Metadata-Version: 2.4
Name: flash-colreduce
Version: 0.2.1
Summary: Fast, memory-efficient attention column reduction (e.g., sum, mean, max)
Author: Z Lab
License: MIT License
Project-URL: Homepage, https://github.com/z-lab/flash-colreduce
Project-URL: Issues, https://github.com/z-lab/flash-colreduce/issues
Keywords: triton,pytorch,cuda,attention,transformers,gpu,kernels
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.1
Requires-Dist: triton>=3.0.0
Provides-Extra: test
Requires-Dist: pytest>=7; extra == "test"
Requires-Dist: pytest-cov>=4; extra == "test"
Provides-Extra: dev
Requires-Dist: black>=24.0; extra == "dev"
Provides-Extra: bench
Requires-Dist: rich>=13.7; extra == "bench"
Requires-Dist: tqdm>=4.67; extra == "bench"
Provides-Extra: all
Requires-Dist: flash-colreduce[test]; extra == "all"
Requires-Dist: flash-colreduce[dev]; extra == "all"
Requires-Dist: flash-colreduce[bench]; extra == "all"
Dynamic: license-file

# Flash-ColReduce

[![PyPI](https://img.shields.io/pypi/v/flash-colreduce)](https://pypi.org/project/flash-colreduce/)
[![License](https://img.shields.io/badge/license-MIT-yellow)](LICENSE)
[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)

**Flash-ColReduce** provides highly optimized Triton kernels for computing column-wise reductions of the attention matrix (sum, mean, or max) without materializing the full $M \times N$ attention weights, which take $O(N^2)$ memory when $M = N$.

This primitive is essential for KV-cache pruning, token importance estimation, and attention analysis in Large Language Models (LLMs) and Vision-Language Models (VLMs). It powers the visual token pruning in [SparseVILA](https://arxiv.org/abs/2510.17777).
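As a rough illustration of the pruning use case, the sketch below keeps the highest-scoring key/value positions given a per-key score tensor of shape `(B, H, N)`. The helper `prune_kv_by_score` is hypothetical and not part of this package, and the scores here are random stand-ins for a real column reduction.

```python
import torch

def prune_kv_by_score(k, v, scores, keep):
    """Keep the `keep` highest-scoring key/value positions per (batch, head).

    k, v:   (B, H, N, D) key and value tensors
    scores: (B, H, N) per-key importance, e.g. a column sum of attention
    """
    # Top-`keep` indices, re-sorted so the surviving tokens stay in original order.
    idx = scores.topk(keep, dim=-1).indices.sort(dim=-1).values      # (B, H, keep)
    gather_idx = idx.unsqueeze(-1).expand(-1, -1, -1, k.shape[-1])   # (B, H, keep, D)
    return k.gather(2, gather_idx), v.gather(2, gather_idx), idx

# Stand-in inputs; in practice `scores` would come from a column reduction.
B, H, N, D = 2, 4, 16, 8
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)
scores = torch.rand(B, H, N)
k_kept, v_kept, kept_idx = prune_kv_by_score(k, v, scores, keep=4)
```

Re-sorting the indices is a design choice: it preserves token order, which matters if positional information is tied to position in the cache.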

## Highlights

- **🚀 Efficient**: Fused kernels compute column reductions in **$O(N)$ memory**.
- **🧩 Flexible**: Supports **causal** and **non-causal** attention with irregular shapes ($M \neq N$).
- **✅ Exact**: Uses online softmax for numerical precision and correct causal masking.
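
To make the memory claim concrete, here is a minimal PyTorch sketch of a tiled column sum using online softmax. It is not the package's Triton kernel: the function name `colsum_online`, the two-pass structure, the block sizes, and the $1/\sqrt{D}$ scaling are all assumptions made for illustration. Only one `block_m × block_n` tile of scores exists at a time, so extra memory stays proportional to the output rather than to $M \times N$.

```python
import math
import torch

def colsum_online(q, k, block_m=64, block_n=64):
    """Column sums of softmax(q @ k^T / sqrt(d)) without forming the full matrix."""
    *lead, m, d = q.shape
    n = k.shape[-2]
    scale = 1.0 / math.sqrt(d)  # assumed scaling convention
    qf, kf = q.float(), k.float()
    out = torch.zeros(*lead, n, dtype=torch.float32, device=q.device)
    for i in range(0, m, block_m):
        q_blk = qf[..., i:i + block_m, :]
        bm = q_blk.shape[-2]
        # Pass 1: online softmax statistics (running max and normalizer) per query row.
        row_max = torch.full((*lead, bm), float("-inf"), device=q.device)
        row_sum = torch.zeros(*lead, bm, device=q.device)
        for j in range(0, n, block_n):
            s = (q_blk @ kf[..., j:j + block_n, :].transpose(-1, -2)) * scale
            new_max = torch.maximum(row_max, s.amax(dim=-1))
            # Rescale the old normalizer to the new max, then add this tile.
            row_sum = row_sum * torch.exp(row_max - new_max) \
                + torch.exp(s - new_max.unsqueeze(-1)).sum(dim=-1)
            row_max = new_max
        # Pass 2: recompute each tile, normalize it, and accumulate column sums.
        for j in range(0, n, block_n):
            s = (q_blk @ kf[..., j:j + block_n, :].transpose(-1, -2)) * scale
            p = torch.exp(s - row_max.unsqueeze(-1)) / row_sum.unsqueeze(-1)
            out[..., j:j + block_n] += p.sum(dim=-2)
    return out
```

The second pass recomputes the scores instead of storing them, trading extra FLOPs for memory.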

## Prerequisites

- **Python**: 3.10+
- **PyTorch**: 2.1+ (with CUDA support)
- **Triton**: 3.0.0+
- **GPU**: NVIDIA GPU; Compute Capability 8.0+ (Ampere or newer) recommended

## Installation

Install from PyPI:
```bash
pip install flash-colreduce
```

Or install from source in editable mode:
```bash
git clone https://github.com/z-lab/flash-colreduce.git
cd flash-colreduce
pip install -e .
```

## Usage

### 1. Non-Causal Attention

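Compute a column-wise reduction of the attention matrix over the query dimension. For `q` of shape `(B, H, M, D)` and `k` of shape `(B, H, N, D)`, the result has shape `(B, H, N)`: one value per key position.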

```python
import torch
from flash_colreduce import flash_colreduce

q = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn(8, 16, 512, 64, device="cuda", dtype=torch.float16)

flash_colreduce(q, k, reduction="sum")  # Shape: (8, 16, 512)
flash_colreduce(q, k, reduction="mean")  # Shape: (8, 16, 512)
flash_colreduce(q, k, reduction="max")  # Shape: (8, 16, 512)
```
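
For sanity-checking results on small inputs, a naive reference that materializes the full attention matrix can be handy. The sketch below is an assumption about what is being computed: the name `naive_colreduce` is hypothetical, and the $1/\sqrt{D}$ scaling follows standard attention rather than anything confirmed about the kernel.

```python
import math
import torch

def naive_colreduce(q, k, reduction="sum"):
    """Reference: build the full (M, N) attention matrix, then reduce over queries."""
    d = q.shape[-1]
    # Assumption: logits are scaled by 1/sqrt(d), as in standard attention.
    scores = (q.float() @ k.float().transpose(-1, -2)) / math.sqrt(d)
    attn = torch.softmax(scores, dim=-1)  # (B, H, M, N), rows sum to 1
    if reduction == "sum":
        return attn.sum(dim=-2)
    if reduction == "mean":
        return attn.mean(dim=-2)
    if reduction == "max":
        return attn.amax(dim=-2)
    raise ValueError(f"unknown reduction: {reduction!r}")

# Small CPU example; every query row sums to 1, so the column sums add up to M.
q = torch.randn(2, 4, 32, 16)
k = torch.randn(2, 4, 48, 16)
col_sum = naive_colreduce(q, k, "sum")    # Shape: (2, 4, 48)
col_mean = naive_colreduce(q, k, "mean")  # Shape: (2, 4, 48)
col_max = naive_colreduce(q, k, "max")    # Shape: (2, 4, 48)
```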

### 2. Causal Attention

Pass `is_causal=True` for autoregressive attention, including the case $M \neq N$ where a short block of queries attends to a longer key sequence. The kernel applies a right-aligned causal mask, matching KV-cached decoding behavior.

```python
import torch
from flash_colreduce import flash_colreduce

q = torch.randn(1, 32, 128, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.float16)

flash_colreduce(q, k, reduction="sum", is_causal=True)  # Shape: (1, 32, 4096)
flash_colreduce(q, k, reduction="mean", is_causal=True)  # Shape: (1, 32, 4096)
flash_colreduce(q, k, reduction="max", is_causal=True)  # Shape: (1, 32, 4096)
```
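
The sketch below shows one reading of "right-aligned": the last query lines up with the last key, so query `i` may attend to key `j` when `j <= i + (N - M)`. This convention is an assumption based on the description above, and both helper names are hypothetical.

```python
import math
import torch

def right_aligned_causal_mask(m, n, device=None):
    """Boolean (m, n) mask; True where query i may attend to key j."""
    i = torch.arange(m, device=device).unsqueeze(-1)
    j = torch.arange(n, device=device).unsqueeze(0)
    return j <= i + (n - m)

def naive_causal_colsum(q, k):
    """Reference causal column sum; assumes M <= N so no query row is fully masked."""
    m, n, d = q.shape[-2], k.shape[-2], q.shape[-1]
    scores = (q.float() @ k.float().transpose(-1, -2)) / math.sqrt(d)
    mask = right_aligned_causal_mask(m, n, q.device)
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1).sum(dim=-2)

mask = right_aligned_causal_mask(3, 5)
# Query 0 sees keys 0..2, query 1 sees keys 0..3, query 2 (the newest) sees all 5.
print(mask.int())
```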

## Performance

Flash-ColReduce fuses softmax and reduction into a single kernel, so it never writes the $B \times H \times M \times N$ attention matrix to GPU memory. This yields speedups and memory savings over naïve implementations that materialize the matrix.

![Benchmark Results on NVIDIA RTX Pro 6000 Blackwell](benchmarks/results/rtx-pro-6000-blackwell.png)
*Benchmarked on NVIDIA RTX Pro 6000 Blackwell with FP16 precision*
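
To get a feel for the memory involved, the following back-of-the-envelope calculation compares the size of a materialized attention matrix with the size of the reduced output. It is plain arithmetic, not a benchmark result; the helper names are hypothetical, and storing both tensors in FP16 (2 bytes per element) is an assumption.

```python
def attention_matrix_bytes(b, h, m, n, bytes_per_elem=2):
    """Bytes needed to store the full (B, H, M, N) attention matrix."""
    return b * h * m * n * bytes_per_elem

def reduced_output_bytes(b, h, n, bytes_per_elem=2):
    """Bytes needed to store the (B, H, N) column reduction."""
    return b * h * n * bytes_per_elem

# Shapes from the non-causal usage example: B=8, H=16, M=N=512.
full = attention_matrix_bytes(8, 16, 512, 512)
reduced = reduced_output_bytes(8, 16, 512)
print(full / 2**20, "MiB")     # 64.0 MiB
print(reduced / 2**10, "KiB")  # 128.0 KiB
```

The ratio between the two is exactly $M$, so the gap widens linearly with the number of queries.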

## Development

### Running Tests
```bash
pip install -e ".[test]"
pytest -v
```

### Running Benchmarks
```bash
pip install -e ".[bench]"
python benchmarks/run.py
```

## Citation

If you use Flash-ColReduce in your research, please cite the SparseVILA paper:

```bibtex
@inproceedings{khaki2025sparsevila,
  title = {{SparseVILA: Decoupling Visual Sparsity for Efficient VLM Inference}},
  author = {Khaki, Samir and Guo, Junxian and Tang, Jiaming and Yang, Shang and Chen, Yukang and Plataniotis, Konstantinos N and Lu, Yao and Han, Song and Liu, Zhijian},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  year = {2025}
}
```

## License

[MIT License](LICENSE)

## Acknowledgments

- **[FlashAttention](https://github.com/Dao-AILab/flash-attention)**: The tiling and online softmax approach is heavily inspired by FlashAttention.
- **[SparseVILA](https://arxiv.org/abs/2510.17777)**: The original project that motivated this primitive.
