Metadata-Version: 2.4
Name: minimal_mjx
Version: 0.1.4
Summary: Short description
Author-email: Neil Janwani <njanwani@gatech.com>
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: numpy
Requires-Dist: jax
Requires-Dist: playground==0.0.5
Requires-Dist: pandas
Requires-Dist: h5py
Requires-Dist: wandb
Requires-Dist: opencv-python
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: black; extra == "dev"

# MINIMAL-MJX

This repository contains starter code for MJX-based RL from the Dynamic Mobility
Lab. The code is built off of MuJoCo Playground and tailored to make
policy training, saving, and evaluation easy. It also implements some nice-to-haves:
1. A swappable backend, allowing fast evaluation with NumPy/C++ MuJoCo and fast training with JAX/MJX (see the sketch below)
2. A base environment class that contains generic reward functions and other useful utilities
3. Configuration files that let you set parameters for your environments and for PPO
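
Conceptually, the backend swap just selects which array module an environment computes
with. The sketch below is illustrative only (the helper and flag names are hypothetical,
not this repository's API); it mirrors the `backend` parameter used in the config files:
```python
import numpy as np
import jax.numpy as jnp


def get_array_module(backend: str):
    """Pick the array library the environment computes with (hypothetical helper)."""
    return jnp if backend == "jnp" else np


xp = get_array_module("np")     # fast CPU evaluation with NumPy
# xp = get_array_module("jnp")  # JIT-friendly training with JAX/MJX
velocity = xp.zeros(3)
```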

There is a `requirements.txt` in the top-level directory that lists the packages
needed to run this code.

# Steps to take to 'do RL'!
We will work off of the Cheetah environment from DeepMind's "DM Control". This guide
is not meant to be comprehensive, but it should give you enough information to fill in
the gaps as you develop and explore the codebase.

0. **The setup basics**

    Each environment has a `reset` and `step` function, which primarily operate
    off of a Markov Decision Process (MDP) state. In our case, this is represented
    as a class that looks like this:
    ```python
    from dataclasses import dataclass

    import mujoco
    import numpy as np


    @dataclass
    class MujocoState:
        data: mujoco.MjData  # full MuJoCo simulation state
        obs: np.ndarray      # observation passed to the policy
        reward: float        # reward for the current timestep
        done: bool           # whether the episode has terminated
        metrics: dict        # quantities to log/plot (e.g. reward terms)
        info: dict           # 'carry' between consecutive steps
    ```
    Note that you should treat `info` as a 'carry' variable between consecutive
    calls to `step`. Store anything here that you need in future timesteps but that
    is not ordinarily stored in `data` (e.g. a history of system states `x`).
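
    For example, a `step` implementation might carry a state history forward through
    `info`. The sketch below is purely illustrative (the `x_history` key, the `step`
    signature, and the elided simulation code are assumptions, not the actual API):
    ```python
    import dataclasses

    import numpy as np


    def step(state: MujocoState, action: np.ndarray) -> MujocoState:
        # ... advance the simulation and compute obs / reward / done here ...
        # Append the current observation to a history carried in `info` so that
        # future timesteps can read it back.
        history = state.info.get("x_history", []) + [np.copy(state.obs)]
        return dataclasses.replace(state, info={**state.info, "x_history": history})
    ```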

1. **Simulate your environment to make sure things look good**
    
    Environments should be simulated and inspected before training. Once you
    are done developing your env, make a configuration file to specify
    run-specific variables, such as PPO hyperparameters or reward weights.

    For the Cheetah, a config has already been created. To simulate the environment, run
    ```bash
    python3 -m envs.simulate envs/dmcontrol/config/cheetah.yaml
    ```

    Under `visualization`, you should see a metrics plot as well as a video of
    your environment. Note that the policy here simply outputs all zeros
    (see `envs/simulate.py` and the sketch at the end of this step). Inspect these
    files to make sure things look good.

    For good measure, change the `backend` parameter in your config file
    to `jnp` to make sure your environment is JAX-compatible.
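
    Roughly speaking, the simulation loop amounts to stepping the environment with an
    all-zeros action. The sketch below is a hypothetical illustration (the `action_size`
    attribute and the exact `reset`/`step` signatures are assumptions; see
    `envs/simulate.py` for the real implementation):
    ```python
    import numpy as np


    def rollout_zero_policy(env, num_steps: int = 1000) -> list:
        """Roll an environment forward with a zero-action 'policy'."""
        state = env.reset()
        all_metrics = []
        for _ in range(num_steps):
            action = np.zeros(env.action_size)  # the "policy" outputs all zeros
            state = env.step(state, action)
            all_metrics.append(state.metrics)   # collect metrics for plotting
        return all_metrics
    ```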

2. **Train a policy**

    Once your environment is ready and JAX-compatible, you can train by running
    ```bash
    python3 -m learning.training.begin_run envs/dmcontrol/config/cheetah.yaml
    ```

    Note that this will train directly in your terminal session. A bash script
    is provided at `learning/training/train.sh` that opens a `tmux` session
    for this (useful if you want to train for long periods of time). *Make sure to
    change the conda environment name in this script to your own.*

    Within the save directory specified in your config file, a new directory will
    be created in which intermediate and final training results are saved.

3. **Rollout your trained policy**

    Once your policy is finished training, simply run
    ```bash
    python3 -m eval.rollout_policy envs/dmcontrol/config/cheetah.yaml
    ```
    
    This will roll out your policy, plot metrics of your reward function, and
    save a video of the result. Make sure to use the `np` backend for quick
    evaluation!

    If you would like to roll out older policies, check the directories of your past
    training runs. Each one contains a config, and you can simply replace
    the config above with that one (to run that *specific* older policy). Note
    that your code may have changed since then, so this config records the git commit
    at which the training was run.
