numpy>=1.15
brainunit>=0.0.8
absl-py
jax>=0.5.0

[cpu]
jax[cpu]>=0.7.2
numba

[cuda12]
jax[cuda12]>=0.7.2

[cuda13]
jax[cuda13]>=0.7.2

[testing]
pytest
pytest-cov

[tpu]
jax[tpu]>=0.7.2
