torch>=2.0
triton>=3.6
numpy
numba
nvidia-cutlass-dsl
tqdm

[bench]
scikit-learn<1.8,>=1.5
tqdm
matplotlib

[dev]
pytest>=7
scikit-learn<1.8,>=1.5
tqdm
