Metadata-Version: 2.4
Name: siglip-kernel
Version: 0.1.0
Summary: Fused Triton kernels for memory-efficient SigLIP training
License: Apache-2.0
Project-URL: Repository, https://github.com/avocardio/siglip-kernel
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: torch>=2.1.0
Requires-Dist: triton>=3.0.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"

# siglip-kernel

Fused Triton kernels for memory-efficient SigLIP training. This is the first public fused GPU kernel for the sigmoid contrastive loss. It computes the loss tile by tile and never materializes the `B × B` logit matrix, which reduces peak loss memory from `O(B²)` to `O(B · D)` and enables per-device batch sizes that the standard implementation cannot allocate.
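
To make the tiling idea concrete, here is a minimal pure-PyTorch sketch of the sigmoid contrastive loss, once with the dense `B × B` logit matrix and once processed in row chunks. The function names and the `chunk` parameter are illustrative assumptions, not part of this package, and the real kernels tile in both dimensions inside Triton rather than looping in Python.

```python
import torch
import torch.nn.functional as F


def siglip_loss_reference(img, txt, log_temp, bias):
    """Naive version: materializes the full B x B logit matrix."""
    logits = img @ txt.T * log_temp.exp() + bias
    # +1 on the diagonal (matching pairs), -1 everywhere else.
    labels = 2 * torch.eye(img.shape[0], device=img.device, dtype=logits.dtype) - 1
    return -F.logsigmoid(labels * logits).sum() / img.shape[0]


def siglip_loss_chunked(img, txt, log_temp, bias, chunk=1024):
    """Row-chunked version: only a (chunk x B) slab of logits exists at a time."""
    B = img.shape[0]
    total = torch.zeros((), device=img.device, dtype=torch.float32)
    for start in range(0, B, chunk):
        stop = min(start + chunk, B)
        # Accumulate in FP32 regardless of the embedding dtype.
        logits = (img[start:stop] @ txt.T).float() * log_temp.exp() + bias
        labels = -torch.ones_like(logits)
        rows = torch.arange(stop - start, device=img.device)
        labels[rows, rows + start] = 1.0  # positives sit on the global diagonal
        total = total - F.logsigmoid(labels * logits).sum()
    return total / B
```

Note that under plain autograd each slab is still saved for the backward pass, so this sketch only illustrates the forward tiling. Lowering peak memory during training also requires recomputing the logits in the backward pass, which is the kind of work a fused kernel takes on.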

On H100 80 GB at `B=65,536` in BF16 the kernel uses 736 MB versus the cuBLAS reference's 33 GB (45× less), runs ~1.40× faster, and reaches `B=131,072` where the reference OOMs. On B200 it reaches `B=262,144` on 2.9 GB. See the [paper](https://github.com/avocardio/fusedsiglip_paper) for full benchmarks and validation (3-epoch CC12M pre-training, 500-step LiT, parity tests).
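
For a rough sense of the scaling behind these numbers, the arithmetic below compares the size of one dense `B × B` matrix with one `B × D` embedding matrix. The helper names and `D = 768` are assumptions for illustration. The measured figures above are end-to-end peaks that include buffers this arithmetic does not model, so it should not be read as a breakdown of them.

```python
def logit_matrix_bytes(batch_size: int, bytes_per_element: int) -> int:
    """Bytes needed to hold one dense B x B logit matrix."""
    return batch_size * batch_size * bytes_per_element


def embedding_bytes(batch_size: int, dim: int, bytes_per_element: int) -> int:
    """Bytes needed to hold one B x D embedding matrix."""
    return batch_size * dim * bytes_per_element


GIB = 1024 ** 3
MIB = 1024 ** 2
B, D = 65_536, 768  # B from the benchmark above; D = 768 is an assumed width

print(logit_matrix_bytes(B, 2) / GIB)  # one BF16 B x B matrix: 8.0 GiB
print(logit_matrix_bytes(B, 4) / GIB)  # one FP32 B x B matrix: 16.0 GiB
print(embedding_bytes(B, D, 2) / MIB)  # one BF16 B x D matrix: 96.0 MiB
```

Doubling `B` quadruples the first two figures but only doubles the third, which is why the dense matrix is the term that blocks large batches.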

## Install

```bash
git clone https://github.com/avocardio/siglip-kernel.git
cd siglip-kernel
pip install -e .
```

Requires Python ≥ 3.10, PyTorch ≥ 2.1, and Triton ≥ 3.0.

## Usage

```python
import torch
from siglip_kernel import fused_siglip_loss

img = torch.randn(8192, 768, device="cuda", dtype=torch.bfloat16)
txt = torch.randn(8192, 768, device="cuda", dtype=torch.bfloat16)
img = torch.nn.functional.normalize(img, dim=-1)
txt = torch.nn.functional.normalize(txt, dim=-1)
log_temp = torch.tensor(2.302, device="cuda", requires_grad=True)  # ln(10)
bias     = torch.tensor(-10.0, device="cuda", requires_grad=True)

loss = fused_siglip_loss(img, txt, log_temp, bias)
loss.backward()
```

A dtype-aware router selects the backend automatically: chunked-BF16 (cuBLAS GEMM + Triton fused BCE) for `bfloat16`/`float16` inputs, and fully fused Triton for `float32` inputs. Pass `backend=` to override the choice (`"chunked_bf16"`, `"fused"`, `"hopper"`, or `"reference"`).
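
The routing rule can be pictured as a small lookup from dtype to backend name. The sketch below is a hypothetical illustration of that rule only: `pick_backend` and `_DEFAULT_BACKEND` are invented names, the error for other dtypes is an assumption, and the package's actual router may use additional criteria (the GPU architecture, for instance).

```python
import torch

# Hypothetical illustration of the routing rule described above;
# not the package's actual code.
_DEFAULT_BACKEND = {
    torch.bfloat16: "chunked_bf16",
    torch.float16: "chunked_bf16",
    torch.float32: "fused",
}


def pick_backend(dtype, backend=None):
    """Return the explicit override if given, else the default for this dtype."""
    if backend is not None:
        return backend
    try:
        return _DEFAULT_BACKEND[dtype]
    except KeyError:
        # Assumption: dtypes outside the documented three are rejected.
        raise ValueError(f"no default backend for {dtype}") from None
```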

## GPU support

| Architecture | SM | Status |
|--------------|----|--------|
| Ampere (A100) | SM80 | Tested, autotuned tiles |
| Ada (RTX 4090, L40) | SM89 | Tested, FP8 path available |
| Hopper (H100, H200) | SM90 | Tested, persistent kernel + TMA |
| Blackwell (B200) | SM100 | Tested via Triton 3.7, no code changes |

## OpenCLIP integration

For users on OpenCLIP, a no-dependency pure-PyTorch chunked variant of `SigLipLoss` is proposed upstream as [PR #1145](https://github.com/mlfoundations/open_clip/pull/1145). That PR addresses the memory side only; for the speed-up too, install this package and use `fused_siglip_loss` directly.

## Tests

```bash
pip install -e ".[dev]"  # quoted so shells such as zsh do not glob the brackets
pytest tests/
```

86 tests cover correctness (forward + backward parity vs FP32 reference) and gradient stability across `B ∈ [32, 768]`, `D ∈ [64, 768]`, dtypes, chunk sizes, and backends.
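
As a rough picture of what a forward + backward parity check involves, the sketch below compares any candidate loss function against a dense FP32 reference on identical inputs. `reference_loss` and `parity_errors` are hypothetical helpers written for this illustration; the tests in `tests/` may be structured differently.

```python
import torch
import torch.nn.functional as F


def reference_loss(img, txt, log_temp, bias):
    """Dense FP32 sigmoid loss used as the ground truth."""
    logits = img.float() @ txt.float().T * log_temp.exp() + bias
    labels = 2 * torch.eye(img.shape[0], device=img.device) - 1
    return -F.logsigmoid(labels * logits).sum() / img.shape[0]


def parity_errors(candidate, B=64, D=32, seed=0, device="cpu"):
    """Max abs differences (loss, img grad, txt grad) versus the reference."""
    gen = torch.Generator().manual_seed(seed)
    base_img = F.normalize(torch.randn(B, D, generator=gen), dim=-1).to(device)
    base_txt = F.normalize(torch.randn(B, D, generator=gen), dim=-1).to(device)
    results = []
    for fn in (reference_loss, candidate):
        # Fresh leaf tensors per run so gradients do not accumulate.
        img = base_img.clone().requires_grad_(True)
        txt = base_txt.clone().requires_grad_(True)
        log_temp = torch.tensor(2.302, device=device, requires_grad=True)
        bias = torch.tensor(-10.0, device=device, requires_grad=True)
        loss = fn(img, txt, log_temp, bias)
        loss.backward()
        results.append((loss.detach(), img.grad, txt.grad))
    (l0, gi0, gt0), (l1, gi1, gt1) = results
    return (
        (l0 - l1).abs().item(),
        (gi0 - gi1).abs().max().item(),
        (gt0 - gt1).abs().max().item(),
    )
```

Passing `fused_siglip_loss` as `candidate` with `device="cuda"` would be the natural use. The tolerance to accept is a judgment call that depends on the dtype and backend.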

## Citation

Paper source at [avocardio/fusedsiglip_paper](https://github.com/avocardio/fusedsiglip_paper); preprint forthcoming.

## License

Apache-2.0.
