torch>=2.1.0
numpy>=1.26.0
einops>=0.7.0
click>=8.1.0
rich>=13.0.0
transformers>=4.40.0
datasets>=2.19.0
scikit-learn>=1.4.0

[dev]
pytest>=8.0
ruff>=0.4.0

[jax]
jax>=0.4.30
flax>=0.8.4

[mamba]
mamba-ssm>=1.2
