jax>=0.6.1
optax
equinox
diffrax
optimistix
lineax

[gpu]
jax[cuda12]>=0.6.1
