numpy>=2.0.0
jax>=0.8.1
optax>=0.2.6
equinox>=0.13.2
matplotlib
jaxtyping
pyright>=1.1.407
pytest>=9.0.2

[cpu]
jax>=0.8.1

[dev]
pre-commit

[docs]
hippogriffe==0.2.2
griffe==1.7.3
mkdocs==1.6.1
mkdocs-include-exclude-files==0.1.0
mkdocs-ipynb==0.1.1
mkdocs-material==9.6.7
mkdocstrings==0.28.3
mkdocstrings-python==1.16.8
pymdown-extensions==10.14.3
mknotebooks

[gpu_cuda12]
jax[cuda12]>=0.8.1

[gpu_cuda13]
jax[cuda13]>=0.8.1

[notebook]
jupyter
seaborn
pandas
diffrax
