Coverage for pattern_lens/activations.py: 94%
223 statements
coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
1"""computing and saving activations given a model and prompts
3# Usage:
5from the command line:
7```bash
8python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
9```
11from a script:
13```python
14from pattern_lens.activations import activations_main
15activations_main(
16 model_name="gpt2",
17 save_path="demo/"
18 prompts_path="data/pile_1k.jsonl",
19)
20```
22"""
import argparse
import json
import re
from collections.abc import Callable
from dataclasses import asdict
from pathlib import Path
from typing import Literal, overload

import numpy as np
import torch
import tqdm
from jaxtyping import Float
from muutils.json_serialize import json_serialize
from muutils.misc.numerical import shorten_numerical_to_str

# custom utils
from muutils.spinner import SpinnerContext
from transformer_lens import (  # type: ignore[import-untyped]
    ActivationCache,
    HookedTransformer,
    HookedTransformerConfig,
)

# pattern_lens
from pattern_lens.consts import (
    ATTN_PATTERN_REGEX,
    DATA_DIR,
    DIVIDER_S1,
    DIVIDER_S2,
    SPINNER_KWARGS,
    ActivationCacheNp,
    ReturnCache,
)
from pattern_lens.indexes import (
    generate_models_jsonl,
    generate_prompts_jsonl,
    write_html_index,
)
from pattern_lens.load_activations import (
    ActivationsMissingError,
    activations_exist,
    augment_prompt_with_hash,
    load_activations,
)
from pattern_lens.prompts import load_text_data


def _rel_path(p: Path) -> str:
    """Return path relative to cwd if possible, otherwise absolute."""
    try:
        return p.relative_to(Path.cwd()).as_posix()
    except ValueError:
        return p.as_posix()


# return nothing, but `stack_heads` still affects how we save the activations
@overload
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: None = None,
    stack_heads: bool = False,
) -> tuple[Path, None]: ...
# return stacked heads in numpy or torch form
@overload
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: Literal["torch"] = "torch",
    stack_heads: Literal[True] = True,
) -> tuple[Path, Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]]: ...
@overload
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: Literal["numpy"] = "numpy",
    stack_heads: Literal[True] = True,
) -> tuple[Path, Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]]: ...
# return dicts in numpy or torch form
@overload
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: Literal["numpy"] = "numpy",
    stack_heads: Literal[False] = False,
) -> tuple[Path, ActivationCacheNp]: ...
@overload
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: Literal["torch"] = "torch",
    stack_heads: Literal[False] = False,
) -> tuple[Path, ActivationCache]: ...
# actual function body
def compute_activations(  # noqa: PLR0915
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: ReturnCache = "torch",
    stack_heads: bool = False,
) -> tuple[
    Path,
    ActivationCacheNp
    | ActivationCache
    | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
    | Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]
    | None,
]:
143 """compute activations for a single prompt and save to disk
145 always runs a forward pass -- does NOT load from disk cache.
146 for cache-aware loading, use `get_activations` which tries disk first.
148 # Parameters:
149 - `prompt : dict | None`
150 (defaults to `None`)
151 - `model : HookedTransformer`
152 - `save_path : Path`
153 (defaults to `Path(DATA_DIR)`)
154 - `names_filter : Callable[[str], bool]|re.Pattern`
155 a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
156 (defaults to `ATTN_PATTERN_REGEX`)
157 - `return_cache : Literal[None, "numpy", "torch"]`
158 will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (False) or a single tensor (True)
159 (defaults to `None`)
160 - `stack_heads : bool`
161 whether the heads should be stacked in the output. this causes a number of changes:
162 - `npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor saved for each prompt instead of `npz` file with dict by layer
163 - `cache` will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer if `return_cache` is `True`
164 will assert that everything in the activation cache is only attention patterns, and is all of the attention patterns. raises an exception if not.
166 # Returns:
167 ```
168 tuple[
169 Path,
170 Union[
171 None,
172 ActivationCacheNp, ActivationCache,
173 Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
174 ]
175 ]
176 ```
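
    # Usage sketch:
    a minimal example (the model name, prompt text, and `demo/` save path are
    illustrative, not from this repo's data); the 'hash' key must be added first:
    ```python
    from pathlib import Path

    from transformer_lens import HookedTransformer

    from pattern_lens.activations import compute_activations
    from pattern_lens.load_activations import augment_prompt_with_hash

    model = HookedTransformer.from_pretrained("gpt2")
    prompt = {"text": "The quick brown fox jumps over the lazy dog."}
    augment_prompt_with_hash(prompt)  # adds the 'hash' key used for the save directory
    path, cache = compute_activations(
        prompt=prompt,
        model=model,
        save_path=Path("demo/"),
        return_cache="numpy",  # dict of numpy attention patterns keyed by hook name
    )
    ```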
177 """
    # check inputs
    assert model is not None, "model must be passed"
    assert "text" in prompt, "prompt must contain 'text' key"
    prompt_str: str = prompt["text"]

    # compute or get prompt metadata
    assert model.tokenizer is not None
    prompt_tokenized: list[str] = prompt.get(
        "tokens",
        model.tokenizer.tokenize(prompt_str),
    )
    # n_tokens counts subword tokens (no BOS); attention patterns include BOS
    # so have dim n_tokens+1. see also compute_activations_batched Phase B.
    prompt.update(
        dict(
            n_tokens=len(prompt_tokenized),
            tokens=prompt_tokenized,
        ),
    )

    # save metadata
    prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    # set up names filter
    names_filter_fn: Callable[[str], bool]
    if isinstance(names_filter, re.Pattern):
        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
    else:
        names_filter_fn = names_filter

    # compute activations
    # NOTE: no padding_side kwarg here -- it's only meaningful for multi-sequence
    # batches where padding is needed. single-string input has no padding.
    # see compute_activations_batched for the batched path that passes padding_side="right".
    cache_torch: ActivationCache
    with torch.no_grad():
        model.eval()
        _, cache_torch = model.run_with_cache(
            prompt_str,
            names_filter=names_filter_fn,
            return_type=None,
        )

    activations_path: Path
    # saving and returning
    if stack_heads:
        n_layers: int = model.cfg.n_layers
        key_pattern: str = "blocks.{i}.attn.hook_pattern"
        # NOTE: this only works for stacking heads at the moment
        # activations_specifier: str = key_pattern.format(i=f'0-{n_layers}')
        activations_specifier: str = key_pattern.format(i="-")
        activations_path = prompt_dir / f"activations-{activations_specifier}.npy"

        # check the keys are only attention heads
        head_keys: list[str] = [key_pattern.format(i=i) for i in range(n_layers)]
        cache_torch_keys_set: set[str] = set(cache_torch.keys())
        assert cache_torch_keys_set == set(head_keys), (
            f"unexpected keys!\n{set(head_keys).symmetric_difference(cache_torch_keys_set) = }\n{cache_torch_keys_set} != {set(head_keys)}"
        )

        # stack heads
        patterns_stacked: Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] = (
            torch.stack([cache_torch[k] for k in head_keys], dim=1)
        )
        # check shape
        pattern_shape_no_ctx: tuple[int, ...] = tuple(patterns_stacked.shape[:3])
        assert pattern_shape_no_ctx == (1, n_layers, model.cfg.n_heads), (
            f"unexpected shape: {patterns_stacked.shape[:3] = } ({pattern_shape_no_ctx = }), expected {(1, n_layers, model.cfg.n_heads) = }"
        )

        patterns_stacked_np: Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = (
            patterns_stacked.cpu().numpy()
        )

        # save
        np.save(activations_path, patterns_stacked_np)

        # return
        match return_cache:
            case "numpy":
                return activations_path, patterns_stacked_np
            case "torch":
                return activations_path, patterns_stacked
            case None:
                return activations_path, None
            case _:
                msg = f"invalid return_cache: {return_cache = }"
                raise ValueError(msg)
    else:
        activations_path = prompt_dir / "activations.npz"

        # save
        cache_np: ActivationCacheNp = {
            k: v.detach().cpu().numpy() for k, v in cache_torch.items()
        }

        np.savez_compressed(
            activations_path,
            **cache_np,  # type: ignore[arg-type]
        )

        # return
        match return_cache:
            case "numpy":
                return activations_path, cache_np
            case "torch":
                return activations_path, cache_torch
            case None:
                return activations_path, None
            case _:
                msg = f"invalid return_cache: {return_cache = }"
                raise ValueError(msg)


def compute_activations_batched(
    prompts: list[dict],
    model: HookedTransformer,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    seq_lens: list[int] | None = None,
) -> list[Path]:
302 """compute and save activations for a batch of prompts in a single forward pass
304 Batched companion to `compute_activations` -- instead of one forward pass per
305 prompt, this runs a single `model.run_with_cache(list_of_strings)` call for the
306 whole batch. TransformerLens tokenizes and right-pads automatically. Each prompt's
307 attention patterns are then trimmed to their actual (unpadded) size and saved
308 individually, producing files identical to the single-prompt path.
310 Does not support `stack_heads` or `return_cache` -- this function is intended for
311 the bulk processing path in `activations_main`, not for interactive use. Use
312 `compute_activations` directly for single-prompt use cases that need those features.
314 ## Why right-padding makes trimming correct without an explicit attention mask
316 With right-padding, pad tokens sit at positions seq_len, seq_len+1, ...,
317 max_seq_len-1 (higher than any real token). The causal attention mask prevents
318 position i from attending to any j > i. So for real tokens at positions
319 0..seq_len-1, they can only attend to 0..i -- all real tokens. The softmax is computed over the same set of positions
320 as in single-prompt inference, producing identical attention patterns.
322 We explicitly pass `padding_side="right"` to `run_with_cache` to guarantee this
323 regardless of the model's default padding side.
325 # Parameters:
326 - `prompts : list[dict]`
327 each prompt must contain 'text' and 'hash' keys. call
328 `augment_prompt_with_hash` on each prompt before passing them here.
329 - `model : HookedTransformer`
330 the model to compute activations with
331 - `save_path : Path`
332 path to save the activations to
333 (defaults to `Path(DATA_DIR)`)
334 - `names_filter : Callable[[str], bool] | re.Pattern`
335 filter for which activations to save. must only match activations with
336 4D shape `[batch, n_heads, seq, seq]` (e.g. attention patterns).
337 non-attention activations will cause incorrect trimming.
338 (defaults to `ATTN_PATTERN_REGEX`)
339 - `seq_lens : list[int] | None`
340 pre-computed model sequence lengths per prompt (from `model.to_tokens`).
341 if `None`, will be computed internally. pass this to avoid redundant
342 tokenization when lengths are already known (e.g. from length-sorting).
343 **important**: these must be from `model.to_tokens()` (includes BOS),
344 NOT from `model.tokenizer.tokenize()` (excludes BOS).
345 (defaults to `None`)
347 # Returns:
348 - `list[Path]`
349 paths to the saved activations files, one per prompt
351 # Modifies:
352 each prompt dict in `prompts` -- adds/overwrites `n_tokens` and `tokens` keys
353 with tokenization metadata (same mutation as `compute_activations`).
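
    # Usage sketch:
    a minimal example of the bulk path (the model name and prompt texts are
    illustrative); hashes must be added first, and each prompt dict is mutated in place:
    ```python
    from pathlib import Path

    from transformer_lens import HookedTransformer

    from pattern_lens.activations import compute_activations_batched
    from pattern_lens.load_activations import augment_prompt_with_hash

    model = HookedTransformer.from_pretrained("gpt2")
    prompts = [{"text": "a short prompt"}, {"text": "a second, somewhat longer prompt"}]
    for p in prompts:
        augment_prompt_with_hash(p)  # required: adds the 'hash' key
    paths = compute_activations_batched(
        prompts=prompts,
        model=model,
        save_path=Path("demo/"),
    )  # one activations.npz path per prompt; seq_lens are computed internally
    ```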
354 """
355 assert model is not None, "model must be passed"
356 assert len(prompts) > 0, "prompts must not be empty"
357 assert "text" in prompts[0], f"prompt must contain 'text' key: {prompts[0].keys()}"
358 assert "hash" in prompts[0], (
359 f"prompt must contain 'hash' key (call augment_prompt_with_hash first): {prompts[0].keys()}"
360 )
362 # --- Phase A: get actual model sequence lengths ---
363 # model.to_tokens() includes BOS if applicable, matching the attention pattern dims
364 # model.tokenizer.tokenize() gives subword strings WITHOUT BOS, used for metadata
365 # these differ by 1 when BOS is prepended -- using the wrong one for trimming
366 # would silently truncate or include garbage
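    # e.g. (illustrative, assuming a BOS-prepending tokenizer such as gpt2 in TransformerLens):
    #   model.to_tokens("Hello world").shape[1]       -> 3  (BOS + 2 subword tokens)
    #   len(model.tokenizer.tokenize("Hello world"))  -> 2  (subword tokens only)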
    if seq_lens is None:
        seq_lens = [model.to_tokens(p["text"]).shape[1] for p in prompts]
    assert len(seq_lens) == len(prompts), (
        f"seq_lens length mismatch: {len(seq_lens)} != {len(prompts)}"
    )

    # --- Phase B: save prompt metadata (mirrors compute_activations's metadata logic) ---
    assert model.tokenizer is not None
    for p in prompts:
        prompt_str: str = p["text"]
        prompt_tokenized: list[str] = p.get(
            "tokens",
            model.tokenizer.tokenize(prompt_str),
        )
        # n_tokens counts subword tokens (no BOS); attention patterns include BOS so have dim n_tokens+1
        p.update(
            dict(
                n_tokens=len(prompt_tokenized),
                tokens=prompt_tokenized,
            ),
        )
        prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / p["hash"]
        prompt_dir.mkdir(parents=True, exist_ok=True)
        with open(prompt_dir / "prompt.json", "w") as f:
            json.dump(p, f)

    # --- Phase C: batched forward pass ---
    names_filter_fn: Callable[[str], bool]
    if isinstance(names_filter, re.Pattern):
        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
    else:
        names_filter_fn = names_filter

    texts: list[str] = [p["text"] for p in prompts]
    cache_torch: ActivationCache
    with torch.no_grad():
        model.eval()
        _, cache_torch = model.run_with_cache(
            texts,
            names_filter=names_filter_fn,
            return_type=None,
            padding_side="right",
        )

    # --- Phase D: split, trim padding, and save per-prompt ---
    # For each prompt i with actual sequence length seq_len_i:
    #   v[i : i+1, :, :seq_len_i, :seq_len_i]
    #     ^^^^^^^ i:i+1 not i -- keeps batch dim [1,...] for
    #             format compatibility with compute_activations
    #              ^^ all attention heads
    #                 ^^^^^^^^^^  ^^^^^^^^^^ trim both query and key dims to actual
    #                                        length, discarding meaningless padding positions
    paths: list[Path] = []
    for i, (prompt, seq_len) in enumerate(zip(prompts, seq_lens, strict=True)):
        prompt_dir = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
        activations_path: Path = prompt_dir / "activations.npz"
        cache_np: ActivationCacheNp = {}
        for k, v in cache_torch.items():
            assert v.ndim == 4, (  # noqa: PLR2004
                f"expected 4D attention pattern tensor for {k!r}, "
                f"got shape {v.shape}. names_filter must only match "
                f"attention pattern activations [batch, n_heads, seq, seq]"
            )
            cache_np[k] = v[i : i + 1, :, :seq_len, :seq_len].detach().cpu().numpy()

        np.savez_compressed(
            activations_path,
            **cache_np,  # type: ignore[arg-type]
        )
        paths.append(activations_path)

    return paths


@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: None = None,
) -> tuple[Path, None]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal["torch"] = "torch",
) -> tuple[Path, ActivationCache]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal["numpy"] = "numpy",
) -> tuple[Path, ActivationCacheNp]: ...
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: ReturnCache = "numpy",
) -> tuple[Path, ActivationCacheNp | ActivationCache | None]:
472 """given a prompt and a model, save or load activations
474 # Parameters:
475 - `prompt : dict`
476 expected to contain the 'text' key
477 - `model : HookedTransformer | str`
478 either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
479 - `save_path : Path`
480 path to save the activations to (and load from)
481 (defaults to `Path(DATA_DIR)`)
482 - `allow_disk_cache : bool`
483 whether to allow loading from disk cache
484 (defaults to `True`)
485 - `return_cache : Literal[None, "numpy", "torch"]`
486 whether to return the cache, and in what format
487 (defaults to `"numpy"`)
489 # Returns:
490 - `tuple[Path, ActivationCacheNp | ActivationCache | None]`
491 the path to the activations and the cache if `return_cache is not None`
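
    # Usage sketch:
    a minimal example (the prompt text and save path are illustrative); passing a
    model name string defers loading the model until the disk cache actually misses:
    ```python
    from pathlib import Path

    from pattern_lens.activations import get_activations

    path, cache = get_activations(
        prompt={"text": "an example prompt"},
        model="gpt2",  # loaded via HookedTransformer.from_pretrained only on a cache miss
        save_path=Path("demo/"),
        return_cache="numpy",
    )
    ```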
493 """
494 # add hash to prompt
495 augment_prompt_with_hash(prompt)
497 # get the model
498 model_name: str = (
499 model.cfg.model_name if isinstance(model, HookedTransformer) else model
500 )
502 # from cache
503 if allow_disk_cache:
504 if return_cache is None:
505 # fast path: check file existence without loading data into memory.
506 # activations_exist just calls .exists() on two paths, whereas
507 # load_activations would decompress the full .npz into numpy arrays
508 # only for us to discard them immediately.
509 if activations_exist(model_name, prompt, save_path):
510 prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
511 return prompt_dir / "activations.npz", None
512 else:
513 try:
514 path, cache = load_activations(
515 model_name=model_name,
516 prompt=prompt,
517 save_path=save_path,
518 )
519 except ActivationsMissingError:
520 pass
521 else:
522 return path, cache
524 # compute them
525 if isinstance(model, str):
526 model = HookedTransformer.from_pretrained(model_name)
528 return compute_activations( # type: ignore[return-value]
529 prompt=prompt,
530 model=model,
531 save_path=save_path,
532 return_cache=return_cache,
533 )
536DEFAULT_DEVICE: torch.device = torch.device(
537 "cuda" if torch.cuda.is_available() else "cpu",
538)
def activations_main(  # noqa: C901, PLR0912, PLR0915
    model_name: str,
    save_path: str | Path,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
    stacked_heads: bool = False,
    device: str | torch.device = DEFAULT_DEVICE,
    batch_size: int = 32,
) -> None:
556 """main function for computing activations
558 # Parameters:
559 - `model_name : str`
560 name of a model to load with `HookedTransformer.from_pretrained`
561 - `save_path : str | Path`
562 path to save the activations to
563 - `prompts_path : str`
564 path to the prompts file
565 - `raw_prompts : bool`
566 whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
567 - `min_chars : int`
568 minimum number of characters for a prompt
569 - `max_chars : int`
570 maximum number of characters for a prompt
571 - `force : bool`
572 whether to overwrite existing files
573 - `n_samples : int`
574 maximum number of samples to process
575 - `no_index_html : bool`
576 whether to write an index.html file
577 - `shuffle : bool`
578 whether to shuffle the prompts
579 (defaults to `False`)
580 - `stacked_heads : bool`
581 whether to stack the heads in the output tensor. will save as `.npy` instead of `.npz` if `True`
582 (defaults to `False`)
583 - `device : str | torch.device`
584 the device to use. if a string, will be passed to `torch.device`
585 - `batch_size : int`
586 number of prompts per forward pass. prompts are sorted by token length
587 (longest first) and grouped so that similar-length prompts share a batch,
588 minimizing padding waste. use `batch_size=1` for one prompt per forward
589 pass (largely equivalent to the old sequential behavior, but note: prompts
590 are still sorted by length and cache checking uses file-existence only,
591 unlike the old path which processed prompts in order and validated cache
592 contents via `load_activations`).
593 the single-prompt functions `compute_activations` and `get_activations`
594 are still available for programmatic use outside of `activations_main`.
595 (defaults to `32`)
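
    # Usage sketch:
    a minimal programmatic call (the model name, paths, and sample count are
    illustrative), mirroring the CLI defaults for the length filters and batch size:
    ```python
    from pattern_lens.activations import activations_main

    activations_main(
        model_name="gpt2",
        save_path="demo/",
        prompts_path="data/pile_1k.jsonl",
        raw_prompts=True,   # filter raw text by length via load_text_data
        min_chars=100,
        max_chars=1000,
        force=False,        # skip prompts whose activations already exist on disk
        n_samples=64,
        no_index_html=False,
        batch_size=32,
    )
    ```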
596 """
597 # figure out the device to use
598 device_: torch.device
599 if isinstance(device, torch.device):
600 device_ = device
601 elif isinstance(device, str):
602 device_ = torch.device(device)
603 else:
604 msg = f"invalid device: {device}"
605 raise TypeError(msg)
607 print(f"using device: {device_}")
609 with SpinnerContext(message="loading model", **SPINNER_KWARGS):
610 model: HookedTransformer = HookedTransformer.from_pretrained(
611 model_name,
612 device=device_,
613 )
614 model.model_name = model_name # type: ignore[unresolved-attribute]
615 model.cfg.model_name = model_name
616 n_params: int = sum(p.numel() for p in model.parameters())
617 print(
618 f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters",
619 )
620 print(f"\tmodel devices: { {p.device for p in model.parameters()} }")
622 save_path_p: Path = Path(save_path)
623 save_path_p.mkdir(parents=True, exist_ok=True)
624 model_path: Path = save_path_p / model_name
625 with SpinnerContext(
626 message=f"saving model info to {_rel_path(model_path)}",
627 **SPINNER_KWARGS,
628 ):
629 model_cfg: HookedTransformerConfig
630 model_cfg = model.cfg
631 model_path.mkdir(parents=True, exist_ok=True)
632 with open(model_path / "model_cfg.json", "w") as f:
633 json.dump(json_serialize(asdict(model_cfg)), f)
635 # load prompts
636 with SpinnerContext(
637 message=f"loading prompts from {Path(prompts_path).as_posix()}",
638 **SPINNER_KWARGS,
639 ):
640 prompts: list[dict]
641 if raw_prompts:
642 prompts = load_text_data(
643 Path(prompts_path),
644 min_chars=min_chars,
645 max_chars=max_chars,
646 shuffle=shuffle,
647 )
648 else:
649 with open(model_path / "prompts.jsonl", "r") as f:
650 prompts = [json.loads(line) for line in f.readlines()]
651 # truncate to n_samples
652 prompts = prompts[:n_samples]
654 print(f" {len(prompts)} prompts loaded")
656 # write index.html
657 with SpinnerContext(
658 message=f"writing {_rel_path(save_path_p / 'index.html')}",
659 **SPINNER_KWARGS,
660 ):
661 if not no_index_html:
662 write_html_index(save_path_p)
664 # TODO: not implemented yet
665 if stacked_heads:
666 raise NotImplementedError("stacked_heads not implemented yet")
668 # augment all prompts with hashes
669 for prompt in prompts:
670 augment_prompt_with_hash(prompt)
672 # filter out cached prompts
673 if not force:
674 uncached: list[dict] = [
675 p for p in prompts if not activations_exist(model_name, p, save_path_p)
676 ]
677 n_cached: int = len(prompts) - len(uncached)
678 if n_cached > 0:
679 print(f" {n_cached} prompts already cached, {len(uncached)} to compute")
680 else:
681 uncached = list(prompts)
683 if uncached:
684 # sort by token length descending so that:
685 # 1. the longest (slowest, most memory-hungry) batches run first --
686 # OOM errors surface immediately rather than after all the cheap work,
687 # and tqdm's ETA stabilizes early for better progress estimation
688 # 2. similar-length prompts are grouped together, minimizing padding waste
689 #
690 # pre-tokenization is a separate step from compute_activations_batched because
691 # we need token lengths *before* batching to sort and group. the resulting
692 # seq_lens are then passed through so compute_activations_batched can skip
693 # re-tokenizing each prompt internally.
694 with SpinnerContext(
695 message="pre-tokenizing prompts for length sorting",
696 **SPINNER_KWARGS,
697 ):
698 uncached_with_lens: list[tuple[dict, int]] = [
699 (p, model.to_tokens(p["text"]).shape[1]) for p in uncached
700 ]
701 uncached_with_lens.sort(key=lambda x: x[1], reverse=True)
702 sorted_uncached: list[dict] = [p for p, _ in uncached_with_lens]
703 sorted_seq_lens: list[int] = [sl for _, sl in uncached_with_lens]
705 # process in batches
706 n_prompts: int = len(sorted_uncached)
707 with tqdm.tqdm(
708 total=n_prompts,
709 desc="Computing activations",
710 unit="prompt",
711 ) as pbar:
712 for batch_start in range(0, n_prompts, batch_size):
713 batch_end: int = min(batch_start + batch_size, n_prompts)
714 batch: list[dict] = sorted_uncached[batch_start:batch_end]
715 batch_seq_lens: list[int] = sorted_seq_lens[batch_start:batch_end]
716 pbar.set_postfix(
717 n_ctx=batch_seq_lens[0],
718 ) # longest in batch (sorted descending)
719 compute_activations_batched(
720 prompts=batch,
721 model=model,
722 save_path=save_path_p,
723 seq_lens=batch_seq_lens,
724 )
725 pbar.update(len(batch))
726 else:
727 print(" all prompts cached, nothing to compute")
729 with SpinnerContext(
730 message="updating jsonl metadata for models and prompts",
731 **SPINNER_KWARGS,
732 ):
733 generate_models_jsonl(save_path_p)
734 generate_prompts_jsonl(save_path_p / model_name)
def main() -> None:
    "generate attention pattern activations for a model and prompts"
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )

        arg_parser.add_argument(
            "--prompts",
            "-p",
            type=str,
            required=False,
            help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
            default=None,
        )

        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )

        # min and max prompt lengths
        arg_parser.add_argument(
            "--min-chars",
            type=int,
            required=False,
            help="The minimum number of characters for a prompt",
            default=100,
        )
        arg_parser.add_argument(
            "--max-chars",
            type=int,
            required=False,
            help="The maximum number of characters for a prompt",
            default=1000,
        )

        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )

        # batch size
        arg_parser.add_argument(
            "--batch-size",
            "-b",
            type=int,
            required=False,
            help="Batch size for computing activations (number of prompts per forward pass)",
            default=32,
        )

        # force overwrite
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="If passed, will overwrite existing files",
        )

        # no index html
        arg_parser.add_argument(
            "--no-index-html",
            action="store_true",
            help="If passed, will not write an index.html file for the model",
        )

        # raw prompts
        arg_parser.add_argument(
            "--raw-prompts",
            "-r",
            action="store_true",
            help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
        )

        # shuffle
        arg_parser.add_argument(
            "--shuffle",
            action="store_true",
            help="If passed, will shuffle the prompts",
        )

        # stack heads
        arg_parser.add_argument(
            "--stacked-heads",
            action="store_true",
            help="If passed, will stack the heads in the output tensor",
        )

        # device
        arg_parser.add_argument(
            "--device",
            type=str,
            required=False,
            help="The device to use for the model",
            default="cuda" if torch.cuda.is_available() else "cpu",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    models: list[str]
    if "," in args.model:
        models = args.model.split(",")
    else:
        models = [args.model]

    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx + 1} / {n_models}: {model}")
        print(DIVIDER_S2)

        activations_main(
            model_name=model,
            save_path=args.save_path,
            prompts_path=args.prompts,
            raw_prompts=args.raw_prompts,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            force=args.force,
            n_samples=args.n_samples,
            no_index_html=args.no_index_html,
            shuffle=args.shuffle,
            stacked_heads=args.stacked_heads,
            device=args.device,
            batch_size=args.batch_size,
        )
        del model

    print(DIVIDER_S1)


if __name__ == "__main__":
    main()