datasets>=4.6.1
flax>=0.12.5
grain>=0.2.16
jax>=0.9.2
optax>=0.2.6
orbax>=0.1.9
safetensors>=0.7.0

[cpu]

[cuda]
jax[cuda]>=0.9.2

[cuda12]
jax[cuda12]>=0.9.2

[cuda12-local]
jax[cuda12-local]>=0.9.2

[cuda13]
jax[cuda13]>=0.9.2

[cuda13-local]
jax[cuda13-local]>=0.9.2

[dev]
pytest>=9.0.2
twine>=6.0.0
build>=1.2.0

[tpu]
jax[tpu]>=0.9.2
