jax
optax

[dev]
jax[cuda12]
matplotlib
torch
torchvision
datasets
flax
huggingface_hub
pytest
flash-attn-jax
einops
