# Python dependencies for this project, in pip requirements-file format.
# Entries written with `==` are locked to that exact version; bare names
# (numpy, tqdm, mergedeep) accept whatever version the installer resolves.
# NOTE(review): the unpinned entries may resolve to different versions on
# different install dates — consider pinning them if reproducibility matters.

# jax is requested with its `cuda12` extra — presumably the CUDA 12 GPU
# build; confirm against the JAX installation docs for the target machines.
jax[cuda12]==0.4.37
flax==0.10.2
optax==0.2.4
# Unpinned: version is chosen by the resolver, subject to the constraints of
# the other packages listed here.
numpy
tqdm
mergedeep
orbax-checkpoint==0.7.0
