numpy
scipy
jax>=0.7.0
optax>=0.2.5

[dev]
pytest>=7.0
pytest-cov>=4.0.0
black>=23.0
isort
