jax>=0.4.20
optax>=0.2.0
numpy
matplotlib
