torch>=2.2
numpy>=1.23
scipy>=1.10
tqdm>=4.65
pyyaml>=6.0
omegaconf>=2.3
typer>=0.9
matplotlib>=3.7
scikit-learn>=1.2
pandas>=1.5
pytorch-lightning>=2.3
torchmetrics>=1.4
benchmark-mi>=0.1.3
jsonlines>=4.0.0
seaborn>=0.13.2
jax==0.4.26
jaxlib==0.4.26
optuna>=4.6.0
huggingface-hub==0.36.0

[jax]
equinox>=0.10
jaxtyping>=0.3
optax>=0.1.9
