jax>=0.4.20
jaxlib>=0.4.20
flax>=0.8.0
optax>=0.1.7
numpy>=1.24.0
tqdm>=4.65.0

[dev]
pytest>=7.0.0
