# Unconditional runtime requirements (installed with or without extras).
# Entries are sorted case-insensitively by project name. All are lower
# bounds except keras, which is pinned to one exact release.
accelerate>=0.29.2
chex>=0.1.9
datasets>=2.18.1
flax>=0.8.2
huggingface-hub>=0.20.3
jax>=0.4.23
jaxlib>=0.4.23
keras==3.9.0
numpy>=1.26.4
optax>=0.1.9
orbax-checkpoint>=0.5.3
pandas>=2.2.3
protobuf>=4.25.3
psutil>=5.9.8
pydantic>=2.7.0
python-dotenv>=1.0.1
PyYAML>=6.0.1
requests>=2.31.0
rich>=13.7.0
safetensors>=0.4.2
scikit-learn>=1.6.1
scipy>=1.15.2
sentry-sdk>=2.8.0
tensorboard>=2.19.0
tensorflow>=2.16.1
tensorflow-probability>=0.20.0
tokenizers>=0.15.2
tqdm>=4.66.2
transformers>=4.40.1
wandb>=0.16.6

# Extra "cpu": PyTorch stack plus the CPU build of ONNX Runtime.
# torchaudio shares torch's version numbers from 2.0 onward, so its floor
# matches torch's; torchvision stays on its own 0.x numbering.
[cpu]
torch>=2.2.2
torchaudio>=2.2.2
torchvision>=0.17.2
onnxruntime>=1.17.0

# Extra "dev": test, formatting and lint tooling. Each tool is held below
# the next major version named in its upper bound.
[dev]
black>=24.2.0,<25.0.0
flake8>=7.0.0,<8.0.0
isort>=5.13.2,<6.0.0
mypy>=1.8.0,<2.0.0
pytest>=8.0.2,<9.0.0

# Extra "gpu": PyTorch stack, the GPU build of ONNX Runtime, and bitsandbytes.
# torchaudio shares torch's version numbers from 2.0 onward, so its floor
# matches torch's; torchvision stays on its own 0.x numbering.
[gpu]
torch>=2.2.2
torchaudio>=2.2.2
torchvision>=0.17.2
onnxruntime-gpu>=1.17.0
bitsandbytes>=0.43.0

# Extra "tpu": TPU runtime library and the Cloud TPU client.
# NOTE(review): confirm that a libtpu release satisfying >=1.0.0 is
# published on the package index in use; otherwise this extra cannot resolve.
[tpu]
cloud-tpu-client>=0.10,<1.0
libtpu>=1.0.0,<2.0.0
