Coverage for pattern_lens / figures.py: 86%
133 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`"""
3import argparse
4import fnmatch
5import functools
6import itertools
7import json
8import multiprocessing
9import re
10import warnings
11from collections import defaultdict
12from pathlib import Path
14import numpy as np
15from jaxtyping import Float
17# custom utils
18from muutils.json_serialize import json_serialize
19from muutils.parallel import run_maybe_parallel
20from muutils.spinner import SpinnerContext
22# pattern_lens
23from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
24from pattern_lens.consts import (
25 DATA_DIR,
26 DIVIDER_S1,
27 DIVIDER_S2,
28 SPINNER_KWARGS,
29 ActivationCacheNp,
30 AttentionMatrix,
31)
32from pattern_lens.figure_util import AttentionMatrixFigureFunc
33from pattern_lens.indexes import (
34 generate_functions_jsonl,
35 generate_models_jsonl,
36 generate_prompts_jsonl,
37)
38from pattern_lens.load_activations import load_activations
41class HTConfigMock:
42 """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json
44 can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
45 - `n_layers: int`
46 - `n_heads: int`
47 - `model_name: str`
49 we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
50 """
52 def __init__(self, **kwargs: dict[str, str | int]) -> None:
53 "will pass all kwargs to `__dict__`"
54 self.n_layers: int
55 self.n_heads: int
56 self.model_name: str
57 self.__dict__.update(kwargs)
59 def serialize(self) -> dict:
60 """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
61 # its fine, we know its a dict
62 return json_serialize(self.__dict__) # type: ignore[return-value]
64 @classmethod
65 def load(cls, data: dict) -> "HTConfigMock":
66 "try to load a config from a dict, using the `__init__` method"
67 return cls(**data)
def process_single_head(
    layer_idx: int,
    head_idx: int,
    attn_pattern: AttentionMatrix,
    save_dir: Path,
    figure_funcs: list[AttentionMatrixFigureFunc],
    force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
    """run each function in `figure_funcs` on one head's attention pattern

    > [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
    > it will skip all figures for that function if any are already saved
    > and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures

    # Parameters:
    - `layer_idx : int`
    - `head_idx : int`
    - `attn_pattern : AttentionMatrix`
        attention pattern for the head
    - `save_dir : Path`
        directory to save the figures to
    - `figure_funcs : list[AttentionMatrixFigureFunc]`
        figure functions to run on the pattern
    - `force_overwrite : bool`
        whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)

    # Returns:
    - `dict[str, bool | Exception]`
        per-function status, keyed by function name: `True` on success (or skip),
        the raised exception on failure
    """
    funcs_status: dict[str, bool | Exception] = {}

    for func in figure_funcs:
        func_name: str = getattr(func, "__name__", "<unknown>")

        # NOTE(review): this glob also matches a leftover `{func_name}.error.txt`,
        # so a function that errored on a previous run is skipped and reported as
        # success here -- confirm this is intended
        if not force_overwrite and any(save_dir.glob(f"{func_name}.*")):
            funcs_status[func_name] = True
            continue

        try:
            func(attn_pattern, save_dir)
        # deliberately broad: one failing figure func should not abort the whole head
        except Exception as err:  # noqa: BLE001
            (save_dir / f"{func_name}.error.txt").write_text(str(err))
            warnings.warn(
                f"Error in {func_name} for L{layer_idx}H{head_idx}: {err!s}",
                stacklevel=2,
            )
            funcs_status[func_name] = err
        else:
            funcs_status[func_name] = True

    return funcs_status
def compute_and_save_figures(
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    activations_path: Path,
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
    figure_funcs: list[AttentionMatrixFigureFunc],
    save_path: Path = Path(DATA_DIR),
    force_overwrite: bool = False,
    track_results: bool = False,
) -> None:
    """run `figure_funcs` on every (layer, head) attention pattern and save the outputs

    # Parameters:
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
        provides `n_layers`, `n_heads`, and `model_name`
    - `activations_path : Path`
        path to the activations file; figures are saved under its parent directory
    - `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
        either a dict of activations keyed by hook name, or a stacked array of patterns
    - `figure_funcs : list[AttentionMatrixFigureFunc]`
        list of functions to run
    - `save_path : Path`
        directory to save the figures to
        (defaults to `Path(DATA_DIR)`)
    - `force_overwrite : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)
    - `track_results : bool`
        whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
        (defaults to `False`)

    # Raises:
    - `TypeError` : if `cache` is neither a dict nor an `np.ndarray`
    """
    prompt_dir: Path = activations_path.parent

    # maps func name -> {(layer, head): status}; only populated when tracking
    results: defaultdict[str, dict[tuple[int, int], bool | Exception]] | None = (
        defaultdict(dict) if track_results else None
    )

    for layer_idx in range(model_cfg.n_layers):
        for head_idx in range(model_cfg.n_heads):
            attn_pattern: AttentionMatrix
            if isinstance(cache, dict):
                # activation-cache dict: select batch element 0, then the head
                attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][
                    0, head_idx
                ]
            elif isinstance(cache, np.ndarray):
                attn_pattern = cache[layer_idx, head_idx]
            else:
                msg = (
                    f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
                )
                raise TypeError(
                    msg,
                )

            head_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
            head_dir.mkdir(parents=True, exist_ok=True)

            head_status: dict[str, bool | Exception] = process_single_head(
                layer_idx=layer_idx,
                head_idx=head_idx,
                attn_pattern=attn_pattern,
                save_dir=head_dir,
                figure_funcs=figure_funcs,
                force_overwrite=force_overwrite,
            )

            if results is not None:
                for func_name, status in head_status.items():
                    results[func_name][(layer_idx, head_idx)] = status

    # TODO: do something with results

    # refresh the prompts index for this model
    generate_prompts_jsonl(save_path / model_cfg.model_name)
def process_prompt(
    prompt: dict,
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    save_path: Path,
    figure_funcs: list[AttentionMatrixFigureFunc],
    force_overwrite: bool = False,
) -> None:
    """load one prompt's activations and generate all of its figures

    thin wrapper: `load_activations` followed by `compute_and_save_figures`

    # Parameters:
    - `prompt : dict`
        prompt to process, should be a dict with the following keys:
        - `"text"`: the prompt string
        - `"hash"`: the hash of the prompt
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
        configuration of the model, used for figuring out where to save
    - `save_path : Path`
        directory to save the figures to
    - `figure_funcs : list[AttentionMatrixFigureFunc]`
        list of functions to run
    - `force_overwrite : bool`
        (defaults to `False`)
    """
    # request numpy-format activations for this prompt
    activations_path, cache = load_activations(
        model_name=model_cfg.model_name,
        prompt=prompt,
        save_path=save_path,
        return_fmt="numpy",
    )

    # hand everything off to the figure pipeline
    compute_and_save_figures(
        model_cfg=model_cfg,
        activations_path=activations_path,
        cache=cache,
        figure_funcs=figure_funcs,
        save_path=save_path,
        force_overwrite=force_overwrite,
    )
def select_attn_figure_funcs(
    figure_funcs_select: set[str] | str | None = None,
) -> list[AttentionMatrixFigureFunc]:
    """resolve a selector into functions from `ATTENTION_MATRIX_FIGURE_FUNCS`

    - `None` selects every registered function
    - a `str` is treated as a glob (fnmatch) pattern over function names
    - a `set` selects functions whose `__name__` is a member

    # Raises:
    - `TypeError` : if the selector is not `None`, `str`, or `set`
    """

    def _name(fn: AttentionMatrixFigureFunc) -> str:
        # safe name lookup for callables without `__name__`
        return getattr(fn, "__name__", "<unknown>")

    if figure_funcs_select is None:
        # all if nothing specified
        return ATTENTION_MATRIX_FIGURE_FUNCS

    if isinstance(figure_funcs_select, str):
        # a string is interpreted as a glob pattern
        matcher: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select))
        return [
            fn for fn in ATTENTION_MATRIX_FIGURE_FUNCS if matcher.match(_name(fn))
        ]

    if isinstance(figure_funcs_select, set):
        # a set is interpreted as exact function names
        return [
            fn for fn in ATTENTION_MATRIX_FIGURE_FUNCS if _name(fn) in figure_funcs_select
        ]

    err_msg: str = (
        f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }"
        f"\n{figure_funcs_select = }"
    )
    raise TypeError(err_msg)
def figures_main(
    model_name: str,
    save_path: str | Path,
    n_samples: int,
    force: bool,
    figure_funcs_select: set[str] | str | None = None,
    parallel: bool | int = True,
) -> None:
    """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_name : str`
        model name to use, used for loading the model config, prompts, activations, and saving the figures
    - `save_path : str | Path`
        base path to look in
    - `n_samples : int`
        max number of samples to process (`None` processes everything in the file)
    - `force : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
    - `figure_funcs_select : set[str]|str|None`
        figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set
        (defaults to `None`)
    - `parallel : bool | int`
        whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
        (defaults to `True`)
    """
    with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
        # save model info or check if it exists
        save_path_p: Path = Path(save_path)
        model_path: Path = save_path_p / model_name
        # explicit encoding -- the platform default is not guaranteed to be utf-8
        model_cfg: HTConfigMock = HTConfigMock.load(
            json.loads((model_path / "model_cfg.json").read_text(encoding="utf-8")),
        )

    with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
        # one json object per line (jsonl)
        prompts: list[dict] = [
            json.loads(line)
            for line in (model_path / "prompts.jsonl")
            .read_text(encoding="utf-8")
            .splitlines()
        ]
        # truncate to n_samples (a `None` bound keeps everything)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs(
        figure_funcs_select=figure_funcs_select,
    )
    print(f"{len(figure_funcs)} figure functions loaded")
    print(
        "\t"
        + ", ".join([getattr(func, "__name__", "<unknown>") for func in figure_funcs]),
    )

    # floor division already yields an int, so no cast is needed
    chunksize: int = max(1, len(prompts) // (5 * multiprocessing.cpu_count()))
    print(f"chunksize: {chunksize}")

    list(
        run_maybe_parallel(
            func=functools.partial(
                process_prompt,
                model_cfg=model_cfg,
                save_path=save_path_p,
                figure_funcs=figure_funcs,
                force_overwrite=force,
            ),
            iterable=prompts,
            parallel=parallel,
            chunksize=chunksize,
            pbar="tqdm",
            pbar_kwargs=dict(
                desc="Making figures",
                unit="prompt",
            ),
        ),
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and functions",
        **SPINNER_KWARGS,
    ):
        generate_models_jsonl(save_path_p)
        generate_functions_jsonl(save_path_p)
def _parse_args() -> tuple[
    argparse.Namespace,
    list[str],  # models
    set[str] | str | None,  # figure_funcs_select
]:
    """parse CLI arguments

    # Returns:
    - `argparse.Namespace` : the raw parsed arguments
    - `list[str]` : model names (the `--model` value split on commas)
    - `set[str] | str | None` : figure function selector -- `None` for all
        functions, a glob string, or a set of exact function names
    """
    arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
    # input and output
    arg_parser.add_argument(
        "--model",
        "-m",
        type=str,
        required=True,
        help="The model name(s) to use. comma separated with no whitespace if multiple",
    )
    arg_parser.add_argument(
        "--save-path",
        "-s",
        type=str,
        required=False,
        help="The path to save the attention patterns",
        default=DATA_DIR,
    )
    # number of samples
    arg_parser.add_argument(
        "--n-samples",
        "-n",
        type=int,
        required=False,
        help="The max number of samples to process, do all in the file if None",
        default=None,
    )
    # force overwrite of existing figures
    # NOTE: this previously used `type=bool`, the classic argparse pitfall --
    # any non-empty string (including "False") parsed as True. `store_true`
    # makes `--force`/`-f` a proper flag which defaults to False when absent.
    arg_parser.add_argument(
        "--force",
        "-f",
        action="store_true",
        help="Force overwrite of existing figures",
    )
    # figure functions
    arg_parser.add_argument(
        "--figure-funcs",
        type=str,
        required=False,
        help="The figure functions to use. if 'None' (default), will use all functions. if a string, will use the function names which match the string. if a comma-separated list of strings, will use the function names in the set",
        default=None,
    )

    args: argparse.Namespace = arg_parser.parse_args()

    # `split` on a comma-free string already returns a single-element list
    models: list[str] = args.model.split(",")

    # figure out the figure-function selector
    figure_funcs_select: set[str] | str | None
    if (args.figure_funcs is None) or (args.figure_funcs.lower().strip() == "none"):
        figure_funcs_select = None
    elif "," in args.figure_funcs:
        figure_funcs_select = {x.strip() for x in args.figure_funcs.split(",")}
    else:
        figure_funcs_select = args.figure_funcs.strip()

    return args, models, figure_funcs_select
def main() -> None:
    "generates figures from the activations using the functions decorated with `register_attn_figure_func`"
    print(DIVIDER_S1)

    args: argparse.Namespace
    models: list[str]
    figure_funcs_select: set[str] | str | None
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        args, models, figure_funcs_select = _parse_args()
    print(f"\targs parsed: '{args}'")
    print(f"\tmodels: '{models}'")
    print(f"\tfigure_funcs_select: '{figure_funcs_select}'")

    # run the figure pipeline once per requested model
    n_models: int = len(models)
    for model_idx, model_name in enumerate(models, start=1):
        print(DIVIDER_S2)
        print(f"processing model {model_idx} / {n_models}: {model_name}")
        print(DIVIDER_S2)
        figures_main(
            model_name=model_name,
            save_path=args.save_path,
            n_samples=args.n_samples,
            force=args.force,
            figure_funcs_select=figure_funcs_select,
        )

    print(DIVIDER_S1)
# script entry point: run `main` only when executed directly, not on import
if __name__ == "__main__":
    main()