jax
jaxlib
flax>=0.12.0
numpy
opt_einsum
absl-py

[dev]
pytest

[docs]
myst-nb
sphinx
sphinx-book-theme
sphinx-tabs
jax
flax
