jax>=0.4.0
jaxlib>=0.4.0
optax>=0.1.7
numpy>=1.20.0

[benchmarks]
matplotlib
scikit-learn
tensorflow-datasets

[dev]
pytest
build
twine
