docs for pattern_lens v0.6.0
View Source on GitHub

pattern_lens.figures

code for generating figures from attention patterns, using the functions decorated with register_attn_figure_func


  1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`"""
  2
  3import argparse
  4import fnmatch
  5import functools
  6import itertools
  7import json
  8import multiprocessing
  9import re
 10import warnings
 11from collections import defaultdict
 12from pathlib import Path
 13
 14import numpy as np
 15from jaxtyping import Float
 16
 17# custom utils
 18from muutils.json_serialize import json_serialize
 19from muutils.parallel import run_maybe_parallel
 20from muutils.spinner import SpinnerContext
 21
 22# pattern_lens
 23from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
 24from pattern_lens.consts import (
 25	DATA_DIR,
 26	DIVIDER_S1,
 27	DIVIDER_S2,
 28	SPINNER_KWARGS,
 29	ActivationCacheNp,
 30	AttentionMatrix,
 31)
 32from pattern_lens.figure_util import AttentionMatrixFigureFunc
 33from pattern_lens.indexes import (
 34	generate_functions_jsonl,
 35	generate_models_jsonl,
 36	generate_prompts_jsonl,
 37)
 38from pattern_lens.load_activations import load_activations
 39
 40
 41class HTConfigMock:
 42	"""Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json
 43
 44	can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
 45	- `n_layers: int`
 46	- `n_heads: int`
 47	- `model_name: str`
 48
 49	we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
 50	"""
 51
 52	def __init__(self, **kwargs: dict[str, str | int]) -> None:
 53		"will pass all kwargs to `__dict__`"
 54		self.n_layers: int
 55		self.n_heads: int
 56		self.model_name: str
 57		self.__dict__.update(kwargs)
 58
 59	def serialize(self) -> dict:
 60		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
 61		# its fine, we know its a dict
 62		return json_serialize(self.__dict__)  # type: ignore[return-value]
 63
 64	@classmethod
 65	def load(cls, data: dict) -> "HTConfigMock":
 66		"try to load a config from a dict, using the `__init__` method"
 67		return cls(**data)
 68
 69
 70def process_single_head(
 71	layer_idx: int,
 72	head_idx: int,
 73	attn_pattern: AttentionMatrix,
 74	save_dir: Path,
 75	figure_funcs: list[AttentionMatrixFigureFunc],
 76	force_overwrite: bool = False,
 77) -> dict[str, bool | Exception]:
 78	"""process a single head's attention pattern, running all the functions in `figure_funcs` on the attention pattern
 79
 80	> [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
 81	> it will skip all figures for that function if any are already saved
 82	> and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures
 83
 84	# Parameters:
 85	- `layer_idx : int`
 86	- `head_idx : int`
 87	- `attn_pattern : AttentionMatrix`
 88		attention pattern for the head
 89	- `save_dir : Path`
 90		directory to save the figures to
 91	- `force_overwrite : bool`
 92		whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
 93		(defaults to `False`)
 94
 95	# Returns:
 96	- `dict[str, bool | Exception]`
 97		a dictionary of the status of each function, with the function name as the key and the status as the value
 98	"""
 99	funcs_status: dict[str, bool | Exception] = dict()
100
101	for func in figure_funcs:
102		func_name: str = getattr(func, "__name__", "<unknown>")
103		fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))
104
105		if not force_overwrite and len(fig_path) > 0:
106			funcs_status[func_name] = True
107			continue
108
109		try:
110			func(attn_pattern, save_dir)
111			funcs_status[func_name] = True
112
113		# bling catch any exception
114		except Exception as e:  # noqa: BLE001
115			error_file = save_dir / f"{func_name}.error.txt"
116			error_file.write_text(str(e))
117			warnings.warn(
118				f"Error in {func_name} for L{layer_idx}H{head_idx}: {e!s}",
119				stacklevel=2,
120			)
121			funcs_status[func_name] = e
122
123	return funcs_status
124
125
def compute_and_save_figures(
	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
	activations_path: Path,
	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
	figure_funcs: list[AttentionMatrixFigureFunc],
	save_path: Path = Path(DATA_DIR),
	force_overwrite: bool = False,
	track_results: bool = False,
) -> None:
	"""compute and save figures for every head of every layer of the model

	# Parameters:
	- `model_cfg : HookedTransformerConfig|HTConfigMock`
		model configuration, read for `n_layers`, `n_heads`, and `model_name`
	- `activations_path : Path`
		path the activations were loaded from; figures go next to it (its parent dir)
	- `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
		activation cache holding the attention patterns for the prompt being processed
	- `figure_funcs : list[AttentionMatrixFigureFunc]`
		figure functions to run on each head
	- `save_path : Path`
		base output directory, used to refresh the prompts jsonl index
		(defaults to `Path(DATA_DIR)`)
	- `force_overwrite : bool`
		when `False`, functions which already saved a figure are skipped
		(defaults to `False`)
	- `track_results : bool`
		collect per-head statuses. not consumed anywhere yet -- TODO
		(defaults to `False`)
	"""
	prompt_dir: Path = activations_path.parent

	# only allocate the results mapping when tracking was requested
	if track_results:
		results: defaultdict[
			str,  # func name
			dict[
				tuple[int, int],  # layer, head
				bool | Exception,  # success or exception
			],
		] = defaultdict(dict)

	for layer_idx in range(model_cfg.n_layers):
		for head_idx in range(model_cfg.n_heads):
			# pull this head's pattern out of the cache
			attn_pattern: AttentionMatrix
			if isinstance(cache, dict):
				# hook-name-keyed cache; index 0 drops the batch dimension
				attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
			elif isinstance(cache, np.ndarray):
				# single stacked array, indexed by (layer, head)
				attn_pattern = cache[layer_idx, head_idx]
			else:
				msg = (
					f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
				)
				raise TypeError(
					msg,
				)

			head_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
			head_dir.mkdir(parents=True, exist_ok=True)
			head_res: dict[str, bool | Exception] = process_single_head(
				layer_idx=layer_idx,
				head_idx=head_idx,
				attn_pattern=attn_pattern,
				save_dir=head_dir,
				force_overwrite=force_overwrite,
				figure_funcs=figure_funcs,
			)

			if track_results:
				for func_name, status in head_res.items():
					results[func_name][(layer_idx, head_idx)] = status

	# TODO: do something with results

	generate_prompts_jsonl(save_path / model_cfg.model_name)
200
201
def process_prompt(
	prompt: dict,
	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
	save_path: Path,
	figure_funcs: list[AttentionMatrixFigureFunc],
	force_overwrite: bool = False,
) -> None:
	"""process one prompt: load its activations, then render and save all figures

	thin wrapper: `load_activations` followed by `compute_and_save_figures`

	# Parameters:
	- `prompt : dict`
		prompt to process; expected keys:
		- `"text"`: the prompt string
		- `"hash"`: the hash of the prompt
	- `model_cfg : HookedTransformerConfig|HTConfigMock`
		model configuration, used to locate the stored activations
	- `save_path : Path`
		base directory to load activations from and save figures into
	- `figure_funcs : list[AttentionMatrixFigureFunc]`
		figure functions to run
	- `force_overwrite : bool`
		(defaults to `False`)
	"""
	# load this prompt's activations as numpy arrays
	loaded_path, loaded_cache = load_activations(
		model_name=model_cfg.model_name,
		prompt=prompt,
		save_path=save_path,
		return_fmt="numpy",
	)

	# render and save figures for every head
	compute_and_save_figures(
		model_cfg=model_cfg,
		activations_path=loaded_path,
		cache=loaded_cache,
		figure_funcs=figure_funcs,
		save_path=save_path,
		force_overwrite=force_overwrite,
	)
246
247
def select_attn_figure_funcs(
	figure_funcs_select: set[str] | str | None = None,
) -> list[AttentionMatrixFigureFunc]:
	"""given a selector, figure out which functions from `ATTENTION_MATRIX_FIGURE_FUNCS` to use

	- if arg is `None`, will use all functions
	- if a string, will use the function names which match the string (glob/fnmatch syntax)
	- if a set, will use functions whose names are in the set

	# Parameters:
	- `figure_funcs_select : set[str] | str | None`
		selector over the registered figure function names
		(defaults to `None`)

	# Returns:
	- `list[AttentionMatrixFigureFunc]`
		the selected functions. always a fresh list, so mutating it cannot
		corrupt the `ATTENTION_MATRIX_FIGURE_FUNCS` registry

	# Raises:
	- `TypeError` : if the selector is not `None`, a `str`, or a `set`
	"""
	# figure out which functions to use
	figure_funcs: list[AttentionMatrixFigureFunc]
	if figure_funcs_select is None:
		# all if nothing specified. copy the registry rather than aliasing it,
		# so callers who mutate the returned list don't mutate the registry
		figure_funcs = list(ATTENTION_MATRIX_FIGURE_FUNCS)
	elif isinstance(figure_funcs_select, str):
		# glob pattern; `fnmatch.translate` anchors the regex, so `match` is a full match
		pattern: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select))
		figure_funcs = [
			func
			for func in ATTENTION_MATRIX_FIGURE_FUNCS
			if pattern.match(getattr(func, "__name__", "<unknown>"))
		]
	elif isinstance(figure_funcs_select, set):
		# a set is treated as exact function names
		figure_funcs = [
			func
			for func in ATTENTION_MATRIX_FIGURE_FUNCS
			if getattr(func, "__name__", "<unknown>") in figure_funcs_select
		]
	else:
		err_msg: str = (
			f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }"
			f"\n{figure_funcs_select = }"
		)
		raise TypeError(err_msg)
	return figure_funcs
285
286
def figures_main(
	model_name: str,
	save_path: str | Path,
	n_samples: int,
	force: bool,
	figure_funcs_select: set[str] | str | None = None,
	parallel: bool | int = True,
) -> None:
	"""main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

	# Parameters:
	- `model_name : str`
		model name to use, used for loading the model config, prompts, activations, and saving the figures
	- `save_path : str | Path`
		base path to look in
	- `n_samples : int`
		max number of samples to process
	- `force : bool`
		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
	- `figure_funcs_select : set[str]|str|None`
		figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set
		(defaults to `None`)
	- `parallel : bool | int`
		whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
		(defaults to `True`)
	"""
	with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
		# load the model config written by the activation-generation step
		save_path_p: Path = Path(save_path)
		model_path: Path = save_path_p / model_name
		# explicit encoding: the platform default is not utf-8 everywhere (e.g. windows)
		with open(model_path / "model_cfg.json", "r", encoding="utf-8") as f:
			model_cfg = HTConfigMock.load(json.load(f))

	with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
		# load prompts, one json object per line; skip blank lines
		# (a trailing newline would otherwise crash `json.loads`)
		with open(model_path / "prompts.jsonl", "r", encoding="utf-8") as f:
			prompts: list[dict] = [json.loads(line) for line in f if line.strip()]
		# truncate to n_samples (slicing with `None` keeps everything)
		prompts = prompts[:n_samples]

	print(f"{len(prompts)} prompts loaded")

	figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs(
		figure_funcs_select=figure_funcs_select,
	)
	print(f"{len(figure_funcs)} figure functions loaded")
	print(
		"\t"
		+ ", ".join([getattr(func, "__name__", "<unknown>") for func in figure_funcs]),
	)

	# aim for roughly 5 chunks per core so the progress bar stays responsive
	chunksize: int = int(
		max(
			1,
			len(prompts) // (5 * multiprocessing.cpu_count()),
		),
	)
	print(f"chunksize: {chunksize}")

	# `list(...)` drains the (possibly lazy) parallel map so all prompts are processed
	list(
		run_maybe_parallel(
			func=functools.partial(
				process_prompt,
				model_cfg=model_cfg,
				save_path=save_path_p,
				figure_funcs=figure_funcs,
				force_overwrite=force,
			),
			iterable=prompts,
			parallel=parallel,
			chunksize=chunksize,
			pbar="tqdm",
			pbar_kwargs=dict(
				desc="Making figures",
				unit="prompt",
			),
		),
	)

	with SpinnerContext(
		message="updating jsonl metadata for models and functions",
		**SPINNER_KWARGS,
	):
		generate_models_jsonl(save_path_p)
		generate_functions_jsonl(save_path_p)
372
373
def _parse_args() -> tuple[
	argparse.Namespace,
	list[str],  # models
	set[str] | str | None,  # figure_funcs_select
]:
	"""parse cli args, returning the raw namespace, model list, and figure-func selector

	# Returns:
	- `tuple[argparse.Namespace, list[str], set[str] | str | None]`
		the parsed args, the model names (split on commas), and the figure
		function selector (`None`, a glob string, or a set of exact names)
	"""
	arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
	# input and output
	arg_parser.add_argument(
		"--model",
		"-m",
		type=str,
		required=True,
		help="The model name(s) to use. comma separated with no whitespace if multiple",
	)
	arg_parser.add_argument(
		"--save-path",
		"-s",
		type=str,
		required=False,
		help="The path to save the attention patterns",
		default=DATA_DIR,
	)
	# number of samples
	arg_parser.add_argument(
		"--n-samples",
		"-n",
		type=int,
		required=False,
		help="The max number of samples to process, do all in the file if None",
		default=None,
	)
	# force overwrite of existing figures
	# NOTE: this was previously `type=bool`, which is a bug: argparse calls
	# `bool("False")`, and any non-empty string is truthy, so `--force False`
	# silently enabled overwriting. `store_true` makes `--force`/`-f` a proper flag.
	arg_parser.add_argument(
		"--force",
		"-f",
		action="store_true",
		help="Force overwrite of existing figures",
	)
	# figure functions
	arg_parser.add_argument(
		"--figure-funcs",
		type=str,
		required=False,
		help="The figure functions to use. if 'None' (default), will use all functions. if a string, will use the function names which match the string. if a comma-separated list of strings, will use the function names in the set",
		default=None,
	)

	args: argparse.Namespace = arg_parser.parse_args()

	# figure out models -- `split` already yields a single-element list when
	# there is no comma, so no special case is needed
	models: list[str] = args.model.split(",")

	# figure out figures
	figure_funcs_select: set[str] | str | None
	if (args.figure_funcs is None) or (args.figure_funcs.lower().strip() == "none"):
		figure_funcs_select = None
	elif "," in args.figure_funcs:
		figure_funcs_select = {x.strip() for x in args.figure_funcs.split(",")}
	else:
		figure_funcs_select = args.figure_funcs.strip()

	return args, models, figure_funcs_select
442
443
def main() -> None:
	"generates figures from the activations using the functions decorated with `register_attn_figure_func`"
	print(DIVIDER_S1)

	# parse args
	args: argparse.Namespace
	models: list[str]
	figure_funcs_select: set[str] | str | None
	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
		args, models, figure_funcs_select = _parse_args()
	print(f"\targs parsed: '{args}'")
	print(f"\tmodels: '{models}'")
	print(f"\tfigure_funcs_select: '{figure_funcs_select}'")

	# process each requested model in turn
	n_models: int = len(models)
	for model_num, model_name in enumerate(models, start=1):
		print(DIVIDER_S2)
		print(f"processing model {model_num} / {n_models}: {model_name}")
		print(DIVIDER_S2)
		figures_main(
			model_name=model_name,
			save_path=args.save_path,
			n_samples=args.n_samples,
			force=args.force,
			figure_funcs_select=figure_funcs_select,
		)

	print(DIVIDER_S1)


if __name__ == "__main__":
	main()

class HTConfigMock:
42class HTConfigMock:
43	"""Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json
44
45	can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
46	- `n_layers: int`
47	- `n_heads: int`
48	- `model_name: str`
49
50	we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
51	"""
52
53	def __init__(self, **kwargs: dict[str, str | int]) -> None:
54		"will pass all kwargs to `__dict__`"
55		self.n_layers: int
56		self.n_heads: int
57		self.model_name: str
58		self.__dict__.update(kwargs)
59
60	def serialize(self) -> dict:
61		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
62		# its fine, we know its a dict
63		return json_serialize(self.__dict__)  # type: ignore[return-value]
64
65	@classmethod
66	def load(cls, data: dict) -> "HTConfigMock":
67		"try to load a config from a dict, using the `__init__` method"
68		return cls(**data)

Mock of transformer_lens.HookedTransformerConfig for type hinting and loading config json

can be initialized with any kwargs, and will update its __dict__ with them. does, however, require the following attributes:

  • n_layers: int
  • n_heads: int
  • model_name: str

we do this to avoid having to import torch and transformer_lens, since this would have to be done for each process in the parallelization and probably slows things down significantly

HTConfigMock(**kwargs: dict[str, str | int])
53	def __init__(self, **kwargs: dict[str, str | int]) -> None:
54		"will pass all kwargs to `__dict__`"
55		self.n_layers: int
56		self.n_heads: int
57		self.model_name: str
58		self.__dict__.update(kwargs)

will pass all kwargs to __dict__

n_layers: int
n_heads: int
model_name: str
def serialize(self) -> dict:
60	def serialize(self) -> dict:
61		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
62		# its fine, we know its a dict
63		return json_serialize(self.__dict__)  # type: ignore[return-value]

serialize the config to json. values which aren't serializable will be converted via muutils.json_serialize.json_serialize

@classmethod
def load(cls, data: dict) -> HTConfigMock:
65	@classmethod
66	def load(cls, data: dict) -> "HTConfigMock":
67		"try to load a config from a dict, using the `__init__` method"
68		return cls(**data)

try to load a config from a dict, using the __init__ method

def process_single_head( layer_idx: int, head_idx: int, attn_pattern: jaxtyping.Float[ndarray, 'n_ctx n_ctx'], save_dir: pathlib._local.Path, figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None]], force_overwrite: bool = False) -> dict[str, bool | Exception]:
 71def process_single_head(
 72	layer_idx: int,
 73	head_idx: int,
 74	attn_pattern: AttentionMatrix,
 75	save_dir: Path,
 76	figure_funcs: list[AttentionMatrixFigureFunc],
 77	force_overwrite: bool = False,
 78) -> dict[str, bool | Exception]:
 79	"""process a single head's attention pattern, running all the functions in `figure_funcs` on the attention pattern
 80
 81	> [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
 82	> it will skip all figures for that function if any are already saved
 83	> and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures
 84
 85	# Parameters:
 86	- `layer_idx : int`
 87	- `head_idx : int`
 88	- `attn_pattern : AttentionMatrix`
 89		attention pattern for the head
 90	- `save_dir : Path`
 91		directory to save the figures to
 92	- `force_overwrite : bool`
 93		whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
 94		(defaults to `False`)
 95
 96	# Returns:
 97	- `dict[str, bool | Exception]`
 98		a dictionary of the status of each function, with the function name as the key and the status as the value
 99	"""
100	funcs_status: dict[str, bool | Exception] = dict()
101
102	for func in figure_funcs:
103		func_name: str = getattr(func, "__name__", "<unknown>")
104		fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))
105
106		if not force_overwrite and len(fig_path) > 0:
107			funcs_status[func_name] = True
108			continue
109
110		try:
111			func(attn_pattern, save_dir)
112			funcs_status[func_name] = True
113
114	# blindly catch any exception
115		except Exception as e:  # noqa: BLE001
116			error_file = save_dir / f"{func_name}.error.txt"
117			error_file.write_text(str(e))
118			warnings.warn(
119				f"Error in {func_name} for L{layer_idx}H{head_idx}: {e!s}",
120				stacklevel=2,
121			)
122			funcs_status[func_name] = e
123
124	return funcs_status

process a single head's attention pattern, running all the functions in figure_funcs on the attention pattern

[gotcha:] if force_overwrite is False, and we used a multi-figure function, it will skip all figures for that function if any are already saved and it assumes a format of {func_name}.{figure_name}.{fmt} for the saved figures

Parameters:

  • layer_idx : int
  • head_idx : int
  • attn_pattern : AttentionMatrix attention pattern for the head
  • save_dir : Path directory to save the figures to
  • force_overwrite : bool whether to overwrite existing figures. if False, will skip any functions which have already saved a figure (defaults to False)

Returns:

  • dict[str, bool | Exception] a dictionary of the status of each function, with the function name as the key and the status as the value
def compute_and_save_figures( model_cfg: 'HookedTransformerConfig|HTConfigMock', activations_path: pathlib._local.Path, cache: dict[str, numpy.ndarray] | jaxtyping.Float[ndarray, 'n_layers n_heads n_ctx n_ctx'], figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None]], save_path: pathlib._local.Path = PosixPath('attn_data'), force_overwrite: bool = False, track_results: bool = False) -> None:
127def compute_and_save_figures(
128	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
129	activations_path: Path,
130	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
131	figure_funcs: list[AttentionMatrixFigureFunc],
132	save_path: Path = Path(DATA_DIR),
133	force_overwrite: bool = False,
134	track_results: bool = False,
135) -> None:
136	"""compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
137
138	# Parameters:
139	- `model_cfg : HookedTransformerConfig|HTConfigMock`
140		configuration of the model, used for loading the activations
141	- `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
142		activation cache containing actual patterns for the prompt we are processing
143	- `figure_funcs : list[AttentionMatrixFigureFunc]`
144		list of functions to run
145	- `save_path : Path`
146		directory to save the figures to
147		(defaults to `Path(DATA_DIR)`)
148	- `force_overwrite : bool`
149		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
150		(defaults to `False`)
151	- `track_results : bool`
152		whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
153		(defaults to `False`)
154	"""
155	prompt_dir: Path = activations_path.parent
156
157	if track_results:
158		results: defaultdict[
159			str,  # func name
160			dict[
161				tuple[int, int],  # layer, head
162				bool | Exception,  # success or exception
163			],
164		] = defaultdict(dict)
165
166	for layer_idx, head_idx in itertools.product(
167		range(model_cfg.n_layers),
168		range(model_cfg.n_heads),
169	):
170		attn_pattern: AttentionMatrix
171		if isinstance(cache, dict):
172			attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
173		elif isinstance(cache, np.ndarray):
174			attn_pattern = cache[layer_idx, head_idx]
175		else:
176			msg = (
177				f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
178			)
179			raise TypeError(
180				msg,
181			)
182
183		save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
184		save_dir.mkdir(parents=True, exist_ok=True)
185		head_res: dict[str, bool | Exception] = process_single_head(
186			layer_idx=layer_idx,
187			head_idx=head_idx,
188			attn_pattern=attn_pattern,
189			save_dir=save_dir,
190			force_overwrite=force_overwrite,
191			figure_funcs=figure_funcs,
192		)
193
194		if track_results:
195			for func_name, status in head_res.items():
196				results[func_name][(layer_idx, head_idx)] = status
197
198	# TODO: do something with results
199
200	generate_prompts_jsonl(save_path / model_cfg.model_name)

compute and save figures for all heads in the model, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

  • model_cfg : HookedTransformerConfig|HTConfigMock configuration of the model, used for loading the activations
  • cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] activation cache containing actual patterns for the prompt we are processing
  • figure_funcs : list[AttentionMatrixFigureFunc] list of functions to run
  • save_path : Path directory to save the figures to (defaults to Path(DATA_DIR))
  • force_overwrite : bool force overwrite of existing figures. if False, will skip any functions which have already saved a figure (defaults to False)
  • track_results : bool whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO (defaults to False)
def process_prompt( prompt: dict, model_cfg: 'HookedTransformerConfig|HTConfigMock', save_path: pathlib._local.Path, figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None]], force_overwrite: bool = False) -> None:
203def process_prompt(
204	prompt: dict,
205	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
206	save_path: Path,
207	figure_funcs: list[AttentionMatrixFigureFunc],
208	force_overwrite: bool = False,
209) -> None:
210	"""process a single prompt, loading the activations and computing and saving the figures
211
212	basically just calls `load_activations` and then `compute_and_save_figures`
213
214	# Parameters:
215	- `prompt : dict`
216		prompt to process, should be a dict with the following keys:
217		- `"text"`: the prompt string
218		- `"hash"`: the hash of the prompt
219	- `model_cfg : HookedTransformerConfig|HTConfigMock`
220		configuration of the model, used for figuring out where to save
221	- `save_path : Path`
222		directory to save the figures to
223	- `figure_funcs : list[AttentionMatrixFigureFunc]`
224		list of functions to run
225	- `force_overwrite : bool`
226		(defaults to `False`)
227	"""
228	# load the activations
229	activations_path: Path
230	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
231	activations_path, cache = load_activations(
232		model_name=model_cfg.model_name,
233		prompt=prompt,
234		save_path=save_path,
235		return_fmt="numpy",
236	)
237
238	# compute and save the figures
239	compute_and_save_figures(
240		model_cfg=model_cfg,
241		activations_path=activations_path,
242		cache=cache,
243		figure_funcs=figure_funcs,
244		save_path=save_path,
245		force_overwrite=force_overwrite,
246	)

process a single prompt, loading the activations and computing and saving the figures

basically just calls load_activations and then compute_and_save_figures

Parameters:

  • prompt : dict prompt to process, should be a dict with the following keys:
    • "text": the prompt string
    • "hash": the hash of the prompt
  • model_cfg : HookedTransformerConfig|HTConfigMock configuration of the model, used for figuring out where to save
  • save_path : Path directory to save the figures to
  • figure_funcs : list[AttentionMatrixFigureFunc] list of functions to run
  • force_overwrite : bool (defaults to False)
def select_attn_figure_funcs( figure_funcs_select: set[str] | str | None = None) -> list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None]]:
249def select_attn_figure_funcs(
250	figure_funcs_select: set[str] | str | None = None,
251) -> list[AttentionMatrixFigureFunc]:
252	"""given a selector, figure out which functions from `ATTENTION_MATRIX_FIGURE_FUNCS` to use
253
254	- if arg is `None`, will use all functions
255	- if a string, will use the function names which match the string (glob/fnmatch syntax)
256	- if a set, will use functions whose names are in the set
257
258	"""
259	# figure out which functions to use
260	figure_funcs: list[AttentionMatrixFigureFunc]
261	if figure_funcs_select is None:
262		# all if nothing specified
263		figure_funcs = ATTENTION_MATRIX_FIGURE_FUNCS
264	elif isinstance(figure_funcs_select, str):
265		# if a string, assume a glob pattern
266		pattern: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select))
267		figure_funcs = [
268			func
269			for func in ATTENTION_MATRIX_FIGURE_FUNCS
270			if pattern.match(getattr(func, "__name__", "<unknown>"))
271		]
272	elif isinstance(figure_funcs_select, set):
273		# if a set, assume a set of function names
274		figure_funcs = [
275			func
276			for func in ATTENTION_MATRIX_FIGURE_FUNCS
277			if getattr(func, "__name__", "<unknown>") in figure_funcs_select
278		]
279	else:
280		err_msg: str = (
281			f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }"
282			f"\n{figure_funcs_select = }"
283		)
284		raise TypeError(err_msg)
285	return figure_funcs

given a selector, figure out which functions from ATTENTION_MATRIX_FIGURE_FUNCS to use

  • if arg is None, will use all functions
  • if a string, will use the function names which match the string (glob/fnmatch syntax)
  • if a set, will use functions whose names are in the set
def figures_main( model_name: str, save_path: str | pathlib._local.Path, n_samples: int, force: bool, figure_funcs_select: set[str] | str | None = None, parallel: bool | int = True) -> None:
288def figures_main(
289	model_name: str,
290	save_path: str | Path,
291	n_samples: int,
292	force: bool,
293	figure_funcs_select: set[str] | str | None = None,
294	parallel: bool | int = True,
295) -> None:
296	"""main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
297
298	# Parameters:
299	- `model_name : str`
300		model name to use, used for loading the model config, prompts, activations, and saving the figures
301	- `save_path : str | Path`
302		base path to look in
303	- `n_samples : int`
304		max number of samples to process
305	- `force : bool`
306		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
307	- `figure_funcs_select : set[str]|str|None`
308		figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set
309		(defaults to `None`)
310	- `parallel : bool | int`
311		whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
312		(defaults to `True`)
313	"""
314	with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
315		# save model info or check if it exists
316		save_path_p: Path = Path(save_path)
317		model_path: Path = save_path_p / model_name
318		with open(model_path / "model_cfg.json", "r") as f:
319			model_cfg = HTConfigMock.load(json.load(f))
320
321	with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
322		# load prompts
323		with open(model_path / "prompts.jsonl", "r") as f:
324			prompts: list[dict] = [json.loads(line) for line in f.readlines()]
325		# truncate to n_samples
326		prompts = prompts[:n_samples]
327
328	print(f"{len(prompts)} prompts loaded")
329
330	figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs(
331		figure_funcs_select=figure_funcs_select,
332	)
333	print(f"{len(figure_funcs)} figure functions loaded")
334	print(
335		"\t"
336		+ ", ".join([getattr(func, "__name__", "<unknown>") for func in figure_funcs]),
337	)
338
339	chunksize: int = int(
340		max(
341			1,
342			len(prompts) // (5 * multiprocessing.cpu_count()),
343		),
344	)
345	print(f"chunksize: {chunksize}")
346
347	list(
348		run_maybe_parallel(
349			func=functools.partial(
350				process_prompt,
351				model_cfg=model_cfg,
352				save_path=save_path_p,
353				figure_funcs=figure_funcs,
354				force_overwrite=force,
355			),
356			iterable=prompts,
357			parallel=parallel,
358			chunksize=chunksize,
359			pbar="tqdm",
360			pbar_kwargs=dict(
361				desc="Making figures",
362				unit="prompt",
363			),
364		),
365	)
366
367	with SpinnerContext(
368		message="updating jsonl metadata for models and functions",
369		**SPINNER_KWARGS,
370	):
371		generate_models_jsonl(save_path_p)
372		generate_functions_jsonl(save_path_p)

main function for generating figures from attention patterns, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

  • model_name : str model name to use, used for loading the model config, prompts, activations, and saving the figures
  • save_path : str | Path base path to look in
  • n_samples : int max number of samples to process
  • force : bool force overwrite of existing figures. if False, will skip any functions which have already saved a figure
  • figure_funcs_select : set[str]|str|None figure functions to use. if None, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set (defaults to None)
  • parallel : bool | int whether to run in parallel. if True, will use all available cores. if False, will run in serial. if an int, will try to use that many cores (defaults to True)
def main() -> None:
445def main() -> None:
446	"generates figures from the activations using the functions decorated with `register_attn_figure_func`"
447	# parse args
448	print(DIVIDER_S1)
449	args: argparse.Namespace
450	models: list[str]
451	figure_funcs_select: set[str] | str | None
452	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
453		args, models, figure_funcs_select = _parse_args()
454	print(f"\targs parsed: '{args}'")
455	print(f"\tmodels: '{models}'")
456	print(f"\tfigure_funcs_select: '{figure_funcs_select}'")
457
458	# compute for each model
459	n_models: int = len(models)
460	for idx, model in enumerate(models):
461		print(DIVIDER_S2)
462		print(f"processing model {idx + 1} / {n_models}: {model}")
463		print(DIVIDER_S2)
464		figures_main(
465			model_name=model,
466			save_path=args.save_path,
467			n_samples=args.n_samples,
468			force=args.force,
469			figure_funcs_select=figure_funcs_select,
470		)
471
472	print(DIVIDER_S1)

generates figures from the activations using the functions decorated with register_attn_figure_func