jax>=0.4.1
flax>=0.6.0
optax>=0.1.3
numpy>=1.21.0

[dev]
pytest>=7.0.0
black>=22.0.0
mypy>=0.950

[tpu]
jax[tpu]>=0.4.1
