chex==0.1.91
grain==0.2.13
jax==0.8.0
flax==0.12.0
ml_dtypes==0.5.3
optax==0.2.6
orbax-checkpoint==0.11.26
orbax-export==0.0.8

[dev]
pytest
pytest-xdist

[grain]

[tfds]
tensorflow==2.20.0
tensorflow_datasets==4.9.9
