
[cpu]
jax<0.9,>=0.8.1

[cuda12]
jax[cuda12]<0.9,>=0.8.1

[cuda13]
jax[cuda13]<0.9,>=0.8.1
