jax>=0.3.18
jax_dataclasses>=1.4.4
numpy
typing_extensions>=4.0.0
tyro

[testing]
mypy
jax!=0.3.19
flax
hypothesis[numpy]
pytest
pytest-xdist[psutil]
pytest-cov
