docs for pattern_lens v0.4.0

pattern_lens.figures

code for generating figures from attention patterns, using the functions decorated with register_attn_figure_func
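For orientation, here is a minimal sketch of registering a custom figure function. Only the `(attn_matrix, save_dir)` signature is pinned down by `AttentionMatrixFigureFunc`; the bare-decorator usage and the `{func_name}.{fmt}` output filename are assumptions based on the gotcha notes below.

import matplotlib

matplotlib.use("Agg")  # headless backend, safe for multiprocessing workers
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

from pattern_lens.attn_figure_funcs import register_attn_figure_func

@register_attn_figure_func
def plain_heatmap(attn_matrix: np.ndarray, save_dir: Path) -> None:
	"""save the raw attention pattern as a png heatmap"""
	plt.imsave(save_dir / "plain_heatmap.png", attn_matrix, cmap="viridis")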


  1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`"""
  2
  3import argparse
  4import fnmatch
  5import functools
  6import itertools
  7import json
  8import multiprocessing
  9import re
 10import warnings
 11from collections import defaultdict
 12from pathlib import Path
 13
 14import numpy as np
 15from jaxtyping import Float
 16
 17# custom utils
 18from muutils.json_serialize import json_serialize
 19from muutils.parallel import run_maybe_parallel
 20from muutils.spinner import SpinnerContext
 21
 22# pattern_lens
 23from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
 24from pattern_lens.consts import (
 25	DATA_DIR,
 26	DIVIDER_S1,
 27	DIVIDER_S2,
 28	SPINNER_KWARGS,
 29	ActivationCacheNp,
 30	AttentionMatrix,
 31)
 32from pattern_lens.figure_util import AttentionMatrixFigureFunc
 33from pattern_lens.indexes import (
 34	generate_functions_jsonl,
 35	generate_models_jsonl,
 36	generate_prompts_jsonl,
 37)
 38from pattern_lens.load_activations import load_activations
 39
 40
 41class HTConfigMock:
 42	"""Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json
 43
 44	can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
 45	- `n_layers: int`
 46	- `n_heads: int`
 47	- `model_name: str`
 48
 49	we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
 50	"""
 51
 52	def __init__(self, **kwargs: str | int) -> None:
 53		"will pass all kwargs to `__dict__`"
 54		self.n_layers: int
 55		self.n_heads: int
 56		self.model_name: str
 57		self.__dict__.update(kwargs)
 58
 59	def serialize(self) -> dict:
 60		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
 61		# it's fine, we know it's a dict
 62		return json_serialize(self.__dict__)  # type: ignore[return-value]
 63
 64	@classmethod
 65	def load(cls, data: dict) -> "HTConfigMock":
 66		"try to load a config from a dict, using the `__init__` method"
 67		return cls(**data)
 68
 69
 70def process_single_head(
 71	layer_idx: int,
 72	head_idx: int,
 73	attn_pattern: AttentionMatrix,
 74	save_dir: Path,
 75	figure_funcs: list[AttentionMatrixFigureFunc],
 76	force_overwrite: bool = False,
 77) -> dict[str, bool | Exception]:
 78	"""process a single head's attention pattern, running all the functions in `figure_funcs` on the attention pattern
 79
 80	> [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
 81	> it will skip all figures for that function if any are already saved
 82	> and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures
 83
 84	# Parameters:
 85	- `layer_idx : int`
 86	- `head_idx : int`
 87	- `attn_pattern : AttentionMatrix`
 88		attention pattern for the head
 89	- `save_dir : Path`
 90		directory to save the figures to
 91	- `force_overwrite : bool`
 92		whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
 93		(defaults to `False`)
 94
 95	# Returns:
 96	- `dict[str, bool | Exception]`
 97		a dictionary of the status of each function, with the function name as the key and the status as the value
 98	"""
 99	funcs_status: dict[str, bool | Exception] = dict()
100
101	for func in figure_funcs:
102		func_name: str = func.__name__
103		fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))
104
105		if not force_overwrite and len(fig_path) > 0:
106			funcs_status[func_name] = True
107			continue
108
109		try:
110			func(attn_pattern, save_dir)
111			funcs_status[func_name] = True
112
113		# blindly catch any exception
114		except Exception as e:  # noqa: BLE001
115			error_file = save_dir / f"{func.__name__}.error.txt"
116			error_file.write_text(str(e))
117			warnings.warn(
118				f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}",
119				stacklevel=2,
120			)
121			funcs_status[func_name] = e
122
123	return funcs_status
124
125
126def compute_and_save_figures(
127	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
128	activations_path: Path,
129	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
130	figure_funcs: list[AttentionMatrixFigureFunc],
131	save_path: Path = Path(DATA_DIR),
132	force_overwrite: bool = False,
133	track_results: bool = False,
134) -> None:
135	"""compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
136
137	# Parameters:
138	- `model_cfg : HookedTransformerConfig|HTConfigMock`
139		configuration of the model, used to iterate over layers and heads and to locate the output directory
140	- `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
141		activation cache containing attention patterns for the prompt we are processing
142	- `figure_funcs : list[AttentionMatrixFigureFunc]`
143		list of functions to run
144	- `save_path : Path`
145		directory to save the figures to
146		(defaults to `Path(DATA_DIR)`)
147	- `force_overwrite : bool`
148		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
149		(defaults to `False`)
150	- `track_results : bool`
151		whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
152		(defaults to `False`)
153	"""
154	prompt_dir: Path = activations_path.parent
155
156	if track_results:
157		results: defaultdict[
158			str,  # func name
159			dict[
160				tuple[int, int],  # layer, head
161				bool | Exception,  # success or exception
162			],
163		] = defaultdict(dict)
164
165	for layer_idx, head_idx in itertools.product(
166		range(model_cfg.n_layers),
167		range(model_cfg.n_heads),
168	):
169		attn_pattern: AttentionMatrix
170		if isinstance(cache, dict):
171			attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
172		elif isinstance(cache, np.ndarray):
173			attn_pattern = cache[layer_idx, head_idx]
174		else:
175			msg = (
176				f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
177			)
178			raise TypeError(
179				msg,
180			)
181
182		save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
183		save_dir.mkdir(parents=True, exist_ok=True)
184		head_res: dict[str, bool | Exception] = process_single_head(
185			layer_idx=layer_idx,
186			head_idx=head_idx,
187			attn_pattern=attn_pattern,
188			save_dir=save_dir,
189			force_overwrite=force_overwrite,
190			figure_funcs=figure_funcs,
191		)
192
193		if track_results:
194			for func_name, status in head_res.items():
195				results[func_name][(layer_idx, head_idx)] = status
196
197	# TODO: do something with results
198
199	generate_prompts_jsonl(save_path / model_cfg.model_name)
200
201
202def process_prompt(
203	prompt: dict,
204	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
205	save_path: Path,
206	figure_funcs: list[AttentionMatrixFigureFunc],
207	force_overwrite: bool = False,
208) -> None:
209	"""process a single prompt, loading the activations and computing and saving the figures
210
211	basically just calls `load_activations` and then `compute_and_save_figures`
212
213	# Parameters:
214	- `prompt : dict`
215		prompt to process, should be a dict with the following keys:
216		- `"text"`: the prompt string
217		- `"hash"`: the hash of the prompt
218	- `model_cfg : HookedTransformerConfig|HTConfigMock`
219		configuration of the model, used for figuring out where to save
220	- `save_path : Path`
221		directory to save the figures to
222	- `figure_funcs : list[AttentionMatrixFigureFunc]`
223		list of functions to run
224	- `force_overwrite : bool`
225		whether to overwrite existing figures (defaults to `False`)
226	"""
227	# load the activations
228	activations_path: Path
229	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
230	activations_path, cache = load_activations(
231		model_name=model_cfg.model_name,
232		prompt=prompt,
233		save_path=save_path,
234		return_fmt="numpy",
235	)
236
237	# compute and save the figures
238	compute_and_save_figures(
239		model_cfg=model_cfg,
240		activations_path=activations_path,
241		cache=cache,
242		figure_funcs=figure_funcs,
243		save_path=save_path,
244		force_overwrite=force_overwrite,
245	)
246
247
248def select_attn_figure_funcs(
249	figure_funcs_select: set[str] | str | None = None,
250) -> list[AttentionMatrixFigureFunc]:
251	"""given a selector, figure out which functions from `ATTENTION_MATRIX_FIGURE_FUNCS` to use
252
253	- if arg is `None`, will use all functions
254	- if a string, will use the function names which match the string (glob/fnmatch syntax)
255	- if a set, will use functions whose names are in the set
256
257	"""
258	# figure out which functions to use
259	figure_funcs: list[AttentionMatrixFigureFunc]
260	if figure_funcs_select is None:
261		# all if nothing specified
262		figure_funcs = ATTENTION_MATRIX_FIGURE_FUNCS
263	elif isinstance(figure_funcs_select, str):
264		# if a string, assume a glob pattern
265		pattern: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select))
266		figure_funcs = [
267			func
268			for func in ATTENTION_MATRIX_FIGURE_FUNCS
269			if pattern.match(func.__name__)
270		]
271	elif isinstance(figure_funcs_select, set):
272		# if a set, assume a set of function names
273		figure_funcs = [
274			func
275			for func in ATTENTION_MATRIX_FIGURE_FUNCS
276			if func.__name__ in figure_funcs_select
277		]
278	else:
279		err_msg: str = (
280			f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }"
281			f"\n{figure_funcs_select = }"
282		)
283		raise TypeError(err_msg)
284	return figure_funcs
285
286
287def figures_main(
288	model_name: str,
289	save_path: str,
290	n_samples: int | None,
291	force: bool,
292	figure_funcs_select: set[str] | str | None = None,
293	parallel: bool | int = True,
294) -> None:
295	"""main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
296
297	# Parameters:
298	- `model_name : str`
299		model name to use, used for loading the model config, prompts, activations, and saving the figures
300	- `save_path : str`
301		base path to look in
302	- `n_samples : int | None`
303		max number of samples to process; if `None`, process all prompts
304	- `force : bool`
305		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
306	- `figure_funcs_select : set[str]|str|None`
307		figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set
308		(defaults to `None`)
309	- `parallel : bool | int`
310		whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
311		(defaults to `True`)
312	"""
313	with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
314		# save model info or check if it exists
315		save_path_p: Path = Path(save_path)
316		model_path: Path = save_path_p / model_name
317		with open(model_path / "model_cfg.json", "r") as f:
318			model_cfg = HTConfigMock.load(json.load(f))
319
320	with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
321		# load prompts
322		with open(model_path / "prompts.jsonl", "r") as f:
323			prompts: list[dict] = [json.loads(line) for line in f.readlines()]
324		# truncate to n_samples
325		prompts = prompts[:n_samples]
326
327	print(f"{len(prompts)} prompts loaded")
328
329	figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs(
330		figure_funcs_select=figure_funcs_select,
331	)
332	print(f"{len(figure_funcs)} figure functions loaded")
333	print("\t" + ", ".join([func.__name__ for func in figure_funcs]))
334
335	chunksize: int = int(
336		max(
337			1,
338			len(prompts) // (5 * multiprocessing.cpu_count()),
339		),
340	)
341	print(f"chunksize: {chunksize}")
342
343	list(
344		run_maybe_parallel(
345			func=functools.partial(
346				process_prompt,
347				model_cfg=model_cfg,
348				save_path=save_path_p,
349				figure_funcs=figure_funcs,
350				force_overwrite=force,
351			),
352			iterable=prompts,
353			parallel=parallel,
354			chunksize=chunksize,
355			pbar="tqdm",
356			pbar_kwargs=dict(
357				desc="Making figures",
358				unit="prompt",
359			),
360		),
361	)
362
363	with SpinnerContext(
364		message="updating jsonl metadata for models and functions",
365		**SPINNER_KWARGS,
366	):
367		generate_models_jsonl(save_path_p)
368		generate_functions_jsonl(save_path_p)
369
370
371def _parse_args() -> tuple[
372	argparse.Namespace,
373	list[str],  # models
374	set[str] | str | None,  # figure_funcs_select
375]:
376	arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
377	# input and output
378	arg_parser.add_argument(
379		"--model",
380		"-m",
381		type=str,
382		required=True,
383		help="The model name(s) to use. comma separated with no whitespace if multiple",
384	)
385	arg_parser.add_argument(
386		"--save-path",
387		"-s",
388		type=str,
389		required=False,
390		help="The path to save the attention patterns",
391		default=DATA_DIR,
392	)
393	# number of samples
394	arg_parser.add_argument(
395		"--n-samples",
396		"-n",
397		type=int,
398		required=False,
399		help="The max number of samples to process, do all in the file if None",
400		default=None,
401	)
402	# force overwrite of existing figures
403	arg_parser.add_argument(
404		"--force",
405		"-f",
406		# `store_true` gives a real boolean flag; `type=bool` treated any non-empty string as True
407		action="store_true",
408		required=False,
409		help="Force overwrite of existing figures",
410	)
411	# figure functions
412	arg_parser.add_argument(
413		"--figure-funcs",
414		type=str,
415		required=False,
416		help="The figure functions to use. if 'None' (default), will use all functions. if a single string, will use the function names matching it as a glob pattern. if a comma-separated list of strings, will use exactly those function names",
417		default=None,
418	)
419
420	args: argparse.Namespace = arg_parser.parse_args()
421
422	# figure out models
423	models: list[str]
424	if "," in args.model:
425		models = args.model.split(",")
426	else:
427		models = [args.model]
428
429	# figure out figures
430	figure_funcs_select: set[str] | str | None
431	if (args.figure_funcs is None) or (args.figure_funcs.lower().strip() == "none"):
432		figure_funcs_select = None
433	elif "," in args.figure_funcs:
434		figure_funcs_select = {x.strip() for x in args.figure_funcs.split(",")}
435	else:
436		figure_funcs_select = args.figure_funcs.strip()
437
438	return args, models, figure_funcs_select
439
440
441def main() -> None:
442	"generates figures from the activations using the functions decorated with `register_attn_figure_func`"
443	# parse args
444	print(DIVIDER_S1)
445	args: argparse.Namespace
446	models: list[str]
447	figure_funcs_select: set[str] | str | None
448	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
449		args, models, figure_funcs_select = _parse_args()
450	print(f"\targs parsed: '{args}'")
451	print(f"\tmodels: '{models}'")
452	print(f"\tfigure_funcs_select: '{figure_funcs_select}'")
453
454	# compute for each model
455	n_models: int = len(models)
456	for idx, model in enumerate(models):
457		print(DIVIDER_S2)
458		print(f"processing model {idx + 1} / {n_models}: {model}")
459		print(DIVIDER_S2)
460		figures_main(
461			model_name=model,
462			save_path=args.save_path,
463			n_samples=args.n_samples,
464			force=args.force,
465			figure_funcs_select=figure_funcs_select,
466		)
467
468	print(DIVIDER_S1)
469
470
471if __name__ == "__main__":
472	main()

class HTConfigMock:
42class HTConfigMock:
43	"""Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json
44
45	can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
46	- `n_layers: int`
47	- `n_heads: int`
48	- `model_name: str`
49
50	we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
51	"""
52
 53	def __init__(self, **kwargs: str | int) -> None:
54		"will pass all kwargs to `__dict__`"
55		self.n_layers: int
56		self.n_heads: int
57		self.model_name: str
58		self.__dict__.update(kwargs)
59
60	def serialize(self) -> dict:
61		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
 62		# it's fine, we know it's a dict
63		return json_serialize(self.__dict__)  # type: ignore[return-value]
64
65	@classmethod
66	def load(cls, data: dict) -> "HTConfigMock":
67		"try to load a config from a dict, using the `__init__` method"
68		return cls(**data)

Mock of transformer_lens.HookedTransformerConfig for type hinting and loading config json

can be initialized with any kwargs, and will update its __dict__ with them. does, however, require the following attributes:

  • n_layers: int
  • n_heads: int
  • model_name: str

we do this to avoid having to import torch and transformer_lens, since this would have to be done for each process in the parallelization and probably slows things down significantly
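A minimal sketch of the intended round-trip, assuming a `model_cfg.json` written by an earlier pattern_lens run (the `attn_data/gpt2/` path is a hypothetical example):

import json
from pathlib import Path

from pattern_lens.figures import HTConfigMock

cfg = HTConfigMock.load(json.loads(Path("attn_data/gpt2/model_cfg.json").read_text()))
print(cfg.model_name, cfg.n_layers, cfg.n_heads)
cfg_json: str = json.dumps(cfg.serialize())  # back to a json-safe form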

HTConfigMock(**kwargs: str | int)
53	def __init__(self, **kwargs: str | int) -> None:
54		"will pass all kwargs to `__dict__`"
55		self.n_layers: int
56		self.n_heads: int
57		self.model_name: str
58		self.__dict__.update(kwargs)

will pass all kwargs to __dict__

n_layers: int
n_heads: int
model_name: str
def serialize(self) -> dict:
60	def serialize(self) -> dict:
61		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
 62		# it's fine, we know it's a dict
63		return json_serialize(self.__dict__)  # type: ignore[return-value]

serialize the config to json. values which aren't serializable will be converted via muutils.json_serialize.json_serialize

@classmethod
def load(cls, data: dict) -> HTConfigMock:
65	@classmethod
66	def load(cls, data: dict) -> "HTConfigMock":
67		"try to load a config from a dict, using the `__init__` method"
68		return cls(**data)

try to load a config from a dict, using the __init__ method

def process_single_head( layer_idx: int, head_idx: int, attn_pattern: jaxtyping.Float[ndarray, 'n_ctx n_ctx'], save_dir: pathlib.Path, figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]], force_overwrite: bool = False) -> dict[str, bool | Exception]:
 71def process_single_head(
 72	layer_idx: int,
 73	head_idx: int,
 74	attn_pattern: AttentionMatrix,
 75	save_dir: Path,
 76	figure_funcs: list[AttentionMatrixFigureFunc],
 77	force_overwrite: bool = False,
 78) -> dict[str, bool | Exception]:
 79	"""process a single head's attention pattern, running all the functions in `figure_funcs` on the attention pattern
 80
 81	> [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
 82	> it will skip all figures for that function if any are already saved
 83	> and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures
 84
 85	# Parameters:
 86	- `layer_idx : int`
 87	- `head_idx : int`
 88	- `attn_pattern : AttentionMatrix`
 89		attention pattern for the head
 90	- `save_dir : Path`
 91		directory to save the figures to
 92	- `force_overwrite : bool`
 93		whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
 94		(defaults to `False`)
 95
 96	# Returns:
 97	- `dict[str, bool | Exception]`
 98		a dictionary of the status of each function, with the function name as the key and the status as the value
 99	"""
100	funcs_status: dict[str, bool | Exception] = dict()
101
102	for func in figure_funcs:
103		func_name: str = func.__name__
104		fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))
105
106		if not force_overwrite and len(fig_path) > 0:
107			funcs_status[func_name] = True
108			continue
109
110		try:
111			func(attn_pattern, save_dir)
112			funcs_status[func_name] = True
113
114		# blindly catch any exception
115		except Exception as e:  # noqa: BLE001
116			error_file = save_dir / f"{func.__name__}.error.txt"
117			error_file.write_text(str(e))
118			warnings.warn(
119				f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}",
120				stacklevel=2,
121			)
122			funcs_status[func_name] = e
123
124	return funcs_status

process a single head's attention pattern, running all the functions in figure_funcs on the attention pattern

[gotcha:] if force_overwrite is False, and we used a multi-figure function, it will skip all figures for that function if any are already saved and it assumes a format of {func_name}.{figure_name}.{fmt} for the saved figures

Parameters:

  • layer_idx : int
  • head_idx : int
  • attn_pattern : AttentionMatrix attention pattern for the head
  • save_dir : Path directory to save the figures to
  • figure_funcs : list[AttentionMatrixFigureFunc] figure functions to run on the attention pattern
  • force_overwrite : bool whether to overwrite existing figures. if False, will skip any functions which have already saved a figure (defaults to False)

Returns:

  • dict[str, bool | Exception] a dictionary of the status of each function, with the function name as the key and the status as the value
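A hedged sketch of calling this directly on a synthetic causal pattern (the demo paths and sizes are made up; assumes at least one figure function is registered):

import numpy as np
from pathlib import Path

from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
from pattern_lens.figures import process_single_head

n_ctx: int = 16
attn = np.tril(np.random.rand(n_ctx, n_ctx))  # lower-triangular, like a causal head
attn /= attn.sum(axis=-1, keepdims=True)  # normalize rows to sum to 1

save_dir = Path("demo") / "L0" / "H0"
save_dir.mkdir(parents=True, exist_ok=True)
status = process_single_head(
	layer_idx=0,
	head_idx=0,
	attn_pattern=attn,
	save_dir=save_dir,
	figure_funcs=ATTENTION_MATRIX_FIGURE_FUNCS,
)
print(status)  # {func_name: True} on success, {func_name: Exception} on failure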
def compute_and_save_figures( model_cfg: 'HookedTransformerConfig|HTConfigMock', activations_path: pathlib.Path, cache: dict[str, numpy.ndarray] | jaxtyping.Float[ndarray, 'n_layers n_heads n_ctx n_ctx'], figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]], save_path: pathlib.Path = PosixPath('attn_data'), force_overwrite: bool = False, track_results: bool = False) -> None:
127def compute_and_save_figures(
128	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
129	activations_path: Path,
130	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
131	figure_funcs: list[AttentionMatrixFigureFunc],
132	save_path: Path = Path(DATA_DIR),
133	force_overwrite: bool = False,
134	track_results: bool = False,
135) -> None:
136	"""compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
137
138	# Parameters:
139	- `model_cfg : HookedTransformerConfig|HTConfigMock`
140		configuration of the model, used to iterate over layers and heads and to locate the output directory
141	- `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
142		activation cache containing attention patterns for the prompt we are processing
143	- `figure_funcs : list[AttentionMatrixFigureFunc]`
144		list of functions to run
145	- `save_path : Path`
146		directory to save the figures to
147		(defaults to `Path(DATA_DIR)`)
148	- `force_overwrite : bool`
149		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
150		(defaults to `False`)
151	- `track_results : bool`
152		whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
153		(defaults to `False`)
154	"""
155	prompt_dir: Path = activations_path.parent
156
157	if track_results:
158		results: defaultdict[
159			str,  # func name
160			dict[
161				tuple[int, int],  # layer, head
162				bool | Exception,  # success or exception
163			],
164		] = defaultdict(dict)
165
166	for layer_idx, head_idx in itertools.product(
167		range(model_cfg.n_layers),
168		range(model_cfg.n_heads),
169	):
170		attn_pattern: AttentionMatrix
171		if isinstance(cache, dict):
172			attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
173		elif isinstance(cache, np.ndarray):
174			attn_pattern = cache[layer_idx, head_idx]
175		else:
176			msg = (
177				f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
178			)
179			raise TypeError(
180				msg,
181			)
182
183		save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
184		save_dir.mkdir(parents=True, exist_ok=True)
185		head_res: dict[str, bool | Exception] = process_single_head(
186			layer_idx=layer_idx,
187			head_idx=head_idx,
188			attn_pattern=attn_pattern,
189			save_dir=save_dir,
190			force_overwrite=force_overwrite,
191			figure_funcs=figure_funcs,
192		)
193
194		if track_results:
195			for func_name, status in head_res.items():
196				results[func_name][(layer_idx, head_idx)] = status
197
198	# TODO: do something with results
199
200	generate_prompts_jsonl(save_path / model_cfg.model_name)

compute and save figures for all heads in the model, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

  • model_cfg : HookedTransformerConfig|HTConfigMock configuration of the model, used to iterate over layers and heads and to locate the output directory
  • activations_path : Path path to the prompt's saved activations; figures are written under its parent directory
  • cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] activation cache containing attention patterns for the prompt we are processing
  • figure_funcs : list[AttentionMatrixFigureFunc] list of functions to run
  • save_path : Path directory to save the figures to (defaults to Path(DATA_DIR))
  • force_overwrite : bool force overwrite of existing figures. if False, will skip any functions which have already saved a figure (defaults to False)
  • track_results : bool whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO (defaults to False)
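The two accepted cache shapes, reconstructed from the indexing in the source (a dict cache is read as cache[f"blocks.{layer}.attn.hook_pattern"][0, head], a stacked array as cache[layer, head]):

import numpy as np

n_layers, n_heads, n_ctx = 2, 4, 16

# stacked-array form: one (n_ctx, n_ctx) pattern per (layer, head)
cache_arr = np.random.rand(n_layers, n_heads, n_ctx, n_ctx)

# dict form, keyed like a transformer_lens cache, with a leading batch dim of 1
cache_dict = {
	f"blocks.{i}.attn.hook_pattern": np.random.rand(1, n_heads, n_ctx, n_ctx)
	for i in range(n_layers)
}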
def process_prompt( prompt: dict, model_cfg: 'HookedTransformerConfig|HTConfigMock', save_path: pathlib.Path, figure_funcs: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]], force_overwrite: bool = False) -> None:
203def process_prompt(
204	prompt: dict,
205	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
206	save_path: Path,
207	figure_funcs: list[AttentionMatrixFigureFunc],
208	force_overwrite: bool = False,
209) -> None:
210	"""process a single prompt, loading the activations and computing and saving the figures
211
212	basically just calls `load_activations` and then `compute_and_save_figures`
213
214	# Parameters:
215	- `prompt : dict`
216		prompt to process, should be a dict with the following keys:
217		- `"text"`: the prompt string
218		- `"hash"`: the hash of the prompt
219	- `model_cfg : HookedTransformerConfig|HTConfigMock`
220		configuration of the model, used for figuring out where to save
221	- `save_path : Path`
222		directory to save the figures to
223	- `figure_funcs : list[AttentionMatrixFigureFunc]`
224		list of functions to run
225	- `force_overwrite : bool`
226		whether to overwrite existing figures (defaults to `False`)
227	"""
228	# load the activations
229	activations_path: Path
230	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
231	activations_path, cache = load_activations(
232		model_name=model_cfg.model_name,
233		prompt=prompt,
234		save_path=save_path,
235		return_fmt="numpy",
236	)
237
238	# compute and save the figures
239	compute_and_save_figures(
240		model_cfg=model_cfg,
241		activations_path=activations_path,
242		cache=cache,
243		figure_funcs=figure_funcs,
244		save_path=save_path,
245		force_overwrite=force_overwrite,
246	)

process a single prompt, loading the activations and computing and saving the figures

basically just calls load_activations and then compute_and_save_figures

Parameters:

  • prompt : dict prompt to process, should be a dict with the following keys:
    • "text": the prompt string
    • "hash": the hash of the prompt
  • model_cfg : HookedTransformerConfig|HTConfigMock configuration of the model, used for figuring out where to save
  • save_path : Path directory to save the figures to
  • figure_funcs : list[AttentionMatrixFigureFunc] list of functions to run
  • force_overwrite : bool whether to overwrite existing figures; if False, will skip any functions which have already saved a figure (defaults to False)
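A hedged usage sketch; the hash value, model config, and save path are hypothetical placeholders, and the activations for this prompt must already exist on disk:

from pathlib import Path

from pattern_lens.figures import HTConfigMock, process_prompt, select_attn_figure_funcs

model_cfg = HTConfigMock(n_layers=12, n_heads=12, model_name="gpt2")
prompt = {"text": "The quick brown fox", "hash": "0123abcd"}  # must match the saved activations
process_prompt(
	prompt=prompt,
	model_cfg=model_cfg,
	save_path=Path("attn_data"),
	figure_funcs=select_attn_figure_funcs(None),  # all registered functions
)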
def select_attn_figure_funcs( figure_funcs_select: set[str] | str | None = None) -> list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]:
249def select_attn_figure_funcs(
250	figure_funcs_select: set[str] | str | None = None,
251) -> list[AttentionMatrixFigureFunc]:
252	"""given a selector, figure out which functions from `ATTENTION_MATRIX_FIGURE_FUNCS` to use
253
254	- if arg is `None`, will use all functions
255	- if a string, will use the function names which match the string (glob/fnmatch syntax)
256	- if a set, will use functions whose names are in the set
257
258	"""
259	# figure out which functions to use
260	figure_funcs: list[AttentionMatrixFigureFunc]
261	if figure_funcs_select is None:
262		# all if nothing specified
263		figure_funcs = ATTENTION_MATRIX_FIGURE_FUNCS
264	elif isinstance(figure_funcs_select, str):
265		# if a string, assume a glob pattern
266		pattern: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select))
267		figure_funcs = [
268			func
269			for func in ATTENTION_MATRIX_FIGURE_FUNCS
270			if pattern.match(func.__name__)
271		]
272	elif isinstance(figure_funcs_select, set):
273		# if a set, assume a set of function names
274		figure_funcs = [
275			func
276			for func in ATTENTION_MATRIX_FIGURE_FUNCS
277			if func.__name__ in figure_funcs_select
278		]
279	else:
280		err_msg: str = (
281			f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }"
282			f"\n{figure_funcs_select = }"
283		)
284		raise TypeError(err_msg)
285	return figure_funcs

given a selector, figure out which functions from ATTENTION_MATRIX_FIGURE_FUNCS to use

  • if arg is None, will use all functions
  • if a string, will use the function names which match the string (glob/fnmatch syntax)
  • if a set, will use functions whose names are in the set
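The three selector forms side by side (the function names "raw" and "raw_svg" are hypothetical; substitute whatever is actually registered):

from pattern_lens.figures import select_attn_figure_funcs

funcs_all = select_attn_figure_funcs(None)  # every registered function
funcs_glob = select_attn_figure_funcs("raw*")  # fnmatch glob over function names
funcs_exact = select_attn_figure_funcs({"raw", "raw_svg"})  # exact-name membership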
def figures_main( model_name: str, save_path: str, n_samples: int | None, force: bool, figure_funcs_select: set[str] | str | None = None, parallel: bool | int = True) -> None:
288def figures_main(
289	model_name: str,
290	save_path: str,
291	n_samples: int | None,
292	force: bool,
293	figure_funcs_select: set[str] | str | None = None,
294	parallel: bool | int = True,
295) -> None:
296	"""main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`
297
298	# Parameters:
299	- `model_name : str`
300		model name to use, used for loading the model config, prompts, activations, and saving the figures
301	- `save_path : str`
302		base path to look in
303	- `n_samples : int | None`
304		max number of samples to process; if `None`, process all prompts
305	- `force : bool`
306		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
307	- `figure_funcs_select : set[str]|str|None`
308		figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set
309		(defaults to `None`)
310	- `parallel : bool | int`
311		whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
312		(defaults to `True`)
313	"""
314	with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
315		# save model info or check if it exists
316		save_path_p: Path = Path(save_path)
317		model_path: Path = save_path_p / model_name
318		with open(model_path / "model_cfg.json", "r") as f:
319			model_cfg = HTConfigMock.load(json.load(f))
320
321	with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
322		# load prompts
323		with open(model_path / "prompts.jsonl", "r") as f:
324			prompts: list[dict] = [json.loads(line) for line in f.readlines()]
325		# truncate to n_samples
326		prompts = prompts[:n_samples]
327
328	print(f"{len(prompts)} prompts loaded")
329
330	figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs(
331		figure_funcs_select=figure_funcs_select,
332	)
333	print(f"{len(figure_funcs)} figure functions loaded")
334	print("\t" + ", ".join([func.__name__ for func in figure_funcs]))
335
336	chunksize: int = int(
337		max(
338			1,
339			len(prompts) // (5 * multiprocessing.cpu_count()),
340		),
341	)
342	print(f"chunksize: {chunksize}")
343
344	list(
345		run_maybe_parallel(
346			func=functools.partial(
347				process_prompt,
348				model_cfg=model_cfg,
349				save_path=save_path_p,
350				figure_funcs=figure_funcs,
351				force_overwrite=force,
352			),
353			iterable=prompts,
354			parallel=parallel,
355			chunksize=chunksize,
356			pbar="tqdm",
357			pbar_kwargs=dict(
358				desc="Making figures",
359				unit="prompt",
360			),
361		),
362	)
363
364	with SpinnerContext(
365		message="updating jsonl metadata for models and functions",
366		**SPINNER_KWARGS,
367	):
368		generate_models_jsonl(save_path_p)
369		generate_functions_jsonl(save_path_p)

main function for generating figures from attention patterns, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

  • model_name : str model name to use, used for loading the model config, prompts, activations, and saving the figures
  • save_path : str base path to look in
  • n_samples : int | None max number of samples to process; if None, process all prompts
  • force : bool force overwrite of existing figures. if False, will skip any functions which have already saved a figure
  • figure_funcs_select : set[str]|str|None figure functions to use. if None, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set (defaults to None)
  • parallel : bool | int whether to run in parallel. if True, will use all available cores. if False, will run in serial. if an int, will try to use that many cores (defaults to True)
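A programmatic call, sketched under the assumption that `attn_data/gpt2/` already holds `model_cfg.json`, `prompts.jsonl`, and saved activations from a prior pattern_lens run:

from pattern_lens.figures import figures_main

figures_main(
	model_name="gpt2",
	save_path="attn_data",
	n_samples=100,
	force=False,
	figure_funcs_select=None,  # all registered functions
	parallel=4,  # try to use 4 worker processes
)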
def main() -> None:
442def main() -> None:
443	"generates figures from the activations using the functions decorated with `register_attn_figure_func`"
444	# parse args
445	print(DIVIDER_S1)
446	args: argparse.Namespace
447	models: list[str]
448	figure_funcs_select: set[str] | str | None
449	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
450		args, models, figure_funcs_select = _parse_args()
451	print(f"\targs parsed: '{args}'")
452	print(f"\tmodels: '{models}'")
453	print(f"\tfigure_funcs_select: '{figure_funcs_select}'")
454
455	# compute for each model
456	n_models: int = len(models)
457	for idx, model in enumerate(models):
458		print(DIVIDER_S2)
459		print(f"processing model {idx + 1} / {n_models}: {model}")
460		print(DIVIDER_S2)
461		figures_main(
462			model_name=model,
463			save_path=args.save_path,
464			n_samples=args.n_samples,
465			force=args.force,
466			figure_funcs_select=figure_funcs_select,
467		)
468
469	print(DIVIDER_S1)

generates figures from the activations using the functions decorated with register_attn_figure_func
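Run as a script, with the flags defined in `_parse_args` (the model name and glob pattern are placeholders; `--force` is a bare flag and takes no value):

python -m pattern_lens.figures --model gpt2 --save-path attn_data --n-samples 100 --figure-funcs "raw*" --force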