jax!=0.5.3,>=0.5.0
mujoco>=3.3.0
mujoco-mjx>=3.3.0
gymnasium
imageio
einops
flax
ml_collections
casadi
numpy

[benchmark]
fire

[gpu]
jax[cuda12]

[test]
pytest>=8.0.0
pytest-cov
pytest-timeout
