Metadata-Version: 2.4
Name: pascal-attn
Version: 0.1.0
Summary: Memory-efficient tiled online-softmax attention with fused GQA KV expansion, tuned for Pascal and later NVIDIA GPUs
Project-URL: Repository, https://github.com/hraisikai/pascal-attn
Author: Kai Robbins
License: MIT
License-File: LICENSE
Keywords: GQA,attention,memory-efficient,pytorch,transformer
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Requires-Dist: torch>=2.0
Provides-Extra: dev
Requires-Dist: pytest; extra == 'dev'
Requires-Dist: pytest-benchmark; extra == 'dev'
Provides-Extra: transformers
Requires-Dist: transformers>=4.36; extra == 'transformers'
Description-Content-Type: text/markdown

# pascal-attn

Memory-efficient tiled online-softmax attention with fused GQA KV expansion, tuned for Pascal and later NVIDIA GPUs.

## What it is

`pascal-attn` implements the **tiled online-softmax attention** algorithm (a pure-PyTorch variant of FlashAttention) with two improvements over a naive chunked implementation:

1. **Fused GQA KV expansion** — in Grouped-Query Attention (GQA), the key/value tensors have fewer heads than the query (`n_kv < n_h`). A naive implementation expands `[B, n_kv, N, d_h]` → `[B, n_h, N, d_h]` upfront. `pascal-attn` instead slices the unexpanded KV per tile and expands only that small `[B, n_h, tile, d_h]` slice, so the extra memory spent on expansion is a per-tile constant rather than growing with sequence length (see the sketch after this list).

2. **Tile size tuned to GPU L2 cache** — the dominant cost in the inner loop is the QK matmul. Keeping the tile small enough to fit in L2 eliminates cache thrashing. The library includes auto-detection via `recommended_tile_size()`.
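
The fused expansion in point 1 boils down to expanding one KV tile at a time. A minimal, self-contained sketch of the idea (illustrative only; the variable names are not the library's internals):

```python
import torch

# GQA shapes matching the examples below: 32 query heads sharing 8 KV heads
B, n_h, n_kv, N, d_h, tile = 2, 32, 8, 2048, 64, 64
k_raw = torch.randn(B, n_kv, N, d_h)   # unexpanded keys
v_raw = torch.randn(B, n_kv, N, d_h)   # unexpanded values
groups = n_h // n_kv                   # query heads per KV head (4 here)

j = 0  # start offset of one KV tile
# expand only the current tile: [B, n_kv, tile, d_h] -> [B, n_h, tile, d_h]
k_tile = k_raw[:, :, j:j + tile].repeat_interleave(groups, dim=1)
v_tile = v_raw[:, :, j:j + tile].repeat_interleave(groups, dim=1)
# the full [B, n_h, N, d_h] expanded tensors are never materialised
```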

## Installation

```bash
pip install pascal-attn
```

With HuggingFace Transformers integration:

```bash
pip install "pascal-attn[transformers]"
```

From source:

```bash
git clone https://github.com/hraisikai/pascal-attn
cd pascal-attn
pip install -e .
```

## Quick start

### Functional API

```python
import torch
from pascal_attn import tiled_attention, recommended_tile_size

tile = recommended_tile_size()   # auto-detect from GPU L2 cache

# Example inputs: [B, n_h, N, d_h] queries, [B, n_kv, N, d_h] keys/values
B, n_h, n_kv, N, d_h = 4, 32, 8, 2048, 64
q = torch.randn(B, n_h, N, d_h)
k = torch.randn(B, n_h, N, d_h)
v = torch.randn(B, n_h, N, d_h)

# MHA (n_kv == n_h)
out = tiled_attention(q, k, v, tile_size=tile)

# GQA — n_kv < n_h, inferred from the tensor shapes
k_gqa = torch.randn(B, n_kv, N, d_h)
v_gqa = torch.randn(B, n_kv, N, d_h)
out = tiled_attention(q, k_gqa, v_gqa, tile_size=tile)

# With a causal mask (additive, -1e9 above the diagonal)
causal = torch.zeros(1, 1, N, N).masked_fill(
    torch.ones(N, N, dtype=torch.bool).triu(1), -1e9
)
out = tiled_attention(q, k, v, mask=causal, tile_size=tile)
```

Input/output shapes:
```
query:  [B, n_h,  N_q, d_h]
key:    [B, n_kv, N_k, d_h]   # n_kv <= n_h, n_h % n_kv == 0
value:  [B, n_kv, N_k, d_h]
output: [B, N_q,  n_h, d_h]
```

### nn.Module API

```python
import torch
from pascal_attn import TiledAttention

attn = TiledAttention(
    n_heads=32,
    n_kv_heads=8,       # GQA: 4 query heads share each KV head
    head_dim=64,
    tile_size='auto',   # detect from GPU at first forward call
)

# Packed inputs [B, N, H]
q = torch.randn(4, 2048, 32 * 64)
k = torch.randn(4, 2048,  8 * 64)
v = torch.randn(4, 2048,  8 * 64)
out = attn(q, k, v)   # [4, 2048, 32 * 64]

# Or unpacked inputs [B, N, n_h, d_h]
q = torch.randn(4, 2048, 32, 64)
k = torch.randn(4, 2048,  8, 64)
v = torch.randn(4, 2048,  8, 64)
out = attn(q, k, v)   # [4, 2048, 32, 64]
```

### HuggingFace Transformers integration

```python
from transformers import AutoConfig, AutoModelForCausalLM
from pascal_attn.hf import register_with_transformers

# Register once before loading your model
register_with_transformers(tile_size=64, name="pascal_chunked")

# Now any model config can use it
config = AutoConfig.from_pretrained("...")
config._attn_implementation = "pascal_chunked"
model = AutoModelForCausalLM.from_pretrained("...", config=config)
```

Or assign per-layer:

```python
from pascal_attn.hf import make_hf_attention_fn

fn = make_hf_attention_fn(tile_size=64)
for layer in model.model.layers:
    layer.self_attn._attention_forward_fn = fn
```

## Memory savings

Configuration: N=2048, n_h=32, n_kv=8 (GQA 4:1), d_h=64, B=4, fp16.

| Implementation       | KV peak per layer | Attn scores peak | Total peak (approx) |
|----------------------|-------------------|------------------|---------------------|
| Naive (expanded KV)  | 84 MB             | 536 MB           | ~620 MB             |
| tiled tile=256       | 84 MB\*           | 33 MB            | ~117 MB             |
| tiled tile=64        | 2.6 MB            | 2.1 MB           | ~5 MB               |

\* The `tile=256` row models tiling without fused GQA expansion, so KV is still expanded upfront. `pascal-attn` with `tile=64` avoids both the upfront expansion and the large scores buffer.

Across a 28-layer 3B model: **~2.4 GB saved from fused GQA alone** (84 MB × 28), before accounting for the scores buffer reduction.
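
To reproduce this kind of comparison on your own hardware, peak allocations can be read back with `torch.cuda.max_memory_allocated()`. A rough sketch under the configuration above (this is not the benchmark script shipped in `benchmarks/`, and it assumes a CUDA device):

```python
import torch
from pascal_attn import tiled_attention

B, n_h, n_kv, N, d_h = 4, 32, 8, 2048, 64
q = torch.randn(B, n_h, N, d_h, device="cuda", dtype=torch.float16)
k = torch.randn(B, n_kv, N, d_h, device="cuda", dtype=torch.float16)
v = torch.randn(B, n_kv, N, d_h, device="cuda", dtype=torch.float16)

for tile in (256, 64):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    out = tiled_attention(q, k, v, tile_size=tile)
    torch.cuda.synchronize()
    peak_mb = torch.cuda.max_memory_allocated() / 2**20
    print(f"tile={tile}: peak allocated {peak_mb:.1f} MB")
```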

## GPU tile size guide

| GPU family                        | L2 cache  | Recommended tile_size |
|-----------------------------------|-----------|-----------------------|
| Pascal (GTX 1080 Ti, P40)         | 3 MB      | **64**                |
| Volta (V100) / Turing (T4, 20xx)  | 6–8 MB    | **128**               |
| Ampere (A100, 30xx) / Ada (40xx)  | 20–80 MB  | **256**               |

Use `recommended_tile_size()` to detect automatically:

```python
from pascal_attn import recommended_tile_size
tile = recommended_tile_size()
print(tile)  # e.g. 64 on a P40
```

The heuristic queries `torch.cuda.get_device_properties().l2_cache_size` and
returns a conservative tile that fits within L2.
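
Roughly, such a heuristic can be written as below. This is an illustrative sketch, not the library's exact code; the device-property attribute name for L2 size has varied across PyTorch releases, so it is read defensively:

```python
import torch

def sketch_recommended_tile_size(default: int = 64) -> int:
    """Map the current GPU's L2 cache size to a tile size (illustrative only)."""
    if not torch.cuda.is_available():
        return default
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    # the attribute name differs between PyTorch versions; fall back to 0 if absent
    l2 = getattr(props, "L2_cache_size", None) or getattr(props, "l2_cache_size", 0)
    if l2 >= 20 * 1024**2:   # Ampere / Ada class
        return 256
    if l2 >= 6 * 1024**2:    # Volta / Turing class
        return 128
    return default           # Pascal class (~3 MB) or unknown
```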

## How it works

The algorithm is a pure-PyTorch implementation of the online-softmax tiling
used in FlashAttention, extended with fused per-tile GQA expansion:

```
For each query tile q_i  [B, n_h, tile, d_h]:
    m = -1e9, l = 0, o = 0          # running max, denominator, output

    For each KV tile k_j with unexpanded slices (k_raw, v_raw)  [B, n_kv, tile, d_h]:
        k_tile = expand(k_raw, n_groups)   # [B, n_h, tile, d_h]  ← fused
        v_tile = expand(v_raw, n_groups)

        s     = q_i @ k_tile.T * scale    # [B, n_h, tile, tile]
        s    += mask[q_i, k_j]            # optional
        m_new = max(m, rowmax(s))
        o     = exp(m - m_new) * o  +  exp(s - m_new) @ v_tile
        l     = exp(m - m_new) * l  +  rowsum(exp(s - m_new))
        m     = m_new

    output[q_i] = o / (l + 1e-8)        # normalise once per query tile
```

Key numerical choices:
- Accumulators `m`, `l`, `o` are kept in **float32** even for fp16/bf16 inputs.
- `m` is initialised to `-1e9` (not `-inf`) to avoid `nan` in `exp(-inf - (-inf))` when every position in a tile is masked.
- `l + 1e-8` guards normalisation against fully-masked rows.
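
For reference, the loop above fits in a few dozen lines of plain PyTorch. The sketch below is illustrative rather than the package's internal code (mask handling is omitted, and it returns the output in the query's `[B, n_h, N_q, d_h]` layout), but it follows the same update rules and numerical choices:

```python
import torch

def reference_tiled_attention(q, k, v, tile=64, scale=None):
    """Online-softmax tiled attention with per-tile GQA expansion (illustrative)."""
    B, n_h, N_q, d_h = q.shape            # q: [B, n_h, N_q, d_h]
    n_kv, N_k = k.shape[1], k.shape[2]    # k, v: [B, n_kv, N_k, d_h]
    groups = n_h // n_kv
    scale = scale if scale is not None else d_h ** -0.5

    out = torch.empty_like(q)
    for qs in range(0, N_q, tile):
        q_i = q[:, :, qs:qs + tile].float()                   # fp32 accumulation
        m = torch.full(q_i.shape[:3], -1e9, device=q.device)  # running row max
        l = torch.zeros_like(m)                               # running denominator
        o = torch.zeros_like(q_i)                             # running output

        for ks in range(0, N_k, tile):
            # fused GQA expansion: only this KV tile is expanded to n_h heads
            k_t = k[:, :, ks:ks + tile].repeat_interleave(groups, dim=1).float()
            v_t = v[:, :, ks:ks + tile].repeat_interleave(groups, dim=1).float()

            s = q_i @ k_t.transpose(-1, -2) * scale           # [B, n_h, t_q, t_k]
            m_new = torch.maximum(m, s.amax(dim=-1))
            p = torch.exp(s - m_new.unsqueeze(-1))
            alpha = torch.exp(m - m_new)
            o = alpha.unsqueeze(-1) * o + p @ v_t
            l = alpha * l + p.sum(dim=-1)
            m = m_new

        out[:, :, qs:qs + tile] = (o / (l + 1e-8).unsqueeze(-1)).to(q.dtype)
    return out
```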

## Running tests

```bash
pip install -e ".[dev]"
pytest tests/ -v
```

Correctness tests run on CPU; VRAM tests are skipped automatically if CUDA is not available.

## Running benchmarks

```bash
# Default: CUDA, fp16, tile=auto, N=[512,1024,2048,4096]
python benchmarks/benchmark.py

# CPU run (fp32)
python benchmarks/benchmark.py --device cpu --dtype fp32 --seq-lens 128 256 512

# Custom config
python benchmarks/benchmark.py --tile 128 --batch 2 --heads 16 --kv-heads 4
```

## License

MIT
