jax>=0.4.20
jaxlib>=0.4.20
numpy>=1.24.0
scipy>=1.10.0
networkx>=3.0
matplotlib>=3.7.0
tqdm>=4.65.0

[dev]
pytest>=7.0.0
pytest-cov>=4.0.0
black>=23.0.0
ruff>=0.1.0
mypy>=1.0.0
pre-commit>=3.0.0

[docs]
sphinx>=6.0.0
sphinx-rtd-theme>=1.3.0
sphinx-autodoc-typehints>=1.22.0
nbsphinx>=0.9.0
myst-parser>=1.0.0

[gpu]
jax[cuda12_pip]>=0.4.20

[tpu]
jax[tpu]>=0.4.20
