numpy
jax>=0.4.16
jaxlib>=0.4.16

[mindep]
jax==0.4.16
jaxlib==0.4.16
numpy<2

[test]
coverage[toml]
pytest
pytest-cov
pytest-xdist
