jax>=0.4.0
flax>=0.7.0
numpy

[all]
jaxflow[dev,viz]

[dev]
pytest>=7.0
black
isort
mypy
mkdocs

[viz]
matplotlib
seaborn
