jax==0.4.30
jaxlib==0.4.30
flax==0.8.5
optax==0.2.4
chex==0.1.90
numpy==1.26.4
scipy==1.13.1
pandas==2.2.3
matplotlib==3.9.4

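# --- Optional: CUDA 12 GPU backend ---
# pip requirements files do not support INI-style "[section]" headers, so the
# former "[cuda12]" group is kept as comments to keep this file installable.
# To enable GPU support, install the pins explicitly, e.g.:
#   pip install "jax[cuda12]==0.4.30" "jaxlib==0.4.30"
# jax[cuda12]==0.4.30
# jaxlib==0.4.30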
