docs for pattern_lens v0.6.0
View Source on GitHub

pattern_lens.activations

computing and saving activations given a model and prompts

Usage:

from the command line:

python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>

from a script:

from pattern_lens.activations import activations_main
activations_main(
        model_name="gpt2",
        save_path="demo/",
        prompts_path="data/pile_1k.jsonl",
)
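
Note that activations_main also requires several arguments beyond the three shown above -- raw_prompts, min_chars, max_chars, force, n_samples, and no_index_html have no defaults in its signature. A fuller call, with illustrative values, might look like:

from pattern_lens.activations import activations_main
activations_main(
        model_name="gpt2",
        save_path="demo/",
        prompts_path="data/pile_1k.jsonl",
        raw_prompts=True,
        min_chars=100,
        max_chars=1000,
        force=False,
        n_samples=100,
        no_index_html=False,
)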

  1"""computing and saving activations given a model and prompts
  2
  3# Usage:
  4
  5from the command line:
  6
  7```bash
  8python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
  9```
 10
 11from a script:
 12
 13```python
 14from pattern_lens.activations import activations_main
 15activations_main(
 16	model_name="gpt2",
 17	save_path="demo/",
 18	prompts_path="data/pile_1k.jsonl",
 19)
 20```
 21
 22"""
 23
 24import argparse
 25import json
 26import re
 27from collections.abc import Callable
 28from dataclasses import asdict
 29from pathlib import Path
 30from typing import Literal, overload
 31
 32import numpy as np
 33import torch
 34import tqdm
 35from jaxtyping import Float
 36from muutils.json_serialize import json_serialize
 37from muutils.misc.numerical import shorten_numerical_to_str
 38
 39# custom utils
 40from muutils.spinner import SpinnerContext
 41from transformer_lens import (  # type: ignore[import-untyped]
 42	ActivationCache,
 43	HookedTransformer,
 44	HookedTransformerConfig,
 45)
 46
 47# pattern_lens
 48from pattern_lens.consts import (
 49	ATTN_PATTERN_REGEX,
 50	DATA_DIR,
 51	DIVIDER_S1,
 52	DIVIDER_S2,
 53	SPINNER_KWARGS,
 54	ActivationCacheNp,
 55	ReturnCache,
 56)
 57from pattern_lens.indexes import (
 58	generate_models_jsonl,
 59	generate_prompts_jsonl,
 60	write_html_index,
 61)
 62from pattern_lens.load_activations import (
 63	ActivationsMissingError,
 64	activations_exist,
 65	augment_prompt_with_hash,
 66	load_activations,
 67)
 68from pattern_lens.prompts import load_text_data
 69
 70
 71def _rel_path(p: Path) -> str:
 72	"""Return path relative to cwd if possible, otherwise absolute."""
 73	try:
 74		return p.relative_to(Path.cwd()).as_posix()
 75	except ValueError:
 76		return p.as_posix()
 77
 78
 79# return nothing, but `stack_heads` still affects how we save the activations
 80@overload
 81def compute_activations(
 82	prompt: dict,
 83	model: HookedTransformer | None = None,
 84	save_path: Path = Path(DATA_DIR),
 85	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
 86	return_cache: None = None,
 87	stack_heads: bool = False,
 88) -> tuple[Path, None]: ...
 89# return stacked heads in numpy or torch form
 90@overload
 91def compute_activations(
 92	prompt: dict,
 93	model: HookedTransformer | None = None,
 94	save_path: Path = Path(DATA_DIR),
 95	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
 96	return_cache: Literal["torch"] = "torch",
 97	stack_heads: Literal[True] = True,
 98) -> tuple[Path, Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]]: ...
 99@overload
100def compute_activations(
101	prompt: dict,
102	model: HookedTransformer | None = None,
103	save_path: Path = Path(DATA_DIR),
104	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
105	return_cache: Literal["numpy"] = "numpy",
106	stack_heads: Literal[True] = True,
107) -> tuple[Path, Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]]: ...
108# return dicts in numpy or torch form
109@overload
110def compute_activations(
111	prompt: dict,
112	model: HookedTransformer | None = None,
113	save_path: Path = Path(DATA_DIR),
114	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
115	return_cache: Literal["numpy"] = "numpy",
116	stack_heads: Literal[False] = False,
117) -> tuple[Path, ActivationCacheNp]: ...
118@overload
119def compute_activations(
120	prompt: dict,
121	model: HookedTransformer | None = None,
122	save_path: Path = Path(DATA_DIR),
123	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
124	return_cache: Literal["torch"] = "torch",
125	stack_heads: Literal[False] = False,
126) -> tuple[Path, ActivationCache]: ...
127# actual function body
128def compute_activations(  # noqa: PLR0915
129	prompt: dict,
130	model: HookedTransformer | None = None,
131	save_path: Path = Path(DATA_DIR),
132	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
133	return_cache: ReturnCache = "torch",
134	stack_heads: bool = False,
135) -> tuple[
136	Path,
137	ActivationCacheNp
138	| ActivationCache
139	| Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
140	| Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]
141	| None,
142]:
143	"""compute activations for a single prompt and save to disk
144
145	always runs a forward pass -- does NOT load from disk cache.
146	for cache-aware loading, use `get_activations` which tries disk first.
147
148	# Parameters:
149	- `prompt : dict`
150		must contain the 'text' and 'hash' keys (see `augment_prompt_with_hash`)
151	- `model : HookedTransformer | None`
152	- `save_path : Path`
153		(defaults to `Path(DATA_DIR)`)
154	- `names_filter : Callable[[str], bool]|re.Pattern`
155		a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
156		(defaults to `ATTN_PATTERN_REGEX`)
157	- `return_cache : Literal[None, "numpy", "torch"]`
158		will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (False) or a single tensor (True)
159		(defaults to `"torch"`)
160	- `stack_heads : bool`
161		whether the heads should be stacked in the output. this causes a number of changes:
162	- `npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor saved for each prompt instead of `npz` file with dict by layer
163	- `cache` will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer if `return_cache` is not `None`
164		will assert that the activation cache contains all of the attention patterns and nothing else, raising an exception if not.
165
166	# Returns:
167	```
168	tuple[
169		Path,
170		Union[
171			None,
172			ActivationCacheNp, ActivationCache,
173			Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
174		]
175	]
176	```
177	"""
178	# check inputs
179	assert model is not None, "model must be passed"
180	assert "text" in prompt, "prompt must contain 'text' key"
181	prompt_str: str = prompt["text"]
182
183	# compute or get prompt metadata
184	assert model.tokenizer is not None
185	prompt_tokenized: list[str] = prompt.get(
186		"tokens",
187		model.tokenizer.tokenize(prompt_str),
188	)
189	# n_tokens counts subword tokens (no BOS); attention patterns include BOS
190	# so have dim n_tokens+1. see also compute_activations_batched Phase B.
191	prompt.update(
192		dict(
193			n_tokens=len(prompt_tokenized),
194			tokens=prompt_tokenized,
195		),
196	)
197
198	# save metadata
199	prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
200	prompt_dir.mkdir(parents=True, exist_ok=True)
201	with open(prompt_dir / "prompt.json", "w") as f:
202		json.dump(prompt, f)
203
204	# set up names filter
205	names_filter_fn: Callable[[str], bool]
206	if isinstance(names_filter, re.Pattern):
207		names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
208	else:
209		names_filter_fn = names_filter
210
211	# compute activations
212	# NOTE: no padding_side kwarg here -- it's only meaningful for multi-sequence
213	# batches where padding is needed. single-string input has no padding.
214	# see compute_activations_batched for the batched path that passes padding_side="right".
215	cache_torch: ActivationCache
216	with torch.no_grad():
217		model.eval()
218		_, cache_torch = model.run_with_cache(
219			prompt_str,
220			names_filter=names_filter_fn,
221			return_type=None,
222		)
223
224	activations_path: Path
225	# saving and returning
226	if stack_heads:
227		n_layers: int = model.cfg.n_layers
228		key_pattern: str = "blocks.{i}.attn.hook_pattern"
229		# NOTE: this only works for stacking heads at the moment
230		# activations_specifier: str = key_pattern.format(i=f'0-{n_layers}')
231		activations_specifier: str = key_pattern.format(i="-")
232		activations_path = prompt_dir / f"activations-{activations_specifier}.npy"
233
234		# check the keys are only attention heads
235		head_keys: list[str] = [key_pattern.format(i=i) for i in range(n_layers)]
236		cache_torch_keys_set: set[str] = set(cache_torch.keys())
237		assert cache_torch_keys_set == set(head_keys), (
238			f"unexpected keys!\n{set(head_keys).symmetric_difference(cache_torch_keys_set) = }\n{cache_torch_keys_set} != {set(head_keys)}"
239		)
240
241		# stack heads
242		patterns_stacked: Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] = (
243			torch.stack([cache_torch[k] for k in head_keys], dim=1)
244		)
245		# check shape
246		pattern_shape_no_ctx: tuple[int, ...] = tuple(patterns_stacked.shape[:3])
247		assert pattern_shape_no_ctx == (1, n_layers, model.cfg.n_heads), (
248			f"unexpected shape: {patterns_stacked.shape[:3] = } ({pattern_shape_no_ctx = }), expected {(1, n_layers, model.cfg.n_heads) = }"
249		)
250
251		patterns_stacked_np: Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = (
252			patterns_stacked.cpu().numpy()
253		)
254
255		# save
256		np.save(activations_path, patterns_stacked_np)
257
258		# return
259		match return_cache:
260			case "numpy":
261				return activations_path, patterns_stacked_np
262			case "torch":
263				return activations_path, patterns_stacked
264			case None:
265				return activations_path, None
266			case _:
267				msg = f"invalid return_cache: {return_cache = }"
268				raise ValueError(msg)
269	else:
270		activations_path = prompt_dir / "activations.npz"
271
272		# save
273		cache_np: ActivationCacheNp = {
274			k: v.detach().cpu().numpy() for k, v in cache_torch.items()
275		}
276
277		np.savez_compressed(
278			activations_path,
279			**cache_np,  # type: ignore[arg-type]
280		)
281
282		# return
283		match return_cache:
284			case "numpy":
285				return activations_path, cache_np
286			case "torch":
287				return activations_path, cache_torch
288			case None:
289				return activations_path, None
290			case _:
291				msg = f"invalid return_cache: {return_cache = }"
292				raise ValueError(msg)
293
294
295def compute_activations_batched(
296	prompts: list[dict],
297	model: HookedTransformer,
298	save_path: Path = Path(DATA_DIR),
299	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
300	seq_lens: list[int] | None = None,
301) -> list[Path]:
302	"""compute and save activations for a batch of prompts in a single forward pass
303
304	Batched companion to `compute_activations` -- instead of one forward pass per
305	prompt, this runs a single `model.run_with_cache(list_of_strings)` call for the
306	whole batch. TransformerLens tokenizes and right-pads automatically. Each prompt's
307	attention patterns are then trimmed to their actual (unpadded) size and saved
308	individually, producing files identical to the single-prompt path.
309
310	Does not support `stack_heads` or `return_cache` -- this function is intended for
311	the bulk processing path in `activations_main`, not for interactive use. Use
312	`compute_activations` directly for single-prompt use cases that need those features.
313
314	## Why right-padding makes trimming correct without an explicit attention mask
315
316	With right-padding, pad tokens sit at positions seq_len, seq_len+1, ...,
317	max_seq_len-1 (higher than any real token). The causal attention mask prevents
318	position i from attending to any j > i. So for real tokens at positions
319	0..seq_len-1, they can only attend to 0..i -- all real tokens. The softmax is computed over the same set of positions
320	as in single-prompt inference, producing identical attention patterns.
321
322	We explicitly pass `padding_side="right"` to `run_with_cache` to guarantee this
323	regardless of the model's default padding side.
324
325	# Parameters:
326	- `prompts : list[dict]`
327		each prompt must contain 'text' and 'hash' keys. call
328		`augment_prompt_with_hash` on each prompt before passing them here.
329	- `model : HookedTransformer`
330		the model to compute activations with
331	- `save_path : Path`
332		path to save the activations to
333		(defaults to `Path(DATA_DIR)`)
334	- `names_filter : Callable[[str], bool] | re.Pattern`
335		filter for which activations to save. must only match activations with
336		4D shape `[batch, n_heads, seq, seq]` (e.g. attention patterns).
337		non-attention activations will cause incorrect trimming.
338		(defaults to `ATTN_PATTERN_REGEX`)
339	- `seq_lens : list[int] | None`
340		pre-computed model sequence lengths per prompt (from `model.to_tokens`).
341		if `None`, will be computed internally. pass this to avoid redundant
342		tokenization when lengths are already known (e.g. from length-sorting).
343		**important**: these must be from `model.to_tokens()` (includes BOS),
344		NOT from `model.tokenizer.tokenize()` (excludes BOS).
345		(defaults to `None`)
346
347	# Returns:
348	- `list[Path]`
349		paths to the saved activations files, one per prompt
350
351	# Modifies:
352	each prompt dict in `prompts` -- adds/overwrites `n_tokens` and `tokens` keys
353	with tokenization metadata (same mutation as `compute_activations`).
354	"""
355	assert model is not None, "model must be passed"
356	assert len(prompts) > 0, "prompts must not be empty"
357	assert "text" in prompts[0], f"prompt must contain 'text' key: {prompts[0].keys()}"
358	assert "hash" in prompts[0], (
359		f"prompt must contain 'hash' key (call augment_prompt_with_hash first): {prompts[0].keys()}"
360	)
361
362	# --- Phase A: get actual model sequence lengths ---
363	# model.to_tokens() includes BOS if applicable, matching the attention pattern dims
364	# model.tokenizer.tokenize() gives subword strings WITHOUT BOS, used for metadata
365	# these differ by 1 when BOS is prepended -- using the wrong one for trimming
366	# would silently truncate or include garbage
367	if seq_lens is None:
368		seq_lens = [model.to_tokens(p["text"]).shape[1] for p in prompts]
369	assert len(seq_lens) == len(prompts), (
370		f"seq_lens length mismatch: {len(seq_lens)} != {len(prompts)}"
371	)
372
373	# --- Phase B: save prompt metadata (mirrors compute_activations's metadata logic) ---
374	assert model.tokenizer is not None
375	for p in prompts:
376		prompt_str: str = p["text"]
377		prompt_tokenized: list[str] = p.get(
378			"tokens",
379			model.tokenizer.tokenize(prompt_str),
380		)
381		# n_tokens counts subword tokens (no BOS); attention patterns include BOS so have dim n_tokens+1
382		p.update(
383			dict(
384				n_tokens=len(prompt_tokenized),
385				tokens=prompt_tokenized,
386			),
387		)
388		prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / p["hash"]
389		prompt_dir.mkdir(parents=True, exist_ok=True)
390		with open(prompt_dir / "prompt.json", "w") as f:
391			json.dump(p, f)
392
393	# --- Phase C: batched forward pass ---
394	names_filter_fn: Callable[[str], bool]
395	if isinstance(names_filter, re.Pattern):
396		names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
397	else:
398		names_filter_fn = names_filter
399
400	texts: list[str] = [p["text"] for p in prompts]
401	cache_torch: ActivationCache
402	with torch.no_grad():
403		model.eval()
404		_, cache_torch = model.run_with_cache(
405			texts,
406			names_filter=names_filter_fn,
407			return_type=None,
408			padding_side="right",
409		)
410
411	# --- Phase D: split, trim padding, and save per-prompt ---
412	# For each prompt i with actual sequence length seq_len_i:
413	#   v[i : i+1, :, :seq_len_i, :seq_len_i]
414	#     ^^^^^^^                               i:i+1 not i -- keeps batch dim [1,...] for
415	#                                           format compatibility with compute_activations
416	#              ^^                           all attention heads
417	#                  ^^^^^^^^^^  ^^^^^^^^^^   trim both query and key dims to actual length,
418	#                                           discarding meaningless padding positions
419	paths: list[Path] = []
420	for i, (prompt, seq_len) in enumerate(zip(prompts, seq_lens, strict=True)):
421		prompt_dir = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
422		activations_path: Path = prompt_dir / "activations.npz"
423		cache_np: ActivationCacheNp = {}
424		for k, v in cache_torch.items():
425			assert v.ndim == 4, (  # noqa: PLR2004
426				f"expected 4D attention pattern tensor for {k!r}, "
427				f"got shape {v.shape}. names_filter must only match "
428				f"attention pattern activations [batch, n_heads, seq, seq]"
429			)
430			cache_np[k] = v[i : i + 1, :, :seq_len, :seq_len].detach().cpu().numpy()
431
432		np.savez_compressed(
433			activations_path,
434			**cache_np,  # type: ignore[arg-type]
435		)
436		paths.append(activations_path)
437
438	return paths
439
440
441@overload
442def get_activations(
443	prompt: dict,
444	model: HookedTransformer | str,
445	save_path: Path = Path(DATA_DIR),
446	allow_disk_cache: bool = True,
447	return_cache: None = None,
448) -> tuple[Path, None]: ...
449@overload
450def get_activations(
451	prompt: dict,
452	model: HookedTransformer | str,
453	save_path: Path = Path(DATA_DIR),
454	allow_disk_cache: bool = True,
455	return_cache: Literal["torch"] = "torch",
456) -> tuple[Path, ActivationCache]: ...
457@overload
458def get_activations(
459	prompt: dict,
460	model: HookedTransformer | str,
461	save_path: Path = Path(DATA_DIR),
462	allow_disk_cache: bool = True,
463	return_cache: Literal["numpy"] = "numpy",
464) -> tuple[Path, ActivationCacheNp]: ...
465def get_activations(
466	prompt: dict,
467	model: HookedTransformer | str,
468	save_path: Path = Path(DATA_DIR),
469	allow_disk_cache: bool = True,
470	return_cache: ReturnCache = "numpy",
471) -> tuple[Path, ActivationCacheNp | ActivationCache | None]:
472	"""given a prompt and a model, save or load activations
473
474	# Parameters:
475	- `prompt : dict`
476		expected to contain the 'text' key
477	- `model : HookedTransformer | str`
478		either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
479	- `save_path : Path`
480		path to save the activations to (and load from)
481		(defaults to `Path(DATA_DIR)`)
482	- `allow_disk_cache : bool`
483		whether to allow loading from disk cache
484		(defaults to `True`)
485	- `return_cache : Literal[None, "numpy", "torch"]`
486		whether to return the cache, and in what format
487		(defaults to `"numpy"`)
488
489	# Returns:
490	- `tuple[Path, ActivationCacheNp | ActivationCache | None]`
491		the path to the activations and the cache if `return_cache is not None`
492
493	"""
494	# add hash to prompt
495	augment_prompt_with_hash(prompt)
496
497	# get the model
498	model_name: str = (
499		model.cfg.model_name if isinstance(model, HookedTransformer) else model
500	)
501
502	# from cache
503	if allow_disk_cache:
504		if return_cache is None:
505			# fast path: check file existence without loading data into memory.
506			# activations_exist just calls .exists() on two paths, whereas
507			# load_activations would decompress the full .npz into numpy arrays
508			# only for us to discard them immediately.
509			if activations_exist(model_name, prompt, save_path):
510				prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
511				return prompt_dir / "activations.npz", None
512		else:
513			try:
514				path, cache = load_activations(
515					model_name=model_name,
516					prompt=prompt,
517					save_path=save_path,
518				)
519			except ActivationsMissingError:
520				pass
521			else:
522				return path, cache
523
524	# compute them
525	if isinstance(model, str):
526		model = HookedTransformer.from_pretrained(model_name)
527
528	return compute_activations(  # type: ignore[return-value]
529		prompt=prompt,
530		model=model,
531		save_path=save_path,
532		return_cache=return_cache,
533	)
534
535
536DEFAULT_DEVICE: torch.device = torch.device(
537	"cuda" if torch.cuda.is_available() else "cpu",
538)
539
540
541def activations_main(  # noqa: C901, PLR0912, PLR0915
542	model_name: str,
543	save_path: str | Path,
544	prompts_path: str,
545	raw_prompts: bool,
546	min_chars: int,
547	max_chars: int,
548	force: bool,
549	n_samples: int,
550	no_index_html: bool,
551	shuffle: bool = False,
552	stacked_heads: bool = False,
553	device: str | torch.device = DEFAULT_DEVICE,
554	batch_size: int = 32,
555) -> None:
556	"""main function for computing activations
557
558	# Parameters:
559	- `model_name : str`
560		name of a model to load with `HookedTransformer.from_pretrained`
561	- `save_path : str | Path`
562		path to save the activations to
563	- `prompts_path : str`
564		path to the prompts file
565	- `raw_prompts : bool`
566		whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
567	- `min_chars : int`
568		minimum number of characters for a prompt
569	- `max_chars : int`
570		maximum number of characters for a prompt
571	- `force : bool`
572		whether to overwrite existing files
573	- `n_samples : int`
574		maximum number of samples to process
575	- `no_index_html : bool`
576		whether to write an index.html file
577	- `shuffle : bool`
578		whether to shuffle the prompts
579		(defaults to `False`)
580	- `stacked_heads : bool`
581		whether to stack the heads in the output tensor. will save as `.npy` instead of `.npz` if `True`
582		(defaults to `False`)
583	- `device : str | torch.device`
584		the device to use. if a string, will be passed to `torch.device`
585	- `batch_size : int`
586		number of prompts per forward pass. prompts are sorted by token length
587		(longest first) and grouped so that similar-length prompts share a batch,
588		minimizing padding waste. use `batch_size=1` for one prompt per forward
589		pass (largely equivalent to the old sequential behavior, but note: prompts
590		are still sorted by length and cache checking uses file-existence only,
591		unlike the old path which processed prompts in order and validated cache
592		contents via `load_activations`).
593		the single-prompt functions `compute_activations` and `get_activations`
594		are still available for programmatic use outside of `activations_main`.
595		(defaults to `32`)
596	"""
597	# figure out the device to use
598	device_: torch.device
599	if isinstance(device, torch.device):
600		device_ = device
601	elif isinstance(device, str):
602		device_ = torch.device(device)
603	else:
604		msg = f"invalid device: {device}"
605		raise TypeError(msg)
606
607	print(f"using device: {device_}")
608
609	with SpinnerContext(message="loading model", **SPINNER_KWARGS):
610		model: HookedTransformer = HookedTransformer.from_pretrained(
611			model_name,
612			device=device_,
613		)
614		model.model_name = model_name  # type: ignore[unresolved-attribute]
615		model.cfg.model_name = model_name
616		n_params: int = sum(p.numel() for p in model.parameters())
617	print(
618		f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters",
619	)
620	print(f"\tmodel devices: { {p.device for p in model.parameters()} }")
621
622	save_path_p: Path = Path(save_path)
623	save_path_p.mkdir(parents=True, exist_ok=True)
624	model_path: Path = save_path_p / model_name
625	with SpinnerContext(
626		message=f"saving model info to {_rel_path(model_path)}",
627		**SPINNER_KWARGS,
628	):
629		model_cfg: HookedTransformerConfig
630		model_cfg = model.cfg
631		model_path.mkdir(parents=True, exist_ok=True)
632		with open(model_path / "model_cfg.json", "w") as f:
633			json.dump(json_serialize(asdict(model_cfg)), f)
634
635	# load prompts
636	with SpinnerContext(
637		message=f"loading prompts from {Path(prompts_path).as_posix()}",
638		**SPINNER_KWARGS,
639	):
640		prompts: list[dict]
641		if raw_prompts:
642			prompts = load_text_data(
643				Path(prompts_path),
644				min_chars=min_chars,
645				max_chars=max_chars,
646				shuffle=shuffle,
647			)
648		else:
649			with open(model_path / "prompts.jsonl", "r") as f:
650				prompts = [json.loads(line) for line in f.readlines()]
651		# truncate to n_samples
652		prompts = prompts[:n_samples]
653
654	print(f"  {len(prompts)} prompts loaded")
655
656	# write index.html
657	with SpinnerContext(
658		message=f"writing {_rel_path(save_path_p / 'index.html')}",
659		**SPINNER_KWARGS,
660	):
661		if not no_index_html:
662			write_html_index(save_path_p)
663
664	# TODO: not implemented yet
665	if stacked_heads:
666		raise NotImplementedError("stacked_heads not implemented yet")
667
668	# augment all prompts with hashes
669	for prompt in prompts:
670		augment_prompt_with_hash(prompt)
671
672	# filter out cached prompts
673	if not force:
674		uncached: list[dict] = [
675			p for p in prompts if not activations_exist(model_name, p, save_path_p)
676		]
677		n_cached: int = len(prompts) - len(uncached)
678		if n_cached > 0:
679			print(f"  {n_cached} prompts already cached, {len(uncached)} to compute")
680	else:
681		uncached = list(prompts)
682
683	if uncached:
684		# sort by token length descending so that:
685		# 1. the longest (slowest, most memory-hungry) batches run first --
686		#    OOM errors surface immediately rather than after all the cheap work,
687		#    and tqdm's ETA stabilizes early for better progress estimation
688		# 2. similar-length prompts are grouped together, minimizing padding waste
689		#
690		# pre-tokenization is a separate step from compute_activations_batched because
691		# we need token lengths *before* batching to sort and group. the resulting
692		# seq_lens are then passed through so compute_activations_batched can skip
693		# re-tokenizing each prompt internally.
694		with SpinnerContext(
695			message="pre-tokenizing prompts for length sorting",
696			**SPINNER_KWARGS,
697		):
698			uncached_with_lens: list[tuple[dict, int]] = [
699				(p, model.to_tokens(p["text"]).shape[1]) for p in uncached
700			]
701			uncached_with_lens.sort(key=lambda x: x[1], reverse=True)
702			sorted_uncached: list[dict] = [p for p, _ in uncached_with_lens]
703			sorted_seq_lens: list[int] = [sl for _, sl in uncached_with_lens]
704
705		# process in batches
706		n_prompts: int = len(sorted_uncached)
707		with tqdm.tqdm(
708			total=n_prompts,
709			desc="Computing activations",
710			unit="prompt",
711		) as pbar:
712			for batch_start in range(0, n_prompts, batch_size):
713				batch_end: int = min(batch_start + batch_size, n_prompts)
714				batch: list[dict] = sorted_uncached[batch_start:batch_end]
715				batch_seq_lens: list[int] = sorted_seq_lens[batch_start:batch_end]
716				pbar.set_postfix(
717					n_ctx=batch_seq_lens[0],
718				)  # longest in batch (sorted descending)
719				compute_activations_batched(
720					prompts=batch,
721					model=model,
722					save_path=save_path_p,
723					seq_lens=batch_seq_lens,
724				)
725				pbar.update(len(batch))
726	else:
727		print("  all prompts cached, nothing to compute")
728
729	with SpinnerContext(
730		message="updating jsonl metadata for models and prompts",
731		**SPINNER_KWARGS,
732	):
733		generate_models_jsonl(save_path_p)
734		generate_prompts_jsonl(save_path_p / model_name)
735
736
737def main() -> None:
738	"generate attention pattern activations for a model and prompts"
739	print(DIVIDER_S1)
740	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
741		arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
742		# input and output
743		arg_parser.add_argument(
744			"--model",
745			"-m",
746			type=str,
747			required=True,
748			help="The model name(s) to use. comma separated with no whitespace if multiple",
749		)
750
751		arg_parser.add_argument(
752			"--prompts",
753			"-p",
754			type=str,
755			required=False,
756			help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
757			default=None,
758		)
759
760		arg_parser.add_argument(
761			"--save-path",
762			"-s",
763			type=str,
764			required=False,
765			help="The path to save the attention patterns",
766			default=DATA_DIR,
767		)
768
769		# min and max prompt lengths
770		arg_parser.add_argument(
771			"--min-chars",
772			type=int,
773			required=False,
774			help="The minimum number of characters for a prompt",
775			default=100,
776		)
777		arg_parser.add_argument(
778			"--max-chars",
779			type=int,
780			required=False,
781			help="The maximum number of characters for a prompt",
782			default=1000,
783		)
784
785		# number of samples
786		arg_parser.add_argument(
787			"--n-samples",
788			"-n",
789			type=int,
790			required=False,
791			help="The max number of samples to process, do all in the file if None",
792			default=None,
793		)
794
795		# batch size
796		arg_parser.add_argument(
797			"--batch-size",
798			"-b",
799			type=int,
800			required=False,
801			help="Batch size for computing activations (number of prompts per forward pass)",
802			default=32,
803		)
804
805		# force overwrite
806		arg_parser.add_argument(
807			"--force",
808			"-f",
809			action="store_true",
810			help="If passed, will overwrite existing files",
811		)
812
813		# no index html
814		arg_parser.add_argument(
815			"--no-index-html",
816			action="store_true",
817			help="If passed, will not write an index.html file for the model",
818		)
819
820		# raw prompts
821		arg_parser.add_argument(
822			"--raw-prompts",
823			"-r",
824			action="store_true",
825			help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
826		)
827
828		# shuffle
829		arg_parser.add_argument(
830			"--shuffle",
831			action="store_true",
832			help="If passed, will shuffle the prompts",
833		)
834
835		# stack heads
836		arg_parser.add_argument(
837			"--stacked-heads",
838			action="store_true",
839			help="If passed, will stack the heads in the output tensor",
840		)
841
842		# device
843		arg_parser.add_argument(
844			"--device",
845			type=str,
846			required=False,
847			help="The device to use for the model",
848			default="cuda" if torch.cuda.is_available() else "cpu",
849		)
850
851		args: argparse.Namespace = arg_parser.parse_args()
852
853	print(f"args parsed: {args}")
854
855	models: list[str]
856	if "," in args.model:
857		models = args.model.split(",")
858	else:
859		models = [args.model]
860
861	n_models: int = len(models)
862	for idx, model in enumerate(models):
863		print(DIVIDER_S2)
864		print(f"processing model {idx + 1} / {n_models}: {model}")
865		print(DIVIDER_S2)
866
867		activations_main(
868			model_name=model,
869			save_path=args.save_path,
870			prompts_path=args.prompts,
871			raw_prompts=args.raw_prompts,
872			min_chars=args.min_chars,
873			max_chars=args.max_chars,
874			force=args.force,
875			n_samples=args.n_samples,
876			no_index_html=args.no_index_html,
877			shuffle=args.shuffle,
878			stacked_heads=args.stacked_heads,
879			device=args.device,
880			batch_size=args.batch_size,
881		)
882		del model
883
884	print(DIVIDER_S1)
885
886
887if __name__ == "__main__":
888	main()

def compute_activations( prompt: dict, model: transformer_lens.HookedTransformer.HookedTransformer | None = None, save_path: pathlib._local.Path = PosixPath('attn_data'), names_filter: Callable[[str], bool] | re.Pattern = re.compile('blocks\\.(\\d+)\\.attn\\.hook_pattern'), return_cache: Optional[Literal['numpy', 'torch']] = 'torch', stack_heads: bool = False) -> tuple[pathlib._local.Path, dict[str, numpy.ndarray] | transformer_lens.ActivationCache.ActivationCache | jaxtyping.Float[ndarray, 'n_layers n_heads n_ctx n_ctx'] | jaxtyping.Float[Tensor, 'n_layers n_heads n_ctx n_ctx'] | None]:
129def compute_activations(  # noqa: PLR0915
130	prompt: dict,
131	model: HookedTransformer | None = None,
132	save_path: Path = Path(DATA_DIR),
133	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
134	return_cache: ReturnCache = "torch",
135	stack_heads: bool = False,
136) -> tuple[
137	Path,
138	ActivationCacheNp
139	| ActivationCache
140	| Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
141	| Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]
142	| None,
143]:
144	"""compute activations for a single prompt and save to disk
145
146	always runs a forward pass -- does NOT load from disk cache.
147	for cache-aware loading, use `get_activations` which tries disk first.
148
149	# Parameters:
150	- `prompt : dict`
151		must contain the 'text' and 'hash' keys (see `augment_prompt_with_hash`)
152	- `model : HookedTransformer | None`
153	- `save_path : Path`
154		(defaults to `Path(DATA_DIR)`)
155	- `names_filter : Callable[[str], bool]|re.Pattern`
156		a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
157		(defaults to `ATTN_PATTERN_REGEX`)
158	- `return_cache : Literal[None, "numpy", "torch"]`
159		will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (False) or a single tensor (True)
160		(defaults to `"torch"`)
161	- `stack_heads : bool`
162		whether the heads should be stacked in the output. this causes a number of changes:
163	- `npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor saved for each prompt instead of `npz` file with dict by layer
164	- `cache` will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer if `return_cache` is not `None`
165		will assert that the activation cache contains all of the attention patterns and nothing else, raising an exception if not.
166
167	# Returns:
168	```
169	tuple[
170		Path,
171		Union[
172			None,
173			ActivationCacheNp, ActivationCache,
174			Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
175		]
176	]
177	```
178	"""
179	# check inputs
180	assert model is not None, "model must be passed"
181	assert "text" in prompt, "prompt must contain 'text' key"
182	prompt_str: str = prompt["text"]
183
184	# compute or get prompt metadata
185	assert model.tokenizer is not None
186	prompt_tokenized: list[str] = prompt.get(
187		"tokens",
188		model.tokenizer.tokenize(prompt_str),
189	)
190	# n_tokens counts subword tokens (no BOS); attention patterns include BOS
191	# so have dim n_tokens+1. see also compute_activations_batched Phase B.
192	prompt.update(
193		dict(
194			n_tokens=len(prompt_tokenized),
195			tokens=prompt_tokenized,
196		),
197	)
198
199	# save metadata
200	prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
201	prompt_dir.mkdir(parents=True, exist_ok=True)
202	with open(prompt_dir / "prompt.json", "w") as f:
203		json.dump(prompt, f)
204
205	# set up names filter
206	names_filter_fn: Callable[[str], bool]
207	if isinstance(names_filter, re.Pattern):
208		names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
209	else:
210		names_filter_fn = names_filter
211
212	# compute activations
213	# NOTE: no padding_side kwarg here -- it's only meaningful for multi-sequence
214	# batches where padding is needed. single-string input has no padding.
215	# see compute_activations_batched for the batched path that passes padding_side="right".
216	cache_torch: ActivationCache
217	with torch.no_grad():
218		model.eval()
219		_, cache_torch = model.run_with_cache(
220			prompt_str,
221			names_filter=names_filter_fn,
222			return_type=None,
223		)
224
225	activations_path: Path
226	# saving and returning
227	if stack_heads:
228		n_layers: int = model.cfg.n_layers
229		key_pattern: str = "blocks.{i}.attn.hook_pattern"
230		# NOTE: this only works for stacking heads at the moment
231		# activations_specifier: str = key_pattern.format(i=f'0-{n_layers}')
232		activations_specifier: str = key_pattern.format(i="-")
233		activations_path = prompt_dir / f"activations-{activations_specifier}.npy"
234
235		# check the keys are only attention heads
236		head_keys: list[str] = [key_pattern.format(i=i) for i in range(n_layers)]
237		cache_torch_keys_set: set[str] = set(cache_torch.keys())
238		assert cache_torch_keys_set == set(head_keys), (
239			f"unexpected keys!\n{set(head_keys).symmetric_difference(cache_torch_keys_set) = }\n{cache_torch_keys_set} != {set(head_keys)}"
240		)
241
242		# stack heads
243		patterns_stacked: Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] = (
244			torch.stack([cache_torch[k] for k in head_keys], dim=1)
245		)
246		# check shape
247		pattern_shape_no_ctx: tuple[int, ...] = tuple(patterns_stacked.shape[:3])
248		assert pattern_shape_no_ctx == (1, n_layers, model.cfg.n_heads), (
249			f"unexpected shape: {patterns_stacked.shape[:3] = } ({pattern_shape_no_ctx = }), expected {(1, n_layers, model.cfg.n_heads) = }"
250		)
251
252		patterns_stacked_np: Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = (
253			patterns_stacked.cpu().numpy()
254		)
255
256		# save
257		np.save(activations_path, patterns_stacked_np)
258
259		# return
260		match return_cache:
261			case "numpy":
262				return activations_path, patterns_stacked_np
263			case "torch":
264				return activations_path, patterns_stacked
265			case None:
266				return activations_path, None
267			case _:
268				msg = f"invalid return_cache: {return_cache = }"
269				raise ValueError(msg)
270	else:
271		activations_path = prompt_dir / "activations.npz"
272
273		# save
274		cache_np: ActivationCacheNp = {
275			k: v.detach().cpu().numpy() for k, v in cache_torch.items()
276		}
277
278		np.savez_compressed(
279			activations_path,
280			**cache_np,  # type: ignore[arg-type]
281		)
282
283		# return
284		match return_cache:
285			case "numpy":
286				return activations_path, cache_np
287			case "torch":
288				return activations_path, cache_torch
289			case None:
290				return activations_path, None
291			case _:
292				msg = f"invalid return_cache: {return_cache = }"
293				raise ValueError(msg)

compute activations for a single prompt and save to disk

always runs a forward pass -- does NOT load from disk cache. for cache-aware loading, use get_activations which tries disk first.

Parameters:

  • prompt : dict must contain the 'text' and 'hash' keys (see augment_prompt_with_hash)
  • model : HookedTransformer | None the model to run; defaults to None but must be passed
  • save_path : Path (defaults to Path(DATA_DIR))
  • names_filter : Callable[[str], bool]|re.Pattern a filter for the names of the activations to return. if an re.Pattern, will use lambda key: names_filter.match(key) is not None (defaults to ATTN_PATTERN_REGEX)
  • return_cache : Literal[None, "numpy", "torch"] will return None as the second element if None, otherwise will return the cache in the specified tensor format. stack_heads still affects whether it will be a dict (False) or a single tensor (True) (defaults to "torch")
  • stack_heads : bool whether the heads should be stacked in the output. this causes a number of changes: a single (n_layers, n_heads, n_ctx, n_ctx) tensor is saved per prompt as an npy file instead of an npz file with a dict by layer, and cache will be a single (n_layers, n_heads, n_ctx, n_ctx) tensor instead of a dict by layer if return_cache is not None. will assert that the activation cache contains all of the attention patterns and nothing else, raising an exception if not. (defaults to False)

Returns:

tuple[
        Path,
        Union[
                None,
                ActivationCacheNp, ActivationCache,
                Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
        ]
]
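
A minimal usage sketch (model name, prompt text, and resulting paths are illustrative): the prompt dict needs a 'hash' key before calling, which augment_prompt_with_hash adds in place.

from transformer_lens import HookedTransformer
from pattern_lens.activations import compute_activations
from pattern_lens.load_activations import augment_prompt_with_hash

model = HookedTransformer.from_pretrained("gpt2")
prompt = {"text": "The quick brown fox jumps over the lazy dog."}
augment_prompt_with_hash(prompt)  # adds the 'hash' key used for the save directory
path, cache = compute_activations(prompt, model=model, return_cache="numpy")
# path -> attn_data/gpt2/prompts/<hash>/activations.npz (with the default save_path)
# cache maps 'blocks.{i}.attn.hook_pattern' keys to numpy arrays
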
def compute_activations_batched( prompts: list[dict], model: transformer_lens.HookedTransformer.HookedTransformer, save_path: pathlib._local.Path = PosixPath('attn_data'), names_filter: Callable[[str], bool] | re.Pattern = re.compile('blocks\\.(\\d+)\\.attn\\.hook_pattern'), seq_lens: list[int] | None = None) -> list[pathlib._local.Path]:
296def compute_activations_batched(
297	prompts: list[dict],
298	model: HookedTransformer,
299	save_path: Path = Path(DATA_DIR),
300	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
301	seq_lens: list[int] | None = None,
302) -> list[Path]:
303	"""compute and save activations for a batch of prompts in a single forward pass
304
305	Batched companion to `compute_activations` -- instead of one forward pass per
306	prompt, this runs a single `model.run_with_cache(list_of_strings)` call for the
307	whole batch. TransformerLens tokenizes and right-pads automatically. Each prompt's
308	attention patterns are then trimmed to their actual (unpadded) size and saved
309	individually, producing files identical to the single-prompt path.
310
311	Does not support `stack_heads` or `return_cache` -- this function is intended for
312	the bulk processing path in `activations_main`, not for interactive use. Use
313	`compute_activations` directly for single-prompt use cases that need those features.
314
315	## Why right-padding makes trimming correct without an explicit attention mask
316
317	With right-padding, pad tokens sit at positions seq_len, seq_len+1, ...,
318	max_seq_len-1 (higher than any real token). The causal attention mask prevents
319	position i from attending to any j > i. So for real tokens at positions
320	0..seq_len-1, they can only attend to 0..i -- all real tokens. The softmax is computed over the same set of positions
321	as in single-prompt inference, producing identical attention patterns.
322
323	We explicitly pass `padding_side="right"` to `run_with_cache` to guarantee this
324	regardless of the model's default padding side.
325
326	# Parameters:
327	- `prompts : list[dict]`
328		each prompt must contain 'text' and 'hash' keys. call
329		`augment_prompt_with_hash` on each prompt before passing them here.
330	- `model : HookedTransformer`
331		the model to compute activations with
332	- `save_path : Path`
333		path to save the activations to
334		(defaults to `Path(DATA_DIR)`)
335	- `names_filter : Callable[[str], bool] | re.Pattern`
336		filter for which activations to save. must only match activations with
337		4D shape `[batch, n_heads, seq, seq]` (e.g. attention patterns).
338		non-attention activations will cause incorrect trimming.
339		(defaults to `ATTN_PATTERN_REGEX`)
340	- `seq_lens : list[int] | None`
341		pre-computed model sequence lengths per prompt (from `model.to_tokens`).
342		if `None`, will be computed internally. pass this to avoid redundant
343		tokenization when lengths are already known (e.g. from length-sorting).
344		**important**: these must be from `model.to_tokens()` (includes BOS),
345		NOT from `model.tokenizer.tokenize()` (excludes BOS).
346		(defaults to `None`)
347
348	# Returns:
349	- `list[Path]`
350		paths to the saved activations files, one per prompt
351
352	# Modifies:
353	each prompt dict in `prompts` -- adds/overwrites `n_tokens` and `tokens` keys
354	with tokenization metadata (same mutation as `compute_activations`).
355	"""
356	assert model is not None, "model must be passed"
357	assert len(prompts) > 0, "prompts must not be empty"
358	assert "text" in prompts[0], f"prompt must contain 'text' key: {prompts[0].keys()}"
359	assert "hash" in prompts[0], (
360		f"prompt must contain 'hash' key (call augment_prompt_with_hash first): {prompts[0].keys()}"
361	)
362
363	# --- Phase A: get actual model sequence lengths ---
364	# model.to_tokens() includes BOS if applicable, matching the attention pattern dims
365	# model.tokenizer.tokenize() gives subword strings WITHOUT BOS, used for metadata
366	# these differ by 1 when BOS is prepended -- using the wrong one for trimming
367	# would silently truncate or include garbage
368	if seq_lens is None:
369		seq_lens = [model.to_tokens(p["text"]).shape[1] for p in prompts]
370	assert len(seq_lens) == len(prompts), (
371		f"seq_lens length mismatch: {len(seq_lens)} != {len(prompts)}"
372	)
373
374	# --- Phase B: save prompt metadata (mirrors compute_activations's metadata logic) ---
375	assert model.tokenizer is not None
376	for p in prompts:
377		prompt_str: str = p["text"]
378		prompt_tokenized: list[str] = p.get(
379			"tokens",
380			model.tokenizer.tokenize(prompt_str),
381		)
382		# n_tokens counts subword tokens (no BOS); attention patterns include BOS so have dim n_tokens+1
383		p.update(
384			dict(
385				n_tokens=len(prompt_tokenized),
386				tokens=prompt_tokenized,
387			),
388		)
389		prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / p["hash"]
390		prompt_dir.mkdir(parents=True, exist_ok=True)
391		with open(prompt_dir / "prompt.json", "w") as f:
392			json.dump(p, f)
393
394	# --- Phase C: batched forward pass ---
395	names_filter_fn: Callable[[str], bool]
396	if isinstance(names_filter, re.Pattern):
397		names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
398	else:
399		names_filter_fn = names_filter
400
401	texts: list[str] = [p["text"] for p in prompts]
402	cache_torch: ActivationCache
403	with torch.no_grad():
404		model.eval()
405		_, cache_torch = model.run_with_cache(
406			texts,
407			names_filter=names_filter_fn,
408			return_type=None,
409			padding_side="right",
410		)
411
412	# --- Phase D: split, trim padding, and save per-prompt ---
413	# For each prompt i with actual sequence length seq_len_i:
414	#   v[i : i+1, :, :seq_len_i, :seq_len_i]
415	#     ^^^^^^^                               i:i+1 not i -- keeps batch dim [1,...] for
416	#                                           format compatibility with compute_activations
417	#              ^^                           all attention heads
418	#                  ^^^^^^^^^^  ^^^^^^^^^^   trim both query and key dims to actual length,
419	#                                           discarding meaningless padding positions
420	paths: list[Path] = []
421	for i, (prompt, seq_len) in enumerate(zip(prompts, seq_lens, strict=True)):
422		prompt_dir = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
423		activations_path: Path = prompt_dir / "activations.npz"
424		cache_np: ActivationCacheNp = {}
425		for k, v in cache_torch.items():
426			assert v.ndim == 4, (  # noqa: PLR2004
427				f"expected 4D attention pattern tensor for {k!r}, "
428				f"got shape {v.shape}. names_filter must only match "
429				f"attention pattern activations [batch, n_heads, seq, seq]"
430			)
431			cache_np[k] = v[i : i + 1, :, :seq_len, :seq_len].detach().cpu().numpy()
432
433		np.savez_compressed(
434			activations_path,
435			**cache_np,  # type: ignore[arg-type]
436		)
437		paths.append(activations_path)
438
439	return paths

compute and save activations for a batch of prompts in a single forward pass

Batched companion to compute_activations -- instead of one forward pass per prompt, this runs a single model.run_with_cache(list_of_strings) call for the whole batch. TransformerLens tokenizes and right-pads automatically. Each prompt's attention patterns are then trimmed to their actual (unpadded) size and saved individually, producing files identical to the single-prompt path.

Does not support stack_heads or return_cache -- this function is intended for the bulk processing path in activations_main, not for interactive use. Use compute_activations directly for single-prompt use cases that need those features.

Why right-padding makes trimming correct without an explicit attention mask

With right-padding, pad tokens sit at positions seq_len, seq_len+1, ..., max_seq_len-1 (higher than any real token). The causal attention mask prevents position i from attending to any j > i. So for real tokens at positions 0..seq_len-1, they can only attend to 0..i -- all real tokens. The softmax is computed over the same set of positions as in single-prompt inference, producing identical attention patterns.

We explicitly pass padding_side="right" to run_with_cache to guarantee this regardless of the model's default padding side.
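
As an illustrative sanity check of this claim (not part of the library), one can run a short prompt both alone and inside a right-padded batch, then compare the trimmed patterns:

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
short = "a short prompt"
longer = "a much longer prompt with many more tokens in it"
seq_len = model.to_tokens(short).shape[1]  # includes BOS

_, single = model.run_with_cache(short, names_filter="blocks.0.attn.hook_pattern", return_type=None)
_, batched = model.run_with_cache([short, longer], names_filter="blocks.0.attn.hook_pattern", return_type=None, padding_side="right")

pattern_single = single["blocks.0.attn.hook_pattern"]
pattern_batched = batched["blocks.0.attn.hook_pattern"][0:1, :, :seq_len, :seq_len]
assert torch.allclose(pattern_single, pattern_batched, atol=1e-5)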

Parameters:

  • prompts : list[dict] each prompt must contain 'text' and 'hash' keys. call augment_prompt_with_hash on each prompt before passing them here.
  • model : HookedTransformer the model to compute activations with
  • save_path : Path path to save the activations to (defaults to Path(DATA_DIR))
  • names_filter : Callable[[str], bool] | re.Pattern filter for which activations to save. must only match activations with 4D shape [batch, n_heads, seq, seq] (e.g. attention patterns). non-attention activations will cause incorrect trimming. (defaults to ATTN_PATTERN_REGEX)
  • seq_lens : list[int] | None pre-computed model sequence lengths per prompt (from model.to_tokens). if None, will be computed internally. pass this to avoid redundant tokenization when lengths are already known (e.g. from length-sorting). important: these must be from model.to_tokens() (includes BOS), NOT from model.tokenizer.tokenize() (excludes BOS). (defaults to None)

Returns:

  • list[Path] paths to the saved activations files, one per prompt

Modifies:

each prompt dict in prompts -- adds/overwrites n_tokens and tokens keys with tokenization metadata (same mutation as compute_activations).
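
A short sketch of calling the batched path directly (prompt texts are illustrative; bulk processing normally goes through activations_main):

from transformer_lens import HookedTransformer
from pattern_lens.activations import compute_activations_batched
from pattern_lens.load_activations import augment_prompt_with_hash

model = HookedTransformer.from_pretrained("gpt2")
prompts = [{"text": "first example prompt"}, {"text": "a second, somewhat longer example prompt"}]
for p in prompts:
	augment_prompt_with_hash(p)  # required: adds the 'hash' key
paths = compute_activations_batched(prompts, model=model)
# one activations.npz per prompt, each trimmed to that prompt's own sequence length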

def get_activations( prompt: dict, model: transformer_lens.HookedTransformer.HookedTransformer | str, save_path: pathlib._local.Path = PosixPath('attn_data'), allow_disk_cache: bool = True, return_cache: Optional[Literal['numpy', 'torch']] = 'numpy') -> tuple[pathlib._local.Path, dict[str, numpy.ndarray] | transformer_lens.ActivationCache.ActivationCache | None]:
466def get_activations(
467	prompt: dict,
468	model: HookedTransformer | str,
469	save_path: Path = Path(DATA_DIR),
470	allow_disk_cache: bool = True,
471	return_cache: ReturnCache = "numpy",
472) -> tuple[Path, ActivationCacheNp | ActivationCache | None]:
473	"""given a prompt and a model, save or load activations
474
475	# Parameters:
476	- `prompt : dict`
477		expected to contain the 'text' key
478	- `model : HookedTransformer | str`
479		either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
480	- `save_path : Path`
481		path to save the activations to (and load from)
482		(defaults to `Path(DATA_DIR)`)
483	- `allow_disk_cache : bool`
484		whether to allow loading from disk cache
485		(defaults to `True`)
486	- `return_cache : Literal[None, "numpy", "torch"]`
487		whether to return the cache, and in what format
488		(defaults to `"numpy"`)
489
490	# Returns:
491	- `tuple[Path, ActivationCacheNp | ActivationCache | None]`
492		the path to the activations and the cache if `return_cache is not None`
493
494	"""
495	# add hash to prompt
496	augment_prompt_with_hash(prompt)
497
498	# get the model
499	model_name: str = (
500		model.cfg.model_name if isinstance(model, HookedTransformer) else model
501	)
502
503	# from cache
504	if allow_disk_cache:
505		if return_cache is None:
506			# fast path: check file existence without loading data into memory.
507			# activations_exist just calls .exists() on two paths, whereas
508			# load_activations would decompress the full .npz into numpy arrays
509			# only for us to discard them immediately.
510			if activations_exist(model_name, prompt, save_path):
511				prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
512				return prompt_dir / "activations.npz", None
513		else:
514			try:
515				path, cache = load_activations(
516					model_name=model_name,
517					prompt=prompt,
518					save_path=save_path,
519				)
520			except ActivationsMissingError:
521				pass
522			else:
523				return path, cache
524
525	# compute them
526	if isinstance(model, str):
527		model = HookedTransformer.from_pretrained(model_name)
528
529	return compute_activations(  # type: ignore[return-value]
530		prompt=prompt,
531		model=model,
532		save_path=save_path,
533		return_cache=return_cache,
534	)

given a prompt and a model, save or load activations

Parameters:

  • prompt : dict expected to contain the 'text' key
  • model : HookedTransformer | str either a HookedTransformer or a string model name, to be loaded with HookedTransformer.from_pretrained
  • save_path : Path path to save the activations to (and load from) (defaults to Path(DATA_DIR))
  • allow_disk_cache : bool whether to allow loading from disk cache (defaults to True)
  • return_cache : Literal[None, "numpy", "torch"] whether to return the cache, and in what format (defaults to "numpy")

Returns:

  • tuple[Path, ActivationCacheNp | ActivationCache | None] the path to the activations and the cache if return_cache is not None
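
A small sketch of the cache-aware path (model name and prompt text are illustrative): the first call runs a forward pass and saves activations, and a later call with the same prompt loads them from disk instead of re-running the model.

from pattern_lens.activations import get_activations

prompt = {"text": "an example prompt"}
path, cache = get_activations(prompt, "gpt2", return_cache="numpy")  # computes and writes activations.npz
path, cache = get_activations(prompt, "gpt2", return_cache="numpy")  # loads the saved .npz from disk
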
DEFAULT_DEVICE: torch.device = device(type='cuda')
def activations_main( model_name: str, save_path: str | pathlib._local.Path, prompts_path: str, raw_prompts: bool, min_chars: int, max_chars: int, force: bool, n_samples: int, no_index_html: bool, shuffle: bool = False, stacked_heads: bool = False, device: str | torch.device = device(type='cuda'), batch_size: int = 32) -> None:
542def activations_main(  # noqa: C901, PLR0912, PLR0915
543	model_name: str,
544	save_path: str | Path,
545	prompts_path: str,
546	raw_prompts: bool,
547	min_chars: int,
548	max_chars: int,
549	force: bool,
550	n_samples: int,
551	no_index_html: bool,
552	shuffle: bool = False,
553	stacked_heads: bool = False,
554	device: str | torch.device = DEFAULT_DEVICE,
555	batch_size: int = 32,
556) -> None:
557	"""main function for computing activations
558
559	# Parameters:
560	- `model_name : str`
561		name of a model to load with `HookedTransformer.from_pretrained`
562	- `save_path : str | Path`
563		path to save the activations to
564	- `prompts_path : str`
565		path to the prompts file
566	- `raw_prompts : bool`
567		whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
568	- `min_chars : int`
569		minimum number of characters for a prompt
570	- `max_chars : int`
571		maximum number of characters for a prompt
572	- `force : bool`
573		whether to overwrite existing files
574	- `n_samples : int`
575		maximum number of samples to process
576	- `no_index_html : bool`
577		whether to write an index.html file
578	- `shuffle : bool`
579		whether to shuffle the prompts
580		(defaults to `False`)
581	- `stacked_heads : bool`
582		whether to stack the heads in the output tensor. will save as `.npy` instead of `.npz` if `True`
583		(defaults to `False`)
584	- `device : str | torch.device`
585		the device to use. if a string, will be passed to `torch.device`
586	- `batch_size : int`
587		number of prompts per forward pass. prompts are sorted by token length
588		(longest first) and grouped so that similar-length prompts share a batch,
589		minimizing padding waste. use `batch_size=1` for one prompt per forward
590		pass (largely equivalent to the old sequential behavior, but note: prompts
591		are still sorted by length and cache checking uses file-existence only,
592		unlike the old path which processed prompts in order and validated cache
593		contents via `load_activations`).
594		the single-prompt functions `compute_activations` and `get_activations`
595		are still available for programmatic use outside of `activations_main`.
596		(defaults to `32`)
597	"""
598	# figure out the device to use
599	device_: torch.device
600	if isinstance(device, torch.device):
601		device_ = device
602	elif isinstance(device, str):
603		device_ = torch.device(device)
604	else:
605		msg = f"invalid device: {device}"
606		raise TypeError(msg)
607
608	print(f"using device: {device_}")
609
610	with SpinnerContext(message="loading model", **SPINNER_KWARGS):
611		model: HookedTransformer = HookedTransformer.from_pretrained(
612			model_name,
613			device=device_,
614		)
615		model.model_name = model_name  # type: ignore[unresolved-attribute]
616		model.cfg.model_name = model_name
617		n_params: int = sum(p.numel() for p in model.parameters())
618	print(
619		f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters",
620	)
621	print(f"\tmodel devices: { {p.device for p in model.parameters()} }")
622
623	save_path_p: Path = Path(save_path)
624	save_path_p.mkdir(parents=True, exist_ok=True)
625	model_path: Path = save_path_p / model_name
626	with SpinnerContext(
627		message=f"saving model info to {_rel_path(model_path)}",
628		**SPINNER_KWARGS,
629	):
630		model_cfg: HookedTransformerConfig
631		model_cfg = model.cfg
632		model_path.mkdir(parents=True, exist_ok=True)
633		with open(model_path / "model_cfg.json", "w") as f:
634			json.dump(json_serialize(asdict(model_cfg)), f)
635
636	# load prompts
637	with SpinnerContext(
638		message=f"loading prompts from {Path(prompts_path).as_posix()}",
639		**SPINNER_KWARGS,
640	):
641		prompts: list[dict]
642		if raw_prompts:
643			prompts = load_text_data(
644				Path(prompts_path),
645				min_chars=min_chars,
646				max_chars=max_chars,
647				shuffle=shuffle,
648			)
649		else:
650			with open(prompts_path, "r") as f:
651				prompts = [json.loads(line) for line in f.readlines()]
652		# truncate to n_samples
653		prompts = prompts[:n_samples]
654
655	print(f"  {len(prompts)} prompts loaded")
656
657	# write index.html
658	if not no_index_html:
659		with SpinnerContext(
660			message=f"writing {_rel_path(save_path_p / 'index.html')}",
661			**SPINNER_KWARGS,
662		):
663			write_html_index(save_path_p)
664
665	# TODO: not implemented yet
666	if stacked_heads:
667		raise NotImplementedError("stacked_heads not implemented yet")
668
669	# augment all prompts with hashes
670	for prompt in prompts:
671		augment_prompt_with_hash(prompt)
672
673	# filter out cached prompts
674	if not force:
675		uncached: list[dict] = [
676			p for p in prompts if not activations_exist(model_name, p, save_path_p)
677		]
678		n_cached: int = len(prompts) - len(uncached)
679		if n_cached > 0:
680			print(f"  {n_cached} prompts already cached, {len(uncached)} to compute")
681	else:
682		uncached = list(prompts)
683
684	if uncached:
685		# sort by token length descending so that:
686		# 1. the longest (slowest, most memory-hungry) batches run first --
687		#    OOM errors surface immediately rather than after all the cheap work,
688		#    and tqdm's ETA stabilizes early for better progress estimation
689		# 2. similar-length prompts are grouped together, minimizing padding waste
690		#
691		# pre-tokenization is a separate step from compute_activations_batched because
692		# we need token lengths *before* batching to sort and group. the resulting
693		# seq_lens are then passed through so compute_activations_batched can skip
694		# re-tokenizing each prompt internally.
695		with SpinnerContext(
696			message="pre-tokenizing prompts for length sorting",
697			**SPINNER_KWARGS,
698		):
699			uncached_with_lens: list[tuple[dict, int]] = [
700				(p, model.to_tokens(p["text"]).shape[1]) for p in uncached
701			]
702			uncached_with_lens.sort(key=lambda x: x[1], reverse=True)
703			sorted_uncached: list[dict] = [p for p, _ in uncached_with_lens]
704			sorted_seq_lens: list[int] = [sl for _, sl in uncached_with_lens]
705
706		# process in batches
707		n_prompts: int = len(sorted_uncached)
708		with tqdm.tqdm(
709			total=n_prompts,
710			desc="Computing activations",
711			unit="prompt",
712		) as pbar:
713			for batch_start in range(0, n_prompts, batch_size):
714				batch_end: int = min(batch_start + batch_size, n_prompts)
715				batch: list[dict] = sorted_uncached[batch_start:batch_end]
716				batch_seq_lens: list[int] = sorted_seq_lens[batch_start:batch_end]
717				pbar.set_postfix(
718					n_ctx=batch_seq_lens[0],
719				)  # longest in batch (sorted descending)
720				compute_activations_batched(
721					prompts=batch,
722					model=model,
723					save_path=save_path_p,
724					seq_lens=batch_seq_lens,
725				)
726				pbar.update(len(batch))
727	else:
728		print("  all prompts cached, nothing to compute")
729
730	with SpinnerContext(
731		message="updating jsonl metadata for models and prompts",
732		**SPINNER_KWARGS,
733	):
734		generate_models_jsonl(save_path_p)
735		generate_prompts_jsonl(save_path_p / model_name)

main function for computing activations

Parameters:

  • model_name : str name of a model to load with HookedTransformer.from_pretrained
  • save_path : str | Path path to save the activations to
  • prompts_path : str path to the prompts file
  • raw_prompts : bool whether the prompts are raw, not filtered by length. load_text_data will be called if True, otherwise just load the "text" field from each line in prompts_path
  • min_chars : int minimum number of characters for a prompt
  • max_chars : int maximum number of characters for a prompt
  • force : bool whether to overwrite existing files
  • n_samples : int maximum number of samples to process
  • no_index_html : bool whether to skip writing an index.html file
  • shuffle : bool whether to shuffle the prompts (defaults to False)
  • stacked_heads : bool whether to stack the heads in the output tensor. will save as .npy instead of .npz if True (not yet implemented; passing True raises NotImplementedError) (defaults to False)
  • device : str | torch.device the device to use. if a string, will be passed to torch.device
  • batch_size : int number of prompts per forward pass. prompts are sorted by token length (longest first) and grouped so that similar-length prompts share a batch, minimizing padding waste. use batch_size=1 for one prompt per forward pass (largely equivalent to the old sequential behavior, but note: prompts are still sorted by length and cache checking uses file-existence only, unlike the old path which processed prompts in order and validated cache contents via load_activations). the single-prompt functions compute_activations and get_activations are still available for programmatic use outside of activations_main. (defaults to 32)
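
For programmatic use, here is a minimal sketch of calling `activations_main` from a script under the signature above. The model name, file paths, and sample count are illustrative placeholders, not values from this repository:

```python
from pattern_lens.activations import activations_main

# hypothetical model, prompts file, and output directory; substitute your own
activations_main(
	model_name="pythia-70m",
	save_path="attn_data/",
	prompts_path="data/prompts.jsonl",  # jsonl with a "text" key on each line
	raw_prompts=True,     # filter raw text by length via load_text_data
	min_chars=100,
	max_chars=1000,
	force=False,          # skip prompts whose activations already exist
	n_samples=512,        # cap on the number of prompts processed
	no_index_html=False,  # also write index.html to the save path
	shuffle=False,
	batch_size=16,        # prompts per forward pass, grouped by token length
)
```

Re-running with `force=False` only computes activations for prompts whose files are missing under `save_path`; `force=True` recomputes everything.
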
def main() -> None:
738def main() -> None:
739	"generate attention pattern activations for a model and prompts"
740	print(DIVIDER_S1)
741	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
742		arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
743		# input and output
744		arg_parser.add_argument(
745			"--model",
746			"-m",
747			type=str,
748			required=True,
749			help="The model name(s) to use. comma separated with no whitespace if multiple",
750		)
751
752		arg_parser.add_argument(
753			"--prompts",
754			"-p",
755			type=str,
756			required=False,
757			help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
758			default=None,
759		)
760
761		arg_parser.add_argument(
762			"--save-path",
763			"-s",
764			type=str,
765			required=False,
766			help="The path to save the attention patterns",
767			default=DATA_DIR,
768		)
769
770		# min and max prompt lengths
771		arg_parser.add_argument(
772			"--min-chars",
773			type=int,
774			required=False,
775			help="The minimum number of characters for a prompt",
776			default=100,
777		)
778		arg_parser.add_argument(
779			"--max-chars",
780			type=int,
781			required=False,
782			help="The maximum number of characters for a prompt",
783			default=1000,
784		)
785
786		# number of samples
787		arg_parser.add_argument(
788			"--n-samples",
789			"-n",
790			type=int,
791			required=False,
792			help="The max number of samples to process, do all in the file if None",
793			default=None,
794		)
795
796		# batch size
797		arg_parser.add_argument(
798			"--batch-size",
799			"-b",
800			type=int,
801			required=False,
802			help="Batch size for computing activations (number of prompts per forward pass)",
803			default=32,
804		)
805
806		# force overwrite
807		arg_parser.add_argument(
808			"--force",
809			"-f",
810			action="store_true",
811			help="If passed, will overwrite existing files",
812		)
813
814		# no index html
815		arg_parser.add_argument(
816			"--no-index-html",
817			action="store_true",
818			help="If passed, will not write an index.html file for the model",
819		)
820
821		# raw prompts
822		arg_parser.add_argument(
823			"--raw-prompts",
824			"-r",
825			action="store_true",
826			help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
827		)
828
829		# shuffle
830		arg_parser.add_argument(
831			"--shuffle",
832			action="store_true",
833			help="If passed, will shuffle the prompts",
834		)
835
836		# stack heads
837		arg_parser.add_argument(
838			"--stacked-heads",
839			action="store_true",
840			help="If passed, will stack the heads in the output tensor",
841		)
842
843		# device
844		arg_parser.add_argument(
845			"--device",
846			type=str,
847			required=False,
848			help="The device to use for the model",
849			default="cuda" if torch.cuda.is_available() else "cpu",
850		)
851
852		args: argparse.Namespace = arg_parser.parse_args()
853
854	print(f"args parsed: {args}")
855
856	models: list[str]
857	if "," in args.model:
858		models = args.model.split(",")
859	else:
860		models = [args.model]
861
862	n_models: int = len(models)
863	for idx, model in enumerate(models):
864		print(DIVIDER_S2)
865		print(f"processing model {idx + 1} / {n_models}: {model}")
866		print(DIVIDER_S2)
867
868		activations_main(
869			model_name=model,
870			save_path=args.save_path,
871			prompts_path=args.prompts,
872			raw_prompts=args.raw_prompts,
873			min_chars=args.min_chars,
874			max_chars=args.max_chars,
875			force=args.force,
876			n_samples=args.n_samples,
877			no_index_html=args.no_index_html,
878			shuffle=args.shuffle,
879			stacked_heads=args.stacked_heads,
880			device=args.device,
881			batch_size=args.batch_size,
882		)
883		del model
884
885	print(DIVIDER_S1)

generate attention pattern activations for a model and prompts
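
As a sketch of the equivalent command-line invocation (model names, paths, and sample count below are illustrative): `--model` accepts a comma-separated list with no whitespace, and each model is processed in turn against the same prompts file.

```bash
# illustrative paths and model names; adjust to your setup
python -m pattern_lens.activations \
	--model gpt2,pythia-70m \
	--prompts data/prompts.jsonl \
	--save-path attn_data/ \
	--raw-prompts \
	--n-samples 512 \
	--batch-size 16 \
	--device cuda
```
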