numpy
jax>=0.5.0
optax>=0.2.5
lineax
scipy
matplotlib
