rich>=13.9.4

[all]
torch>=2.2
jax>=0.4.30
flax>=0.8.0

[flax]
flax>=0.8.0
jax>=0.4.30

[jax]
jax>=0.4.30

[pytorch]
torch>=2.2
