jax>=0.4.17
numpy>=1.24.4
chex>=0.1.85
einops>=0.8.0
