Module bbrl.agents.dataloader

Expand source code
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

# import gym
import torch
from gym.utils import seeding

# from torch.utils.data import DataLoader

from bbrl.agents.agent import Agent


class ShuffledDatasetAgent(Agent):
    """An agent that read a dataset in a shuffle order, in an infinite way."""

    def __init__(self, dataset, batch_size, output_names=("x", "y")):
        """Create the agent

        Args:
            dataset ([torch.utils.data.Dataset]): The dataset to sample from
            batch_size ([int]): The number of datapoints to write at each call
            output_names (tuple, optional): The name of the variables. Defaults to ("x", "y").
        """
        super().__init__()
        self.output_names = output_names
        self.dataset = dataset
        self.batch_size = batch_size
        self.ghost_params = torch.nn.Parameter(torch.randn(()))

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def forward(self, **kwargs):
        """Write a batch of data at timestep==0 in the workspace"""
        vs = []
        for k in range(self.batch_size):
            idx = self.np_random.randint(len(self.dataset))
            x = self.dataset[idx]
            xs = []
            for xx in x:
                if isinstance(xx, torch.Tensor):
                    xs.append(xx.unsqueeze(0))
                else:
                    xs.append(torch.tensor(xx).unsqueeze(0))
            vs.append(xs)

        vals = []
        for k in range(len(vs[0])):
            val = [v[k] for v in vs]
            val = torch.cat(val, dim=0)
            vals.append(val)

        for name, value in zip(self.output_names, vals):
            self.set((name, 0), value.to(self.ghost_params.device))


class DataLoaderAgent(Agent):
    """An agent based on a DataLoader that read a single dataset
    Usage is: agent.forward(), then one has to check if agent.finished() is True or Not. If True, then no data have been written in the workspace since the reading of the daaset is terminated
    """

    def __init__(self, dataloader, output_names=("x", "y")):
        """Create the agent based on a dataloader

        Args:
            dataloader ([DataLoader]): The underlying PyTorch DataLoader object
            output_names (tuple, optional): Names of the variables to write in the workspace. Defaults to ("x", "y").
        """
        super().__init__()
        self.dataloader = dataloader
        self.iter = iter(self.dataloader)
        self.output_names = output_names
        self._finished = False
        self.ghost_params = torch.nn.Parameter(torch.randn(()))

    def reset(self):
        self.iter = iter(self.dataloader)
        self._finished = False

    def finished(self):
        return self._finished

    def forward(self, **kwargs):
        try:
            output_values = next(self.iter)
        except StopIteration:
            self.iter = None
            self._finished = True
        else:
            for name, value in zip(self.output_names, output_values):
                self.set((name, 0), value.to(self.ghost_params.device))

Classes

class DataLoaderAgent (dataloader, output_names=('x', 'y'))

An agent based on a DataLoader that reads a single dataset once. Usage: call agent.forward(), then check agent.finished(). If it returns True, no data has been written in the workspace because the dataset has been fully read.

Create the agent based on a dataloader

Args

dataloader : [DataLoader]
The underlying PyTorch DataLoader object
output_names : tuple, optional
Names of the variables to write in the workspace. Defaults to ("x", "y").
Expand source code
class DataLoaderAgent(Agent):
    """An agent based on a DataLoader that read a single dataset
    Usage is: agent.forward(), then one has to check if agent.finished() is True or Not. If True, then no data have been written in the workspace since the reading of the daaset is terminated
    """

    def __init__(self, dataloader, output_names=("x", "y")):
        """Create the agent based on a dataloader

        Args:
            dataloader ([DataLoader]): The underlying PyTorch DataLoader object
            output_names (tuple, optional): Names of the variables to write in the workspace. Defaults to ("x", "y").
        """
        super().__init__()
        self.dataloader = dataloader
        self.iter = iter(self.dataloader)
        self.output_names = output_names
        self._finished = False
        self.ghost_params = torch.nn.Parameter(torch.randn(()))

    def reset(self):
        self.iter = iter(self.dataloader)
        self._finished = False

    def finished(self):
        return self._finished

    def forward(self, **kwargs):
        try:
            output_values = next(self.iter)
        except StopIteration:
            self.iter = None
            self._finished = True
        else:
            for name, value in zip(self.output_names, output_values):
                self.set((name, 0), value.to(self.ghost_params.device))

Ancestors

  • Agent
  • torch.nn.modules.module.Module

Class variables

var dump_patches : bool
var training : bool

Methods

def finished(self)
Expand source code
def finished(self):
    return self._finished

def reset(self)
Expand source code
def reset(self):
    self.iter = iter(self.dataloader)
    self._finished = False

Inherited members
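
A hedged usage sketch, not part of the generated documentation: it assumes the usual bbrl pattern where Workspace is importable from bbrl.workspace and applying the agent as agent(workspace) dispatches to forward() on that workspace; the get accessor used to read the batch back is also an assumption.

import torch
from torch.utils.data import DataLoader, TensorDataset

from bbrl.agents.dataloader import DataLoaderAgent
from bbrl.workspace import Workspace

dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
agent = DataLoaderAgent(DataLoader(dataset, batch_size=32))

agent.reset()
while True:
    workspace = Workspace()
    agent(workspace)           # writes "x" and "y" at timestep 0
    if agent.finished():       # True once the underlying iterator is exhausted
        break
    x = workspace.get("x", 0)  # assumed Workspace accessor: variable "x" at t=0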

class ShuffledDatasetAgent (dataset, batch_size, output_names=('x', 'y'))

An agent that reads a dataset in shuffled order, indefinitely.

Create the agent

Args

dataset : [torch.utils.data.Dataset]
The dataset to sample from
batch_size : [int]
The number of datapoints to write at each call
output_names : tuple, optional
The name of the variables. Defaults to ("x", "y").
Expand source code
class ShuffledDatasetAgent(Agent):
    """An agent that read a dataset in a shuffle order, in an infinite way."""

    def __init__(self, dataset, batch_size, output_names=("x", "y")):
        """Create the agent

        Args:
            dataset ([torch.utils.data.Dataset]): The dataset to sample from
            batch_size ([int]): The number of datapoints to write at each call
            output_names (tuple, optional): The name of the variables. Defaults to ("x", "y").
        """
        super().__init__()
        self.output_names = output_names
        self.dataset = dataset
        self.batch_size = batch_size
        self.ghost_params = torch.nn.Parameter(torch.randn(()))

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def forward(self, **kwargs):
        """Write a batch of data at timestep==0 in the workspace"""
        vs = []
        for k in range(self.batch_size):
            idx = self.np_random.randint(len(self.dataset))
            x = self.dataset[idx]
            xs = []
            for xx in x:
                if isinstance(xx, torch.Tensor):
                    xs.append(xx.unsqueeze(0))
                else:
                    xs.append(torch.tensor(xx).unsqueeze(0))
            vs.append(xs)

        vals = []
        for k in range(len(vs[0])):
            val = [v[k] for v in vs]
            val = torch.cat(val, dim=0)
            vals.append(val)

        for name, value in zip(self.output_names, vals):
            self.set((name, 0), value.to(self.ghost_params.device))

Ancestors

  • Agent
  • torch.nn.modules.module.Module

Class variables

var dump_patches : bool
var training : bool

Methods

def forward(self, **kwargs) -> Callable[..., Any]

Write a batch of data at timestep==0 in the workspace

Expand source code
def forward(self, **kwargs):
    """Write a batch of data at timestep==0 in the workspace"""
    vs = []
    for k in range(self.batch_size):
        idx = self.np_random.randint(len(self.dataset))
        x = self.dataset[idx]
        xs = []
        for xx in x:
            if isinstance(xx, torch.Tensor):
                xs.append(xx.unsqueeze(0))
            else:
                xs.append(torch.tensor(xx).unsqueeze(0))
        vs.append(xs)

    vals = []
    for k in range(len(vs[0])):
        val = [v[k] for v in vs]
        val = torch.cat(val, dim=0)
        vals.append(val)

    for name, value in zip(self.output_names, vals):
        self.set((name, 0), value.to(self.ghost_params.device))

Inherited members
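
A hedged usage sketch, not part of the generated documentation, under the same assumptions as above about Workspace and the agent(workspace) calling convention:

import torch
from torch.utils.data import TensorDataset

from bbrl.agents.dataloader import ShuffledDatasetAgent
from bbrl.workspace import Workspace

dataset = TensorDataset(torch.randn(1000, 8), torch.randint(0, 2, (1000,)))
agent = ShuffledDatasetAgent(dataset, batch_size=64)
agent.seed(42)                 # initialises self.np_random via gym's seeding utilities

workspace = Workspace()
agent(workspace)               # writes "x" and "y" at timestep 0
x = workspace.get("x", 0)      # assumed Workspace accessor; x has batch dimension 64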