jaxtyping>=0.2.21
chex>=0.1.85
optax>=0.1.8
pytest>=8.0.0
beartype
twine>=6.0.1
ipython>=8.31.0
jupyter>=1.1.1
numpy>=2.2.1
scipy>=1.14.1
matplotlib>=3.10.0
jupyter-black>=0.4.0
ipykernel>=6.29.5
pandas>=2.3.0
isort>=6.0.1
myst-parser>=4.0.0

[:sys_platform == "linux"]
jax[cuda12]>=0.4.38

[:sys_platform != "linux"]
jax>=0.4.38

[docs]
sphinx>=8.1.3
sphinx-rtd-theme>=3.0.2
myst-parser>=4.0.0
nbsphinx>=0.9.6
ipykernel>=6.29.5
pydata-sphinx-theme>=0.16.1
sphinxcontrib-napoleon>=0.7

[test]
pytest>=8.0.0
pytest-cov
absl-py
