Metadata-Version: 2.4
Name: fpwap
Version: 0.1.2
Summary: Forward Pass Weight Amortization Protocol — invert the inference loop for large transformer models.
Author: Michael Klear
License: MIT
Requires-Python: >=3.11
Requires-Dist: accelerate>=1.0
Requires-Dist: numpy>=1.26
Requires-Dist: pyarrow>=15.0
Requires-Dist: safetensors>=0.4
Requires-Dist: torch>=2.1
Requires-Dist: tqdm>=4.66
Requires-Dist: transformers>=4.40
Provides-Extra: dev
Requires-Dist: mypy>=1.10; extra == 'dev'
Requires-Dist: pytest-cov>=5.0; extra == 'dev'
Requires-Dist: pytest>=8.0; extra == 'dev'
Requires-Dist: ruff>=0.6; extra == 'dev'
Description-Content-Type: text/markdown

# fpwap — Forward Pass Weight Amortization Protocol

A single-purpose library for running activation extraction over large transformer models **whose weights don't fit in your GPU**, across datasets of **thousands of prompts**, on **consumer hardware**, at **full precision**.

## The regime

You're a mech-interp researcher. Your model is bigger than your VRAM. Your dataset is thousands of prompts. Each adjacent tool fails this regime in its own way:

- **Quantization** (bitsandbytes, GPTQ) changes the activations you're trying to read.
- **Inference servers** (vLLM, TGI) optimize next-token throughput, not residual-stream extraction.
- **`accelerate.cpu_offload`** streams weights once per batch — 10k prompts × 80 layers on a 70B model is hundreds of TB of weight I/O, hours of wall-clock per dataset pass.
- **Cloud GPUs** break your interactive iteration loop and cost hundreds per experiment.

fpwap inverts the inference loop: **load each layer once, stream the whole dataset through it**, spill intermediates to disk, move on. Total weight I/O drops from `O(N_batches × N_layers)` to `O(N_layers)`. A 10k-sample Llama-3.1-70B extraction on a 32 GB consumer GPU pays roughly the weight I/O of a single batch under the naive approach, with the same weights, no quantization, and no cloud.
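The inversion can be illustrated with a toy sketch. Everything here is hypothetical (`load_layer`, `apply_layer`, and the scalar "weights" are stand-ins, not fpwap internals); the point is only that swapping the loop order leaves the outputs unchanged while the number of weight loads falls from `N_batches × N_layers` to `N_layers`.

```python
# Toy model: each "layer" is a (scale, shift) pair standing in for real weights.
LAYERS = [(2.0, 1.0), (0.5, -3.0), (1.5, 0.25)]


def load_layer(idx, counter):
    """Pretend to read one layer's weights from disk; count every load."""
    counter["loads"] += 1
    return LAYERS[idx]


def apply_layer(weights, batch):
    """Apply one toy layer to every value in a batch."""
    scale, shift = weights
    return [scale * x + shift for x in batch]


def naive_pass(batches):
    """Batch-outer loop: every batch reloads every layer."""
    counter = {"loads": 0}
    outputs = []
    for batch in batches:
        for idx in range(len(LAYERS)):
            batch = apply_layer(load_layer(idx, counter), batch)
        outputs.append(batch)
    return outputs, counter["loads"]


def inverted_pass(batches):
    """Layer-outer loop: load each layer once, stream all batches through it."""
    counter = {"loads": 0}
    state = [list(b) for b in batches]  # intermediates (the part fpwap spills to disk)
    for idx in range(len(LAYERS)):
        weights = load_layer(idx, counter)
        state = [apply_layer(weights, b) for b in state]
    return state, counter["loads"]


batches = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
naive_out, naive_loads = naive_pass(batches)
inv_out, inv_loads = inverted_pass(batches)
print(naive_out == inv_out, naive_loads, inv_loads)  # True 12 3
```

With 4 batches and 3 layers the naive loop performs 12 loads and the inverted loop 3. The price is the `state` list: the inverted loop has to hold every batch's intermediate between layers.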

## Aspirational performance

These are targets, not measurements: the numbers fpwap is being built to hit. Each row unlocks only after its milestone lands (70B gates on the bit-perfect test; 405B gates on the mmap-from-HF-cache path). Measured benchmarks replace these targets as they come in; current measurements are in the [Status](#status) section.

### Reference machine

| Component | Spec |
| --------- | ---- |
| GPU       | NVIDIA RTX 5090, 32 GB VRAM |
| CPU       | Modern desktop-class, 16+ cores |
| RAM       | 128 GB DDR5 |
| Storage   | NVMe SSD (Gen 4+), ≥ 1 TB free |
| Interconnect | PCIe 5.0 x16 |
| Network   | None — fully local, no cloud |

### Dataset-scale activation extraction (10,000 prompts × 256 tokens = 2.56M tokens)

Residual stream (`residual_post`) captured at every layer, pooled to last token, persisted to disk. `RawActivations(layers="all")`.

| Model | Weights (bf16) | Loading strategy | Wall-clock target | Throughput target | vs. naive `accelerate.cpu_offload` |
| ----- | -------------- | ---------------- | ----------------- | ----------------- | ----------------------------------- |
| Llama-3.1-8B   | 16 GB   | `cpu_offload`      | ≤ 8 min  | ≥ 5,000 tok/s | ≥ 4× faster |
| Llama-3.1-70B  | 140 GB  | `disk_offload`     | ≤ 45 min | ≥ 950 tok/s   | ≥ 4× faster (naive ≈ 3 h) |
| Llama-3.1-405B | 810 GB  | `mmap_from_cache`  | ≤ 4 h    | ≥ 180 tok/s   | naive infeasible (OOM in RAM) |

Throughput is end-to-end tokens per second: total tokens processed (samples × seq_len) divided by wall-clock from `Sweep(...).run()` entry to return, including weight I/O, forward, callbacks, and buffer write.
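As a quick consistency check, the throughput definition can be applied to the wall-clock targets in the table. `end_to_end_tok_per_s` is a hypothetical helper written for this illustration, not part of fpwap.

```python
def end_to_end_tok_per_s(n_samples, seq_len, wall_clock_s):
    """Total tokens processed divided by end-to-end wall-clock seconds."""
    return n_samples * seq_len / wall_clock_s


# Wall-clock targets from the table above, in minutes.
targets_min = {"8B": 8, "70B": 45, "405B": 240}

implied = {
    name: round(end_to_end_tok_per_s(10_000, 256, minutes * 60))
    for name, minutes in targets_min.items()
}
print(implied)  # {'8B': 5333, '70B': 948, '405B': 178}
```

The throughput column (5,000 / 950 / 180 tok/s) looks like a rounded form of these implied values.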

### Single-pass cost per layer (Llama-3.1-70B, 1.75 GB weights per layer)

The inner loop that fpwap is optimizing. On the reference machine, per layer, per full sweep of 10k × 256-token samples:

| Phase | Budget | Notes |
| ----- | ------ | ----- |
| Weight load  | ≤ 1.0 s    | NVMe → CPU → GPU, `disk_offload` path; once per layer, not once per batch |
| Forward      | ≤ 15 s     | 10k samples, bf16, batched at engine's discretion |
| Callback     | ≤ 1.0 s    | Aggregate across all registered callbacks for this layer |
| Buffer write | ≤ 1.0 s    | Pooled activations to memmap; raw `[N, S, H]` budget is higher |
| **Per-layer total** | **≤ 18 s** | × 80 layers ≈ 24 min (leaves headroom vs. 45 min end-to-end target) |
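The note on the buffer-write row (pooled vs. raw `[N, S, H]`) is easier to appreciate with sizes attached. The sketch below assumes bf16 storage (2 bytes per element) and a hidden size of 8192 for Llama-3.1-70B; `buffer_bytes` is a hypothetical helper, not an fpwap function.

```python
BYTES_BF16 = 2      # assumption: activations are stored as bf16
HIDDEN_70B = 8192   # assumption: Llama-3.1-70B hidden size


def buffer_bytes(n_samples, hidden, seq_len=None):
    """Per-layer buffer size: pooled [N, H] if seq_len is None, else raw [N, S, H]."""
    positions = 1 if seq_len is None else seq_len
    return n_samples * positions * hidden * BYTES_BF16


pooled = buffer_bytes(10_000, HIDDEN_70B)
raw = buffer_bytes(10_000, HIDDEN_70B, seq_len=256)
print(f"pooled: {pooled / 1e9:.2f} GB, raw: {raw / 1e9:.2f} GB")
# pooled: 0.16 GB, raw: 41.94 GB
```

Under these assumptions a raw capture is 256 times larger per layer than a pooled one, which is why the raw write budget has to be higher.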

### Overhead budgets

| Surface | Budget | Why |
| ------- | ------ | --- |
| Profile + progress, combined | < 1% wall-clock | Has to stay on by default — see the [Observability](#observability) section |
| `verify=True` (vs. naive `cpu_offload` at every layer) | 2–3× slower | Correctness debugging only; not for production runs |
| Preflight | < 5 s | Rejects infeasible configurations before GPU contact |

## The API

One verb. One callback class. One result.

```python
from fpwap import Sweep
from fpwap.callbacks.common import RawActivations, IncrementalPCA, DiffOfMeans

run = Sweep(
    model="meta-llama/Llama-3.1-70B",
    dataset=my_dataset,                # iterable of {"input_ids": ..., "label": ...}
    seq_len=256,
    callbacks=[
        RawActivations(layers=[40, 45, 50]),               # pooled by default
        IncrementalPCA(layers="all", n_components=64),
        DiffOfMeans(layers="all", label_fn=lambda s: s["label"]),
    ],
)

plan = run.preflight()
print(plan.summary())                   # check feasibility before GPU contact

result = run.run()
acts  = result.activations(layer=45, hook="residual_post")   # [N, H]
basis = result.artifact("pca_basis", layer=45)
```

That is the entire user-facing surface for read-only workflows. No backend objects to construct. No `batch_size` knob to foot-gun. No `loader` / `accumulator` triple to wire up. Construction is cheap; `.preflight()` inspects the plan and rejects infeasible configurations with actionable messages; `.run()` executes.

### Layer indexing

Hook names follow the HF `hidden_states` convention:

| Hook | Equals |
| ---- | ------ |
| `residual_pre` at layer `L`  | `hidden_states[L]`   (input to block `L`) |
| `residual_post` at layer `L` | `hidden_states[L+1]` (output of block `L`) |
| `attn_out` at layer `L`      | attention sub-layer output at block `L` |
| `mlp_out` at layer `L`       | MLP sub-layer output at block `L` |

No off-by-one translation at the call site.
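The table can be expressed as a small lookup. `hidden_states_index` is a hypothetical helper for illustration only; it encodes the two residual rows above and refuses the sub-layer hooks, which have no `hidden_states` entry.

```python
def hidden_states_index(hook, layer):
    """Map a residual hook at block `layer` to an index into HF `hidden_states`."""
    if hook == "residual_pre":
        return layer          # input to block `layer`
    if hook == "residual_post":
        return layer + 1      # output of block `layer`
    # attn_out / mlp_out are sub-layer outputs and do not appear in hidden_states.
    raise ValueError(f"{hook!r} has no hidden_states equivalent")


print(hidden_states_index("residual_pre", 45))   # 45
print(hidden_states_index("residual_post", 45))  # 46
```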

### Writing your own callback

Subclass `Callback`. Declare which layers and hooks you want; implement `on_batch`. Return an `Emit` to persist a tensor, a `WriteBack` to modify the residual before the next layer, or `None` to no-op.

```python
from fpwap import Callback, Emit

class LastTokenLogNorm(Callback):
    target_layers = [32]
    target_hooks = ("residual_post",)
    phase = "read"

    def on_batch(self, layer_idx, hook, acts, sample_ids):
        return Emit(acts[:, -1, :].norm(dim=-1).log())
```

### Write-backs and multi-pass workflows

The same entry point handles steering. A callback with `phase = "write"` modifies the residual stream between layers; artifacts from one run feed the next.

```python
from fpwap.callbacks.common import SteerInBasis

# Pass 2: steer in the basis fit during pass 1
steer = Sweep(
    model="meta-llama/Llama-3.1-70B",
    dataset=my_dataset,
    seq_len=256,
    callbacks=[
        SteerInBasis(
            basis_artifact=result.artifact("pca_basis", layer=45),
            direction_idx=0,
            alpha=2.0,
            layers=[45],
        ),
    ],
)
steered = steer.run()
```
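For intuition, an additive intervention in a basis amounts to adding a scaled basis direction to the residual vector. The sketch below is a plain-Python illustration under that assumption; the `steer` function is hypothetical and does not show how `SteerInBasis` is implemented.

```python
def steer(residual, basis, direction_idx, alpha):
    """Return residual + alpha * basis[direction_idx], elementwise."""
    direction = basis[direction_idx]
    if len(direction) != len(residual):
        raise ValueError("basis direction and residual must have the same length")
    return [h + alpha * d for h, d in zip(residual, direction)]


# Toy 3-dimensional residual and a 2-direction basis.
basis = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
steered_vec = steer([0.5, -1.0, 2.0], basis, direction_idx=0, alpha=2.0)
print(steered_vec)  # [2.5, -1.0, 2.0]
```

Only the component along the chosen direction changes; `alpha` controls how far the residual moves.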

## Observability

Performance is the product, so every run is profiled by default with a measurement overhead small enough (target: under 1% wall-clock) that you never have to opt in. When a run is slower than you want, the answer is already in `result.profile` — no re-running with `profile=True`.

```python
result = run.run()

result.profile.summary()            # human-readable breakdown per layer
result.profile.by_phase()           # load / forward / callback / write
result.profile.slowest_layer()      # where the time went
result.profile.bytes_moved()        # weight I/O, buffer I/O
```

Interactive progress is on by default — a tqdm-style bar across layers × batches, because a run on the workstation under your desk should not sit silent for 40 minutes. Disable with `progress=False`; pass a callable (`progress=my_reporter`) to stream events into wandb, rich, or any other backend.

### Known cliff: CUDA allocator fragmentation on K-sweep configs

If a K-sweep run on tight VRAM (K-packed sweeps, `chunk_size=1`, large per-K residual buffer) shows episodic multi-minute pauses every few sweeps with the process stuck in D-state but NVMe mostly idle in `iostat`, the cause is almost certainly CUDA caching-allocator fragmentation, not host I/O. Set `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` before launch — segments grow contiguously on demand and the cliff disappears (verified on a 70B / K=30 / 14k repro: 2:20 → 0:55 wall, no other change). See [#72](https://github.com/AlliedToasters/fpwap/issues/72) for the diagnostic walk.
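If exporting the variable in the shell is inconvenient, one option is to set it at the top of the launch script. This is a sketch of that pattern, not something fpwap does for you; the setting has to be in place before the CUDA allocator initializes, so the safest position is before `torch` is imported.

```python
import os

# Must run before the first CUDA allocation (safest: before importing torch).
# setdefault leaves any value you already exported untouched, so if you set
# other allocator options in the shell, add expandable_segments there instead.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])
```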

## Reference callbacks

Four callbacks ship with the library as examples and integration tests:

- **`RawActivations`** — persist per-sample activations, pooled (`last_token_only=True`) by default to avoid an `[N, S, H]` memory landmine.
- **`IncrementalPCA`** — fit a PCA basis over the entire dataset in a single pass.
- **`DiffOfMeans`** — compute per-class activation means for binary-labeled data.
- **`SteerInBasis`** — additive intervention in a pre-computed basis; `phase = "write"`.

Anything beyond these four is a consumer's problem.
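To show the kind of single-pass statistic these callbacks compute, here is a streaming per-class mean accumulator in plain Python. It is a hypothetical illustration, not the shipped `DiffOfMeans`; in particular, taking the difference (class 1 minus class 0) at the end is an assumption based on the callback's name.

```python
class RunningClassMeans:
    """Accumulate per-class sums batch by batch; no batch is retained."""

    def __init__(self, hidden):
        self.sums = {0: [0.0] * hidden, 1: [0.0] * hidden}
        self.counts = {0: 0, 1: 0}

    def update(self, acts, labels):
        """Fold one batch of activation vectors and binary labels into the sums."""
        for vec, label in zip(acts, labels):
            self.counts[label] += 1
            self.sums[label] = [s + v for s, v in zip(self.sums[label], vec)]

    def diff_of_means(self):
        """Mean of class 1 minus mean of class 0 (assumed convention)."""
        mean = {c: [s / self.counts[c] for s in self.sums[c]] for c in (0, 1)}
        return [a - b for a, b in zip(mean[1], mean[0])]


acc = RunningClassMeans(hidden=2)
acc.update([[1.0, 2.0], [3.0, 4.0]], labels=[0, 1])
acc.update([[5.0, 6.0], [7.0, 8.0]], labels=[0, 1])
print(acc.diff_of_means())  # [2.0, 2.0]
```

Because only sums and counts are kept, memory stays constant in the number of samples, which is what makes a one-pass fit over the whole dataset possible.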

## Integrating fpwap into a research codebase

The recommended shape is a single classmethod on your codebase's activation-source type, inserted **above** any per-batch sharding your framework does:

```python
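from fpwap import Sweep
from fpwap.callbacks.common import RawActivations


class Activations:
    @classmethod
    def from_fpwap(cls, model_id, prompts, layers, pool="last_token"):
        # _as_dataset and cls.from_result are assumed to be your codebase's own adapters.
        run = Sweep(
            model=model_id,
            dataset=_as_dataset(prompts),
            seq_len=...,
            callbacks=[
                RawActivations(
                    layers=layers,
                    last_token_only=(pool == "last_token"),
                ),
            ],
        )
        return cls.from_result(run.run())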
```

Branch `use_fpwap` at your dispatch layer — the same place you'd branch between `from_model`, `from_goodfire`, etc. — not inside a per-batch loop. fpwap's value (amortizing layer loads across the whole dataset) only materializes if it sees the dataset; if your framework shards externally and calls an extractor per shard, lift the dispatch up one level before integrating.

## Scope

fpwap is a plumbing layer. It produces activations and accepts transforms. It does not know what a probe is. Linear probe fitting, SAE training, attribution analysis, and any other statistical modeling of activations belong in consumer libraries. If it requires knowing what a probe is, it's out of scope.

## Related work

The loop inversion at the heart of fpwap — load each layer once, stream the dataset through it — was explored independently by [FlexGen](https://arxiv.org/abs/2303.06865) (Sheng et al., ICML 2023) for high-throughput generative inference on a single GPU. FlexGen calls this a "zig-zag block schedule" and proves it is within 2× of I/O-optimal (Theorem 4.1) — a result that applies directly to fpwap's loop, since our schedule is the same modulo KV cache. FlexGen solves a harder scheduling problem (KV cache placement across GPU/CPU/disk, multi-step autoregressive decoding, CPU compute delegation) and applies 4-bit group-wise quantization to further compress weights. fpwap targets a narrower regime — forward-pass activation extraction for mechanistic interpretability — where full precision is non-negotiable and generation is not needed, so the implementation is much simpler. The absence of KV cache and autoregressive decoding also means fpwap's cost model has fewer free variables, making strategy selection tractable without an LP solver.

## Status

**Llama-3.1-405B on a single RTX 5090 (32 GB VRAM), streaming 803 GB of
bf16 weights from NVMe — 45.7 tok/s in under 12 minutes.** 70B at
10,000 prompts × 128 tokens hits 1,221 tok/s. That's the regime fpwap
exists for: the model doesn't fit in VRAM (or even RAM), the dataset is
thousands of prompts, and no quantization is involved. Measured on the
reference machine (RTX 5090, 128 GB DDR5, PCIe 5.0 NVMe):

| Model | Path | Samples × seq_len | Throughput (bf16) | SPEC target |
| ----- | ---- | ------------------ | ----------------- | ----------- |
| Llama-3.1-405B-Instruct | streaming, prefetch | 256 × 128    | **45.7 tok/s** | ≥ 180 |
| Llama-3.3-70B-Instruct | streaming, prefetch | 10,000 × 128 | **1,221 tok/s** | ≥ 950 |
| Llama-3.3-70B-Instruct | streaming            | 1,024 × 128  |  1,026 tok/s | ≥ 950 |
| Llama-3.1-8B-Instruct  | streaming            | 1,024 × 128  | 10,442 tok/s | ≥ 5,000 |
| Llama-3.1-8B-Instruct  | preloaded            | 256 × 128    | 11,894 tok/s | ≥ 5,000 |

The 405B number is end-to-end across 126 layers streaming 803 GB of
weights from NVMe SSD at 1.12 GB/s sustained, with prefetch fully hiding
disk reads behind compute (0.000s load per layer at steady state). The
70B hero number is end-to-end across 80 layers with a pinned-CPU
residual buffer (21 GB), async D2H, and a worker-thread weight prefetch
that overlaps layer L+1's safetensors read with layer L's compute.
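The prefetch pattern described above can be sketched with the standard library. This is a toy under stated assumptions: `load_weights` and `compute` are hypothetical stand-ins that just sleep, and the code is not fpwap's implementation. It only shows how a single worker thread lets layer L+1's read overlap layer L's compute.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def load_weights(idx):
    """Stand-in for a safetensors read of layer `idx`."""
    time.sleep(0.05)
    return f"weights-{idx}"


def compute(weights):
    """Stand-in for streaming the dataset through one loaded layer."""
    time.sleep(0.05)
    return f"done:{weights}"


def run_with_prefetch(n_layers):
    results = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(load_weights, 0)
        for idx in range(n_layers):
            weights = pending.result()  # blocks only if the read is slower than compute
            if idx + 1 < n_layers:
                # Kick off the next layer's read before computing on this one.
                pending = pool.submit(load_weights, idx + 1)
            results.append(compute(weights))
    return results


out = run_with_prefetch(4)
print(out)
```

When the read is no slower than the compute, only the first load is exposed; every later load finishes while the previous layer is still computing, which is the steady state the 405B run reports as 0.000s load per layer.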

Baseline sanity: an 8B streaming-vs-naive head-to-head shows a **7.25×
speedup** at 1024 × 128 (SPEC §17 ratio target ≥ 4×): 10,442 tok/s
streaming against the naive baseline, `accelerate.cpu_offload` at
1,440 tok/s (reproducible via `scripts/benchmark.py --mode naive`). 70B
can't be ratio-tested on this machine (141 GB bf16 > 128 GB RAM for
`cpu_offload`), so the 70B claim is absolute throughput.

Correctness: `tests/gpu/test_real_llama_bit_exact.py` runs Llama-3.2-1B in bf16
on CUDA and compares every layer's `residual_post` against a naive HF forward.
The two are bit-exact (`torch.equal`) at every real (non-padding) token
position. This holds when `microbatch_size` equals the naive batch size; at
different microbatch sizes, bf16 outputs diverge by LSB-level accumulation
noise (see the memory note on `bf16_microbatch_determinism`).

What's wired:

- Pre-loaded and streaming model paths.
- `Sweep` + `Callback` + `Result` API.
- Padded-batch + attention-mask propagation.
- RoPE-aware Llama plumbing and GPT-2 plumbing.
- All four hooks (`residual_pre`, `attn_out`, `mlp_out`, `residual_post`), with a fast-path block forward when no sub-layer hook is wanted.
- WriteBack at every hook (sub-layer WriteBack is threaded through the block mid-forward so the modified tensor actually affects downstream compute).
- All four reference callbacks (`RawActivations`, `IncrementalPCA`, `DiffOfMeans`, `SteerInBasis`).
- `result.activations(...)`.
- tqdm progress, plus callable `progress=reporter` emitting `ProgressEvent`s for wandb/rich sinks.
- Pinned-CPU `buffer_device="cpu"` with async D2H copy (so oversized residual buffers don't block compute).
- Worker-thread **concurrent weight prefetch** on the streaming path (layer L+1's safetensors read + H2D overlap with layer L's compute).
- `MemmapBackend` for disk-backed emits.
- `ProfileReport.throughput_tok_per_s()` / `weight_bandwidth_gb_per_s()`.
- `verify=True` fail-fast against a naive-forward baseline (pre-loaded models).
- Per-layer `on_layer_end` artifacts collected into `result.artifacts`.

Model families covered by the structural matcher: Llama, Mistral, Qwen2,
Gemma, DeepSeek-V2, and any future HF causal LM exposing the same
`model.{embed_tokens, layers, rotary_emb}` layout. GPT-2 is covered by
its own plumbing.

What's not wired yet: checkpoint/resume, an NVMe-backed ResidualBuffer, and
`verify=True` on the streaming path (pre-loaded only).

See `SPEC.md` for the full design.
