jax
jaxlib
tensorflow-probability>=0.19.0
numpy
scipy
optax
jaxtyping>=0.2.15
tqdm
datasets

[dev]
autoflake
black
flake8
isort>=5.1
mypy>=0.920
pytest-cov
pytest-random-order
pytest-xdist
pytest>=3.5.0
ipykernel
ipython
jupyter_client
notebook
jupytext
nbconvert
nbformat
matplotlib>=3.7.1
