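# CUDA 11 jaxlib wheels are hosted on the JAX release index, not PyPI.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
jax[cuda11_pip]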
graphcast @ git+https://github.com/deepmind/graphcast.git@e80f4d4
