jax>=0.4.20
optax>=0.2.0
numpy
matplotlib

[cuda12]
jax[cuda12]==0.6.0
