datasets>=4.0.0
pillow>=11.3.0
rich>=14.1.0
safetensors>=0.6.2
tokenizers>=0.21.2
transformers<=5.3.0,>=5.0.0
typer>=0.17.4
peft
hf_transfer
cloudpathlib>=0.23.0

[aws]
cloudpathlib[s3]

[azure]
cloudpathlib[azure]

[dev]
mkdocs
mkdocs-material
mkdocstrings[python]>=0.24.0
pymdown-extensions>=10.7
pytest
pytest-forked
pytest-asyncio
pre-commit
litellm
torch
ty
cloudpathlib[s3]
alembic
griffe2md

[flashrl]
skyrl[skyrl-train]

[flashrl:sys_platform == "linux"]
flash-attn==2.8.3
torch==2.7.0
flashinfer-python
torchvision

[fsdp]
skyrl[skyrl-train]

[fsdp:sys_platform == "linux"]
vllm==0.19.0
vllm-router
nixl
flash-linear-attention
causal-conv1d
flash-attn==2.8.3
torch==2.10.0
torchvision

[fsdp:sys_platform == "linux" and platform_machine == "x86_64"]
flashinfer-python==0.6.6
flashinfer-jit-cache==0.6.6

[gcp]
cloudpathlib[gs]

[gpu]

[gpu:sys_platform == "linux"]
jax[cuda12]>=0.7.2

[harbor]

[harbor:python_version >= "3.12"]
harbor[daytona,modal]

[jax]
jax<1.0,>=0.8
flax>=0.12.2
optax>=0.2.5

[mamba]

[mamba:sys_platform == "linux"]
mamba-ssm>=2.3.0

[megatron]
skyrl[skyrl-train]

[megatron:sys_platform == "linux"]
transformer-engine[pytorch]==2.10.0
flash-attn==2.8.3
flash-linear-attention
causal-conv1d
vllm==0.19.0
vllm-router
nixl
torch==2.10.0
torchvision
nvidia-modelopt

[megatron:sys_platform == "linux" and platform_machine == "x86_64"]
flashinfer-python==0.6.6
flashinfer-jit-cache==0.6.6

[megatron:sys_platform == "linux" and python_version >= "3.12"]
megatron-bridge
megatron-core

[miniswe]
skyrl[skyrl-train]
mini-swe-agent>=1.12.0
litellm

[skyrl-train]
loguru
tqdm
ninja
tensorboard
func_timeout
hydra-core==1.3.2
accelerate
torchdata
omegaconf
ray==2.51.1
peft
debugpy==1.8.0
hf_transfer
wandb
datasets>=4.0.0
tensordict
jaxtyping
skyrl-gym==0.3.0
polars
s3fs
fastapi
uvicorn
pybind11
setuptools

[skyrl-train:sys_platform == "linux"]
flash-attn==2.8.3
vllm-router
nixl

[tinker]
tinker>=0.3.0
fastapi[standard]
sqlmodel
sqlalchemy[asyncio]
aiosqlite
asyncpg
psycopg2-binary

[tpu]

[tpu:sys_platform == "linux"]
jax[tpu]>=0.7.2
