jax>=0.4.0
jaxlib>=0.4.0

[cuda12]

[cuda12:platform_system == "Linux" and (platform_machine == "x86_64" or platform_machine == "aarch64")]
jax[cuda12]>=0.4.0

[cuda13]

[cuda13:platform_system == "Linux" and (platform_machine == "x86_64" or platform_machine == "aarch64")]
jax[cuda13]>=0.4.0

[dev]
pytest
ruff
numpy
scipy
matplotlib
pandas
joblib
tqdm
cvxpy
optax
pydecomp
pyyaml
equinox
ipython

[docs]
mkdocs-material>=9.5
mkdocstrings[python]>=0.24
pymdown-extensions>=10.0
