jax<0.10.0,>=0.9.0
jaxlib<0.10.0,>=0.9.0
numpy>=2.0.0
numpyro>=0.20.0
orbax-checkpoint>=0.11.0
flax>=0.12.0
optax>=0.2.0
matplotlib
pyswarms
scipy
tqdm
jaxtyping
seaborn
joblib
