Metadata-Version: 2.4
Name: hiera-optim
Version: 0.1.0
Summary: Drop-in throughput and memory optimisations for FAIR Hiera (4D-SDPA, gather/scatter, Triton kernels).
Author: Maxi Kalcher
License: MIT
Project-URL: Homepage, https://github.com/avocardio/hiera-optim
Project-URL: Repository, https://github.com/avocardio/hiera-optim
Project-URL: Issues, https://github.com/avocardio/hiera-optim/issues
Keywords: pytorch,transformer,vision,hiera,mae,flash-attention,triton,hopper,h100,gh200
Classifier: Development Status :: 4 - Beta
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Operating System :: POSIX :: Linux
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.5.0
Requires-Dist: triton>=2.3.0
Provides-Extra: hiera
Requires-Dist: hiera-transformer>=0.1.4; extra == "hiera"
Provides-Extra: test
Requires-Dist: pytest>=7.0; extra == "test"
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Requires-Dist: ruff; extra == "dev"
Requires-Dist: build; extra == "dev"
Requires-Dist: twine; extra == "dev"
Dynamic: license-file

# hiera-optim

Drop-in throughput and memory optimisations for [FAIR's Hiera](https://github.com/facebookresearch/hiera) and its MAE variant. Two lines:

```python
from hiera_optim import optimize
optimize(model)
```

move attention off PyTorch's silent math fallback and back onto FlashAttention / cuDNN-attn, replace boolean mask indexing with `torch.gather` / `scatter_`, and unblock `torch.compile`. Outputs are numerically equivalent to the original within bf16 noise.

## Results

Measured on H100 (GH200) in bf16 over a full forward + backward pass.

### Production config: Hiera-Base, 224x224, 8 in-chans, B=128

| | ms / step | samples / s | peak mem |
|---|---|---|---|
| FAIR baseline + `torch.compile` | 131.7 | 972 | 14.0 GB |
| **hiera-optim + `torch.compile`** | **70.3** | **1820** | **9.4 GB** |
| speedup / saving | 1.88x | 1.87x | 33% |

### Across the variant matrix (444 GH200 cells)

| | median | mean | best | worst |
|---|---|---|---|---|
| speedup | 1.35x | 1.42x | 2.10x | 1.10x |
| memory ratio | 74% | 73% | 29% | 99% |

On an RTX 4090 (Hiera-Base, 8 in-chans, B=32) the speedup is 1.81x in eager mode and **2.86x with `torch.compile`**.

Full matrix and per-cell numbers: [`MATRIX_RESULTS.md`](MATRIX_RESULTS.md).
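
To take the same kind of measurement on your own hardware, a minimal timing harness could look like the sketch below. It is not the script that produced the numbers above: `benchmark` and `toy_step` are hypothetical names, the warmup and step counts are arbitrary, and GB is taken as 1e9 bytes.

```python
import time

import torch


def benchmark(step_fn, batch_size, steps=10, warmup=3):
    """Time `step_fn` (one full forward + backward pass).

    Returns (ms_per_step, samples_per_s, peak_mem_gb).
    """
    cuda = torch.cuda.is_available()
    for _ in range(warmup):  # let compilation and autotuning settle
        step_fn()
    if cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(steps):
        step_fn()
    if cuda:
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    ms_per_step = (time.perf_counter() - start) * 1000.0 / steps
    samples_per_s = batch_size * 1000.0 / ms_per_step
    # Peak memory is only meaningful on CUDA; report NaN on CPU.
    peak_gb = torch.cuda.max_memory_allocated() / 1e9 if cuda else float("nan")
    return ms_per_step, samples_per_s, peak_gb


# Toy stand-in so the sketch runs anywhere, including on CPU.
toy_model = torch.nn.Linear(16, 4)
toy_x = torch.randn(8, 16)


def toy_step():
    toy_model.zero_grad(set_to_none=True)
    toy_model(toy_x).mean().backward()


ms, sps, peak = benchmark(toy_step, batch_size=8)
print(f"{ms:.3f} ms/step, {sps:.0f} samples/s, peak {peak} GB")
```

For the MAE model, `step_fn` would instead run `loss, *_ = model(x, mask_ratio=0.6)` followed by `loss.backward()`.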

## Install

```bash
pip install hiera-optim
```

From source:

```bash
git clone https://github.com/avocardio/hiera-optim.git
cd hiera-optim
pip install -e .
```

Requires PyTorch >= 2.5 and Triton >= 2.3. Recognises FAIR Hiera in-tree (`models.hiera`) or via PyPI (`hiera-transformer`).

## Usage

```python
import torch
from hiera_optim import optimize
from hiera import mae_hiera_base_224

model = mae_hiera_base_224(pretrained=False, in_chans=3, input_size=(224, 224))
model = model.to("cuda", dtype=torch.bfloat16)  # match the device and dtype of the input below
optimize(model)
model = torch.compile(model, mode="default", dynamic=False)

x = torch.randn(128, 3, 224, 224, device="cuda", dtype=torch.bfloat16)
loss, *_ = model(x, mask_ratio=0.6)
loss.backward()
```

`optimize(model)` makes two changes in place and preserves the weights:

1. Swap every `MaskUnitAttention` for a 4D-reshape variant so PyTorch SDPA dispatches to the FlashAttention / cuDNN-attn / mem-efficient kernels instead of the math fallback. FAIR's original feeds SDPA a 5-D tensor, which the fused kernels reject; the fallback makes each attention call roughly 13x slower on Ada and 6x slower on Hopper.
2. Swap `x[mask.tile(...)]` and `x_dec[mask] = ...` for explicit `torch.gather` / `scatter_`. This removes a slow `indexing_backward_kernel` and the `aten::nonzero` graph break that interrupts `torch.compile`.
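
Both rewrites can be illustrated in isolation. The sketch below is not the package's code: `sdpa_4d`, `gather_kept` and `scatter_kept` are hypothetical helpers, and the tensor layouts (`[B, heads, windows, tokens, head_dim]` for attention, `[B, N, C]` tokens with a fixed number of kept indices per sample) are assumptions. It checks that folding the extra axis gives the same attention output, and that index-based `gather` / `scatter_` reproduce boolean-mask indexing while keeping output shapes static.

```python
import math

import torch
import torch.nn.functional as F


def sdpa_4d(q, k, v):
    """Attention for [B, heads, windows, tokens, head_dim] inputs via a 4-D SDPA call."""
    B, H, W, _, D = q.shape
    # Fold windows into the head axis; SDPA treats both as independent batch-like axes.
    q4, k4, v4 = (t.reshape(B, H * W, t.shape[-2], D) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q4, k4, v4)
    return out.reshape(B, H, W, q.shape[-2], D)


def gather_kept(x, keep_idx):
    """Select kept tokens: x [B, N, C], keep_idx [B, K] -> [B, K, C]."""
    idx = keep_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1])
    return torch.gather(x, 1, idx)


def scatter_kept(tokens, keep_idx, num_tokens, fill):
    """Write kept tokens [B, K, C] into a [B, num_tokens, C] buffer pre-filled with `fill`."""
    B, K, C = tokens.shape
    out = fill.expand(B, num_tokens, C).clone()
    idx = keep_idx.unsqueeze(-1).expand(-1, -1, C)
    return out.scatter_(1, idx, tokens)


torch.manual_seed(0)

# 1. The 4-D call matches a plain softmax(Q K^T / sqrt(d)) V reference.
q, k, v = (torch.randn(2, 3, 4, 5, 8) for _ in range(3))
ref = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(8), dim=-1) @ v
sdpa_ok = torch.allclose(sdpa_4d(q, k, v), ref, atol=1e-4)

# 2. gather / scatter_ match boolean-mask indexing (K kept tokens per sample).
B, N, K, C = 2, 6, 3, 4
x = torch.randn(B, N, C)
keep_idx = torch.rand(B, N).argsort(dim=1)[:, :K].sort(dim=1).values
mask = torch.zeros(B, N, dtype=torch.bool)
mask[torch.arange(B).unsqueeze(1), keep_idx] = True

kept = gather_kept(x, keep_idx)
gather_ok = torch.equal(kept, x[mask].reshape(B, K, C))

fill = torch.zeros(C)  # stand-in for a mask token
ref_buf = fill.expand(B, N, C).clone()
ref_buf[mask] = kept.reshape(-1, C)
scatter_ok = torch.equal(scatter_kept(kept, keep_idx, N, fill), ref_buf)

print(sdpa_ok, gather_ok, scatter_ok)
```

Boolean indexing produces an output whose size depends on the mask contents, which is why it needs `nonzero`; with an index tensor of known shape, every intermediate shape is fixed ahead of time.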

## Optional

```python
from hiera_optim import optimize, enable_stage_checkpointing

optimize(model, sdpa_backend="auto")           # per-block SDPA hint
enable_stage_checkpointing(model, stages=(2,)) # OOM lever
```
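
The general technique behind a checkpointing lever is activation checkpointing: intermediate activations of a block are dropped after the forward pass and recomputed during backward, trading extra compute for lower peak memory. A minimal sketch with `torch.utils.checkpoint` follows; `CheckpointedBlock` is a hypothetical wrapper for illustration, and how `enable_stage_checkpointing` is actually implemented is not shown here.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Wrap a block so its activations are recomputed during backward."""

    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, *args, **kwargs):
        # use_reentrant=False is the non-reentrant variant recommended by PyTorch.
        return checkpoint(self.block, *args, use_reentrant=False, **kwargs)


torch.manual_seed(0)
block = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
x = torch.randn(4, 8, requires_grad=True)

# Reference gradient without checkpointing.
block(x).sum().backward()
ref_grad = x.grad.clone()

# Same block with checkpointing: gradients should match.
x.grad = None
CheckpointedBlock(block)(x).sum().backward()
grads_match = torch.allclose(x.grad, ref_grad)
print(grads_match)
```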

## GPU support

| Architecture | SM | Status |
|---|---|---|
| Ada (RTX 4090, L40) | SM89 | Tested |
| Hopper (H100, GH200) | SM90 | Tested |
| Ampere (A100) | SM80 | Should work |
| Blackwell (B200) | SM100 | Should work |
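
To see which row applies to the local GPU, the table can be turned into a lookup keyed on CUDA compute capability. This is a convenience sketch, not part of the package: `SUPPORT`, `support_status` and `current_gpu_status` are hypothetical names.

```python
# Compute capability (major, minor) -> status, copied from the table above.
SUPPORT = {
    (8, 0): "Should work",   # Ampere (A100)
    (8, 9): "Tested",        # Ada (RTX 4090, L40)
    (9, 0): "Tested",        # Hopper (H100, GH200)
    (10, 0): "Should work",  # Blackwell (B200)
}


def support_status(capability):
    """Return the table status for a (major, minor) compute capability."""
    return SUPPORT.get(tuple(capability), "Not listed")


def current_gpu_status():
    """Look up the status of the default CUDA device, if there is one."""
    import torch  # imported lazily so the lookup itself has no dependencies

    if not torch.cuda.is_available():
        return "No CUDA device"
    return support_status(torch.cuda.get_device_capability())


print(support_status((9, 0)))
```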

## Tests

```bash
pip install -e ".[test]"
pytest
```

112 tests cover all 5 Hiera variants x q_pool {1, 2, 3} x mask ratios x bf16/fp16/fp32 x 1D/2D/3D inputs x classification + MAE.

## License

MIT.
