jax==0.5.3
jaxdecomp==0.2.7
tensorflow
orbax-checkpoint==0.11.18
optax==0.2.5
numpy
matplotlib
setuptools>=70.1.1

[cuda12]
jax[cuda12]==0.5.3
