chex==0.1.90
grain==0.2.12
jax==0.7.0
flax==0.11.1
ml_dtypes==0.5.3
optax==0.2.5
orbax-checkpoint==0.11.23
orbax-export==0.0.6

[dev]
pytest
pytest-xdist

[grain]

[tfds]
tensorflow==2.20.0
tensorflow_datasets==4.9.9
