Metadata-Version: 2.4
Name: flash-attn-res
Version: 0.1.12
Summary: Attention Residuals (AttnRes) kernels
Project-URL: Homepage, https://github.com/catswe/Flash-Attention-Residuals
Project-URL: Issues, https://github.com/catswe/Flash-Attention-Residuals/issues
Author: William Bui
License: Apache-2.0
License-File: LICENSE
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Dist: torch
Requires-Dist: triton
Description-Content-Type: text/markdown

## Flash Attention Residuals

> **4x faster inference/training** vs. a naive attention-residuals implementation under torch.compile

> **20% reduction in training memory** (without activation checkpointing)*

*Benchmarked on an H100. Results depend on problem size and setup.

Reference: https://arxiv.org/abs/2603.15031 (Kimi Team, MoonshotAI, 2026)

## Credits:
Thanks to Mohamed Osman (https://github.com/spaghettiSystems) and Cartesia (https://github.com/cartesia-ai) for advising on and supporting the development of this project.

## Install

```
pip install flash-attn-res
```

## Usage
This package contains Triton kernels, `triton_op` wrappers compatible with torch.compile, and an experimental high-performance Block AttnRes autograd implementation.
See the `src` and `benchmarks` folders; a minimal, hypothetical usage sketch follows.
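
A minimal sketch, assuming a hypothetical entry point named `attn_res` exported from `flash_attn_res` (the real function names, signatures, and return values live in `src`):

```
import torch

# Hypothetical import: `attn_res` is a placeholder name, not the package's
# documented API; check `src` for the actual entry points and signatures.
from flash_attn_res import attn_res

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# The `triton_op` wrappers register the kernels as custom ops, so the call
# composes with torch.compile without graph breaks.
compiled = torch.compile(attn_res)
out = compiled(q, k, v)
```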


<!-- TODO: -->
<!-- - Figure out first block phase 1 special case redundant computation output -->
<!-- - Determine redundant store -->
<!-- - Consider "phase_2_online_softmax_merge_intrablock_backward_kernel probably does not need atomic_add" -->
<!-- - Consider two-phase reduction -->

## Roadmap:
- Better autotuning setup
- Better benchmarks
- More robust autograd impl.
- Precision tuning
- Mixed FP16/BF16 with stored quantization scales
- Stochastic rounding
- Implementations in CuTe, CUDA, and other DSLs

## Development Notes:
- Normalizing in phase 1 keeps outputs bounded (each partial is a convex combination of the values), so bf16 error doesn't scale with softmax flatness. Phase 2 computes in fp32, and the reduction algebra matches split-KV Flash Attention (a minimal sketch of this merge follows this list).
- Certain dimensions, especially NUM_QUERIES_PER_BLOCK, are small, so a semi-elementwise (B, T) kernel using `tl.static_range` is faster than going through `tl.dot`.
- The kernel is memory-bound, and the semi-elementwise formulation makes kernel fusion possible.
- NUM_SOURCE_BLOCKS and NUM_QUERIES_PER_BLOCK should be autotuning keys (which torch.compile does not allow), enabling faster kernels; see the Triton sketch after this list.
- NUM_QUERIES_PER_BLOCK is small, so loads should use `eviction_policy="evict_last"`.
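
A minimal sketch of that phase-2 merge, assuming each phase-1 partial output is already normalized and carries its row max and softmax denominator (the names here are illustrative, not the package's API):

```
import torch

def merge_partials(o: torch.Tensor, m: torch.Tensor, l: torch.Tensor) -> torch.Tensor:
    """Merge S split-KV partials: o is (S, T, D) normalized outputs,
    m is (S, T) row maxima, l is (S, T) softmax denominators."""
    o, m, l = o.float(), m.float(), l.float()            # phase 2 computes in fp32
    m_all = m.max(dim=0).values                          # global row max across splits
    scale = l * torch.exp(m - m_all)                     # each split's rescaled denominator
    weights = (scale / scale.sum(dim=0)).unsqueeze(-1)   # convex weights, sum to 1 per row
    return (weights * o).sum(dim=0)                      # merged output stays bounded
```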
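
And an illustrative Triton sketch of the last three points, reducing a hypothetical (B, T, NUM_QUERIES_PER_BLOCK) tensor over its last axis (this is not one of the package's kernels, and the real kernels also key on NUM_SOURCE_BLOCKS): a small constexpr dimension as an autotuning key, an unrolled `tl.static_range` loop instead of `tl.dot`, and `eviction_policy="evict_last"` on the small, reused loads.

```
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({"BLOCK_T": bt}, num_warps=w)
             for bt in (64, 128) for w in (2, 4)],
    # Keying on the small constexpr tile size re-tunes per problem shape,
    # which torch.compile-generated Triton does not expose.
    key=["NUM_QUERIES_PER_BLOCK"],
)
@triton.jit
def _reduce_queries_kernel(
    x_ptr, out_ptr, T,
    stride_xb, stride_xt, stride_xq,
    stride_ob, stride_ot,
    NUM_QUERIES_PER_BLOCK: tl.constexpr,
    BLOCK_T: tl.constexpr,
):
    pid_b = tl.program_id(0)
    pid_t = tl.program_id(1)
    offs_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)
    mask = offs_t < T

    acc = tl.zeros((BLOCK_T,), dtype=tl.float32)
    # NUM_QUERIES_PER_BLOCK is small: an unrolled elementwise loop beats tl.dot.
    for q in tl.static_range(NUM_QUERIES_PER_BLOCK):
        ptrs = x_ptr + pid_b * stride_xb + offs_t * stride_xt + q * stride_xq
        # Small, reused tile: keep it resident in cache with evict_last.
        acc += tl.load(ptrs, mask=mask, other=0.0,
                       eviction_policy="evict_last").to(tl.float32)

    tl.store(out_ptr + pid_b * stride_ob + offs_t * stride_ot, acc, mask=mask)
```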
