absl-py>=2.3.0
einshape
jax>=0.8.1
jaxlib>=0.8.1
jaxtyping>=0.3
pydantic>=2.11.0
qwix>=0.1.2
tensorflow
tqdm
typing_extensions>=4.5.0
typeguard==2.13.3
xprof
immutabledict

[bench]
google-benchmark>=1.9.0

[cuda]
jax[cuda12]>=0.8.0
nvidia-cudnn-cu12>=9.0.0

[test]
chex>=0.1.91
pytest-xdist
pytest
flax

[tpu]
jax[tpu]>=0.8.0
hypothesis
