Metadata-Version: 2.4
Name: flash-sparse-attn
Version: 2.0.1
Summary: Flash Sparse Attention: Fast and Memory-Efficient Trainable Sparse Attention
Author-email: Jingze Shi <losercheems@gmail.com>, Yifan Wu <ywu012@connect.hkust-gz.edu.cn>, Bingheng Wu <wubingheng52136@gmail.com>, Yiran Peng <amagipeng@gmail.com>, Liangdong Wang <wangliangdong@baai.ac.cn>, Guang Liu <liuguang@baai.ac.cn>, Yuyu Luo <yuyuluo@hkust-gz.edu.cn>
Maintainer-email: Jingze Shi <losercheems@gmail.com>
License: BSD 3-Clause License
        
        Copyright (c) 2025, the respective contributors, as shown by the AUTHORS file.
        All rights reserved.
        
        Redistribution and use in source and binary forms, with or without
        modification, are permitted provided that the following conditions are met:
        
        * Redistributions of source code must retain the above copyright notice, this
          list of conditions and the following disclaimer.
        
        * Redistributions in binary form must reproduce the above copyright notice,
          this list of conditions and the following disclaimer in the documentation
          and/or other materials provided with the distribution.
        
        * Neither the name of the copyright holder nor the names of its
          contributors may be used to endorse or promote products derived from
          this software without specific prior written permission.
        
        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
        AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
        IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
        DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
        FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
        DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
        SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
        CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
        OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
        OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
        
Project-URL: Homepage, https://github.com/HKUSTDial/flash-sparse-attention
Project-URL: Source, https://github.com/HKUSTDial/flash-sparse-attention
Project-URL: Issues, https://github.com/HKUSTDial/flash-sparse-attention/issues
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: Unix
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
License-File: AUTHORS
Requires-Dist: triton>=2.0.0
Provides-Extra: cute
Requires-Dist: nvidia-cutlass-dsl>=4.4.2; extra == "cute"
Requires-Dist: einops; extra == "cute"
Requires-Dist: typing_extensions; extra == "cute"
Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5; extra == "cute"
Requires-Dist: torch-c-dlpack-ext; extra == "cute"
Requires-Dist: quack-kernels>=0.3.3; extra == "cute"
Provides-Extra: dev
Requires-Dist: transformers>=4.38.0; extra == "dev"
Requires-Dist: pytest>=6.0; extra == "dev"
Requires-Dist: pytest-benchmark; extra == "dev"
Requires-Dist: numpy; extra == "dev"
Requires-Dist: tabulate>=0.9.0; extra == "dev"
Provides-Extra: docs
Requires-Dist: mkdocs>=1.6; extra == "docs"
Requires-Dist: mkdocs-material>=9.5; extra == "docs"
Requires-Dist: mkdocstrings[python]>=0.28; extra == "docs"
Requires-Dist: pymdown-extensions>=10.0; extra == "docs"
Dynamic: license-file

<div align="center">
  <img src="./assets/logo.png" alt="flash-algo" width="100%">
</div>

<div align="center">


**English** | [简体中文](./README_zh.md)

</div>


Flash-Sparse-Attention is a high-performance trainable sparse attention implementation that combines Flash Attention's memory efficiency with sparse computation for handling extremely long sequences in Transformer models.


## Key Features

> [!NOTE]
> Support for arbitrary mask and bias shapes is available in [this branch](https://github.com/HKUSTDial/flash-sparse-attention/tree/final_mask_version). The current main branch no longer maintains that feature set.

### Supported Features

- Forward and backward passes for dense attention, sparse attention, and gated attention
- Regular batched inputs and varlen inputs
- Causal attention and local window attention
- Arbitrary combinations of Q and KV sequence lengths, with head dimensions up to 256
- Grouped Query Attention and Multi Query Attention
- Sparse softmax threshold control
- Gated attention with gate inputs and configurable gating sparsity
- Split-KV path optimization for decoding workloads

### Features We Aim to Support

- Paged Attention
- TMA, WGMMA, and FP8 low precision
- Sequence parallelism


## Installation

### Requirements

- **Linux**: Ubuntu 22.04 or later
- **NVIDIA GPU**: Compute Capability 8.0 or higher
- **Runtime**: NVIDIA driver and runtime compatible with your PyTorch and Triton installation
- **Python**: 3.10 or later
- **PyTorch**: 2.5.1 or later
- **Triton**: Installed automatically as a default dependency
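
If you want to confirm your setup meets these requirements before installing, a minimal check along these lines works (a sketch; adjust to your environment):

```python
import sys
import torch

# Quick sanity check against the requirements listed above.
print("Python:", sys.version.split()[0])            # expect >= 3.10
print("PyTorch:", torch.__version__)                 # expect >= 2.5.1
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Compute Capability: {major}.{minor}")    # expect >= 8.0
```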

### Install

Install from PyPI:

```bash
pip install flash-sparse-attn
```

To install from source:

```bash
git clone https://github.com/HKUSTDial/flash-sparse-attention.git
cd flash-sparse-attention
pip install .
```
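
To sanity-check the installation, you can import one of the public entry points (the module path matches the Quick Start section below):

```bash
python -c "from flash_sparse_attn.ops.triton.interface import flash_sparse_attn_func; print('flash-sparse-attn OK')"
```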


### Install via HuggingFace Kernel

You can also load the kernels directly from [HuggingFace Kernel](https://github.com/huggingface/kernels) without installing the package:

```python
from kernels import get_kernel

fsa = get_kernel("JingzeShi/flash-sparse-attention", version=1)

out = fsa.flash_dense_attn_func(q, k, v, is_causal=True)
out = fsa.flash_sparse_attn_func(q, k, v, is_causal=True, softmax_threshold=0.01)
out = fsa.flash_gated_attn_func(q, k, v, alpha, delta, is_causal=True)
```

Requires `pip install kernels`.


## Quick Start

### Basic Usage

Below are examples for the three common attention variants:

```python
import torch
from flash_sparse_attn.ops.triton.interface import (
    flash_dense_attn_func,
    flash_sparse_attn_func,
    flash_gated_attn_func,
)

dtype = torch.bfloat16
device = torch.device("cuda")
batch_size, seqlen_q, seqlen_k, num_heads, num_kv_heads, head_dim = 2, 1024, 1024, 8, 2, 64

query = torch.randn(batch_size, seqlen_q, num_heads, head_dim, dtype=dtype, device=device)
key = torch.randn(batch_size, seqlen_k, num_kv_heads, head_dim, dtype=dtype, device=device)
value = torch.randn(batch_size, seqlen_k, num_kv_heads, head_dim, dtype=dtype, device=device)
```

### Dense Attention

Use this when you do not need explicit sparsification but still want an efficient attention kernel.

```python
output_dense = flash_dense_attn_func(
    query=query,
    key=key,
    value=value,
    is_causal=True,
)

print(output_dense.shape)
```

### Sparse Attention

Use this when you want to skip low-contribution attention weights through `softmax_threshold` and reduce effective compute on long sequences.

```python
output_sparse = flash_sparse_attn_func(
    query=query,
    key=key,
    value=value,
    is_causal=True,
    softmax_threshold=1.0,
)

print(output_sparse.shape)
```

### Gated Attention

Use this when you need explicit gating signals for sparse attention. `alpha` controls query-side gating and `delta` controls key-side gating.

```python
alpha = torch.randn(batch_size, num_heads, seqlen_q, device=device, dtype=dtype)
delta = torch.randn(batch_size, num_kv_heads, seqlen_k, device=device, dtype=dtype)

output_gated = flash_gated_attn_func(
    query=query,
    key=key,
    value=value,
    alpha=alpha,
    delta=delta,
    is_causal=True,
    softmax_threshold=1.0,
    gate_threshold=1.0,
)

print(output_gated.shape)
```


## Performance

The following benchmarks were collected on an SM120 GPU (compute capability 12.0) and cover forward, backward, and decoding workloads. They compare the Dense, Sparse, and Gated implementations against a FlashAttention baseline.

### Forward Performance

![Attention forward speed, head dim 128](assets/sm120_forward_benchmark.png)

### Backward Performance

![Attention backward speed, head dim 128](assets/sm120_backward_benchmark.png)

### Decode Performance

![Attention decode speed, head dim 128](assets/sm120_decode_benchmark.png)


## Benchmarking

Benchmark scripts are located under [tests](tests/), covering forward, backward, and decoding performance.

By default, these scripts use the attention projection layers from the Qwen model family to generate Q, K, and V states with distributions closer to real LLM workloads, and they build input sequences from the Needle-in-a-Haystack dataset.
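
If you just want a rough standalone number without the bundled scripts, the sketch below times the dense and sparse kernels on synthetic random inputs using Triton's built-in timer; because it does not use the Qwen projections or the Needle-in-a-Haystack data, results will differ from the plots above.

```python
import torch
import triton
from flash_sparse_attn.ops.triton.interface import (
    flash_dense_attn_func,
    flash_sparse_attn_func,
)

dtype, device = torch.bfloat16, torch.device("cuda")
batch, seqlen, num_heads, num_kv_heads, head_dim = 1, 4096, 8, 2, 128

query = torch.randn(batch, seqlen, num_heads, head_dim, dtype=dtype, device=device)
key = torch.randn(batch, seqlen, num_kv_heads, head_dim, dtype=dtype, device=device)
value = torch.randn(batch, seqlen, num_kv_heads, head_dim, dtype=dtype, device=device)

# do_bench handles warmup and returns the median runtime in milliseconds.
dense_ms = triton.testing.do_bench(
    lambda: flash_dense_attn_func(query=query, key=key, value=value, is_causal=True)
)
sparse_ms = triton.testing.do_bench(
    lambda: flash_sparse_attn_func(
        query=query, key=key, value=value, is_causal=True, softmax_threshold=1.0
    )
)
print(f"dense:  {dense_ms:.3f} ms")
print(f"sparse: {sparse_ms:.3f} ms")
```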

### Forward Performance

```bash
python tests/benchmark_forward.py
```

### Backward Performance

```bash
python tests/benchmark_backward.py
```

### Decode Performance

```bash
python tests/benchmark_decode.py
```


## Citation

If you use Flash Sparse Attention (FSA) in your research, please cite:

```bibtex
@misc{shi2025trainabledynamicmasksparse,
      title={Trainable Dynamic Mask Sparse Attention},
      author={Jingze Shi and Yifan Wu and Bingheng Wu and Yiran Peng and Liangdong Wang and Guang Liu and Yuyu Luo},
      year={2025},
      eprint={2508.02124},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2508.02124},
}
```


## Acknowledgments

This project builds upon and integrates several excellent works:

- **[OpenSeek](https://github.com/FlagAI-Open/OpenSeek)** - Kernel development support
- **[Flash-Attention](https://github.com/Dao-AILab/flash-attention)** - Memory-efficient attention computation
- **[NVIDIA CUTLASS](https://github.com/NVIDIA/cutlass)** - High-performance matrix operations library

We thank the open-source community for its contributions to efficient Transformer implementations.
