absl-py>=2.3.0
chex>=0.1.81
einshape
flax
google-benchmark>=1.9.0
jax-triton
jax[cuda12,tpu]>=0.7.2
jaxlib>=0.7.2
jaxtyping>=0.3
nvidia-cudnn-cu12>=9.0.0
protobuf
pydantic>=2.11.0
pytest-xdist
pytest
tensorflow
tqdm
triton>=3.0.0
typeguard==2.13.3
xprof
hypothesis
