numpy>=1.20

[dev]
pytest>=7.0
pytest-benchmark>=4.0

[jax]
jax>=0.4
jaxlib>=0.4

[torch]
torch>=2.0
