Module bbrl.visu.visu_policies

Expand source code
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import random

import matplotlib.pyplot as plt
import numpy as np
import torch as th

from bbrl.visu.common import final_show


def plot_policy(
    agent, env, directory, env_name, best_reward, plot=False, stochastic=False
):
    """
    Dispatch to the environment-specific policy plotting function.

    Only environments whose name contains "cartpole" or "pendulum"
    (case-insensitive) are supported; any other name prints a message
    and returns without plotting.
    :param agent: the policy agent to be plotted
    :param env: the evaluation environment
    :param directory: the path where the figure is saved
    :param env_name: the environment name used to pick the plotting routine
    :param best_reward: reward value embedded in the saved figure's file name
    :param plot: whether the plot should be interactive
    :param stochastic: whether to plot a stochastic or deterministic policy
    :return: nothing
    """
    lowered = env_name.lower()
    if "cartpole" in lowered:
        plotter = plot_cartpole_policy
    elif "pendulum" in lowered:
        plotter = plot_pendulum_policy
    else:
        print("Environment not supported for plot. Please use CartPole or Pendulum")
        return
    figname = f"policy_{env_name}_{best_reward}.png"
    plotter(agent, env, directory, figname, plot, True, stochastic)


def plot_pendulum_policy(
    agent, env, directory, figname, plot=True, save_figure=True, stochastic=None
):
    """
    Plot an agent for the Pendulum environment.

    The observation is built as (cos(theta), sin(theta), theta_dot); the
    resulting image shows the chosen action over theta (x axis) and
    theta_dot (y axis).
    :param agent: the policy specifying the action to be plotted
    :param env: the evaluation environment
    :param directory: the path to the file to save the figure
    :param figname: the name of the file to save the figure
    :param plot: whether the plot should be interactive
    :param save_figure: whether the figure should be saved
    :param stochastic: whether one wants to plot a deterministic or stochastic policy
    :return: nothing
    :raises ValueError: if the observation space has fewer than 3 dimensions
    """
    if env.observation_space.shape[0] <= 2:
        raise ValueError(
            f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        )
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    # Hoisted out of the loops: both grids are loop-invariant.
    thetas = np.linspace(-np.pi, np.pi, num=definition)
    theta_dots = np.linspace(state_min[2], state_max[2], num=definition)
    for index_t, t in enumerate(thetas):
        for index_td, td in enumerate(theta_dots):
            # Batch of one observation, already float32 (no astype copy needed).
            obs = th.from_numpy(
                np.array([[np.cos(t), np.sin(t), td]], dtype=np.float32)
            )
            action = agent.predict_action(obs, stochastic)
            # Row 0 is the top of the image, so flip the velocity index.
            portrait[definition - (1 + index_td), index_t] = action.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[-np.pi, np.pi, state_min[2], state_max[2]],
        aspect="auto",
    )

    title = "Pendulum Actor"
    plt.colorbar(label="action")
    directory += "/pendulum_policies/"
    # Add a point at the center
    plt.scatter([0], [0])
    # Fall back to generic labels when the space does not expose names.
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figname, x_label, y_label, title)


def plot_cartpole_policy(
    agent, env, directory, figname, plot=True, save_figure=True, stochastic=None
):
    """
    Visualization of a policy in a N-dimensional state space.

    The N-dimensional state space is projected into its first and third
    dimensions (cart position and pole angle); the remaining two features
    are filled with small random values.
    A FeatureInverter wrapper should be used to select which features to put first to plot them
    :param agent: the policy agent to be plotted
    :param env: the environment
    :param directory: the path to the file to save the figure
    :param figname: the name of the file to save the figure
    :param plot: whether the plot should be interactive
    :param save_figure: whether the figure should be saved
    :param stochastic: whether one wants to plot a deterministic or stochastic policy
    :return: nothing
    :raises ValueError: if the observation space has fewer than 3 dimensions
    """
    if env.observation_space.shape[0] <= 2:
        raise ValueError(
            f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        )
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    # Hoisted out of the loops: both grids are loop-invariant.
    xs = np.linspace(state_min[0], state_max[0], num=definition)
    ys = np.linspace(state_min[2], state_max[2], num=definition)
    for index_x, x in enumerate(xs):
        for index_y, y in enumerate(ys):
            # Jitter the two non-plotted features around zero.
            z1 = random.random() - 0.5
            z2 = random.random() - 0.5
            # Build the observation in one call instead of a chain of
            # np.append reallocations.
            # NOTE(review): assumes the observation space has exactly 4 dims
            # (cart position, cart velocity, pole angle, pole velocity).
            obs = th.from_numpy(np.array([x, z1, y, z2], dtype=np.float32))
            action = agent.predict_action(obs, stochastic)
            # Row 0 is the top of the image, so flip the y index.
            portrait[definition - (1 + index_y), index_x] = action.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[state_min[0], state_max[0], state_min[2], state_max[2]],
        aspect="auto",
    )

    title = "Cartpole Actor"
    plt.colorbar(label="action")
    directory += "/cartpole_policies/"
    # Add a point at the center
    plt.scatter([0], [0])
    # Fall back to generic labels when the space does not expose names.
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figname, x_label, y_label, title)

Functions

def plot_cartpole_policy(agent, env, directory, figname, plot=True, save_figure=True, stochastic=None)

Visualization of a policy in a N-dimensional state space The N-dimensional state space is projected into its first two dimensions. A FeatureInverter wrapper should be used to select which features to put first to plot them :param agent: the policy agent to be plotted :param env: the environment :param figname: the name of the file to save the figure :param directory: the path to the file to save the figure :param plot: whether the plot should be interactive :param save_figure: whether the figure should be saved :param stochastic: whether one wants to plot a deterministic or stochastic policy :return: nothing

Expand source code
def plot_cartpole_policy(
    agent, env, directory, figname, plot=True, save_figure=True, stochastic=None
):
    """
    Visualization of a policy in a N-dimensional state space.

    The N-dimensional state space is projected into its first and third
    dimensions (cart position and pole angle); the remaining two features
    are filled with small random values.
    A FeatureInverter wrapper should be used to select which features to put first to plot them
    :param agent: the policy agent to be plotted
    :param env: the environment
    :param directory: the path to the file to save the figure
    :param figname: the name of the file to save the figure
    :param plot: whether the plot should be interactive
    :param save_figure: whether the figure should be saved
    :param stochastic: whether one wants to plot a deterministic or stochastic policy
    :return: nothing
    :raises ValueError: if the observation space has fewer than 3 dimensions
    """
    if env.observation_space.shape[0] <= 2:
        raise ValueError(
            f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        )
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    # Hoisted out of the loops: both grids are loop-invariant.
    xs = np.linspace(state_min[0], state_max[0], num=definition)
    ys = np.linspace(state_min[2], state_max[2], num=definition)
    for index_x, x in enumerate(xs):
        for index_y, y in enumerate(ys):
            # Jitter the two non-plotted features around zero.
            z1 = random.random() - 0.5
            z2 = random.random() - 0.5
            # Build the observation in one call instead of a chain of
            # np.append reallocations.
            # NOTE(review): assumes the observation space has exactly 4 dims
            # (cart position, cart velocity, pole angle, pole velocity).
            obs = th.from_numpy(np.array([x, z1, y, z2], dtype=np.float32))
            action = agent.predict_action(obs, stochastic)
            # Row 0 is the top of the image, so flip the y index.
            portrait[definition - (1 + index_y), index_x] = action.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[state_min[0], state_max[0], state_min[2], state_max[2]],
        aspect="auto",
    )

    title = "Cartpole Actor"
    plt.colorbar(label="action")
    directory += "/cartpole_policies/"
    # Add a point at the center
    plt.scatter([0], [0])
    # Fall back to generic labels when the space does not expose names.
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figname, x_label, y_label, title)
def plot_pendulum_policy(agent, env, directory, figname, plot=True, save_figure=True, stochastic=None)

Plot an agent for the Pendulum environment :param agent: the policy specifying the action to be plotted :param env: the evaluation environment :param figname: the name of the file to save the figure :param directory: the path to the file to save the figure :param plot: whether the plot should be interactive :param save_figure: whether the figure should be saved :param stochastic: whether one wants to plot a deterministic or stochastic policy :return: nothing

Expand source code
def plot_pendulum_policy(
    agent, env, directory, figname, plot=True, save_figure=True, stochastic=None
):
    """
    Plot an agent for the Pendulum environment.

    The observation is built as (cos(theta), sin(theta), theta_dot); the
    resulting image shows the chosen action over theta (x axis) and
    theta_dot (y axis).
    :param agent: the policy specifying the action to be plotted
    :param env: the evaluation environment
    :param directory: the path to the file to save the figure
    :param figname: the name of the file to save the figure
    :param plot: whether the plot should be interactive
    :param save_figure: whether the figure should be saved
    :param stochastic: whether one wants to plot a deterministic or stochastic policy
    :return: nothing
    :raises ValueError: if the observation space has fewer than 3 dimensions
    """
    if env.observation_space.shape[0] <= 2:
        raise ValueError(
            f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        )
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    # Hoisted out of the loops: both grids are loop-invariant.
    thetas = np.linspace(-np.pi, np.pi, num=definition)
    theta_dots = np.linspace(state_min[2], state_max[2], num=definition)
    for index_t, t in enumerate(thetas):
        for index_td, td in enumerate(theta_dots):
            # Batch of one observation, already float32 (no astype copy needed).
            obs = th.from_numpy(
                np.array([[np.cos(t), np.sin(t), td]], dtype=np.float32)
            )
            action = agent.predict_action(obs, stochastic)
            # Row 0 is the top of the image, so flip the velocity index.
            portrait[definition - (1 + index_td), index_t] = action.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[-np.pi, np.pi, state_min[2], state_max[2]],
        aspect="auto",
    )

    title = "Pendulum Actor"
    plt.colorbar(label="action")
    directory += "/pendulum_policies/"
    # Add a point at the center
    plt.scatter([0], [0])
    # Fall back to generic labels when the space does not expose names.
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figname, x_label, y_label, title)
def plot_policy(agent, env, directory, env_name, best_reward, plot=False, stochastic=False)
Expand source code
def plot_policy(
    agent, env, directory, env_name, best_reward, plot=False, stochastic=False
):
    """
    Dispatch to the environment-specific policy plotting function.

    Only environments whose name contains "cartpole" or "pendulum"
    (case-insensitive) are supported; any other name prints a message
    and returns without plotting.
    :param agent: the policy agent to be plotted
    :param env: the evaluation environment
    :param directory: the path where the figure is saved
    :param env_name: the environment name used to pick the plotting routine
    :param best_reward: reward value embedded in the saved figure's file name
    :param plot: whether the plot should be interactive
    :param stochastic: whether to plot a stochastic or deterministic policy
    :return: nothing
    """
    lowered = env_name.lower()
    if "cartpole" in lowered:
        plotter = plot_cartpole_policy
    elif "pendulum" in lowered:
        plotter = plot_pendulum_policy
    else:
        print("Environment not supported for plot. Please use CartPole or Pendulum")
        return
    figname = f"policy_{env_name}_{best_reward}.png"
    plotter(agent, env, directory, figname, plot, True, stochastic)