jax>=0.4.0
flax>=0.8.0
optax>=0.1.0

[test]
pytest>=7.0
