jax
jaxlib
numpy
typing_extensions

[cpu]
jax[cpu]

[cuda12]
jax[cuda12]

[testing]
pytest

[tpu]
jax[tpu]
