jaxlib
jax[cuda12]
matplotlib==3.7.2
numpy==1.24.4
pandas==2.0.3
scikit_learn==1.3.0
scipy==1.12.0
setuptools==68.1.2
torch>=1.13.1
pyyaml
pytest

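# [dev] -- development-only dependencies. pip requirements files do not support
# section headers, so list dev packages below as plain lines (or move them to a
# separate requirements-dev.txt installed with `pip install -r requirements-dev.txt`).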
