# Unconditional runtime requirements (everything before the first [section] header).
h5py>=3.0.0
jax>=0.4.0
jaxlib>=0.4.0
numpy>=1.20.0
tqdm>=4.60.0

# Developer tooling: formatter (black), linter (flake8), test runner (pytest).
[dev]
black>=21.0.0
flake8>=3.8.0
pytest>=6.0.0

# GPU support through JAX's CUDA 12 extra. No version is pinned here; the
# unconditional `jax>=0.4.0` requirement above still applies when this extra is installed.
# NOTE(review): the `cuda12` extra only exists in newer JAX releases — confirm the
# effective minimum jax version and consider pinning it here explicitly.
[gpu]
jax[cuda12]

# TPU support through JAX's `tpu` extra; version bounded by the base `jax>=0.4.0` line.
# NOTE(review): some JAX releases needed an extra package index (find-links) to resolve
# libtpu, which cannot be expressed in this file — verify against the JAX install guide.
[tpu]
jax[tpu]
