torch>=2.1.0

[all]
nexustrain[distributed,full,logging,speed,triton,tuning]

[dev]
pytest>=7.4.0
pytest-cov>=4.1.0
black>=23.0.0
isort>=5.12.0
mypy>=1.5.0

[distributed]
deepspeed>=0.14.0
accelerate>=0.27.0

[full]
transformers>=4.40.0
accelerate>=0.27.0
peft>=0.10.0
trl>=0.8.0
bitsandbytes>=0.43.0
datasets>=2.14.0

[logging]
wandb>=0.16.0
tensorboard>=2.14.0

[speed]
flash-attn>=2.5.0
triton>=2.1.0

[triton]
triton>=2.1.0

[tuning]
optuna>=3.5.0
