Module bbrl.utils.functional

Source code
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

import torch


def _index(tensor_3d, tensor_2d):
    """Gather one value per (time, batch) entry of a (T, B, A) tensor.

    Returns a (T, B) tensor v such that v[t, b] = tensor_3d[t, b, tensor_2d[t, b]].
    """
    x, y, z = tensor_3d.size()
    t = tensor_3d.reshape(x * y, z)
    tt = tensor_2d.reshape(x * y)
    v = t[torch.arange(x * y, device=tensor_3d.device), tt]
    v = v.reshape(x, y)
    return v


def cumulated_reward(reward, done):
    """Average, over the batch, of the undiscounted return of each finished trajectory."""
    T, B = done.size()
    done = done.detach().clone()

    v_done, index_done = done.float().max(0)
    assert v_done.eq(
        1.0
    ).all(), "[bbrl.utils.functional.cumulated_reward] Computing cumulated reward over unfinished trajectories"
    arange = torch.arange(T, device=done.device).unsqueeze(-1).repeat(1, B)
    index_done = index_done.unsqueeze(0).repeat(T, 1)

    # Keep only the rewards received up to (and including) the first done step
    mask = arange.le(index_done)
    reward = (reward * mask.float()).sum(0)
    return reward.mean().item()


def temporal_difference(critic, reward, must_bootstrap, discount_factor):
    """One-step TD error, zero-padded on the last time step to keep a (T, B) shape."""
    target = discount_factor * critic[1:].detach() * must_bootstrap.float() + reward[1:]
    td = target - critic[:-1]
    to_add = torch.zeros(1, td.size()[1], device=td.device)
    td = torch.cat([td, to_add], dim=0)
    return td


def doubleqlearning_temporal_difference(
    q, action, q_target, reward, must_bootstrap, discount_factor
):
    """Double Q-learning TD error: actions are selected with q and evaluated with q_target."""
    action_max = q.max(-1)[1]
    q_target_max = _index(q_target, action_max).detach()[1:]

    mb = must_bootstrap.float()
    target = reward[1:] + discount_factor * q_target_max * mb

    q = _index(q, action)[:-1]
    td = target - q
    # Zero-pad the last time step so the output keeps the (T, B) shape of the inputs
    to_add = torch.zeros(1, td.size()[1], device=td.device)
    td = torch.cat([td, to_add], dim=0)
    return td


def gae(critic, reward, must_bootstrap, discount_factor, gae_coef):
    """Generalized Advantage Estimation computed backward in time over a (T, B) rollout."""
    mb = must_bootstrap.float()
    td = reward[1:] + discount_factor * mb * critic[1:].detach() - critic[:-1]
    # With gae_coef == 0, GAE reduces to the one-step TD error (TD(0))
    if gae_coef == 0.0:
        return td

    td_shape = td.shape[0]
    gae_val = td[-1]
    gaes = [gae_val]
    # Backward recursion: A_t = delta_t + discount_factor * gae_coef * must_bootstrap_t * A_{t+1}
    for t in range(td_shape - 2, -1, -1):
        gae_val = td[t] + discount_factor * gae_coef * mb[:-1][t] * gae_val
        gaes.append(gae_val)
    gaes = [g.unsqueeze(0) for g in reversed(gaes)]
    gaes = torch.cat(gaes, dim=0)
    return gaes


def compute_reinforce_loss(
    reward, action_probabilities, baseline, action, done, discount_factor
):
    """REINFORCE losses (baseline, policy, entropy) for a batch of finished episodes."""
    batch_size = reward.size()[1]

    # Find the first occurrence of done for each element in the batch
    v_done, trajectories_length = done.float().max(0)
    trajectories_length += 1
    assert v_done.eq(1.0).all()
    max_trajectories_length = trajectories_length.max().item()
    # Shorten trajectories for faster computation
    reward = reward[:max_trajectories_length]
    action_probabilities = action_probabilities[:max_trajectories_length]
    baseline = baseline[:max_trajectories_length]
    action = action[:max_trajectories_length]

    # Create a binary mask to mask useless values (of size max_trajectories_length x batch_size)
    arange = (
        torch.arange(max_trajectories_length, device=done.device)
        .unsqueeze(-1)
        .repeat(1, batch_size)
    )
    mask = arange.lt(
        trajectories_length.unsqueeze(0).repeat(max_trajectories_length, 1)
    )
    reward = reward * mask

    # Compute discounted cumulated reward
    cum_reward = [torch.zeros_like(reward[-1])]
    for t in range(max_trajectories_length - 1, 0, -1):
        cum_reward.append(discount_factor * cum_reward[-1] + reward[t])
    cum_reward.reverse()
    cum_reward = torch.cat([c.unsqueeze(0) for c in cum_reward])

    # baseline loss
    g = baseline - cum_reward
    baseline_loss = g**2
    baseline_loss = (baseline_loss * mask).mean()

    # policy loss
    log_probabilities = _index(action_probabilities, action).log()
    policy_loss = log_probabilities * -g.detach()
    policy_loss = policy_loss * mask
    policy_loss = policy_loss.mean()

    # entropy loss
    entropy = torch.distributions.Categorical(action_probabilities).entropy() * mask
    entropy_loss = entropy.mean()

    return {
        "baseline_loss": baseline_loss,
        "policy_loss": policy_loss,
        "entropy_loss": entropy_loss,
    }
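
All of these helpers work on time-major tensors, as typically read from a bbrl workspace: the first dimension is time (T), the second is the batch of parallel environments (B), and Q-values or action probabilities carry a trailing action dimension (T, B, A). The snippet below only illustrates those shapes with random data; it is not output taken from a real workspace:

import torch

T, B, A = 5, 3, 4                       # 5 time steps, 3 environments, 4 discrete actions
reward = torch.randn(T, B)              # reward received at each step
critic = torch.randn(T, B)              # V(s_t) estimates
q_values = torch.randn(T, B, A)         # Q(s_t, a) estimates
action = torch.randint(0, A, (T, B))    # actions actually taken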

Functions

def compute_reinforce_loss(reward, action_probabilities, baseline, action, done, discount_factor)
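Returns a dict with three entries: baseline_loss (mean squared error between the baseline and the discounted return), policy_loss (log-probability of the taken actions weighted by the detached advantage) and entropy_loss (the masked mean entropy of the action distributions). Below is a minimal usage sketch; the tensor shapes and the discount value are assumptions chosen only for illustration:

import torch
from bbrl.utils.functional import compute_reinforce_loss

T, B, A = 20, 8, 2
reward = torch.randn(T, B)
action_probabilities = torch.softmax(torch.randn(T, B, A), dim=-1)
baseline = torch.randn(T, B)
action = torch.randint(0, A, (T, B))
done = torch.zeros(T, B)
done[-1] = 1.0    # every trajectory must be finished, otherwise the assert fails

losses = compute_reinforce_loss(
    reward, action_probabilities, baseline, action, done, discount_factor=0.95
)
print(losses["policy_loss"], losses["baseline_loss"], losses["entropy_loss"])

Note that policy_loss is the term log pi(a_t|s_t) times the detached advantage, i.e. a quantity to maximize; how to weight and combine the three entries is left to the training loop.
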
def cumulated_reward(reward, done)
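Sums, for every trajectory in the batch, the rewards received up to and including its first done step, and returns the average of these undiscounted returns as a Python float. It asserts that every trajectory has actually terminated. A toy example with made-up values:

import torch
from bbrl.utils.functional import cumulated_reward

T, B = 10, 4
reward = torch.ones(T, B)
done = torch.zeros(T, B)
done[-1] = 1.0                          # all trajectories end at the last time step
print(cumulated_reward(reward, done))   # 10.0 with these toy values
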
def doubleqlearning_temporal_difference(q, action, q_target, reward, must_bootstrap, discount_factor)
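Computes the double Q-learning TD error: greedy actions are selected with the online network q and evaluated with q_target, and the last time step is zero-padded so the output has the same (T, B) shape as the inputs. In the sketch below, must_bootstrap is given shape (T - 1, B) so that it broadcasts with the bootstrapped values taken at steps 1..T-1; all shapes and hyper-parameters are illustrative assumptions:

import torch
from bbrl.utils.functional import doubleqlearning_temporal_difference

T, B, A = 6, 3, 4
q = torch.randn(T, B, A)                # online network
q_target = torch.randn(T, B, A)         # target network
action = torch.randint(0, A, (T, B))
reward = torch.randn(T, B)
must_bootstrap = torch.ones(T - 1, B)   # 0 after true terminations
td = doubleqlearning_temporal_difference(
    q, action, q_target, reward, must_bootstrap, discount_factor=0.99
)
print(td.shape)                         # torch.Size([6, 3])
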
def gae(critic, reward, must_bootstrap, discount_factor, gae_coef)
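Implements generalized advantage estimation: it first computes the one-step TD errors and, unless gae_coef is 0 (in which case those TD errors are returned directly), folds them backward in time with the recursion A_t = delta_t + discount_factor * gae_coef * must_bootstrap_t * A_{t+1}. The result has shape (T - 1, B). A sketch with random inputs, where must_bootstrap is shaped (T - 1, B) so that it broadcasts with critic[1:] inside the function; shapes and coefficients are assumptions for illustration:

import torch
from bbrl.utils.functional import gae

T, B = 8, 4
critic = torch.randn(T, B)              # V(s_t) estimates
reward = torch.randn(T, B)
must_bootstrap = torch.ones(T - 1, B)   # 0 after true terminations
adv = gae(critic, reward, must_bootstrap, discount_factor=0.99, gae_coef=0.95)
print(adv.shape)                        # torch.Size([7, 4])
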
def temporal_difference(critic, reward, must_bootstrap, discount_factor)
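Computes the one-step TD error reward[t+1] + discount_factor * V(s_{t+1}) * must_bootstrap - V(s_t) and zero-pads the last time step so the output keeps the (T, B) shape of the critic. As in the other helpers, must_bootstrap multiplies critic[1:], so in this sketch it is given shape (T - 1, B); shapes and the discount are assumptions for illustration:

import torch
from bbrl.utils.functional import temporal_difference

T, B = 8, 4
critic = torch.randn(T, B)
reward = torch.randn(T, B)
must_bootstrap = torch.ones(T - 1, B)   # aligned with critic[1:]
td = temporal_difference(critic, reward, must_bootstrap, discount_factor=0.99)
print(td.shape)                         # torch.Size([8, 4]); the last row is zero padding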