flax>=0.7.1
optax
scikit-learn
matplotlib
jax
