jax
jaxlib
tensorflow-probability>=0.19.0
numpy
scipy
optax
jaxtyping>=0.2.15
tqdm
datasets
