Metadata-Version: 2.4
Name: retrieval-heads
Version: 0.1.0
Summary: Retrieval Head detection in LLMs with vLLM
Author-email: Max Zuo <zuo@brown.edu>
Requires-Python: >=3.12
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: matplotlib>=3.11.0
Requires-Dist: nnsight>=0.7.0
Requires-Dist: pyyaml>=6.0.3
Requires-Dist: rouge-score>=0.1.2
Requires-Dist: seaborn>=0.13.2
Requires-Dist: torch>=2.10.0
Requires-Dist: tqdm>=4.68.2
Requires-Dist: tyro>=1.0.13
Requires-Dist: vllm==0.19.0
Dynamic: license-file

# retrieval-heads

Retrieval head detection in LLMs using vLLM and nnsight activation tracing.

This is my attempt to faithfully reproduce [Retrieval Head Mechanistically Explains Long-Context Factuality](https://arxiv.org/abs/2404.15574). It should work out of the box with any model that uses vLLM's `Attention` or `GatedDeltaNetAttention` implementations.

Two main workflows:

1. **Needle-in-a-haystack (NIAH)** – insert a known fact (the needle) into contexts of varying length at varying depths, then measure retrieval accuracy with ROUGE-L.
2. **Retrieval head detection** – trace query/key activations through every attention head on NIAH results to identify which heads are responsible for retrieval.
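
To make the first workflow concrete, here is a minimal sketch of placing a needle at a given depth. It is illustrative only: `insert_needle` is a hypothetical helper, not part of this package, and it counts whitespace-separated words, whereas a real run would more likely count model tokens.

```python
def insert_needle(haystack: str, needle: str, depth_percent: float) -> str:
    """Insert `needle` at roughly `depth_percent` of the way through `haystack`.

    Hypothetical helper: positions are measured in whitespace-separated words.
    """
    if not 0 <= depth_percent <= 100:
        raise ValueError("depth_percent must be between 0 and 100")
    words = haystack.split()
    # Index of the first haystack word that comes after the needle.
    cut = round(len(words) * depth_percent / 100)
    return " ".join(words[:cut] + [needle.strip()] + words[cut:])


haystack = "one two three four five six seven eight nine ten"
needle = "\nThe best thing to do in San Francisco is eat a sandwich.\n"
print(insert_needle(haystack, needle, 50))
```

A depth of 0 puts the needle at the very start of the context and 100 puts it at the very end.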

## Example Results

### NIAH Heatmap

![NIAH Heatmap](imgs/heatmap.png)

### Retrieval Head Detection

![Retrieval Head Detection Heatmap](imgs/detect_heatmap.png)

## Setup

Installation:
```bash
git clone https://github.com/maxzuo/retrieval-heads.git
cd retrieval-heads
pip install -e .
```
Tested using Python 3.12 and vLLM 0.19.0.

## Usage

### NIAH sweep

```bash
retrieval-heads.niah --config configs/qwen3_5_9b.yaml
```

Runs the needle-in-a-haystack evaluation across a grid of context lengths and
document depths. Results are written to `output_dir` as `results.jsonl` (one
JSON record per cell) alongside the resolved `config.yaml`.
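
Because `results.jsonl` holds one JSON record per line, it can be post-processed with the standard library alone. The sketch below loads the records and averages a score per grid cell. The field names `depth_percent`, `context_length`, and `score` are assumptions for illustration; check a real record for the actual keys.

```python
import json
from collections import defaultdict
from pathlib import Path


def load_results(path):
    """Read one JSON record per non-empty line of a JSONL file."""
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def to_grid(records, row_key="depth_percent", col_key="context_length",
            value_key="score"):
    """Average `value_key` for each (row, col) cell.

    The default key names are hypothetical, not taken from the tool's output.
    """
    cells = defaultdict(list)
    for rec in records:
        cells[(rec[row_key], rec[col_key])].append(rec[value_key])
    return {cell: sum(vals) / len(vals) for cell, vals in cells.items()}
```

For example, `to_grid(load_results("results/qwen3_5_9b/results.jsonl"))` would return a dictionary keyed by `(depth, length)` pairs, assuming the records carry those fields.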

Any config field can be overridden via CLI flags:

```bash
retrieval-heads.niah --config configs/qwen3_5_9b.yaml \
    --model.max-model-len 16384 \
    --output-dir ./results/short
```

### Retrieval head detection

```bash
retrieval-heads.detect --config configs/detect.yaml
```

Takes NIAH result files as input, traces each forward pass with nnsight to
capture per-head query/key matrices, and scores each head on whether it attends
to the needle span. Outputs `detected.json` and `detected-agg.json`.
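
As a rough illustration of the scoring idea, the sketch below follows the retrieval score as I understand it from the cited paper: a head "copies" a token when, at a decoding step, its most-attended context position lies inside the needle and holds the token being generated. The score is the fraction of distinct needle tokens copied this way. This is not the package's implementation, which works from traced query/key activations; `retrieval_score` and its arguments are hypothetical.

```python
def retrieval_score(attn, generated, context, needle_start, needle_end):
    """Fraction of distinct needle tokens that one head copies while decoding.

    attn[t][j]   : the head's attention weight from decoding step t to context position j
    generated[t] : token id emitted at step t
    context[j]   : token id at context position j
    The needle occupies context positions needle_start <= j < needle_end.
    """
    needle_tokens = set(context[needle_start:needle_end])
    copied = set()
    for weights, token in zip(attn, generated):
        # Position this head attends to most strongly at this step.
        top = max(range(len(weights)), key=weights.__getitem__)
        if needle_start <= top < needle_end and context[top] == token:
            copied.add(token)
    return len(copied) / len(needle_tokens) if needle_tokens else 0.0
```

A head that scores close to 1 across many NIAH samples is a candidate retrieval head; a head that never attends into the needle scores 0.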

### Visualization

```bash
retrieval-heads.visualize niah --results results/qwen3_5_9b/results.jsonl
retrieval-heads.visualize detect --results results/detect/detected-agg.json
```

## Configuration

Configs are YAML files with the following sections:

```yaml
model:
  model: Qwen/Qwen3.5-9B
  max_model_len: 32768
  dtype: bfloat16
  chat_template: path/to/template.jinja
  language_model_only: true

haystack:
  haystack_dir: ./PaulGrahamEssays
  needle: "\nThe best thing to do in San Francisco is eat a sandwich...\n"
  retrieval_question: "What is the best thing to do in San Francisco?"

sweep:
  context_lengths: {min: 1000, max: 32000, intervals: 31}
  document_depths: {min: 0, max: 100, intervals: 10}

output_dir: ./results/qwen3_5_9b
```

Sweep dimensions accept either a `{min, max, intervals}` shorthand or an
explicit list of values.
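
The sketch below shows one way the shorthand could expand into an explicit list. It assumes `intervals` is the number of evenly spaced points with both `min` and `max` included, as in `numpy.linspace`; the tool itself may interpret the field differently. `expand_dimension` is a hypothetical name.

```python
def expand_dimension(spec):
    """Turn a sweep dimension into an explicit list of integer values.

    ASSUMPTION: `intervals` counts evenly spaced points, endpoints included.
    An explicit list is passed through unchanged.
    """
    if isinstance(spec, (list, tuple)):
        return list(spec)
    lo, hi, n = spec["min"], spec["max"], spec["intervals"]
    if n < 1:
        raise ValueError("intervals must be at least 1")
    if n == 1:
        return [lo]
    step = (hi - lo) / (n - 1)
    return [round(lo + i * step) for i in range(n)]


print(expand_dimension({"min": 0, "max": 100, "intervals": 11}))
print(expand_dimension([1000, 4000, 16000]))
```

Under this assumption, 11 points between 0 and 100 land on multiples of 10, while other counts produce rounded, unevenly spaced integers.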
