Module bbrl.utils.replay_buffer
Expand source code
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import copy
import torch
from bbrl.workspace import Workspace
class ReplayBuffer:
    def __init__(self, max_size, device=torch.device("cpu")):
        self.max_size = int(max_size)
        self.variables = None
        self.position = 0
        self.is_full = False
        self.device = device
    def init_workspace(self, all_tensors):
        """
        Create an array to store workspaces, based on the keys of the given all_tensors.
        Shape of stored tensors: [key] => [self.max_size][time_size][key_dim]
        Makes a copy of the input content.
        """
        if self.variables is None:
            self.variables = {}
            for k, v in all_tensors.items():
                # Input tensors are shaped [time_size, batch_size, ...]: replace the
                # batch dimension by max_size, then swap the first two dimensions so
                # the buffer is stored as [max_size, time_size, ...]
                s = list(v.size())
                s[1] = self.max_size
                _s = copy.deepcopy(s)
                s[0] = _s[1]
                s[1] = _s[0]
                tensor = torch.zeros(*s, dtype=v.dtype, device=self.device)
                self.variables[k] = tensor
            self.is_full = False
            self.position = 0
    def _insert(self, k, indexes, v):
        # v is [time_size, batch_size, ...]; swap its first two axes so it matches
        # the buffer layout [max_size, time_size, ...] before writing at `indexes`
        self.variables[k][indexes] = v.detach().moveaxis((0, 1), (1, 0))
    def put(self, workspace):
        """
        Add the content of a workspace to the replay buffer.
        The given workspace must have keys of shape: [time_size][batch_size][key_dim]
        """
        new_data = {
            k: workspace.get_full(k).detach().to(self.device) for k in workspace.keys()
        }
        self.init_workspace(new_data)
        batch_size = None
        arange = None
        indexes = None
        for k, v in new_data.items():
            if batch_size is None:
                batch_size = v.size()[1]
            if self.position + batch_size < self.max_size:
                # The case where the batch can be inserted before the end of the replay buffer
                if indexes is None:
                    indexes = torch.arange(batch_size) + self.position
                    arange = torch.arange(batch_size)
                    self.position = self.position + batch_size
                indexes = indexes.to(dtype=torch.long, device=v.device)
                arange = arange.to(dtype=torch.long, device=v.device)
                self._insert(k, indexes, v)
            else:
                # The case where the batch cannot be inserted before the end of the replay buffer:
                # one part goes at the end, the other part wraps around to the beginning
                self.is_full = True
                # the number of entries written at the end of the RB
                batch_end_size = self.max_size - self.position
                # the number of entries written at the beginning of the RB
                batch_begin_size = batch_size - batch_end_size
                if indexes is None:
                    # the part of the indexes at the end of the RB
                    indexes = torch.arange(batch_end_size) + self.position
                    arange = torch.arange(batch_end_size)
                    # the part of the indexes at the beginning of the RB
                    indexes = torch.cat((indexes, torch.arange(batch_begin_size)), 0)
                    arange = torch.cat((arange, torch.arange(batch_begin_size)), 0)
                    self.position = batch_begin_size
                indexes = indexes.to(dtype=torch.long, device=v.device)
                arange = arange.to(dtype=torch.long, device=v.device)
                self._insert(k, indexes, v)
    def size(self):
        if self.is_full:
            return self.max_size
        else:
            return self.position

    def print_obs(self):
        print(f"position: {self.position}")
        print(self.variables["env/env_obs"])

    def get_shuffled(self, batch_size):
        # Sample batch_size indexes uniformly (with replacement) among the stored entries
        who = torch.randint(
            low=0, high=self.size(), size=(batch_size,), device=self.device
        )
        workspace = Workspace()
        for k in self.variables:
            # Back to the workspace convention [time_size, batch_size, ...]
            workspace.set_full(k, self.variables[k][who].transpose(0, 1))
        return workspace
    def to(self, device):
        n_vars = {k: v.to(device) for k, v in self.variables.items()}
        self.variables = n_vars
        # Remember the new device so that later calls to put() and get_shuffled()
        # allocate their tensors on the same device as the stored variables
        self.device = device
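A minimal usage sketch (not part of the module source): it fills a Workspace with tensors shaped [time_size, batch_size, dim], puts it into a buffer, and samples a minibatch. The key names ("env/env_obs", "action") and all sizes are illustrative.

import torch
from bbrl.workspace import Workspace
from bbrl.utils.replay_buffer import ReplayBuffer

T, B, obs_dim = 2, 8, 4
ws = Workspace()
ws.set_full("env/env_obs", torch.randn(T, B, obs_dim))
ws.set_full("action", torch.randint(0, 3, (T, B, 1)))

rb = ReplayBuffer(max_size=100)
rb.put(ws)                      # stores the 8 batch elements as 8 buffer entries
print(rb.size())                # -> 8

batch = rb.get_shuffled(16)     # Workspace with 16 sampled entries
obs = batch.get_full("env/env_obs")   # shape [T, 16, obs_dim] = [2, 16, 4]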
Classes
class ReplayBuffer (max_size, device=device(type='cpu'))
-
A fixed-size replay buffer that stores the content of Workspace objects and overwrites its oldest entries once full. Sampling returns new Workspace objects.
Methods
def get_shuffled(self, batch_size)
-
Return a new Workspace containing batch_size entries drawn uniformly at random (with replacement) from the buffer.
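A short hedged sketch of minibatch sampling, with illustrative sizes. Since indexes are drawn with torch.randint, sampling is with replacement and a minibatch may contain duplicate entries.

import torch
from bbrl.workspace import Workspace
from bbrl.utils.replay_buffer import ReplayBuffer

ws = Workspace()
ws.set_full("env/env_obs", torch.randn(2, 8, 4))
rb = ReplayBuffer(max_size=100)
rb.put(ws)

for _ in range(3):
    minibatch = rb.get_shuffled(32)           # Workspace with 32 sampled entries
    obs = minibatch.get_full("env/env_obs")   # shaped [2, 32, 4]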
def init_workspace(self, all_tensors)
-
Create an array to store workspaces, based on the keys of the given all_tensors. Shape of stored tensors: [key] => [self.max_size][time_size][key_dim]. Makes a copy of the input content.
def print_obs(self)
-
Debugging helper: print the current insertion position and the stored "env/env_obs" tensor.
def put(self, workspace)
-
Add the content of a workspace to the replay buffer. The given workspace must have keys of shape: [time_size][batch_size][key_dim]
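A hedged sketch of the wrap-around behaviour, with illustrative sizes: once position + batch_size reaches max_size, the first part of the batch fills the end of the buffer and the remainder overwrites the oldest entries at the beginning.

import torch
from bbrl.workspace import Workspace
from bbrl.utils.replay_buffer import ReplayBuffer

rb = ReplayBuffer(max_size=10)
ws = Workspace()
ws.set_full("env/env_obs", torch.randn(2, 6, 4))  # time=2, batch=6, dim=4
rb.put(ws)        # position becomes 6
rb.put(ws)        # 6 + 6 >= 10: 4 entries fill the end, 2 wrap to the start
print(rb.size())  # -> 10 (rb.is_full is now True)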
def size(self)
-
Return the number of entries currently stored (max_size once the buffer is full).
def to(self, device)
-
Move all stored tensors to the given device.
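A brief hedged sketch: moving an already-filled buffer (rb as in the sketches above) to the GPU before training, assuming a CUDA device is available.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rb.to(device)                 # stored tensors now live on `device`
batch = rb.get_shuffled(32)   # sampled tensors are returned on `device` as well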