numpy
jax
brainstate>=0.2.9
brainunit
braintools

[cpu]
jax

[cuda12]
jax[cuda12]

[cuda13]
jax[cuda13]

[testing]
pytest

[tpu]
jax[tpu]
