# Base requirements: every entry listed before the first [section] header.
# Each one sets only a minimum version (>=); none has an upper bound.
# jax and jaxlib share the same floor (0.4.31).
# NOTE(review): this layout (bare requirement specifiers followed by [name]
# sections) is not valid TOML; it resembles a setuptools-style requires.txt.
# Confirm the real file type, and if the file is generated, change the
# project's packaging metadata rather than editing it by hand.
jax>=0.4.31
jaxlib>=0.4.31
optax>=0.2.2
matplotlib>=3.8.2
numpy>=1.24.0
einops>=0.8.0
tqdm>=4.66.1
jax-tqdm>=0.1.2
flax>=0.8.5
chex>=0.1.86
jaxtyping>=0.2.24

# "dev" group: test tooling. pytest has a floor of 6.0; pytest-cov is
# unpinned, so any version satisfies it.
# NOTE(review): presumably an optional extra that is not installed by
# default -- confirm against the packaging configuration.
[dev]
pytest>=6.0
pytest-cov

# "docs" group: documentation tooling. Both entries are unpinned, so any
# version satisfies them.
# NOTE(review): presumably an optional extra used only when building the
# documentation -- confirm against the packaging configuration.
[docs]
sphinx
sphinx-rtd-theme
