jax>=0.9.0
jaxlib>=0.9.0
numpy>=1.25
optax>=0.2.3
scipy>=1.13
typing-extensions>=4.4.0

[progress]
fastprogress<1.1,>=1.0.0
