numpy>=2.1
jax>=0.6.0
jaxlib>=0.6.0
flax>=0.10.4
optax>=0.2.5
orbax-checkpoint>=0.11.5
huggingface_hub>=1.0.0
pyyaml>=6.0
safetensors>=0.5.0

[blackjax]
blackjax>=1.2.5

[viz]
matplotlib>=3.9
