datasets>=4.6.1
flax>=0.12.5
grain>=0.2.16
jax>=0.9.1
optax>=0.2.6
orbax>=0.1.9
tokenizers>=0.22.2

[dev]
pytest>=9.0.2
twine>=6.0.0
build>=1.2.0
