pattern_lens.activations
computing and saving activations given a model and prompts
Usage:

from the command line:

```bash
python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
```

from a script:

```python
from pattern_lens.activations import activations_main

activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
)
```
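To read previously saved activations back from disk, the loader in `pattern_lens.load_activations` can be used. A minimal sketch, assuming activations for this exact prompt were already computed under `demo/` (prompt text and paths are placeholders):

```python
from pathlib import Path

from pattern_lens.load_activations import augment_prompt_with_hash, load_activations

prompt = {"text": "The quick brown fox jumps over the lazy dog."}
augment_prompt_with_hash(prompt)  # adds the "hash" key that names the per-prompt directory

# returns the path to the saved .npz and a dict of attention patterns keyed by hook name,
# e.g. "blocks.0.attn.hook_pattern"; raises ActivationsMissingError if nothing is saved yet
path, cache = load_activations(
    model_name="gpt2",
    prompt=prompt,
    save_path=Path("demo/"),
)
```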
1"""computing and saving activations given a model and prompts 2 3# Usage: 4 5from the command line: 6 7```bash 8python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples> 9``` 10 11from a script: 12 13```python 14from pattern_lens.activations import activations_main 15activations_main( 16 model_name="gpt2", 17 save_path="demo/" 18 prompts_path="data/pile_1k.jsonl", 19) 20``` 21 22""" 23 24import argparse 25import json 26import re 27from collections.abc import Callable 28from dataclasses import asdict 29from pathlib import Path 30from typing import Literal, overload 31 32import numpy as np 33import torch 34import tqdm 35from jaxtyping import Float 36from muutils.json_serialize import json_serialize 37from muutils.misc.numerical import shorten_numerical_to_str 38 39# custom utils 40from muutils.spinner import SpinnerContext 41from transformer_lens import ( # type: ignore[import-untyped] 42 ActivationCache, 43 HookedTransformer, 44 HookedTransformerConfig, 45) 46 47# pattern_lens 48from pattern_lens.consts import ( 49 ATTN_PATTERN_REGEX, 50 DATA_DIR, 51 DIVIDER_S1, 52 DIVIDER_S2, 53 SPINNER_KWARGS, 54 ActivationCacheNp, 55 ReturnCache, 56) 57from pattern_lens.indexes import ( 58 generate_models_jsonl, 59 generate_prompts_jsonl, 60 write_html_index, 61) 62from pattern_lens.load_activations import ( 63 ActivationsMissingError, 64 activations_exist, 65 augment_prompt_with_hash, 66 load_activations, 67) 68from pattern_lens.prompts import load_text_data 69 70 71def _rel_path(p: Path) -> str: 72 """Return path relative to cwd if possible, otherwise absolute.""" 73 try: 74 return p.relative_to(Path.cwd()).as_posix() 75 except ValueError: 76 return p.as_posix() 77 78 79# return nothing, but `stack_heads` still affects how we save the activations 80@overload 81def compute_activations( 82 prompt: dict, 83 model: HookedTransformer | None = None, 84 save_path: Path = Path(DATA_DIR), 85 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 86 return_cache: None = None, 87 stack_heads: bool = False, 88) -> tuple[Path, None]: ... 89# return stacked heads in numpy or torch form 90@overload 91def compute_activations( 92 prompt: dict, 93 model: HookedTransformer | None = None, 94 save_path: Path = Path(DATA_DIR), 95 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 96 return_cache: Literal["torch"] = "torch", 97 stack_heads: Literal[True] = True, 98) -> tuple[Path, Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]]: ... 99@overload 100def compute_activations( 101 prompt: dict, 102 model: HookedTransformer | None = None, 103 save_path: Path = Path(DATA_DIR), 104 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 105 return_cache: Literal["numpy"] = "numpy", 106 stack_heads: Literal[True] = True, 107) -> tuple[Path, Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]]: ... 108# return dicts in numpy or torch form 109@overload 110def compute_activations( 111 prompt: dict, 112 model: HookedTransformer | None = None, 113 save_path: Path = Path(DATA_DIR), 114 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 115 return_cache: Literal["numpy"] = "numpy", 116 stack_heads: Literal[False] = False, 117) -> tuple[Path, ActivationCacheNp]: ... 
118@overload 119def compute_activations( 120 prompt: dict, 121 model: HookedTransformer | None = None, 122 save_path: Path = Path(DATA_DIR), 123 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 124 return_cache: Literal["torch"] = "torch", 125 stack_heads: Literal[False] = False, 126) -> tuple[Path, ActivationCache]: ... 127# actual function body 128def compute_activations( # noqa: PLR0915 129 prompt: dict, 130 model: HookedTransformer | None = None, 131 save_path: Path = Path(DATA_DIR), 132 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 133 return_cache: ReturnCache = "torch", 134 stack_heads: bool = False, 135) -> tuple[ 136 Path, 137 ActivationCacheNp 138 | ActivationCache 139 | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] 140 | Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] 141 | None, 142]: 143 """compute activations for a single prompt and save to disk 144 145 always runs a forward pass -- does NOT load from disk cache. 146 for cache-aware loading, use `get_activations` which tries disk first. 147 148 # Parameters: 149 - `prompt : dict | None` 150 (defaults to `None`) 151 - `model : HookedTransformer` 152 - `save_path : Path` 153 (defaults to `Path(DATA_DIR)`) 154 - `names_filter : Callable[[str], bool]|re.Pattern` 155 a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None` 156 (defaults to `ATTN_PATTERN_REGEX`) 157 - `return_cache : Literal[None, "numpy", "torch"]` 158 will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (False) or a single tensor (True) 159 (defaults to `None`) 160 - `stack_heads : bool` 161 whether the heads should be stacked in the output. this causes a number of changes: 162 - `npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor saved for each prompt instead of `npz` file with dict by layer 163 - `cache` will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer if `return_cache` is `True` 164 will assert that everything in the activation cache is only attention patterns, and is all of the attention patterns. raises an exception if not. 165 166 # Returns: 167 ``` 168 tuple[ 169 Path, 170 Union[ 171 None, 172 ActivationCacheNp, ActivationCache, 173 Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"], 174 ] 175 ] 176 ``` 177 """ 178 # check inputs 179 assert model is not None, "model must be passed" 180 assert "text" in prompt, "prompt must contain 'text' key" 181 prompt_str: str = prompt["text"] 182 183 # compute or get prompt metadata 184 assert model.tokenizer is not None 185 prompt_tokenized: list[str] = prompt.get( 186 "tokens", 187 model.tokenizer.tokenize(prompt_str), 188 ) 189 # n_tokens counts subword tokens (no BOS); attention patterns include BOS 190 # so have dim n_tokens+1. see also compute_activations_batched Phase B. 
191 prompt.update( 192 dict( 193 n_tokens=len(prompt_tokenized), 194 tokens=prompt_tokenized, 195 ), 196 ) 197 198 # save metadata 199 prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / prompt["hash"] 200 prompt_dir.mkdir(parents=True, exist_ok=True) 201 with open(prompt_dir / "prompt.json", "w") as f: 202 json.dump(prompt, f) 203 204 # set up names filter 205 names_filter_fn: Callable[[str], bool] 206 if isinstance(names_filter, re.Pattern): 207 names_filter_fn = lambda key: names_filter.match(key) is not None # noqa: E731 208 else: 209 names_filter_fn = names_filter 210 211 # compute activations 212 # NOTE: no padding_side kwarg here -- it's only meaningful for multi-sequence 213 # batches where padding is needed. single-string input has no padding. 214 # see compute_activations_batched for the batched path that passes padding_side="right". 215 cache_torch: ActivationCache 216 with torch.no_grad(): 217 model.eval() 218 _, cache_torch = model.run_with_cache( 219 prompt_str, 220 names_filter=names_filter_fn, 221 return_type=None, 222 ) 223 224 activations_path: Path 225 # saving and returning 226 if stack_heads: 227 n_layers: int = model.cfg.n_layers 228 key_pattern: str = "blocks.{i}.attn.hook_pattern" 229 # NOTE: this only works for stacking heads at the moment 230 # activations_specifier: str = key_pattern.format(i=f'0-{n_layers}') 231 activations_specifier: str = key_pattern.format(i="-") 232 activations_path = prompt_dir / f"activations-{activations_specifier}.npy" 233 234 # check the keys are only attention heads 235 head_keys: list[str] = [key_pattern.format(i=i) for i in range(n_layers)] 236 cache_torch_keys_set: set[str] = set(cache_torch.keys()) 237 assert cache_torch_keys_set == set(head_keys), ( 238 f"unexpected keys!\n{set(head_keys).symmetric_difference(cache_torch_keys_set) = }\n{cache_torch_keys_set} != {set(head_keys)}" 239 ) 240 241 # stack heads 242 patterns_stacked: Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] = ( 243 torch.stack([cache_torch[k] for k in head_keys], dim=1) 244 ) 245 # check shape 246 pattern_shape_no_ctx: tuple[int, ...] 
= tuple(patterns_stacked.shape[:3]) 247 assert pattern_shape_no_ctx == (1, n_layers, model.cfg.n_heads), ( 248 f"unexpected shape: {patterns_stacked.shape[:3] = } ({pattern_shape_no_ctx = }), expected {(1, n_layers, model.cfg.n_heads) = }" 249 ) 250 251 patterns_stacked_np: Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = ( 252 patterns_stacked.cpu().numpy() 253 ) 254 255 # save 256 np.save(activations_path, patterns_stacked_np) 257 258 # return 259 match return_cache: 260 case "numpy": 261 return activations_path, patterns_stacked_np 262 case "torch": 263 return activations_path, patterns_stacked 264 case None: 265 return activations_path, None 266 case _: 267 msg = f"invalid return_cache: {return_cache = }" 268 raise ValueError(msg) 269 else: 270 activations_path = prompt_dir / "activations.npz" 271 272 # save 273 cache_np: ActivationCacheNp = { 274 k: v.detach().cpu().numpy() for k, v in cache_torch.items() 275 } 276 277 np.savez_compressed( 278 activations_path, 279 **cache_np, # type: ignore[arg-type] 280 ) 281 282 # return 283 match return_cache: 284 case "numpy": 285 return activations_path, cache_np 286 case "torch": 287 return activations_path, cache_torch 288 case None: 289 return activations_path, None 290 case _: 291 msg = f"invalid return_cache: {return_cache = }" 292 raise ValueError(msg) 293 294 295def compute_activations_batched( 296 prompts: list[dict], 297 model: HookedTransformer, 298 save_path: Path = Path(DATA_DIR), 299 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 300 seq_lens: list[int] | None = None, 301) -> list[Path]: 302 """compute and save activations for a batch of prompts in a single forward pass 303 304 Batched companion to `compute_activations` -- instead of one forward pass per 305 prompt, this runs a single `model.run_with_cache(list_of_strings)` call for the 306 whole batch. TransformerLens tokenizes and right-pads automatically. Each prompt's 307 attention patterns are then trimmed to their actual (unpadded) size and saved 308 individually, producing files identical to the single-prompt path. 309 310 Does not support `stack_heads` or `return_cache` -- this function is intended for 311 the bulk processing path in `activations_main`, not for interactive use. Use 312 `compute_activations` directly for single-prompt use cases that need those features. 313 314 ## Why right-padding makes trimming correct without an explicit attention mask 315 316 With right-padding, pad tokens sit at positions seq_len, seq_len+1, ..., 317 max_seq_len-1 (higher than any real token). The causal attention mask prevents 318 position i from attending to any j > i. So for real tokens at positions 319 0..seq_len-1, they can only attend to 0..i -- all real tokens. The softmax is computed over the same set of positions 320 as in single-prompt inference, producing identical attention patterns. 321 322 We explicitly pass `padding_side="right"` to `run_with_cache` to guarantee this 323 regardless of the model's default padding side. 324 325 # Parameters: 326 - `prompts : list[dict]` 327 each prompt must contain 'text' and 'hash' keys. call 328 `augment_prompt_with_hash` on each prompt before passing them here. 329 - `model : HookedTransformer` 330 the model to compute activations with 331 - `save_path : Path` 332 path to save the activations to 333 (defaults to `Path(DATA_DIR)`) 334 - `names_filter : Callable[[str], bool] | re.Pattern` 335 filter for which activations to save. must only match activations with 336 4D shape `[batch, n_heads, seq, seq]` (e.g. 
attention patterns). 337 non-attention activations will cause incorrect trimming. 338 (defaults to `ATTN_PATTERN_REGEX`) 339 - `seq_lens : list[int] | None` 340 pre-computed model sequence lengths per prompt (from `model.to_tokens`). 341 if `None`, will be computed internally. pass this to avoid redundant 342 tokenization when lengths are already known (e.g. from length-sorting). 343 **important**: these must be from `model.to_tokens()` (includes BOS), 344 NOT from `model.tokenizer.tokenize()` (excludes BOS). 345 (defaults to `None`) 346 347 # Returns: 348 - `list[Path]` 349 paths to the saved activations files, one per prompt 350 351 # Modifies: 352 each prompt dict in `prompts` -- adds/overwrites `n_tokens` and `tokens` keys 353 with tokenization metadata (same mutation as `compute_activations`). 354 """ 355 assert model is not None, "model must be passed" 356 assert len(prompts) > 0, "prompts must not be empty" 357 assert "text" in prompts[0], f"prompt must contain 'text' key: {prompts[0].keys()}" 358 assert "hash" in prompts[0], ( 359 f"prompt must contain 'hash' key (call augment_prompt_with_hash first): {prompts[0].keys()}" 360 ) 361 362 # --- Phase A: get actual model sequence lengths --- 363 # model.to_tokens() includes BOS if applicable, matching the attention pattern dims 364 # model.tokenizer.tokenize() gives subword strings WITHOUT BOS, used for metadata 365 # these differ by 1 when BOS is prepended -- using the wrong one for trimming 366 # would silently truncate or include garbage 367 if seq_lens is None: 368 seq_lens = [model.to_tokens(p["text"]).shape[1] for p in prompts] 369 assert len(seq_lens) == len(prompts), ( 370 f"seq_lens length mismatch: {len(seq_lens)} != {len(prompts)}" 371 ) 372 373 # --- Phase B: save prompt metadata (mirrors compute_activations's metadata logic) --- 374 assert model.tokenizer is not None 375 for p in prompts: 376 prompt_str: str = p["text"] 377 prompt_tokenized: list[str] = p.get( 378 "tokens", 379 model.tokenizer.tokenize(prompt_str), 380 ) 381 # n_tokens counts subword tokens (no BOS); attention patterns include BOS so have dim n_tokens+1 382 p.update( 383 dict( 384 n_tokens=len(prompt_tokenized), 385 tokens=prompt_tokenized, 386 ), 387 ) 388 prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / p["hash"] 389 prompt_dir.mkdir(parents=True, exist_ok=True) 390 with open(prompt_dir / "prompt.json", "w") as f: 391 json.dump(p, f) 392 393 # --- Phase C: batched forward pass --- 394 names_filter_fn: Callable[[str], bool] 395 if isinstance(names_filter, re.Pattern): 396 names_filter_fn = lambda key: names_filter.match(key) is not None # noqa: E731 397 else: 398 names_filter_fn = names_filter 399 400 texts: list[str] = [p["text"] for p in prompts] 401 cache_torch: ActivationCache 402 with torch.no_grad(): 403 model.eval() 404 _, cache_torch = model.run_with_cache( 405 texts, 406 names_filter=names_filter_fn, 407 return_type=None, 408 padding_side="right", 409 ) 410 411 # --- Phase D: split, trim padding, and save per-prompt --- 412 # For each prompt i with actual sequence length seq_len_i: 413 # v[i : i+1, :, :seq_len_i, :seq_len_i] 414 # ^^^^^^^ i:i+1 not i -- keeps batch dim [1,...] 
for 415 # format compatibility with compute_activations 416 # ^^ all attention heads 417 # ^^^^^^^^^^ ^^^^^^^^^^ trim both query and key dims to actual length, 418 # discarding meaningless padding positions 419 paths: list[Path] = [] 420 for i, (prompt, seq_len) in enumerate(zip(prompts, seq_lens, strict=True)): 421 prompt_dir = save_path / model.cfg.model_name / "prompts" / prompt["hash"] 422 activations_path: Path = prompt_dir / "activations.npz" 423 cache_np: ActivationCacheNp = {} 424 for k, v in cache_torch.items(): 425 assert v.ndim == 4, ( # noqa: PLR2004 426 f"expected 4D attention pattern tensor for {k!r}, " 427 f"got shape {v.shape}. names_filter must only match " 428 f"attention pattern activations [batch, n_heads, seq, seq]" 429 ) 430 cache_np[k] = v[i : i + 1, :, :seq_len, :seq_len].detach().cpu().numpy() 431 432 np.savez_compressed( 433 activations_path, 434 **cache_np, # type: ignore[arg-type] 435 ) 436 paths.append(activations_path) 437 438 return paths 439 440 441@overload 442def get_activations( 443 prompt: dict, 444 model: HookedTransformer | str, 445 save_path: Path = Path(DATA_DIR), 446 allow_disk_cache: bool = True, 447 return_cache: None = None, 448) -> tuple[Path, None]: ... 449@overload 450def get_activations( 451 prompt: dict, 452 model: HookedTransformer | str, 453 save_path: Path = Path(DATA_DIR), 454 allow_disk_cache: bool = True, 455 return_cache: Literal["torch"] = "torch", 456) -> tuple[Path, ActivationCache]: ... 457@overload 458def get_activations( 459 prompt: dict, 460 model: HookedTransformer | str, 461 save_path: Path = Path(DATA_DIR), 462 allow_disk_cache: bool = True, 463 return_cache: Literal["numpy"] = "numpy", 464) -> tuple[Path, ActivationCacheNp]: ... 465def get_activations( 466 prompt: dict, 467 model: HookedTransformer | str, 468 save_path: Path = Path(DATA_DIR), 469 allow_disk_cache: bool = True, 470 return_cache: ReturnCache = "numpy", 471) -> tuple[Path, ActivationCacheNp | ActivationCache | None]: 472 """given a prompt and a model, save or load activations 473 474 # Parameters: 475 - `prompt : dict` 476 expected to contain the 'text' key 477 - `model : HookedTransformer | str` 478 either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained` 479 - `save_path : Path` 480 path to save the activations to (and load from) 481 (defaults to `Path(DATA_DIR)`) 482 - `allow_disk_cache : bool` 483 whether to allow loading from disk cache 484 (defaults to `True`) 485 - `return_cache : Literal[None, "numpy", "torch"]` 486 whether to return the cache, and in what format 487 (defaults to `"numpy"`) 488 489 # Returns: 490 - `tuple[Path, ActivationCacheNp | ActivationCache | None]` 491 the path to the activations and the cache if `return_cache is not None` 492 493 """ 494 # add hash to prompt 495 augment_prompt_with_hash(prompt) 496 497 # get the model 498 model_name: str = ( 499 model.cfg.model_name if isinstance(model, HookedTransformer) else model 500 ) 501 502 # from cache 503 if allow_disk_cache: 504 if return_cache is None: 505 # fast path: check file existence without loading data into memory. 506 # activations_exist just calls .exists() on two paths, whereas 507 # load_activations would decompress the full .npz into numpy arrays 508 # only for us to discard them immediately. 
509 if activations_exist(model_name, prompt, save_path): 510 prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"] 511 return prompt_dir / "activations.npz", None 512 else: 513 try: 514 path, cache = load_activations( 515 model_name=model_name, 516 prompt=prompt, 517 save_path=save_path, 518 ) 519 except ActivationsMissingError: 520 pass 521 else: 522 return path, cache 523 524 # compute them 525 if isinstance(model, str): 526 model = HookedTransformer.from_pretrained(model_name) 527 528 return compute_activations( # type: ignore[return-value] 529 prompt=prompt, 530 model=model, 531 save_path=save_path, 532 return_cache=return_cache, 533 ) 534 535 536DEFAULT_DEVICE: torch.device = torch.device( 537 "cuda" if torch.cuda.is_available() else "cpu", 538) 539 540 541def activations_main( # noqa: C901, PLR0912, PLR0915 542 model_name: str, 543 save_path: str | Path, 544 prompts_path: str, 545 raw_prompts: bool, 546 min_chars: int, 547 max_chars: int, 548 force: bool, 549 n_samples: int, 550 no_index_html: bool, 551 shuffle: bool = False, 552 stacked_heads: bool = False, 553 device: str | torch.device = DEFAULT_DEVICE, 554 batch_size: int = 32, 555) -> None: 556 """main function for computing activations 557 558 # Parameters: 559 - `model_name : str` 560 name of a model to load with `HookedTransformer.from_pretrained` 561 - `save_path : str | Path` 562 path to save the activations to 563 - `prompts_path : str` 564 path to the prompts file 565 - `raw_prompts : bool` 566 whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path` 567 - `min_chars : int` 568 minimum number of characters for a prompt 569 - `max_chars : int` 570 maximum number of characters for a prompt 571 - `force : bool` 572 whether to overwrite existing files 573 - `n_samples : int` 574 maximum number of samples to process 575 - `no_index_html : bool` 576 whether to write an index.html file 577 - `shuffle : bool` 578 whether to shuffle the prompts 579 (defaults to `False`) 580 - `stacked_heads : bool` 581 whether to stack the heads in the output tensor. will save as `.npy` instead of `.npz` if `True` 582 (defaults to `False`) 583 - `device : str | torch.device` 584 the device to use. if a string, will be passed to `torch.device` 585 - `batch_size : int` 586 number of prompts per forward pass. prompts are sorted by token length 587 (longest first) and grouped so that similar-length prompts share a batch, 588 minimizing padding waste. use `batch_size=1` for one prompt per forward 589 pass (largely equivalent to the old sequential behavior, but note: prompts 590 are still sorted by length and cache checking uses file-existence only, 591 unlike the old path which processed prompts in order and validated cache 592 contents via `load_activations`). 593 the single-prompt functions `compute_activations` and `get_activations` 594 are still available for programmatic use outside of `activations_main`. 
595 (defaults to `32`) 596 """ 597 # figure out the device to use 598 device_: torch.device 599 if isinstance(device, torch.device): 600 device_ = device 601 elif isinstance(device, str): 602 device_ = torch.device(device) 603 else: 604 msg = f"invalid device: {device}" 605 raise TypeError(msg) 606 607 print(f"using device: {device_}") 608 609 with SpinnerContext(message="loading model", **SPINNER_KWARGS): 610 model: HookedTransformer = HookedTransformer.from_pretrained( 611 model_name, 612 device=device_, 613 ) 614 model.model_name = model_name # type: ignore[unresolved-attribute] 615 model.cfg.model_name = model_name 616 n_params: int = sum(p.numel() for p in model.parameters()) 617 print( 618 f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters", 619 ) 620 print(f"\tmodel devices: { {p.device for p in model.parameters()} }") 621 622 save_path_p: Path = Path(save_path) 623 save_path_p.mkdir(parents=True, exist_ok=True) 624 model_path: Path = save_path_p / model_name 625 with SpinnerContext( 626 message=f"saving model info to {_rel_path(model_path)}", 627 **SPINNER_KWARGS, 628 ): 629 model_cfg: HookedTransformerConfig 630 model_cfg = model.cfg 631 model_path.mkdir(parents=True, exist_ok=True) 632 with open(model_path / "model_cfg.json", "w") as f: 633 json.dump(json_serialize(asdict(model_cfg)), f) 634 635 # load prompts 636 with SpinnerContext( 637 message=f"loading prompts from {Path(prompts_path).as_posix()}", 638 **SPINNER_KWARGS, 639 ): 640 prompts: list[dict] 641 if raw_prompts: 642 prompts = load_text_data( 643 Path(prompts_path), 644 min_chars=min_chars, 645 max_chars=max_chars, 646 shuffle=shuffle, 647 ) 648 else: 649 with open(model_path / "prompts.jsonl", "r") as f: 650 prompts = [json.loads(line) for line in f.readlines()] 651 # truncate to n_samples 652 prompts = prompts[:n_samples] 653 654 print(f" {len(prompts)} prompts loaded") 655 656 # write index.html 657 with SpinnerContext( 658 message=f"writing {_rel_path(save_path_p / 'index.html')}", 659 **SPINNER_KWARGS, 660 ): 661 if not no_index_html: 662 write_html_index(save_path_p) 663 664 # TODO: not implemented yet 665 if stacked_heads: 666 raise NotImplementedError("stacked_heads not implemented yet") 667 668 # augment all prompts with hashes 669 for prompt in prompts: 670 augment_prompt_with_hash(prompt) 671 672 # filter out cached prompts 673 if not force: 674 uncached: list[dict] = [ 675 p for p in prompts if not activations_exist(model_name, p, save_path_p) 676 ] 677 n_cached: int = len(prompts) - len(uncached) 678 if n_cached > 0: 679 print(f" {n_cached} prompts already cached, {len(uncached)} to compute") 680 else: 681 uncached = list(prompts) 682 683 if uncached: 684 # sort by token length descending so that: 685 # 1. the longest (slowest, most memory-hungry) batches run first -- 686 # OOM errors surface immediately rather than after all the cheap work, 687 # and tqdm's ETA stabilizes early for better progress estimation 688 # 2. similar-length prompts are grouped together, minimizing padding waste 689 # 690 # pre-tokenization is a separate step from compute_activations_batched because 691 # we need token lengths *before* batching to sort and group. the resulting 692 # seq_lens are then passed through so compute_activations_batched can skip 693 # re-tokenizing each prompt internally. 
694 with SpinnerContext( 695 message="pre-tokenizing prompts for length sorting", 696 **SPINNER_KWARGS, 697 ): 698 uncached_with_lens: list[tuple[dict, int]] = [ 699 (p, model.to_tokens(p["text"]).shape[1]) for p in uncached 700 ] 701 uncached_with_lens.sort(key=lambda x: x[1], reverse=True) 702 sorted_uncached: list[dict] = [p for p, _ in uncached_with_lens] 703 sorted_seq_lens: list[int] = [sl for _, sl in uncached_with_lens] 704 705 # process in batches 706 n_prompts: int = len(sorted_uncached) 707 with tqdm.tqdm( 708 total=n_prompts, 709 desc="Computing activations", 710 unit="prompt", 711 ) as pbar: 712 for batch_start in range(0, n_prompts, batch_size): 713 batch_end: int = min(batch_start + batch_size, n_prompts) 714 batch: list[dict] = sorted_uncached[batch_start:batch_end] 715 batch_seq_lens: list[int] = sorted_seq_lens[batch_start:batch_end] 716 pbar.set_postfix( 717 n_ctx=batch_seq_lens[0], 718 ) # longest in batch (sorted descending) 719 compute_activations_batched( 720 prompts=batch, 721 model=model, 722 save_path=save_path_p, 723 seq_lens=batch_seq_lens, 724 ) 725 pbar.update(len(batch)) 726 else: 727 print(" all prompts cached, nothing to compute") 728 729 with SpinnerContext( 730 message="updating jsonl metadata for models and prompts", 731 **SPINNER_KWARGS, 732 ): 733 generate_models_jsonl(save_path_p) 734 generate_prompts_jsonl(save_path_p / model_name) 735 736 737def main() -> None: 738 "generate attention pattern activations for a model and prompts" 739 print(DIVIDER_S1) 740 with SpinnerContext(message="parsing args", **SPINNER_KWARGS): 741 arg_parser: argparse.ArgumentParser = argparse.ArgumentParser() 742 # input and output 743 arg_parser.add_argument( 744 "--model", 745 "-m", 746 type=str, 747 required=True, 748 help="The model name(s) to use. comma separated with no whitespace if multiple", 749 ) 750 751 arg_parser.add_argument( 752 "--prompts", 753 "-p", 754 type=str, 755 required=False, 756 help="The path to the prompts file (jsonl with 'text' key on each line). 
If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory", 757 default=None, 758 ) 759 760 arg_parser.add_argument( 761 "--save-path", 762 "-s", 763 type=str, 764 required=False, 765 help="The path to save the attention patterns", 766 default=DATA_DIR, 767 ) 768 769 # min and max prompt lengths 770 arg_parser.add_argument( 771 "--min-chars", 772 type=int, 773 required=False, 774 help="The minimum number of characters for a prompt", 775 default=100, 776 ) 777 arg_parser.add_argument( 778 "--max-chars", 779 type=int, 780 required=False, 781 help="The maximum number of characters for a prompt", 782 default=1000, 783 ) 784 785 # number of samples 786 arg_parser.add_argument( 787 "--n-samples", 788 "-n", 789 type=int, 790 required=False, 791 help="The max number of samples to process, do all in the file if None", 792 default=None, 793 ) 794 795 # batch size 796 arg_parser.add_argument( 797 "--batch-size", 798 "-b", 799 type=int, 800 required=False, 801 help="Batch size for computing activations (number of prompts per forward pass)", 802 default=32, 803 ) 804 805 # force overwrite 806 arg_parser.add_argument( 807 "--force", 808 "-f", 809 action="store_true", 810 help="If passed, will overwrite existing files", 811 ) 812 813 # no index html 814 arg_parser.add_argument( 815 "--no-index-html", 816 action="store_true", 817 help="If passed, will not write an index.html file for the model", 818 ) 819 820 # raw prompts 821 arg_parser.add_argument( 822 "--raw-prompts", 823 "-r", 824 action="store_true", 825 help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)", 826 ) 827 828 # shuffle 829 arg_parser.add_argument( 830 "--shuffle", 831 action="store_true", 832 help="If passed, will shuffle the prompts", 833 ) 834 835 # stack heads 836 arg_parser.add_argument( 837 "--stacked-heads", 838 action="store_true", 839 help="If passed, will stack the heads in the output tensor", 840 ) 841 842 # device 843 arg_parser.add_argument( 844 "--device", 845 type=str, 846 required=False, 847 help="The device to use for the model", 848 default="cuda" if torch.cuda.is_available() else "cpu", 849 ) 850 851 args: argparse.Namespace = arg_parser.parse_args() 852 853 print(f"args parsed: {args}") 854 855 models: list[str] 856 if "," in args.model: 857 models = args.model.split(",") 858 else: 859 models = [args.model] 860 861 n_models: int = len(models) 862 for idx, model in enumerate(models): 863 print(DIVIDER_S2) 864 print(f"processing model {idx + 1} / {n_models}: {model}") 865 print(DIVIDER_S2) 866 867 activations_main( 868 model_name=model, 869 save_path=args.save_path, 870 prompts_path=args.prompts, 871 raw_prompts=args.raw_prompts, 872 min_chars=args.min_chars, 873 max_chars=args.max_chars, 874 force=args.force, 875 n_samples=args.n_samples, 876 no_index_html=args.no_index_html, 877 shuffle=args.shuffle, 878 stacked_heads=args.stacked_heads, 879 device=args.device, 880 batch_size=args.batch_size, 881 ) 882 del model 883 884 print(DIVIDER_S1) 885 886 887if __name__ == "__main__": 888 main()
```python
def compute_activations(  # noqa: PLR0915
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: ReturnCache = "torch",
    stack_heads: bool = False,
) -> tuple[
    Path,
    ActivationCacheNp
    | ActivationCache
    | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
    | Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]
    | None,
]: ...
```
compute activations for a single prompt and save to disk
always runs a forward pass -- does NOT load from disk cache.
for cache-aware loading, use get_activations which tries disk first.
Parameters:
- `prompt : dict`
    the prompt to run the model on; must contain a 'text' key
- `model : HookedTransformer`
    the model to run. must be passed, despite the `None` default in the signature
    (defaults to `None`)
- `save_path : Path`
    (defaults to `Path(DATA_DIR)`)
- `names_filter : Callable[[str], bool] | re.Pattern`
    a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
    (defaults to `ATTN_PATTERN_REGEX`)
- `return_cache : Literal[None, "numpy", "torch"]`
    will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (`False`) or a single tensor (`True`)
    (defaults to `"torch"`)
- `stack_heads : bool`
    whether the heads should be stacked in the output. this causes a number of changes:
    - an `.npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor is saved for each prompt, instead of an `.npz` file with a dict by layer
    - the returned cache will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer, if `return_cache` is not `None`
    asserts that the activation cache contains exactly the attention patterns (all of them and nothing else), and raises an exception if not
    (defaults to `False`)
Returns:
```
tuple[
    Path,
    Union[
        None,
        ActivationCacheNp, ActivationCache,
        Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
    ]
]
```
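A minimal single-prompt sketch (model name and prompt text are placeholders; the prompt needs a hash before anything is saved, hence `augment_prompt_with_hash`):

```python
from pathlib import Path

from transformer_lens import HookedTransformer

from pattern_lens.activations import compute_activations
from pattern_lens.load_activations import augment_prompt_with_hash

model = HookedTransformer.from_pretrained("gpt2")
prompt = {"text": "Attention patterns are saved per prompt."}
augment_prompt_with_hash(prompt)  # saving goes under <save_path>/<model>/prompts/<hash>/

path, cache = compute_activations(
    prompt=prompt,
    model=model,
    save_path=Path("demo/"),
    return_cache="numpy",
)
# cache maps hook names like "blocks.0.attn.hook_pattern" to arrays
# of shape (1, n_heads, n_ctx, n_ctx); path points at the saved .npz
for name, pattern in cache.items():
    print(name, pattern.shape)
```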
```python
def compute_activations_batched(
    prompts: list[dict],
    model: HookedTransformer,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    seq_lens: list[int] | None = None,
) -> list[Path]: ...
```
compute and save activations for a batch of prompts in a single forward pass
Batched companion to compute_activations -- instead of one forward pass per
prompt, this runs a single model.run_with_cache(list_of_strings) call for the
whole batch. TransformerLens tokenizes and right-pads automatically. Each prompt's
attention patterns are then trimmed to their actual (unpadded) size and saved
individually, producing files identical to the single-prompt path.
Does not support stack_heads or return_cache -- this function is intended for
the bulk processing path in activations_main, not for interactive use. Use
compute_activations directly for single-prompt use cases that need those features.
Why right-padding makes trimming correct without an explicit attention mask
With right-padding, pad tokens sit at positions seq_len, seq_len+1, ..., max_seq_len-1 (higher than any real token). The causal attention mask prevents position i from attending to any j > i. So for real tokens at positions 0..seq_len-1, they can only attend to 0..i -- all real tokens. The softmax is computed over the same set of positions as in single-prompt inference, producing identical attention patterns.
We explicitly pass padding_side="right" to run_with_cache to guarantee this
regardless of the model's default padding side.
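A rough way to sanity-check this claim on a small model (a sketch, not part of the library; assumes `gpt2` and tolerates small floating-point differences):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
short = "a short prompt"
long = "a much longer prompt that forces the short one to be padded in the batch"

_, single = model.run_with_cache(short, return_type=None)
_, batched = model.run_with_cache([short, long], return_type=None, padding_side="right")

seq_len = model.to_tokens(short).shape[1]  # includes BOS
key = "blocks.0.attn.hook_pattern"
trimmed = batched[key][0:1, :, :seq_len, :seq_len]  # trim query and key dims to the real length
assert torch.allclose(single[key], trimmed, atol=1e-5)
```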
Parameters:
- `prompts : list[dict]`
    each prompt must contain 'text' and 'hash' keys. call `augment_prompt_with_hash` on each prompt before passing them here.
- `model : HookedTransformer`
    the model to compute activations with
- `save_path : Path`
    path to save the activations to
    (defaults to `Path(DATA_DIR)`)
- `names_filter : Callable[[str], bool] | re.Pattern`
    filter for which activations to save. must only match activations with 4D shape `[batch, n_heads, seq, seq]` (e.g. attention patterns). non-attention activations will cause incorrect trimming.
    (defaults to `ATTN_PATTERN_REGEX`)
- `seq_lens : list[int] | None`
    pre-computed model sequence lengths per prompt (from `model.to_tokens`). if `None`, will be computed internally. pass this to avoid redundant tokenization when lengths are already known (e.g. from length-sorting). **important**: these must be from `model.to_tokens()` (includes BOS), NOT from `model.tokenizer.tokenize()` (excludes BOS).
    (defaults to `None`)
Returns:
- `list[Path]`
    paths to the saved activations files, one per prompt
Modifies:
each prompt dict in prompts -- adds/overwrites n_tokens and tokens keys
with tokenization metadata (same mutation as compute_activations).
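A batched sketch under the same assumptions as above (placeholder prompt texts; results land under `demo/` via `save_path`):

```python
import numpy as np
from pathlib import Path

from transformer_lens import HookedTransformer

from pattern_lens.activations import compute_activations_batched
from pattern_lens.load_activations import augment_prompt_with_hash

model = HookedTransformer.from_pretrained("gpt2")
prompts = [{"text": t} for t in ["first example prompt", "a second, longer example prompt"]]
for p in prompts:
    augment_prompt_with_hash(p)  # 'hash' keys are required before calling

paths = compute_activations_batched(prompts=prompts, model=model, save_path=Path("demo/"))

# each saved .npz holds the trimmed attention patterns for one prompt
with np.load(paths[0]) as npz:
    for name in npz.files:
        print(name, npz[name].shape)  # (1, n_heads, seq_len, seq_len)
```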
```python
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: ReturnCache = "numpy",
) -> tuple[Path, ActivationCacheNp | ActivationCache | None]: ...
```
given a prompt and a model, save or load activations
Parameters:
- `prompt : dict`
    expected to contain the 'text' key
- `model : HookedTransformer | str`
    either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
- `save_path : Path`
    path to save the activations to (and load from)
    (defaults to `Path(DATA_DIR)`)
- `allow_disk_cache : bool`
    whether to allow loading from disk cache
    (defaults to `True`)
- `return_cache : Literal[None, "numpy", "torch"]`
    whether to return the cache, and in what format
    (defaults to `"numpy"`)
Returns:
- `tuple[Path, ActivationCacheNp | ActivationCache | None]`
    the path to the activations and the cache if `return_cache` is not `None`
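A cache-aware sketch (placeholder prompt; assumes `gpt2`): the first call computes and saves, later calls with the same prompt and `save_path` load from disk instead of re-running the model:

```python
from pathlib import Path

from pattern_lens.activations import get_activations

prompt = {"text": "The same prompt is only computed once."}

# first call: nothing on disk yet, so the model is loaded and a forward pass is run
path, cache = get_activations(prompt, "gpt2", save_path=Path("demo/"))

# second call: served from the saved .npz, no forward pass
path_again, cache_again = get_activations(prompt, "gpt2", save_path=Path("demo/"))
assert path == path_again
```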
```python
def activations_main(  # noqa: C901, PLR0912, PLR0915
    model_name: str,
    save_path: str | Path,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
    stacked_heads: bool = False,
    device: str | torch.device = DEFAULT_DEVICE,
    batch_size: int = 32,
) -> None: ...
```
main function for computing activations
Parameters:
- `model_name : str`
    name of a model to load with `HookedTransformer.from_pretrained`
- `save_path : str | Path`
    path to save the activations to
- `prompts_path : str`
    path to the prompts file
- `raw_prompts : bool`
    whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
- `min_chars : int`
    minimum number of characters for a prompt
- `max_chars : int`
    maximum number of characters for a prompt
- `force : bool`
    whether to overwrite existing files
- `n_samples : int`
    maximum number of samples to process
- `no_index_html : bool`
    whether to skip writing an index.html file
- `shuffle : bool`
    whether to shuffle the prompts
    (defaults to `False`)
- `stacked_heads : bool`
    whether to stack the heads in the output tensor. will save as `.npy` instead of `.npz` if `True`
    (defaults to `False`)
- `device : str | torch.device`
    the device to use. if a string, will be passed to `torch.device`
    (defaults to `DEFAULT_DEVICE`)
- `batch_size : int`
    number of prompts per forward pass. prompts are sorted by token length (longest first) and grouped so that similar-length prompts share a batch, minimizing padding waste. use `batch_size=1` for one prompt per forward pass (largely equivalent to the old sequential behavior, but note: prompts are still sorted by length and cache checking uses file-existence only, unlike the old path which processed prompts in order and validated cache contents via `load_activations`). the single-prompt functions `compute_activations` and `get_activations` are still available for programmatic use outside of `activations_main`.
    (defaults to `32`)
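A fuller script-side sketch with the required arguments spelled out (paths and model name are placeholders):

```python
from pattern_lens.activations import activations_main

activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
    raw_prompts=True,      # prompts file is unfiltered jsonl with a 'text' key per line
    min_chars=100,
    max_chars=1000,
    force=False,           # skip prompts whose activations already exist on disk
    n_samples=500,
    no_index_html=False,   # also write demo/index.html
    batch_size=16,
)
```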
```python
def main() -> None: ...
```
generate attention pattern activations for a model and prompts
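For example, a hypothetical invocation processing two models over raw prompts (flag names per the argument parser above; model list is comma separated with no whitespace):

```bash
python -m pattern_lens.activations \
    --model gpt2,gpt2-medium \
    --prompts data/pile_1k.jsonl \
    --raw-prompts \
    --save-path demo/ \
    --n-samples 500 \
    --batch-size 16
```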