Metadata-Version: 2.4
Name: trifast
Version: 0.1.2
Summary: Fast kernel for triangle self attention.
Author-email: Liam Atkinson <liamatkinson@gmail.com>
Requires-Python: >=3.11
Requires-Dist: einops>=0.8.0
Requires-Dist: jaxtyping>=0.2.36
Requires-Dist: numpy>=2.1.3
Requires-Dist: pyyaml>=6.0.2
Requires-Dist: setuptools>=75.6.0
Requires-Dist: torch>=2.5.1
Requires-Dist: triton>=3.1.0
Provides-Extra: benchmark
Requires-Dist: deepspeed>=0.16.0; extra == 'benchmark'
Provides-Extra: test
Requires-Dist: pytest>=8.3.4; extra == 'test'
Description-Content-Type: text/markdown

Fused triangle self-attention kernel, written in Triton. Basically FlashAttention, but for triangle self-attention.
The implementation is heavily inspired by [FlagAttention](https://github.com/FlagOpen/FlagAttention/tree/main) and the [Triton fused attention tutorial](https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html#sphx-glr-getting-started-tutorials-06-fused-attention-py).

- O(n^2) memory complexity (vs O(n^3) for pure PyTorch).
- Faster (~2x) backward pass than the next fastest implementation I could find (the DeepSpeed4Science (DS4S) Evoformer kernel).
- Faster (~4x) forward pass than the next fastest implementation I could find (the DS4S Evoformer kernel).
- As far as I can tell, faster than a naive implementation.
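
To show where the n^3 memory in an unfused implementation comes from, here is a minimal NumPy sketch of naive triangle self-attention. It assumes the AlphaFold-style "starting node" formulation, where the pair bias is shared across rows. The function name, argument names, and tensor layout are illustrative assumptions and are not trifast's API.

```python
import numpy as np


def naive_triangle_attention(q, k, v, bias):
    """Unfused reference (hypothetical helper, not part of trifast).

    q, k, v: arrays of shape (h, n, n, d)
    bias:    array of shape (h, n, n)
    returns: array of shape (h, n, n, d)
    """
    d = q.shape[-1]
    # logits[h, i, j, k] = <q[h, i, j], k[h, i, k]> / sqrt(d) + bias[h, j, k]
    # This (h, n, n, n) tensor is the n^3 intermediate a fused kernel avoids storing.
    logits = np.einsum("hijd,hikd->hijk", q, k) / np.sqrt(d)
    logits = logits + bias[:, None, :, :]
    # Numerically stable softmax over the key axis k.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # out[h, i, j] = sum_k weights[h, i, j, k] * v[h, i, k]
    return np.einsum("hijk,hikd->hijd", weights, v)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h, n, d = 2, 8, 4
    q, k, v = (rng.standard_normal((h, n, n, d)) for _ in range(3))
    bias = rng.standard_normal((h, n, n))
    out = naive_triangle_attention(q, k, v, bias)
    print(out.shape)
```

The inputs and output each hold h * n * n * d elements, while the intermediate `logits` holds h * n^3. A flash-attention-style kernel processes the key axis in blocks with a running softmax, so the full `logits` tensor never has to be stored.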

## Plots
All benchmarks were run on an RTX 3090 in bfloat16.
### Forward
![TSA forward runtime](benchmark_plots/tri_attn_fwd.png "TSA forward runtime")
![TSA forward memory](benchmark_plots/peak_memory_fwd.png "TSA forward memory")

### Backward
![TSA backward runtime](benchmark_plots/tri_attn_bwd.png "TSA backward runtime")
![TSA backward memory](benchmark_plots/peak_memory_bwd.png "TSA backward memory")


Todos:
- [ ] Try to train a model with it.
- [ ] Can we perform any of dq/db/dkv transposed?
- [ ] Rewrite the autotuner.
