jax>=0.4.34
optax>=0.2.4
numpy>=1.24.1
