chex>=0.1.90
flax==0.10.3
imageio>=2.37.0
jax==0.4.38
jaxlib==0.4.38
matplotlib>=3.8.0
numpy>=1.26.0
pillow>=10.0.0
scipy>=1.12.0
seaborn>=0.13.0

[baselines-llm]
hydra-core>=1.3.0
omegaconf>=2.3.0
wandb>=0.16.0
tqdm>=4.66.0
openai>=1.0.0
anthropic>=0.40.0
google-genai>=0.3.0

[baselines-rl]
jaxmarl>=0.0.3
distrax>=0.1.5
optax>=0.2.0
orbax-checkpoint>=0.5.0
hydra-core>=1.3.0
omegaconf>=2.3.0
wandb>=0.16.0
pyyaml>=6.0

[cuda12]
jax[cuda12]==0.4.38

[dev]
pytest>=8.0.0
pytest-cov>=5.0.0
jaxtyping>=0.2.34
beartype>=0.18.0
ruff>=0.4.0

[gpu]
jax[cuda12]==0.4.38

[llm]
openai>=1.0.0

[play]
pygame>=2.6.0
