jax[cuda12]==0.4.37
flax>=0.10.2
optax>=0.2.4
numpy
tqdm
mergedeep
orbax-checkpoint>=0.5.20
