Module bbrl.agents.utils
Expand source code
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import torch
import torch.nn as nn
from bbrl.agents.agent import Agent
class Agents(Agent):
    """A composite agent that runs several agents one after the other.

    Calling this agent on a workspace calls each sub-agent, in the order they
    were given, with the same workspace and keyword arguments.

    Args:
        agents ([bbrl.Agent]): The agents to execute sequentially.
    """

    def __init__(self, *agents, name=None):
        """Build the composite agent.

        Args:
            *agents ([bbrl.Agent]): The sub-agents; each must be an ``Agent``.
            name ([str], optional): Name of the resulting agent. Defaults to None.
        """
        super().__init__(name=name)
        assert all(isinstance(sub_agent, Agent) for sub_agent in agents)
        # Stored in an nn.ModuleList so the sub-agents are registered as
        # submodules of this nn.Module.
        self.agents = nn.ModuleList(agents)

    def __call__(self, workspace, **kwargs):
        """Run every sub-agent, in order, on ``workspace``."""
        for sub_agent in self.agents:
            sub_agent(workspace, **kwargs)

    def forward(self, **kwargs):
        # Execution is driven by __call__, which dispatches to the sub-agents.
        raise NotImplementedError

    def seed(self, seed):
        """Forward ``seed`` to every sub-agent."""
        for sub_agent in self.agents:
            sub_agent.seed(seed)

    def __getitem__(self, k):
        """Return the sub-agent at index ``k``."""
        return self.agents[k]

    def get_by_name(self, n):
        """Collect every agent named ``n`` within this composite.

        Matches found inside the sub-agents come first (in sub-agent order),
        followed by this agent itself if its own name is ``n``.
        """
        found = sum((sub_agent.get_by_name(n) for sub_agent in self.agents), [])
        if n == self._name:
            found = found + [self]
        return found
class TemporalAgent(Agent):
    """Wrap a single Agent and execute it over consecutive timesteps."""

    def __init__(self, agent, name=None):
        """Wrap ``agent`` so it can be unrolled over time.

        Args:
            agent ([bbrl.Agent]): The agent executed at each timestep.
            name ([str], optional): Name of this temporal agent.
        """
        super().__init__(name=name)
        self.agent = agent

    def __call__(self, workspace, t=0, n_steps=None, stop_variable=None, **kwargs):
        """Run the wrapped agent at times t, t+1, ... until a stop condition.

        The wrapped agent is always executed at least once (at time ``t``).
        After each step, the loop stops when either:
          * ``stop_variable`` is given and its value at the current time is
            True everywhere (``.all()``), or
          * ``n_steps`` is given and ``n_steps`` steps have been executed.

        At least one of ``n_steps`` / ``stop_variable`` must be provided.

        Args:
            workspace ([bbrl.Workspace]): The workspace passed to the agent.
            t (int, optional): The first timestep. Defaults to 0.
            n_steps ([type], optional): The number of steps. Defaults to None.
            stop_variable ([type], optional): Variable read from the workspace
                after each step. If it is True everywhere at the current
                time, execution stops. Defaults to None (not used).
        """
        assert n_steps is not None or stop_variable is not None
        end_t = None if n_steps is None else t + n_steps
        current_t = t
        while True:
            self.agent(workspace, t=current_t, **kwargs)
            if stop_variable is not None and workspace.get(stop_variable, current_t).all():
                break
            current_t += 1
            if end_t is not None and current_t >= end_t:
                break

    def forward(self, **kwargs):
        # Not used: execution is driven by __call__ above.
        raise NotImplementedError

    def seed(self, seed):
        """Seed the wrapped agent."""
        self.agent.seed(seed)

    def get_by_name(self, n):
        """Return agents named ``n``.

        Matches from the wrapped agent come first, followed by this agent
        itself if its name is ``n``.
        """
        found = self.agent.get_by_name(n)
        return found + [self] if n == self._name else found
class CopyTAgent(Agent):
    """An agent that copies a variable, optionally detaching it from the graph."""

    def __init__(self, input_name, output_name, detach=False, name=None):
        """
        Args:
            input_name ([str]): The variable to copy from
            output_name ([str]): The variable to copy to
            detach ([bool]): copy with detach if True
            name ([str], optional): Name of the agent. Defaults to None.
        """
        super().__init__(name=name)
        self.input_name = input_name
        self.output_name = output_name
        self.detach = detach

    def forward(self, t=None, **kwargs):
        """Copy ``input_name`` into ``output_name``.

        Args:
            t ([int], optional): if not None, copy only the value at time t;
                otherwise copy the whole variable. Defaults to None.
        """
        # Source and destination are addressed the same way: the whole
        # variable when t is None, a single timestep otherwise. Detaching
        # only changes the value written, never the destination key.
        if t is None:
            x = self.get(self.input_name)
            target = self.output_name
        else:
            x = self.get((self.input_name, t))
            target = (self.output_name, t)
        self.set(target, x.detach() if self.detach else x)
class PrintAgent(Agent):
    """Debugging agent that prints workspace variables to the console.

    It is given the names of the variables to print. When no names are given,
    it prints every variable present in the workspace.
    """

    def __init__(self, *names, name=None):
        """
        Args:
            names ([str], optional): The variables to print. When empty, all
                workspace variables are printed.
            name ([str], optional): Name of the agent.
        """
        super().__init__(name=name)
        self.names = names

    def reset(self):
        """Clear the selected names, including any passed to the constructor.

        After this, the next forward() call prints all workspace variables.
        """
        self.names = tuple()

    def forward(self, t, **kwargs):
        """Print the value of each selected variable at time ``t``."""
        if self.names == ():
            # Nothing selected: take every workspace variable. The result is
            # stored in self.names, so later calls reuse it until reset().
            self.names = self.workspace.keys()
        for var_name in self.names:
            value = self.get((var_name, t))
            print(var_name, " = ", value)
class EpisodesDone(Agent):
    """
    If done is encountered at time t, then done=True for all timesteps t' >= t.
    It allows simulating a single-episode agent on top of an auto-reset agent.
    """

    def __init__(self, in_var="env/done", out_var="env/done", name=None):
        """
        Args:
            in_var ([str]): Variable holding the per-step done flags.
            out_var ([str]): Variable receiving the cumulative done flags
                (by default the same variable as in_var).
            name ([str], optional): Name of the agent, as for the other agents
                in this module. Defaults to None.
        """
        super().__init__(name=name)
        self.in_var = in_var
        self.out_var = out_var
        # Running "done so far" flags; created lazily from the first input.
        self.state = None

    def forward(self, t, **kwargs):
        d = self.get((self.in_var, t))
        # Restart the running flags at t == 0. Also initialise them when no
        # state exists yet (e.g. the first call happens at t > 0); otherwise
        # self.state would be undefined.
        if t == 0 or self.state is None:
            self.state = torch.zeros_like(d).bool()
        # Once an entry is done it stays done for every later timestep.
        self.state = torch.logical_or(self.state, d)
        self.set((self.out_var, t), self.state)
Classes
class Agents (*agents, name=None)
-
An agent that contains multiple agents that will be executed sequentially.
Args
agents : [bbrl.Agent]
- The agents to execute, in order.
Creates the agent from multiple agents.
Args
name : [str], optional
- Name of the resulting agent. Defaults to None.
Expand source code
class Agents(Agent):
    """A composite agent that runs several agents one after the other.

    Calling this agent on a workspace calls each sub-agent, in the order they
    were given, with the same workspace and keyword arguments.

    Args:
        agents ([bbrl.Agent]): The agents to execute sequentially.
    """

    def __init__(self, *agents, name=None):
        """Build the composite agent.

        Args:
            *agents ([bbrl.Agent]): The sub-agents; each must be an ``Agent``.
            name ([str], optional): Name of the resulting agent. Defaults to None.
        """
        super().__init__(name=name)
        assert all(isinstance(sub_agent, Agent) for sub_agent in agents)
        # Stored in an nn.ModuleList so the sub-agents are registered as
        # submodules of this nn.Module.
        self.agents = nn.ModuleList(agents)

    def __call__(self, workspace, **kwargs):
        """Run every sub-agent, in order, on ``workspace``."""
        for sub_agent in self.agents:
            sub_agent(workspace, **kwargs)

    def forward(self, **kwargs):
        # Execution is driven by __call__, which dispatches to the sub-agents.
        raise NotImplementedError

    def seed(self, seed):
        """Forward ``seed`` to every sub-agent."""
        for sub_agent in self.agents:
            sub_agent.seed(seed)

    def __getitem__(self, k):
        """Return the sub-agent at index ``k``."""
        return self.agents[k]

    def get_by_name(self, n):
        """Collect every agent named ``n`` within this composite.

        Matches found inside the sub-agents come first (in sub-agent order),
        followed by this agent itself if its own name is ``n``.
        """
        found = sum((sub_agent.get_by_name(n) for sub_agent in self.agents), [])
        if n == self._name:
            found = found + [self]
        return found
Ancestors
- Agent
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Inherited members
class CopyTAgent (input_name, output_name, detach=False, name=None)
-
An agent that copies a variable.
Args
input_name : [str]
- The variable to copy from.
output_name : [str]
- The variable to copy to.
detach : [bool]
- Copy with detach if True.
Expand source code
class CopyTAgent(Agent):
    """An agent that copies a variable, optionally detaching it from the graph."""

    def __init__(self, input_name, output_name, detach=False, name=None):
        """
        Args:
            input_name ([str]): The variable to copy from
            output_name ([str]): The variable to copy to
            detach ([bool]): copy with detach if True
            name ([str], optional): Name of the agent. Defaults to None.
        """
        super().__init__(name=name)
        self.input_name = input_name
        self.output_name = output_name
        self.detach = detach

    def forward(self, t=None, **kwargs):
        """Copy ``input_name`` into ``output_name``.

        Args:
            t ([int], optional): if not None, copy only the value at time t;
                otherwise copy the whole variable. Defaults to None.
        """
        # Source and destination are addressed the same way: the whole
        # variable when t is None, a single timestep otherwise. Detaching
        # only changes the value written, never the destination key.
        if t is None:
            x = self.get(self.input_name)
            target = self.output_name
        else:
            x = self.get((self.input_name, t))
            target = (self.output_name, t)
        self.set(target, x.detach() if self.detach else x)
Ancestors
- Agent
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def forward(self, t=None, **kwargs) ‑> Callable[..., Any]
-
Args
t : [type], optional
- If not None, copy only the value at time t; otherwise copy the whole variable. Defaults to None.
Expand source code
def forward(self, t=None, **kwargs):
    """Copy ``input_name`` into ``output_name``.

    Args:
        t ([int], optional): if not None, copy only the value at time t;
            otherwise copy the whole variable. Defaults to None.
    """
    # Source and destination are addressed the same way: the whole variable
    # when t is None, a single timestep otherwise. Detaching only changes the
    # value written, never the destination key.
    if t is None:
        x = self.get(self.input_name)
        target = self.output_name
    else:
        x = self.get((self.input_name, t))
        target = (self.output_name, t)
    self.set(target, x.detach() if self.detach else x)
Inherited members
class EpisodesDone (in_var='env/done', out_var='env/done')
-
If done is encountered at time t, then done=True for all timesteps t' >= t. It allows simulating a single-episode agent on top of an auto-reset agent.
To create a new Agent
Args
name : [type], optional
- An agent can have a name that allows operations to be performed on agents that are composed into more complex agents.
Expand source code
class EpisodesDone(Agent):
    """
    If done is encountered at time t, then done=True for all timesteps t' >= t.
    It allows simulating a single-episode agent on top of an auto-reset agent.
    """

    def __init__(self, in_var="env/done", out_var="env/done", name=None):
        """
        Args:
            in_var ([str]): Variable holding the per-step done flags.
            out_var ([str]): Variable receiving the cumulative done flags
                (by default the same variable as in_var).
            name ([str], optional): Name of the agent, as for the other agents
                in this module. Defaults to None.
        """
        super().__init__(name=name)
        self.in_var = in_var
        self.out_var = out_var
        # Running "done so far" flags; created lazily from the first input.
        self.state = None

    def forward(self, t, **kwargs):
        d = self.get((self.in_var, t))
        # Restart the running flags at t == 0. Also initialise them when no
        # state exists yet (e.g. the first call happens at t > 0); otherwise
        # self.state would be undefined.
        if t == 0 or self.state is None:
            self.state = torch.zeros_like(d).bool()
        # Once an entry is done it stays done for every later timestep.
        self.state = torch.logical_or(self.state, d)
        self.set((self.out_var, t), self.state)
Ancestors
- Agent
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Inherited members
class PrintAgent (*names, name=None)
-
An agent that prints variables to the console (mainly for debugging). It can be passed a list of strings corresponding to the variables to print; if nothing is passed, it prints all the existing variables in the workspace.
Args
names : [str], optional
- The variables to print.
Expand source code
class PrintAgent(Agent):
    """Debugging agent that prints workspace variables to the console.

    It is given the names of the variables to print. When no names are given,
    it prints every variable present in the workspace.
    """

    def __init__(self, *names, name=None):
        """
        Args:
            names ([str], optional): The variables to print. When empty, all
                workspace variables are printed.
            name ([str], optional): Name of the agent.
        """
        super().__init__(name=name)
        self.names = names

    def reset(self):
        """Clear the selected names, including any passed to the constructor.

        After this, the next forward() call prints all workspace variables.
        """
        self.names = tuple()

    def forward(self, t, **kwargs):
        """Print the value of each selected variable at time ``t``."""
        if self.names == ():
            # Nothing selected: take every workspace variable. The result is
            # stored in self.names, so later calls reuse it until reset().
            self.names = self.workspace.keys()
        for var_name in self.names:
            value = self.get((var_name, t))
            print(var_name, " = ", value)
Ancestors
- Agent
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def reset(self)
-
Expand source code
def reset(self):
    """Clear the selected variable names.

    This also discards any names passed to the constructor, so the next
    forward() call falls back to printing every workspace variable.
    """
    self.names = tuple()
Inherited members
class TemporalAgent (agent, name=None)
-
Execute one Agent over multiple timesteps.
Creates a temporal agent from the agent to encapsulate.
Args
agent : [bbrl.Agent]
- The agent to encapsulate.
name : [str], optional
- Name of the agent.
Expand source code
class TemporalAgent(Agent):
    """Wrap a single Agent and execute it over consecutive timesteps."""

    def __init__(self, agent, name=None):
        """Wrap ``agent`` so it can be unrolled over time.

        Args:
            agent ([bbrl.Agent]): The agent executed at each timestep.
            name ([str], optional): Name of this temporal agent.
        """
        super().__init__(name=name)
        self.agent = agent

    def __call__(self, workspace, t=0, n_steps=None, stop_variable=None, **kwargs):
        """Run the wrapped agent at times t, t+1, ... until a stop condition.

        The wrapped agent is always executed at least once (at time ``t``).
        After each step, the loop stops when either:
          * ``stop_variable`` is given and its value at the current time is
            True everywhere (``.all()``), or
          * ``n_steps`` is given and ``n_steps`` steps have been executed.

        At least one of ``n_steps`` / ``stop_variable`` must be provided.

        Args:
            workspace ([bbrl.Workspace]): The workspace passed to the agent.
            t (int, optional): The first timestep. Defaults to 0.
            n_steps ([type], optional): The number of steps. Defaults to None.
            stop_variable ([type], optional): Variable read from the workspace
                after each step. If it is True everywhere at the current
                time, execution stops. Defaults to None (not used).
        """
        assert n_steps is not None or stop_variable is not None
        end_t = None if n_steps is None else t + n_steps
        current_t = t
        while True:
            self.agent(workspace, t=current_t, **kwargs)
            if stop_variable is not None and workspace.get(stop_variable, current_t).all():
                break
            current_t += 1
            if end_t is not None and current_t >= end_t:
                break

    def forward(self, **kwargs):
        # Not used: execution is driven by __call__ above.
        raise NotImplementedError

    def seed(self, seed):
        """Seed the wrapped agent."""
        self.agent.seed(seed)

    def get_by_name(self, n):
        """Return agents named ``n``.

        Matches from the wrapped agent come first, followed by this agent
        itself if its name is ``n``.
        """
        found = self.agent.get_by_name(n)
        return found + [self] if n == self._name else found
Ancestors
- Agent
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Inherited members