torch>=2.1.0
transformers>=4.36.0
accelerate>=0.25.0
datasets>=2.16.0
numpy>=1.24.0
matplotlib>=3.7.0
seaborn>=0.13.0
tqdm>=4.65.0
einops>=0.7.0
jaxtyping>=0.2.25
safetensors>=0.4.0

[all]
reward-lens[dev,sae]

[dev]
pytest>=7.4.0
pytest-cov>=4.1.0
ruff>=0.1.0
mypy>=1.7.0

[sae]
wandb>=0.16.0
