jax>=0.4
diffrax>=0.4
jaxtyping

[dev]
pytest
mypy
ruff
pre-commit

[test]
pytest
mypy
ruff
pre-commit
