jax>=0.5

[test]
pytest>=8.3
