# Python package dependencies in pip requirements-file format, one requirement per line.
# Entries using `==` are pinned to exactly that version; entries with no version
# specifier accept whatever version the installer resolves at install time.
# NOTE(review): the JAX-related pins below (jax, jaxdecomp, orbax-checkpoint, optax, chex)
# are presumably chosen to work together — verify compatibility before bumping any one alone.
jax==0.6.0
jaxdecomp==0.2.9
# No version constraint on tensorflow.
# NOTE(review): unpinned entries can resolve differently over time — confirm this is intended.
tensorflow
orbax-checkpoint==0.11.25
optax==0.2.6
# No version constraint on numpy or matplotlib.
numpy
matplotlib
# Lower bound only: any setuptools release at or above 70.1.1 satisfies this line.
setuptools>=70.1.1
chex==0.1.90
