flax==0.8.5
optax==0.2.4
chex==0.1.90
matplotlib>=3.8

[:python_version < "3.13"]
jax==0.4.30
jaxlib==0.4.30
numpy==1.26.4
scipy==1.13.1

[:python_version >= "3.13"]
jax==0.4.35
jaxlib==0.4.35
numpy==2.1.3
scipy==1.14.1

[cuda12]

[cuda12:python_version < "3.13"]
jax[cuda12]==0.4.30

[cuda12:python_version >= "3.13"]
jax[cuda12]==0.4.35
