Metadata-Version: 2.4
Name: lerax
Version: 0.0.4
Summary: Deep Reinforcement Learning with JAX and Equinox.
Project-URL: documentation, https://docs.tedpinkerton.ca/lerax
Project-URL: homepage, https://github.com/RunnersNum40/lerax
Project-URL: issues, https://github.com/RunnersNum40/lerax/issues
Project-URL: source, https://github.com/RunnersNum40/lerax.git
Author-email: Theodore Pinkerton <ted@tedpinkerton.ca>
License-Expression: GPL-3.0-or-later
License-File: LICENSE
Requires-Python: >=3.13
Requires-Dist: diffrax
Requires-Dist: distreqx
Requires-Dist: equinox
Requires-Dist: jax
Requires-Dist: jaxtyping
Requires-Dist: optax
Requires-Dist: rich
Requires-Dist: tensorboard
Requires-Dist: tensorboardx
Provides-Extra: compatibility
Requires-Dist: flax; extra == 'compatibility'
Requires-Dist: gymnasium; extra == 'compatibility'
Requires-Dist: gymnax; extra == 'compatibility'
Requires-Dist: stable-baselines3; extra == 'compatibility'
Provides-Extra: docs
Requires-Dist: mkdocstrings-python; extra == 'docs'
Requires-Dist: zensical; extra == 'docs'
Provides-Extra: mujoco
Requires-Dist: mujoco-mjx; extra == 'mujoco'
Provides-Extra: render
Requires-Dist: imageio; extra == 'render'
Requires-Dist: imageio-ffmpeg; extra == 'render'
Requires-Dist: pygame; extra == 'render'
Description-Content-Type: text/markdown

# Lerax: Fully JITable reinforcement learning with JAX

Lerax is a reinforcement learning library built on top of JAX, designed for creating, training, and evaluating RL agents in a fully JITable manner.
It provides modular components for building custom environments, policies, and training algorithms.

Built on top of [JAX](https://docs.jax.dev/en/latest/index.html), [Equinox](https://docs.kidger.site/equinox/), and [Diffrax](https://docs.kidger.site/diffrax/).

## Installation

```bash
pip install lerax
```
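The package metadata also declares optional extras (`compatibility`, `docs`, `mujoco`, and `render`). Assuming these extras are published as shown above, they can be installed with pip's extras syntax, for example:

```shell
# Optional: interop with Gymnasium, Gymnax, Flax, and Stable-Baselines3
pip install "lerax[compatibility]"

# Optional: rendering support (imageio, imageio-ffmpeg, pygame)
pip install "lerax[render]"

# Optional: MuJoCo MJX environments
pip install "lerax[mujoco]"
```

Extras can also be combined, e.g. `pip install "lerax[render,mujoco]"`.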

## Training Example

```py
from jax import random as jr

from lerax.algorithm import PPO
from lerax.env import CartPole
from lerax.policy import MLPActorCriticPolicy

env = CartPole()
policy = MLPActorCriticPolicy(env=env, key=jr.key(0))
algo = PPO()

policy = algo.learn(env, policy, total_timesteps=2**16, key=jr.key(1))
```

## Documentation

Check out: [lerax.tedpinkerton.ca](https://lerax.tedpinkerton.ca)
