einops>=0.8.1
flax>=0.12.1
jax>=0.8.1
optax>=0.2.6
orbax-checkpoint>=0.11.28

[dev]
pytest>=9.0.1
