jax
optax

[examples]
matplotlib
scipy
scikit-learn
scipy-stubs
seaborn

[gpu]
jax[cuda12]

[test]
pytest>=7.0
pytest-cov>=4.0
