Metadata-Version: 2.4
Name: lerax
Version: 0.0.5
Summary: Deep Reinforcement Learning with JAX and Equinox.
Author-email: Theodore Pinkerton <ted@tedpinkerton.ca>
License-Expression: GPL-3.0-or-later
Project-URL: documentation, https://lerax.tedpinkerton.ca/
Project-URL: homepage, https://lerax.tedpinkerton.ca/
Project-URL: issues, https://github.com/RunnersNum40/lerax/issues
Project-URL: source, https://github.com/RunnersNum40/lerax.git
Requires-Python: >=3.13
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: diffrax
Requires-Dist: distreqx>=0.0.2
Requires-Dist: equinox
Requires-Dist: jax
Requires-Dist: jaxtyping
Requires-Dist: optax
Requires-Dist: rich
Requires-Dist: tensorboard
Requires-Dist: tensorboardX
Provides-Extra: compatibility
Requires-Dist: gymnasium; extra == "compatibility"
Requires-Dist: stable-baselines3; extra == "compatibility"
Requires-Dist: gymnax; extra == "compatibility"
Requires-Dist: flax; extra == "compatibility"
Provides-Extra: docs
Requires-Dist: zensical; extra == "docs"
Requires-Dist: mkdocstrings-python; extra == "docs"
Provides-Extra: mujoco
Requires-Dist: mujoco-mjx; extra == "mujoco"
Provides-Extra: render
Requires-Dist: pygame; extra == "render"
Requires-Dist: imageio; extra == "render"
Requires-Dist: imageio-ffmpeg; extra == "render"
Dynamic: license-file

# Lerax: Fully JITable reinforcement learning with JAX

Lerax is a reinforcement learning library for creating, training, and evaluating RL agents in a fully JITable manner.
It provides modular components for building custom environments, policies, and training algorithms.

Lerax is built on [JAX](https://docs.jax.dev/en/latest/index.html), [Equinox](https://docs.kidger.site/equinox/), and [Diffrax](https://docs.kidger.site/diffrax/).
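In general JAX terms, "fully JITable" means that the environment step and the rollout loop are pure functions of arrays and PRNG keys, so the whole loop can be traced once and compiled with `jax.jit`. The sketch below illustrates that idea in plain JAX with a hypothetical toy environment; `toy_step` and `rollout` are illustrative names and are not part of the Lerax API.

```py
import jax
import jax.numpy as jnp
from jax import random as jr


# Hypothetical toy environment: the state is a scalar position, the action
# (-1 or +1) moves it, and the reward is the negative distance from the origin.
def toy_step(state, action):
    next_state = state + action
    reward = -jnp.abs(next_state)
    return next_state, reward


@jax.jit  # the rollout, loop included, is traced once and compiled
def rollout(key, initial_state):
    # One PRNG key per step so the random policy is reproducible
    keys = jr.split(key, 100)

    def body(state, step_key):
        # Random policy: pick -1.0 or +1.0 uniformly
        action = jr.choice(step_key, jnp.array([-1.0, 1.0]))
        return toy_step(state, action)

    # lax.scan replaces a Python for-loop so the loop itself can be compiled
    final_state, rewards = jax.lax.scan(body, initial_state, keys)
    return final_state, rewards


final_state, rewards = rollout(jr.key(0), jnp.zeros(()))
print(rewards.shape)  # (100,)
```

Because `rollout` is a pure function of its inputs, calling it again with the same key reproduces the same rewards.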

## Installation

```bash
pip install lerax
```

## Training Example

```py
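from jax import random as jr

from lerax.algorithm import PPO
from lerax.env import CartPole
from lerax.policy import MLPActorCriticPolicy

# Classic CartPole control task
env = CartPole()
# MLP actor-critic policy built for `env`, initialized from PRNG key 0
policy = MLPActorCriticPolicy(env=env, key=jr.key(0))
# Proximal Policy Optimization with its default hyperparameters
algo = PPO()

# Train for 2**16 (65,536) timesteps; learn returns the trained policy
policy = algo.learn(env, policy, total_timesteps=2**16, key=jr.key(1))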
```
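`PPO` in the training example is Proximal Policy Optimization, whose core is a clipped surrogate objective that limits how far each update can move the policy away from the one that collected the data. The sketch below writes that objective in plain JAX for orientation only; it is a generic illustration of the technique, not Lerax's internal implementation, and the function name `ppo_clip_loss` and its `clip_eps=0.2` default are assumptions. A full PPO loss usually also includes value-function and entropy terms, which are omitted here.

```py
import jax.numpy as jnp


def ppo_clip_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Clipped surrogate loss from PPO (Schulman et al., 2017).

    All array arguments are per-sample vectors of equal shape. The returned
    value is a loss to minimize, i.e. the negated surrogate objective.
    """
    # Probability ratio r = pi_new(a|s) / pi_old(a|s), computed in log space
    ratio = jnp.exp(log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Take the pessimistic (minimum) objective per sample, then average
    return -jnp.mean(jnp.minimum(unclipped, clipped))


# A ratio of e^0.5 (about 1.65) is clipped to 1.2 for a positive advantage
loss = ppo_clip_loss(
    log_probs=jnp.array([0.5]),
    old_log_probs=jnp.array([0.0]),
    advantages=jnp.array([1.0]),
)
```

Taking the minimum means the clip only ever removes incentive: a large ratio cannot inflate the objective for a good action, while a bad action is still fully penalized.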

## Documentation

Full documentation is available at [lerax.tedpinkerton.ca](https://lerax.tedpinkerton.ca).
