jax>=0.4.20
optax>=0.2.0
chex
numpy
matplotlib
