absl-py>=2.3.0
chex>=0.1.81
einshape
google-benchmark>=1.9.0
jax>=0.8.0
jaxlib>=0.8.0
jaxtyping>=0.3
protobuf
pydantic>=2.11.0
tensorflow
tqdm
typeguard==2.13.3
xprof
hypothesis
jax-triton

[cuda]
jax[cuda12]>=0.8.0
nvidia-cudnn-cu12>=9.0.0
triton>=3.0.0

[test]
pytest-xdist
pytest
flax

[tpu]
jax[tpu]>=0.8.0
