Module bbrl.visu.visu_critics

Expand source code
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import random

import matplotlib.pyplot as plt
import numpy as np
import torch as th

from bbrl.visu.common import final_show


def plot_critic(agent, env, directory, env_name, best_reward, plot=False) -> None:
    figure_name = f"critic_{env_name}_{best_reward}.png"
    if agent.is_q_function:
        if "cartpole" in env_name.lower():
            env_string = "CartPole"
            plot_env = plot_cartpole_critic_q
        elif "pendulum" in env_name.lower():
            env_string = "Pendulum"
            plot_env = plot_pendulum_critic_q
        else:
            env_string = env_name
            plot_env = plot_any_env_critic_q
        plot_env(agent, env, env_string, directory, figure_name, plot, action=None)
    else:
        if "cartpole" in env_name.lower():
            env_string = "CartPole"
            plot_env = plot_cartpole_critic_v
        elif "pendulum" in env_name.lower():
            env_string = "Pendulum"
            plot_env = plot_pendulum_critic_v
        else:
            env_string = env_name
            plot_env = plot_any_env_critic_v
        plot_env(agent, env, env_string, directory, figure_name, plot)


def plot_pendulum_critic_v(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    stochastic=None,
):
    """
    Plot a critic for the Pendulum environment
    :param agent: the critic agent to be plotted
    :param env: the evaluation environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file to save the figure
    :param save_figure: whether the figure should be saved
    :param stochastic: whether we plot the deterministic or stochastic version
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_t, t in enumerate(np.linspace(-np.pi, np.pi, num=definition)):
        for index_td, td in enumerate(
            np.linspace(state_min[2], state_max[2], num=definition)
        ):
            obs = np.array([[np.cos(t), np.sin(t), td]])
            obs = th.from_numpy(obs.astype(np.float32))

            value = agent.model(obs).squeeze(-1)

            portrait[definition - (1 + index_td), index_t] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[-180, 180, state_min[2], state_max[2]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)


def plot_cartpole_critic_v(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    stochastic=None,
):
    """
    Visualization of the critic in a N-dimensional state space
    The N-dimensional state space is projected into its first two dimensions.
    A FeatureInverter wrapper should be used to select which features to put first to plot them
    :param agent: the critic agent to be plotted
    :param env: the environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file where to plot the function
    :param save_figure: whether the plot should be saved into a file
    :param stochastic: whether we plot the deterministic or stochastic version
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_x, x in enumerate(
        np.linspace(state_min[0], state_max[0], num=definition)
    ):
        for index_y, y in enumerate(
            np.linspace(state_min[2], state_max[2], num=definition)
        ):
            obs = np.array([x])
            z1 = random.random() - 0.5
            z2 = random.random() - 0.5
            obs = np.append(obs, z1)
            obs = np.append(obs, y)
            obs = np.append(obs, z2)
            obs = obs.reshape(1, -1)
            obs = th.from_numpy(obs.astype(np.float32))
            value = agent.model(obs).squeeze(-1)
            portrait[definition - (1 + index_y), index_x] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[state_min[0], state_max[0], state_min[2], state_max[2]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)


def plot_any_env_critic_v(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    stochastic=None,
):
    """
    Visualization of the critic in a N-dimensional state space
    The N-dimensional state space is projected into its first two dimensions.
    A FeatureInverter wrapper should be used to select which features to put first to plot them
    :param agent: the critic agent to be plotted
    :param env: the environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file where to plot the function
    :param save_figure: whether the plot should be saved into a file
    :param stochastic: whether we plot the deterministic or stochastic version
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_x, x in enumerate(
        np.linspace(state_min[0], state_max[0], num=definition)
    ):
        for index_y, y in enumerate(
            np.linspace(state_min[1], state_max[1], num=definition)
        ):
            obs = np.array([x])
            obs = np.append(obs, y)
            for z in range(state_min.size() - 2):
                z1 = random.random() - 0.5
                obs = np.append(obs, z1)
            obs = obs.reshape(1, -1)
            obs = th.from_numpy(obs.astype(np.float32))
            value = agent.model(obs).squeeze(-1)
            portrait[definition - (1 + index_y), index_x] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[state_min[0], state_max[0], state_min[1], state_max[1]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)


def plot_pendulum_critic_q(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    action=None,
):
    """
    Plot a critic for the Pendulum environment
    :param agent: the critic agent to be plotted
    :param env: the evaluation environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file to save the figure
    :param save_figure: whether the figure should be saved
    :param action: the action for which we want to plot the value
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_t, t in enumerate(np.linspace(-np.pi, np.pi, num=definition)):
        for index_td, td in enumerate(
            np.linspace(state_min[2], state_max[2], num=definition)
        ):
            obs = np.array([[np.cos(t), np.sin(t), td]])
            obs = th.from_numpy(obs.astype(np.float32))

            if action is None:
                action = th.Tensor([0])
            value = agent.predict_value(obs[0], action)

            portrait[definition - (1 + index_td), index_t] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[-180, 180, state_min[2], state_max[2]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)


def plot_cartpole_critic_q(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    action=None,
):
    """
    Visualization of the critic in a N-dimensional state space
    The N-dimensional state space is projected into its first two dimensions.
    A FeatureInverter wrapper should be used to select which features to put first to plot them
    :param agent: the critic agent to be plotted
    :param env: the environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file where to plot the function
    :param save_figure: whether the plot should be saved into a file
    :param action: the action for which we want to plot the value
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_x, x in enumerate(
        np.linspace(state_min[0], state_max[0], num=definition)
    ):
        for index_y, y in enumerate(
            np.linspace(state_min[2], state_max[2], num=definition)
        ):
            obs = np.array([x])
            z1 = random.random() - 0.5
            z2 = random.random() - 0.5
            obs = np.append(obs, z1)
            obs = np.append(obs, y)
            obs = np.append(obs, z2)
            obs = obs.reshape(1, -1)
            obs = th.from_numpy(obs.astype(np.float32))

            if action is None:
                action = th.Tensor([0])
            value = agent.predict_value(obs[0], action)

            portrait[definition - (1 + index_y), index_x] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[state_min[0], state_max[0], state_min[2], state_max[2]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)


def plot_any_env_critic_q(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    action=None,
):
    """
    Visualization of the critic in a N-dimensional state space
    The N-dimensional state space is projected into its first two dimensions.
    A FeatureInverter wrapper should be used to select which features to put first to plot them
    :param agent: the critic agent to be plotted
    :param env: the environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file where to plot the function
    :param save_figure: whether the plot should be saved into a file
    :param action: the action for which we want to plot the value
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_x, x in enumerate(
        np.linspace(state_min[0], state_max[0], num=definition)
    ):
        for index_y, y in enumerate(
            np.linspace(state_min[1], state_max[1], num=definition)
        ):
            obs = np.array([x])
            obs = np.append(obs, y)
            for z in range(state_min.size() - 2):
                z1 = random.random() - 0.5
                obs = np.append(obs, z1)
            obs = obs.reshape(1, -1)
            obs = th.from_numpy(obs.astype(np.float32))

            if action is None:
                action = th.Tensor([0])
            value = agent.predict_value(obs[0], action)

            portrait[definition - (1 + index_y), index_x] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[state_min[0], state_max[0], state_min[1], state_max[1]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)

Functions

def plot_any_env_critic_q(agent, env, env_string, directory, figure_name, plot=True, save_figure=True, action=None)

Visualization of the critic in a N-dimensional state space The N-dimensional state space is projected into its first two dimensions. A FeatureInverter wrapper should be used to select which features to put first to plot them :param agent: the critic agent to be plotted :param env: the environment :param env_string: the name of the environment :param plot: whether the plot should be interactive :param directory: the directory where to save the figure :param figure_name: the name of the file where to plot the function :param save_figure: whether the plot should be saved into a file :param action: the action for which we want to plot the value :return: nothing

Expand source code
def plot_any_env_critic_q(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    action=None,
):
    """
    Visualization of the critic in a N-dimensional state space
    The N-dimensional state space is projected into its first two dimensions.
    A FeatureInverter wrapper should be used to select which features to put first to plot them
    :param agent: the critic agent to be plotted
    :param env: the environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file where to plot the function
    :param save_figure: whether the plot should be saved into a file
    :param action: the action for which we want to plot the value
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_x, x in enumerate(
        np.linspace(state_min[0], state_max[0], num=definition)
    ):
        for index_y, y in enumerate(
            np.linspace(state_min[1], state_max[1], num=definition)
        ):
            obs = np.array([x])
            obs = np.append(obs, y)
            for z in range(state_min.size() - 2):
                z1 = random.random() - 0.5
                obs = np.append(obs, z1)
            obs = obs.reshape(1, -1)
            obs = th.from_numpy(obs.astype(np.float32))

            if action is None:
                action = th.Tensor([0])
            value = agent.predict_value(obs[0], action)

            portrait[definition - (1 + index_y), index_x] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[state_min[0], state_max[0], state_min[1], state_max[1]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)
def plot_any_env_critic_v(agent, env, env_string, directory, figure_name, plot=True, save_figure=True, stochastic=None)

Visualization of the critic in a N-dimensional state space The N-dimensional state space is projected into its first two dimensions. A FeatureInverter wrapper should be used to select which features to put first to plot them :param agent: the critic agent to be plotted :param env: the environment :param env_string: the name of the environment :param plot: whether the plot should be interactive :param directory: the directory where to save the figure :param figure_name: the name of the file where to plot the function :param save_figure: whether the plot should be saved into a file :param stochastic: whether we plot the deterministic or stochastic version :return: nothing

Expand source code
def plot_any_env_critic_v(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    stochastic=None,
):
    """
    Visualization of the critic in a N-dimensional state space
    The N-dimensional state space is projected into its first two dimensions.
    A FeatureInverter wrapper should be used to select which features to put first to plot them
    :param agent: the critic agent to be plotted
    :param env: the environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file where to plot the function
    :param save_figure: whether the plot should be saved into a file
    :param stochastic: whether we plot the deterministic or stochastic version
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_x, x in enumerate(
        np.linspace(state_min[0], state_max[0], num=definition)
    ):
        for index_y, y in enumerate(
            np.linspace(state_min[1], state_max[1], num=definition)
        ):
            obs = np.array([x])
            obs = np.append(obs, y)
            for z in range(state_min.size() - 2):
                z1 = random.random() - 0.5
                obs = np.append(obs, z1)
            obs = obs.reshape(1, -1)
            obs = th.from_numpy(obs.astype(np.float32))
            value = agent.model(obs).squeeze(-1)
            portrait[definition - (1 + index_y), index_x] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[state_min[0], state_max[0], state_min[1], state_max[1]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)
def plot_cartpole_critic_q(agent, env, env_string, directory, figure_name, plot=True, save_figure=True, action=None)

Visualization of the critic in a N-dimensional state space The N-dimensional state space is projected into its first two dimensions. A FeatureInverter wrapper should be used to select which features to put first to plot them :param agent: the critic agent to be plotted :param env: the environment :param env_string: the name of the environment :param plot: whether the plot should be interactive :param directory: the directory where to save the figure :param figure_name: the name of the file where to plot the function :param save_figure: whether the plot should be saved into a file :param action: the action for which we want to plot the value :return: nothing

Expand source code
def plot_cartpole_critic_q(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    action=None,
):
    """
    Visualization of the critic in a N-dimensional state space
    The N-dimensional state space is projected into its first two dimensions.
    A FeatureInverter wrapper should be used to select which features to put first to plot them
    :param agent: the critic agent to be plotted
    :param env: the environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file where to plot the function
    :param save_figure: whether the plot should be saved into a file
    :param action: the action for which we want to plot the value
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_x, x in enumerate(
        np.linspace(state_min[0], state_max[0], num=definition)
    ):
        for index_y, y in enumerate(
            np.linspace(state_min[2], state_max[2], num=definition)
        ):
            obs = np.array([x])
            z1 = random.random() - 0.5
            z2 = random.random() - 0.5
            obs = np.append(obs, z1)
            obs = np.append(obs, y)
            obs = np.append(obs, z2)
            obs = obs.reshape(1, -1)
            obs = th.from_numpy(obs.astype(np.float32))

            if action is None:
                action = th.Tensor([0])
            value = agent.predict_value(obs[0], action)

            portrait[definition - (1 + index_y), index_x] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[state_min[0], state_max[0], state_min[2], state_max[2]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)
def plot_cartpole_critic_v(agent, env, env_string, directory, figure_name, plot=True, save_figure=True, stochastic=None)

Visualization of the critic in a N-dimensional state space The N-dimensional state space is projected into its first two dimensions. A FeatureInverter wrapper should be used to select which features to put first to plot them :param agent: the critic agent to be plotted :param env: the environment :param env_string: the name of the environment :param plot: whether the plot should be interactive :param directory: the directory where to save the figure :param figure_name: the name of the file where to plot the function :param save_figure: whether the plot should be saved into a file :param stochastic: whether we plot the deterministic or stochastic version :return: nothing

Expand source code
def plot_cartpole_critic_v(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    stochastic=None,
):
    """
    Visualization of the critic in a N-dimensional state space
    The N-dimensional state space is projected into its first two dimensions.
    A FeatureInverter wrapper should be used to select which features to put first to plot them
    :param agent: the critic agent to be plotted
    :param env: the environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file where to plot the function
    :param save_figure: whether the plot should be saved into a file
    :param stochastic: whether we plot the deterministic or stochastic version
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_x, x in enumerate(
        np.linspace(state_min[0], state_max[0], num=definition)
    ):
        for index_y, y in enumerate(
            np.linspace(state_min[2], state_max[2], num=definition)
        ):
            obs = np.array([x])
            z1 = random.random() - 0.5
            z2 = random.random() - 0.5
            obs = np.append(obs, z1)
            obs = np.append(obs, y)
            obs = np.append(obs, z2)
            obs = obs.reshape(1, -1)
            obs = th.from_numpy(obs.astype(np.float32))
            value = agent.model(obs).squeeze(-1)
            portrait[definition - (1 + index_y), index_x] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[state_min[0], state_max[0], state_min[2], state_max[2]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)
def plot_critic(agent, env, directory, env_name, best_reward, plot=False) ‑> None
Expand source code
def plot_critic(agent, env, directory, env_name, best_reward, plot=False) -> None:
    figure_name = f"critic_{env_name}_{best_reward}.png"
    if agent.is_q_function:
        if "cartpole" in env_name.lower():
            env_string = "CartPole"
            plot_env = plot_cartpole_critic_q
        elif "pendulum" in env_name.lower():
            env_string = "Pendulum"
            plot_env = plot_pendulum_critic_q
        else:
            env_string = env_name
            plot_env = plot_any_env_critic_q
        plot_env(agent, env, env_string, directory, figure_name, plot, action=None)
    else:
        if "cartpole" in env_name.lower():
            env_string = "CartPole"
            plot_env = plot_cartpole_critic_v
        elif "pendulum" in env_name.lower():
            env_string = "Pendulum"
            plot_env = plot_pendulum_critic_v
        else:
            env_string = env_name
            plot_env = plot_any_env_critic_v
        plot_env(agent, env, env_string, directory, figure_name, plot)
def plot_pendulum_critic_q(agent, env, env_string, directory, figure_name, plot=True, save_figure=True, action=None)

Plot a critic for the Pendulum environment :param agent: the critic agent to be plotted :param env: the evaluation environment :param env_string: the name of the environment :param plot: whether the plot should be interactive :param directory: the directory where to save the figure :param figure_name: the name of the file to save the figure :param save_figure: whether the figure should be saved :param action: the action for which we want to plot the value :return: nothing

Expand source code
def plot_pendulum_critic_q(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    action=None,
):
    """
    Plot a critic for the Pendulum environment
    :param agent: the critic agent to be plotted
    :param env: the evaluation environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file to save the figure
    :param save_figure: whether the figure should be saved
    :param action: the action for which we want to plot the value
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_t, t in enumerate(np.linspace(-np.pi, np.pi, num=definition)):
        for index_td, td in enumerate(
            np.linspace(state_min[2], state_max[2], num=definition)
        ):
            obs = np.array([[np.cos(t), np.sin(t), td]])
            obs = th.from_numpy(obs.astype(np.float32))

            if action is None:
                action = th.Tensor([0])
            value = agent.predict_value(obs[0], action)

            portrait[definition - (1 + index_td), index_t] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[-180, 180, state_min[2], state_max[2]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)
def plot_pendulum_critic_v(agent, env, env_string, directory, figure_name, plot=True, save_figure=True, stochastic=None)

Plot a critic for the Pendulum environment :param agent: the critic agent to be plotted :param env: the evaluation environment :param env_string: the name of the environment :param plot: whether the plot should be interactive :param directory: the directory where to save the figure :param figure_name: the name of the file to save the figure :param save_figure: whether the figure should be saved :param stochastic: whether we plot the deterministic or stochastic version :return: nothing

Expand source code
def plot_pendulum_critic_v(
    agent,
    env,
    env_string,
    directory,
    figure_name,
    plot=True,
    save_figure=True,
    stochastic=None,
):
    """
    Plot a critic for the Pendulum environment
    :param agent: the critic agent to be plotted
    :param env: the evaluation environment
    :param env_string: the name of the environment
    :param plot: whether the plot should be interactive
    :param directory: the directory where to save the figure
    :param figure_name: the name of the file to save the figure
    :param save_figure: whether the figure should be saved
    :param stochastic: whether we plot the deterministic or stochastic version
    :return: nothing
    """
    if env.observation_space.shape[0] <= 2:
        msg = f"Observation space dim {env.observation_space.shape[0]}, should be > 2"
        raise (ValueError(msg))
    definition = 100
    portrait = np.zeros((definition, definition))
    state_min = env.observation_space.low
    state_max = env.observation_space.high

    for index_t, t in enumerate(np.linspace(-np.pi, np.pi, num=definition)):
        for index_td, td in enumerate(
            np.linspace(state_min[2], state_max[2], num=definition)
        ):
            obs = np.array([[np.cos(t), np.sin(t), td]])
            obs = th.from_numpy(obs.astype(np.float32))

            value = agent.model(obs).squeeze(-1)

            portrait[definition - (1 + index_td), index_t] = value.item()

    plt.figure(figsize=(10, 10))
    plt.imshow(
        portrait,
        cmap="inferno",
        extent=[-180, 180, state_min[2], state_max[2]],
        aspect="auto",
    )

    directory += "/" + env_string + "_critics/"
    title = env_string + " Critic"
    plt.colorbar(label="critic value")

    # Add a point at the center
    plt.scatter([0], [0])
    x_label, y_label = getattr(env.observation_space, "names", ["x", "y"])
    final_show(save_figure, plot, directory, figure_name, x_label, y_label, title)