absl-py>=2.3.0
einshape
jax>=0.9.2
jaxlib>=0.9.2
jaxtyping>=0.3
pydantic>=2.11.0
qwix>=0.1.2
tqdm
typing_extensions>=4.5.0
typeguard==2.13.3
immutabledict
tensorboardx

[bench]
libtpu>=0.0.35
google-benchmark>=1.9.0
xprof

[cuda]
jax[cuda12]>=0.8.0
nvidia-cudnn-cu12>=9.0.0

[test]
chex>=0.1.91
flatbuffers
flax
pytest-xdist
pytest
xprof

[tpu]
jax[tpu]>=0.8.0
hypothesis
