docs for pattern_lens v0.4.0
View Source on GitHub

pattern_lens.activations

computing and saving activations given a model and prompts

Usage:

from the command line:

python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>

from a script:

from pattern_lens.activations import activations_main
activations_main(
        model_name="gpt2",
        save_path="demo/",
        prompts_path="data/pile_1k.jsonl",
)
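
The prompts file is JSONL with a "text" key on each line; when using raw prompts, each item is also expected to carry a "meta" key. Illustrative lines (field values made up):

{"text": "The quick brown fox jumps over the lazy dog.", "meta": {"source": "example"}}
{"text": "Call me Ishmael.", "meta": {"source": "example"}}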

  1"""computing and saving activations given a model and prompts
  2
  3# Usage:
  4
  5from the command line:
  6
  7```bash
  8python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
  9```
 10
 11from a script:
 12
 13```python
 14from pattern_lens.activations import activations_main
 15activations_main(
 16	model_name="gpt2",
 17	save_path="demo/",
 18	prompts_path="data/pile_1k.jsonl",
 19)
 20```
 21
 22"""
 23
 24import argparse
 25import functools
 26import json
 27import re
 28from collections.abc import Callable
 29from dataclasses import asdict
 30from pathlib import Path
 31from typing import Literal, overload
 32
 33import numpy as np
 34import torch
 35import tqdm
 36from jaxtyping import Float
 37from muutils.json_serialize import json_serialize
 38from muutils.misc.numerical import shorten_numerical_to_str
 39
 40# custom utils
 41from muutils.spinner import SpinnerContext
 42from transformer_lens import (  # type: ignore[import-untyped]
 43	ActivationCache,
 44	HookedTransformer,
 45	HookedTransformerConfig,
 46)
 47
 48# pattern_lens
 49from pattern_lens.consts import (
 50	ATTN_PATTERN_REGEX,
 51	DATA_DIR,
 52	DIVIDER_S1,
 53	DIVIDER_S2,
 54	SPINNER_KWARGS,
 55	ActivationCacheNp,
 56	ReturnCache,
 57)
 58from pattern_lens.indexes import (
 59	generate_models_jsonl,
 60	generate_prompts_jsonl,
 61	write_html_index,
 62)
 63from pattern_lens.load_activations import (
 64	ActivationsMissingError,
 65	augment_prompt_with_hash,
 66	load_activations,
 67)
 68from pattern_lens.prompts import load_text_data
 69
 70
 71# return nothing, but `stack_heads` still affects how we save the activations
 72@overload
 73def compute_activations(
 74	prompt: dict,
 75	model: HookedTransformer | None = None,
 76	save_path: Path = Path(DATA_DIR),
 77	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
 78	return_cache: Literal[None] = None,
 79	stack_heads: bool = False,
 80) -> tuple[Path, None]: ...
 81# return stacked heads in numpy or torch form
 82@overload
 83def compute_activations(
 84	prompt: dict,
 85	model: HookedTransformer | None = None,
 86	save_path: Path = Path(DATA_DIR),
 87	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
 88	return_cache: Literal["torch"] = "torch",
 89	stack_heads: Literal[True] = True,
 90) -> tuple[Path, Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]]: ...
 91@overload
 92def compute_activations(
 93	prompt: dict,
 94	model: HookedTransformer | None = None,
 95	save_path: Path = Path(DATA_DIR),
 96	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
 97	return_cache: Literal["numpy"] = "numpy",
 98	stack_heads: Literal[True] = True,
 99) -> tuple[Path, Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]]: ...
100# return dicts in numpy or torch form
101@overload
102def compute_activations(
103	prompt: dict,
104	model: HookedTransformer | None = None,
105	save_path: Path = Path(DATA_DIR),
106	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
107	return_cache: Literal["numpy"] = "numpy",
108	stack_heads: Literal[False] = False,
109) -> tuple[Path, ActivationCacheNp]: ...
110@overload
111def compute_activations(
112	prompt: dict,
113	model: HookedTransformer | None = None,
114	save_path: Path = Path(DATA_DIR),
115	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
116	return_cache: Literal["torch"] = "torch",
117	stack_heads: Literal[False] = False,
118) -> tuple[Path, ActivationCache]: ...
119# actual function body
120def compute_activations(  # noqa: PLR0915
121	prompt: dict,
122	model: HookedTransformer | None = None,
123	save_path: Path = Path(DATA_DIR),
124	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
125	return_cache: ReturnCache = "torch",
126	stack_heads: bool = False,
127) -> tuple[
128	Path,
129	ActivationCacheNp
130	| ActivationCache
131	| Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
132	| Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]
133	| None,
134]:
135	"""compute and save activations for a given model and prompt
136
137	the prompt dict must contain a 'hash' key, normally added by `augment_prompt_with_hash`
138
139	# Parameters:
140	- `prompt : dict`
141		the prompt dict; must contain the 'text' key
142	- `model : HookedTransformer`
143	- `save_path : Path`
144		(defaults to `Path(DATA_DIR)`)
145	- `names_filter : Callable[[str], bool]|re.Pattern`
146		a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
147		(defaults to `ATTN_PATTERN_REGEX`)
148	- `return_cache : Literal[None, "numpy", "torch"]`
149		will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (False) or a single tensor (True)
150		(defaults to `"torch"`)
151	- `stack_heads : bool`
152		whether the heads should be stacked in the output. this causes a number of changes:
153		- a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor is saved to an `npy` file for each prompt, instead of an `npz` file with a dict by layer
154		- the returned cache will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer, if `return_cache` is not `None`
155		asserts that the activation cache contains only attention patterns, and all of the attention patterns; raises an exception otherwise. (defaults to `False`)
156
157	# Returns:
158	```
159	tuple[
160		Path,
161		Union[
162			None,
163			ActivationCacheNp, ActivationCache,
164			Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
165		]
166	]
167	```
168	"""
169	# check inputs
170	assert model is not None, "model must be passed"
171	assert "text" in prompt, "prompt must contain 'text' key"
172	prompt_str: str = prompt["text"]
173
174	# compute or get prompt metadata
175	prompt_tokenized: list[str] = prompt.get(
176		"tokens",
177		model.tokenizer.tokenize(prompt_str),
178	)
179	prompt.update(
180		dict(
181			n_tokens=len(prompt_tokenized),
182			tokens=prompt_tokenized,
183		),
184	)
185
186	# save metadata
187	prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
188	prompt_dir.mkdir(parents=True, exist_ok=True)
189	with open(prompt_dir / "prompt.json", "w") as f:
190		json.dump(prompt, f)
191
192	# set up names filter
193	names_filter_fn: Callable[[str], bool]
194	if isinstance(names_filter, re.Pattern):
195		names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
196	else:
197		names_filter_fn = names_filter
198
199	# compute activations
200	cache_torch: ActivationCache
201	with torch.no_grad():
202		model.eval()
203		# TODO: batching?
204		_, cache_torch = model.run_with_cache(
205			prompt_str,
206			names_filter=names_filter_fn,
207			return_type=None,
208		)
209
210	activations_path: Path
211	# saving and returning
212	if stack_heads:
213		n_layers: int = model.cfg.n_layers
214		key_pattern: str = "blocks.{i}.attn.hook_pattern"
215		# NOTE: this only works for stacking heads at the moment
216		# activations_specifier: str = key_pattern.format(i=f'0-{n_layers}')
217		activations_specifier: str = key_pattern.format(i="-")
218		activations_path = prompt_dir / f"activations-{activations_specifier}.npy"
219
220		# check the keys are only attention heads
221		head_keys: list[str] = [key_pattern.format(i=i) for i in range(n_layers)]
222		cache_torch_keys_set: set[str] = set(cache_torch.keys())
223		assert cache_torch_keys_set == set(head_keys), (
224			f"unexpected keys!\n{set(head_keys).symmetric_difference(cache_torch_keys_set) = }\n{cache_torch_keys_set} != {set(head_keys)}"
225		)
226
227		# stack heads
228		patterns_stacked: Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] = (
229			torch.stack([cache_torch[k] for k in head_keys], dim=1)
230		)
231		# check shape
232		pattern_shape_no_ctx: tuple[int, ...] = tuple(patterns_stacked.shape[:3])
233		assert pattern_shape_no_ctx == (1, n_layers, model.cfg.n_heads), (
234			f"unexpected shape: {patterns_stacked.shape[:3] = } ({pattern_shape_no_ctx = }), expected {(1, n_layers, model.cfg.n_heads) = }"
235		)
236
237		patterns_stacked_np: Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = (
238			patterns_stacked.cpu().numpy()
239		)
240
241		# save
242		np.save(activations_path, patterns_stacked_np)
243
244		# return
245		match return_cache:
246			case "numpy":
247				return activations_path, patterns_stacked_np
248			case "torch":
249				return activations_path, patterns_stacked
250			case None:
251				return activations_path, None
252			case _:
253				msg = f"invalid return_cache: {return_cache = }"
254				raise ValueError(msg)
255	else:
256		activations_path = prompt_dir / "activations.npz"
257
258		# save
259		cache_np: ActivationCacheNp = {
260			k: v.detach().cpu().numpy() for k, v in cache_torch.items()
261		}
262
263		np.savez_compressed(
264			activations_path,
265			**cache_np,
266		)
267
268		# return
269		match return_cache:
270			case "numpy":
271				return activations_path, cache_np
272			case "torch":
273				return activations_path, cache_torch
274			case None:
275				return activations_path, None
276			case _:
277				msg = f"invalid return_cache: {return_cache = }"
278				raise ValueError(msg)
279
280
281@overload
282def get_activations(
283	prompt: dict,
284	model: HookedTransformer | str,
285	save_path: Path = Path(DATA_DIR),
286	allow_disk_cache: bool = True,
287	return_cache: Literal[None] = None,
288) -> tuple[Path, None]: ...
289@overload
290def get_activations(
291	prompt: dict,
292	model: HookedTransformer | str,
293	save_path: Path = Path(DATA_DIR),
294	allow_disk_cache: bool = True,
295	return_cache: Literal["torch"] = "torch",
296) -> tuple[Path, ActivationCache]: ...
297@overload
298def get_activations(
299	prompt: dict,
300	model: HookedTransformer | str,
301	save_path: Path = Path(DATA_DIR),
302	allow_disk_cache: bool = True,
303	return_cache: Literal["numpy"] = "numpy",
304) -> tuple[Path, ActivationCacheNp]: ...
305def get_activations(
306	prompt: dict,
307	model: HookedTransformer | str,
308	save_path: Path = Path(DATA_DIR),
309	allow_disk_cache: bool = True,
310	return_cache: ReturnCache = "numpy",
311) -> tuple[Path, ActivationCacheNp | ActivationCache | None]:
312	"""given a prompt and a model, save or load activations
313
314	# Parameters:
315	- `prompt : dict`
316		expected to contain the 'text' key
317	- `model : HookedTransformer | str`
318		either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
319	- `save_path : Path`
320		path to save the activations to (and load from)
321		(defaults to `Path(DATA_DIR)`)
322	- `allow_disk_cache : bool`
323		whether to allow loading from disk cache
324		(defaults to `True`)
325	- `return_cache : Literal[None, "numpy", "torch"]`
326		whether to return the cache, and in what format
327		(defaults to `"numpy"`)
328
329	# Returns:
330	- `tuple[Path, ActivationCacheNp | ActivationCache | None]`
331		the path to the activations and the cache if `return_cache is not None`
332
333	"""
334	# add hash to prompt
335	augment_prompt_with_hash(prompt)
336
337	# get the model
338	model_name: str = (
339		model.cfg.model_name if isinstance(model, HookedTransformer) else model
340	)
341
342	# from cache
343	if allow_disk_cache:
344		try:
345			path, cache = load_activations(
346				model_name=model_name,
347				prompt=prompt,
348				save_path=save_path,
349			)
350			if return_cache:
351				return path, cache
352			else:
353				# TODO: this basically does nothing, since we load the activations and then immediately get rid of them.
354				# maybe refactor this so that load_activations can take a parameter to simply assert that the cache exists?
355				# this will let us avoid loading it, which slows things down
356				return path, None
357		except ActivationsMissingError:
358			pass
359
360	# compute them
361	if isinstance(model, str):
362		model = HookedTransformer.from_pretrained(model_name)
363
364	return compute_activations(
365		prompt=prompt,
366		model=model,
367		save_path=save_path,
368		return_cache=return_cache,
369	)
370
371
372DEFAULT_DEVICE: torch.device = torch.device(
373	"cuda" if torch.cuda.is_available() else "cpu",
374)
375
376
377def activations_main(
378	model_name: str,
379	save_path: str,
380	prompts_path: str,
381	raw_prompts: bool,
382	min_chars: int,
383	max_chars: int,
384	force: bool,
385	n_samples: int,
386	no_index_html: bool,
387	shuffle: bool = False,
388	stacked_heads: bool = False,
389	device: str | torch.device = DEFAULT_DEVICE,
390) -> None:
391	"""main function for computing activations
392
393	# Parameters:
394	- `model_name : str`
395		name of a model to load with `HookedTransformer.from_pretrained`
396	- `save_path : str`
397		path to save the activations to
398	- `prompts_path : str`
399		path to the prompts file
400	- `raw_prompts : bool`
401		whether the prompts are raw, not filtered by length. `load_text_data` will be called on `prompts_path` if `True`, otherwise previously-saved prompts are loaded from `prompts.jsonl` in the model directory
402	- `min_chars : int`
403		minimum number of characters for a prompt
404	- `max_chars : int`
405		maximum number of characters for a prompt
406	- `force : bool`
407		whether to overwrite existing files
408	- `n_samples : int`
409		maximum number of samples to process
410	- `no_index_html : bool`
411		whether to skip writing an index.html file
412	- `shuffle : bool`
413		whether to shuffle the prompts
414		(defaults to `False`)
415	- `stacked_heads : bool`
416		whether to stack the heads in the output tensor. will save as `.npy` instead of `.npz` if `True`
417		(defaults to `False`)
418	- `device : str | torch.device`
419		the device to use. if a string, will be passed to `torch.device`
420	"""
421	# figure out the device to use
422	device_: torch.device
423	if isinstance(device, torch.device):
424		device_ = device
425	elif isinstance(device, str):
426		device_ = torch.device(device)
427	else:
428		msg = f"invalid device: {device}"
429		raise TypeError(msg)
430
431	print(f"using device: {device_}")
432
433	with SpinnerContext(message="loading model", **SPINNER_KWARGS):
434		model: HookedTransformer = HookedTransformer.from_pretrained(
435			model_name,
436			device=device_,
437		)
438		model.model_name = model_name
439		model.cfg.model_name = model_name
440		n_params: int = sum(p.numel() for p in model.parameters())
441	print(
442		f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters",
443	)
444	print(f"\tmodel devices: { {p.device for p in model.parameters()} }")
445
446	save_path_p: Path = Path(save_path)
447	save_path_p.mkdir(parents=True, exist_ok=True)
448	model_path: Path = save_path_p / model_name
449	with SpinnerContext(
450		message=f"saving model info to {model_path.as_posix()}",
451		**SPINNER_KWARGS,
452	):
453		model_cfg: HookedTransformerConfig
454		model_cfg = model.cfg
455		model_path.mkdir(parents=True, exist_ok=True)
456		with open(model_path / "model_cfg.json", "w") as f:
457			json.dump(json_serialize(asdict(model_cfg)), f)
458
459	# load prompts
460	with SpinnerContext(
461		message=f"loading prompts from {prompts_path = }",
462		**SPINNER_KWARGS,
463	):
464		prompts: list[dict]
465		if raw_prompts:
466			prompts = load_text_data(
467				Path(prompts_path),
468				min_chars=min_chars,
469				max_chars=max_chars,
470				shuffle=shuffle,
471			)
472		else:
473			with open(model_path / "prompts.jsonl", "r") as f:
474				prompts = [json.loads(line) for line in f.readlines()]
475		# truncate to n_samples
476		prompts = prompts[:n_samples]
477
478	print(f"{len(prompts)} prompts loaded")
479
480	# write index.html
481	with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
482		if not no_index_html:
483			write_html_index(save_path_p)
484
485	# TODO: not implemented yet
486	if stacked_heads:
487		raise NotImplementedError("stacked_heads not implemented yet")
488
489	# get activations
490	list(
491		tqdm.tqdm(
492			map(
493				functools.partial(
494					get_activations,
495					model=model,
496					save_path=save_path_p,
497					allow_disk_cache=not force,
498					return_cache=None,
499					# stacked_heads=stacked_heads,
500				),
501				prompts,
502			),
503			total=len(prompts),
504			desc="Computing activations",
505			unit="prompt",
506		),
507	)
508
509	with SpinnerContext(
510		message="updating jsonl metadata for models and prompts",
511		**SPINNER_KWARGS,
512	):
513		generate_models_jsonl(save_path_p)
514		generate_prompts_jsonl(save_path_p / model_name)
515
516
517def main() -> None:
518	"generate attention pattern activations for a model and prompts"
519	print(DIVIDER_S1)
520	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
521		arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
522		# input and output
523		arg_parser.add_argument(
524			"--model",
525			"-m",
526			type=str,
527			required=True,
528			help="The model name(s) to use. comma separated with no whitespace if multiple",
529		)
530
531		arg_parser.add_argument(
532			"--prompts",
533			"-p",
534			type=str,
535			required=False,
536			help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
537			default=None,
538		)
539
540		arg_parser.add_argument(
541			"--save-path",
542			"-s",
543			type=str,
544			required=False,
545			help="The path to save the attention patterns",
546			default=DATA_DIR,
547		)
548
549		# min and max prompt lengths
550		arg_parser.add_argument(
551			"--min-chars",
552			type=int,
553			required=False,
554			help="The minimum number of characters for a prompt",
555			default=100,
556		)
557		arg_parser.add_argument(
558			"--max-chars",
559			type=int,
560			required=False,
561			help="The maximum number of characters for a prompt",
562			default=1000,
563		)
564
565		# number of samples
566		arg_parser.add_argument(
567			"--n-samples",
568			"-n",
569			type=int,
570			required=False,
571			help="The max number of samples to process, do all in the file if None",
572			default=None,
573		)
574
575		# force overwrite
576		arg_parser.add_argument(
577			"--force",
578			"-f",
579			action="store_true",
580			help="If passed, will overwrite existing files",
581		)
582
583		# no index html
584		arg_parser.add_argument(
585			"--no-index-html",
586			action="store_true",
587			help="If passed, will not write an index.html file for the model",
588		)
589
590		# raw prompts
591		arg_parser.add_argument(
592			"--raw-prompts",
593			"-r",
594			action="store_true",
595			help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
596		)
597
598		# shuffle
599		arg_parser.add_argument(
600			"--shuffle",
601			action="store_true",
602			help="If passed, will shuffle the prompts",
603		)
604
605		# stack heads
606		arg_parser.add_argument(
607			"--stacked-heads",
608			action="store_true",
609			help="If passed, will stack the heads in the output tensor",
610		)
611
612		# device
613		arg_parser.add_argument(
614			"--device",
615			type=str,
616			required=False,
617			help="The device to use for the model",
618			default="cuda" if torch.cuda.is_available() else "cpu",
619		)
620
621		args: argparse.Namespace = arg_parser.parse_args()
622
623	print(f"args parsed: {args}")
624
625	models: list[str]
626	if "," in args.model:
627		models = args.model.split(",")
628	else:
629		models = [args.model]
630
631	n_models: int = len(models)
632	for idx, model in enumerate(models):
633		print(DIVIDER_S2)
634		print(f"processing model {idx + 1} / {n_models}: {model}")
635		print(DIVIDER_S2)
636
637		activations_main(
638			model_name=model,
639			save_path=args.save_path,
640			prompts_path=args.prompts,
641			raw_prompts=args.raw_prompts,
642			min_chars=args.min_chars,
643			max_chars=args.max_chars,
644			force=args.force,
645			n_samples=args.n_samples,
646			no_index_html=args.no_index_html,
647			shuffle=args.shuffle,
648			stacked_heads=args.stacked_heads,
649			device=args.device,
650		)
651		del model
652
653	print(DIVIDER_S1)
654
655
656if __name__ == "__main__":
657	main()

def compute_activations( prompt: dict, model: transformer_lens.HookedTransformer.HookedTransformer | None = None, save_path: pathlib.Path = PosixPath('attn_data'), names_filter: Callable[[str], bool] | re.Pattern = re.compile('blocks\\.(\\d+)\\.attn\\.hook_pattern'), return_cache: Literal[None, 'numpy', 'torch'] = 'torch', stack_heads: bool = False) -> tuple[pathlib.Path, dict[str, numpy.ndarray] | transformer_lens.ActivationCache.ActivationCache | jaxtyping.Float[ndarray, 'n_layers n_heads n_ctx n_ctx'] | jaxtyping.Float[Tensor, 'n_layers n_heads n_ctx n_ctx'] | None]:

compute and save activations for a given model and prompt

the prompt dict must contain a 'hash' key, normally added by augment_prompt_with_hash

Parameters:

  • prompt : dict the prompt dict; must contain the 'text' key
  • model : HookedTransformer
  • save_path : Path (defaults to Path(DATA_DIR))
  • names_filter : Callable[[str], bool]|re.Pattern a filter for the names of the activations to return. if an re.Pattern, will use lambda key: names_filter.match(key) is not None (defaults to ATTN_PATTERN_REGEX)
  • return_cache : Literal[None, "numpy", "torch"] will return None as the second element if None, otherwise will return the cache in the specified tensor format. stack_heads still affects whether it will be a dict (False) or a single tensor (True) (defaults to "torch")
  • stack_heads : bool whether the heads should be stacked in the output. this causes a number of changes:
      • a single (n_layers, n_heads, n_ctx, n_ctx) tensor is saved to an npy file for each prompt, instead of an npz file with a dict by layer
      • the returned cache will be a single (n_layers, n_heads, n_ctx, n_ctx) tensor instead of a dict by layer, if return_cache is not None
    asserts that the activation cache contains only attention patterns, and all of the attention patterns; raises an exception otherwise. (defaults to False)

Returns:

tuple[
        Path,
        Union[
                None,
                ActivationCacheNp, ActivationCache,
                Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
        ]
]
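
A minimal usage sketch (model, prompt text, and save path are illustrative; compute_activations itself does not add the 'hash' key, so it is added here explicitly):

from pathlib import Path
from transformer_lens import HookedTransformer
from pattern_lens.activations import compute_activations
from pattern_lens.load_activations import augment_prompt_with_hash

model = HookedTransformer.from_pretrained("gpt2")
prompt = {"text": "The quick brown fox jumps over the lazy dog."}
augment_prompt_with_hash(prompt)  # compute_activations reads prompt["hash"] for the save directory

# saves activations.npz under demo/gpt2/prompts/<hash>/ and returns the cache as numpy arrays
path, cache = compute_activations(
    prompt=prompt,
    model=model,
    save_path=Path("demo/"),
    return_cache="numpy",
    stack_heads=False,
)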
def get_activations( prompt: dict, model: transformer_lens.HookedTransformer.HookedTransformer | str, save_path: pathlib.Path = PosixPath('attn_data'), allow_disk_cache: bool = True, return_cache: Literal[None, 'numpy', 'torch'] = 'numpy') -> tuple[pathlib.Path, dict[str, numpy.ndarray] | transformer_lens.ActivationCache.ActivationCache | None]:

given a prompt and a model, save or load activations

Parameters:

  • prompt : dict expected to contain the 'text' key
  • model : HookedTransformer | str either a HookedTransformer or a string model name, to be loaded with HookedTransformer.from_pretrained
  • save_path : Path path to save the activations to (and load from) (defaults to Path(DATA_DIR))
  • allow_disk_cache : bool whether to allow loading from disk cache (defaults to True)
  • return_cache : Literal[None, "numpy", "torch"] whether to return the cache, and in what format (defaults to "numpy")

Returns:

  • tuple[Path, ActivationCacheNp | ActivationCache | None] the path to the activations and the cache if return_cache is not None
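
A minimal usage sketch (model name and paths are illustrative):

from pathlib import Path
from pattern_lens.activations import get_activations

prompt = {"text": "The quick brown fox jumps over the lazy dog."}
path, cache = get_activations(
    prompt=prompt,
    model="gpt2",            # a string is loaded with HookedTransformer.from_pretrained on a cache miss
    save_path=Path("demo/"),
    allow_disk_cache=True,   # reuse previously saved activations if present
    return_cache="numpy",
)
# cache maps hook names like "blocks.0.attn.hook_pattern" to numpy arrays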
DEFAULT_DEVICE: torch.device = device(type='cuda')
def activations_main( model_name: str, save_path: str, prompts_path: str, raw_prompts: bool, min_chars: int, max_chars: int, force: bool, n_samples: int, no_index_html: bool, shuffle: bool = False, stacked_heads: bool = False, device: str | torch.device = device(type='cuda')) -> None:

main function for computing activations

Parameters:

  • model_name : str name of a model to load with HookedTransformer.from_pretrained
  • save_path : str path to save the activations to
  • prompts_path : str path to the prompts file
  • raw_prompts : bool whether the prompts are raw, not filtered by length. load_text_data will be called on prompts_path if True, otherwise previously-saved prompts are loaded from prompts.jsonl in the model directory
  • min_chars : int minimum number of characters for a prompt
  • max_chars : int maximum number of characters for a prompt
  • force : bool whether to overwrite existing files
  • n_samples : int maximum number of samples to process
  • no_index_html : bool whether to skip writing an index.html file
  • shuffle : bool whether to shuffle the prompts (defaults to False)
  • stacked_heads : bool whether to stack the heads in the output tensor. will save as .npy instead of .npz if True (defaults to False)
  • device : str | torch.device the device to use. if a string, will be passed to torch.device
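
After a run, per-prompt activations are saved under <save_path>/<model_name>/prompts/<prompt_hash>/. A sketch of reading one back with plain numpy (the hash value here is hypothetical):

import numpy as np
from pathlib import Path

prompt_dir = Path("demo/gpt2/prompts/0a1b2c3d")  # hypothetical prompt hash
with np.load(prompt_dir / "activations.npz") as npz:
    pattern = npz["blocks.0.attn.hook_pattern"]  # layer-0 attention pattern, shape (1, n_heads, n_ctx, n_ctx)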
def main() -> None:

generate attention pattern activations for a model and prompts
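
An illustrative invocation (values are examples only):

python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --raw-prompts --save-path demo/ --min-chars 100 --max-chars 1000 --n-samples 100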