numpy>=1.24.0
scipy>=1.10.0
jax>=0.5.0

[all]
jaxamg[cuda12,dev,mpi]

[cuda12]
jax[cuda12]>=0.5.0

[cuda13]
jax[cuda13]>=0.5.0

[dev]
pytest>=7.0.0
pytest-mpi>=0.6
pytest-mypy
black
mypy
ruff
pre-commit
zensical
mkdocstrings[python]
lineax

[mpi]
mpi4py>=3.1.0
mpi4jax>=0.9.0
