Coverage for pattern_lens/figures.py: 86% (133 statements). Report generated by coverage.py v7.6.12 at 2025-04-06 15:09 -0600.
1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`"""
3import argparse
4import fnmatch
5import functools
6import itertools
7import json
8import multiprocessing
9import re
10import warnings
11from collections import defaultdict
12from pathlib import Path
14import numpy as np
15from jaxtyping import Float
17# custom utils
18from muutils.json_serialize import json_serialize
19from muutils.parallel import run_maybe_parallel
20from muutils.spinner import SpinnerContext
22# pattern_lens
23from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
24from pattern_lens.consts import (
25 DATA_DIR,
26 DIVIDER_S1,
27 DIVIDER_S2,
28 SPINNER_KWARGS,
29 ActivationCacheNp,
30 AttentionMatrix,
31)
32from pattern_lens.figure_util import AttentionMatrixFigureFunc
33from pattern_lens.indexes import (
34 generate_functions_jsonl,
35 generate_models_jsonl,
36 generate_prompts_jsonl,
37)
38from pattern_lens.load_activations import load_activations
41class HTConfigMock:
42 """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json
44 can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
45 - `n_layers: int`
46 - `n_heads: int`
47 - `model_name: str`
49 we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
50 """
52 def __init__(self, **kwargs: dict[str, str | int]) -> None:
53 "will pass all kwargs to `__dict__`"
54 self.n_layers: int
55 self.n_heads: int
56 self.model_name: str
57 self.__dict__.update(kwargs)
59 def serialize(self) -> dict:
60 """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
61 # its fine, we know its a dict
62 return json_serialize(self.__dict__) # type: ignore[return-value]
64 @classmethod
65 def load(cls, data: dict) -> "HTConfigMock":
66 "try to load a config from a dict, using the `__init__` method"
67 return cls(**data)
70def process_single_head(
71 layer_idx: int,
72 head_idx: int,
73 attn_pattern: AttentionMatrix,
74 save_dir: Path,
75 figure_funcs: list[AttentionMatrixFigureFunc],
76 force_overwrite: bool = False,
77) -> dict[str, bool | Exception]:
78 """process a single head's attention pattern, running all the functions in `figure_funcs` on the attention pattern
80 > [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
81 > it will skip all figures for that function if any are already saved
82 > and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures
84 # Parameters:
85 - `layer_idx : int`
86 - `head_idx : int`
87 - `attn_pattern : AttentionMatrix`
88 attention pattern for the head
89 - `save_dir : Path`
90 directory to save the figures to
91 - `force_overwrite : bool`
92 whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
93 (defaults to `False`)
95 # Returns:
96 - `dict[str, bool | Exception]`
97 a dictionary of the status of each function, with the function name as the key and the status as the value
98 """
99 funcs_status: dict[str, bool | Exception] = dict()
101 for func in figure_funcs:
102 func_name: str = func.__name__
103 fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))
105 if not force_overwrite and len(fig_path) > 0:
106 funcs_status[func_name] = True
107 continue
109 try:
110 func(attn_pattern, save_dir)
111 funcs_status[func_name] = True
113 # bling catch any exception
114 except Exception as e: # noqa: BLE001
115 error_file = save_dir / f"{func.__name__}.error.txt"
116 error_file.write_text(str(e))
117 warnings.warn(
118 f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}",
119 stacklevel=2,
120 )
121 funcs_status[func_name] = e
123 return funcs_status
def compute_and_save_figures(
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    activations_path: Path,
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
    figure_funcs: list[AttentionMatrixFigureFunc],
    save_path: Path = Path(DATA_DIR),
    force_overwrite: bool = False,
    track_results: bool = False,
) -> None:
    """compute and save figures for every (layer, head) pair of the model

    iterates over all heads, extracts each head's attention pattern from `cache`,
    and runs every function in `figure_funcs` on it via `process_single_head`

    # Parameters:
     - `model_cfg : HookedTransformerConfig|HTConfigMock`
        configuration of the model, provides `n_layers`, `n_heads`, and `model_name`
     - `activations_path : Path`
        path the activations were loaded from; figures are saved alongside it
     - `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
        activation cache containing attention patterns for the prompt we are processing
     - `figure_funcs : list[AttentionMatrixFigureFunc]`
        list of functions to run
     - `save_path : Path`
        directory to save the figures to
        (defaults to `Path(DATA_DIR)`)
     - `force_overwrite : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)
     - `track_results : bool`
        whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
        (defaults to `False`)

    # Raises:
     - `TypeError` : if `cache` is neither a dict nor an `np.ndarray`
    """
    prompt_dir: Path = activations_path.parent

    # maps func name -> {(layer, head): success or exception}; only built when tracking
    results: defaultdict[str, dict[tuple[int, int], bool | Exception]] | None = (
        defaultdict(dict) if track_results else None
    )

    for layer, head in itertools.product(
        range(model_cfg.n_layers),
        range(model_cfg.n_heads),
    ):
        pattern: AttentionMatrix
        if isinstance(cache, np.ndarray):
            # stacked array: indexed directly by layer and head
            pattern = cache[layer, head]
        elif isinstance(cache, dict):
            # activation-cache dict: keyed by hook name, batch dim first
            pattern = cache[f"blocks.{layer}.attn.hook_pattern"][0, head]
        else:
            msg = (
                f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
            )
            raise TypeError(
                msg,
            )

        head_dir: Path = prompt_dir / f"L{layer}" / f"H{head}"
        head_dir.mkdir(parents=True, exist_ok=True)
        statuses: dict[str, bool | Exception] = process_single_head(
            layer_idx=layer,
            head_idx=head,
            attn_pattern=pattern,
            save_dir=head_dir,
            figure_funcs=figure_funcs,
            force_overwrite=force_overwrite,
        )

        if results is not None:
            for func_name, status in statuses.items():
                results[func_name][(layer, head)] = status

    # TODO: do something with results

    generate_prompts_jsonl(save_path / model_cfg.model_name)
def process_prompt(
    prompt: dict,
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    save_path: Path,
    figure_funcs: list[AttentionMatrixFigureFunc],
    force_overwrite: bool = False,
) -> None:
    """process a single prompt: load its activations, then compute and save all figures

    thin wrapper chaining `load_activations` into `compute_and_save_figures`

    # Parameters:
     - `prompt : dict`
        prompt to process, should be a dict with the following keys:
        - `"text"`: the prompt string
        - `"hash"`: the hash of the prompt
     - `model_cfg : HookedTransformerConfig|HTConfigMock`
        configuration of the model, used for figuring out where to save
     - `save_path : Path`
        directory to save the figures to
     - `figure_funcs : list[AttentionMatrixFigureFunc]`
        list of functions to run
     - `force_overwrite : bool`
        (defaults to `False`)
    """
    # fetch the cached attention patterns for this prompt as numpy arrays
    activations_path: Path
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
    activations_path, cache = load_activations(
        model_name=model_cfg.model_name,
        prompt=prompt,
        save_path=save_path,
        return_fmt="numpy",
    )

    # run every figure function over every head of the model
    compute_and_save_figures(
        model_cfg=model_cfg,
        activations_path=activations_path,
        cache=cache,
        figure_funcs=figure_funcs,
        save_path=save_path,
        force_overwrite=force_overwrite,
    )
def select_attn_figure_funcs(
    figure_funcs_select: set[str] | str | None = None,
) -> list[AttentionMatrixFigureFunc]:
    """given a selector, figure out which functions from `ATTENTION_MATRIX_FIGURE_FUNCS` to use

    - if arg is `None`, will use all functions
    - if a string, will use the function names which match the string (glob/fnmatch syntax)
    - if a set, will use functions whose names are in the set

    # Returns:
     - `list[AttentionMatrixFigureFunc]`
        always a new list, never the registry itself, so callers may mutate it freely

    # Raises:
     - `TypeError` : if the selector is not `None`, a `str`, or a `set`
    """
    # figure out which functions to use
    figure_funcs: list[AttentionMatrixFigureFunc]
    if figure_funcs_select is None:
        # all if nothing specified. copy the registry so that a caller mutating
        # the returned list cannot corrupt the global (the other branches
        # already return fresh lists)
        figure_funcs = list(ATTENTION_MATRIX_FIGURE_FUNCS)
    elif isinstance(figure_funcs_select, str):
        # if a string, assume a glob pattern
        pattern: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select))
        figure_funcs = [
            func
            for func in ATTENTION_MATRIX_FIGURE_FUNCS
            if pattern.match(func.__name__)
        ]
    elif isinstance(figure_funcs_select, set):
        # if a set, assume a set of function names
        figure_funcs = [
            func
            for func in ATTENTION_MATRIX_FIGURE_FUNCS
            if func.__name__ in figure_funcs_select
        ]
    else:
        err_msg: str = (
            f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }"
            f"\n{figure_funcs_select = }"
        )
        raise TypeError(err_msg)
    return figure_funcs
def figures_main(
    model_name: str,
    save_path: str,
    n_samples: int | None,
    force: bool,
    figure_funcs_select: set[str] | str | None = None,
    parallel: bool | int = True,
) -> None:
    """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
     - `model_name : str`
        model name to use, used for loading the model config, prompts, activations, and saving the figures
     - `save_path : str`
        base path to look in
     - `n_samples : int | None`
        max number of samples to process. if `None`, process all prompts
     - `force : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
     - `figure_funcs_select : set[str]|str|None`
        figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set
        (defaults to `None`)
     - `parallel : bool | int`
        whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
        (defaults to `True`)
    """
    with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
        # save model info or check if it exists
        save_path_p: Path = Path(save_path)
        model_path: Path = save_path_p / model_name
        # be explicit about encoding so reading the JSON does not depend on the
        # platform default (JSON interchange is UTF-8)
        with open(model_path / "model_cfg.json", "r", encoding="utf-8") as f:
            model_cfg = HTConfigMock.load(json.load(f))

    with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
        # load prompts, one JSON object per line
        with open(model_path / "prompts.jsonl", "r", encoding="utf-8") as f:
            prompts: list[dict] = [json.loads(line) for line in f]
        # truncate to n_samples (a `None` bound keeps everything)
        prompts = prompts[:n_samples]
    print(f"{len(prompts)} prompts loaded")

    figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs(
        figure_funcs_select=figure_funcs_select,
    )
    print(f"{len(figure_funcs)} figure functions loaded")
    print("\t" + ", ".join([func.__name__ for func in figure_funcs]))

    # roughly 5 chunks per core keeps the progress bar responsive
    chunksize: int = int(
        max(
            1,
            len(prompts) // (5 * multiprocessing.cpu_count()),
        ),
    )
    print(f"chunksize: {chunksize}")

    # `list(...)` forces the (possibly lazy) parallel map to run to completion
    list(
        run_maybe_parallel(
            func=functools.partial(
                process_prompt,
                model_cfg=model_cfg,
                save_path=save_path_p,
                figure_funcs=figure_funcs,
                force_overwrite=force,
            ),
            iterable=prompts,
            parallel=parallel,
            chunksize=chunksize,
            pbar="tqdm",
            pbar_kwargs=dict(
                desc="Making figures",
                unit="prompt",
            ),
        ),
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and functions",
        **SPINNER_KWARGS,
    ):
        generate_models_jsonl(save_path_p)
        generate_functions_jsonl(save_path_p)
def _parse_args() -> tuple[
    argparse.Namespace,
    list[str],  # models
    set[str] | str | None,  # figure_funcs_select
]:
    """parse CLI arguments for `main`

    # Returns:
     - `tuple[argparse.Namespace, list[str], set[str] | str | None]`
        the raw parsed namespace, the list of model names, and the figure-function selector
    """
    arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
    # input and output
    arg_parser.add_argument(
        "--model",
        "-m",
        type=str,
        required=True,
        help="The model name(s) to use. comma separated with no whitespace if multiple",
    )
    arg_parser.add_argument(
        "--save-path",
        "-s",
        type=str,
        required=False,
        help="The path to save the attention patterns",
        default=DATA_DIR,
    )
    # number of samples
    arg_parser.add_argument(
        "--n-samples",
        "-n",
        type=int,
        required=False,
        help="The max number of samples to process, do all in the file if None",
        default=None,
    )
    # force overwrite of existing figures
    # NOTE: this was previously `type=bool`, which is an argparse footgun --
    # any non-empty string (including "False") parses as True. a store_true
    # flag gives the intended semantics: absent -> False, present -> True
    arg_parser.add_argument(
        "--force",
        "-f",
        action="store_true",
        help="Force overwrite of existing figures",
        default=False,
    )
    # figure functions
    arg_parser.add_argument(
        "--figure-funcs",
        type=str,
        required=False,
        help="The figure functions to use. if 'None' (default), will use all functions. if a string, will use the function names which match the string. if a comma-separated list of strings, will use the function names in the set",
        default=None,
    )

    args: argparse.Namespace = arg_parser.parse_args()

    # figure out models
    models: list[str]
    if "," in args.model:
        models = args.model.split(",")
    else:
        models = [args.model]

    # figure out figures
    figure_funcs_select: set[str] | str | None
    if (args.figure_funcs is None) or (args.figure_funcs.lower().strip() == "none"):
        figure_funcs_select = None
    elif "," in args.figure_funcs:
        figure_funcs_select = {x.strip() for x in args.figure_funcs.split(",")}
    else:
        figure_funcs_select = args.figure_funcs.strip()

    return args, models, figure_funcs_select
def main() -> None:
    "generates figures from the activations using the functions decorated with `register_attn_figure_func`"
    print(DIVIDER_S1)

    # parse CLI arguments
    args: argparse.Namespace
    models: list[str]
    figure_funcs_select: set[str] | str | None
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        args, models, figure_funcs_select = _parse_args()
    print(f"\targs parsed: '{args}'")
    print(f"\tmodels: '{models}'")
    print(f"\tfigure_funcs_select: '{figure_funcs_select}'")

    # run figure generation for each requested model in turn
    n_models: int = len(models)
    for position, model in enumerate(models, start=1):
        print(DIVIDER_S2)
        print(f"processing model {position} / {n_models}: {model}")
        print(DIVIDER_S2)
        figures_main(
            model_name=model,
            save_path=args.save_path,
            n_samples=args.n_samples,
            force=args.force,
            figure_funcs_select=figure_funcs_select,
        )

    print(DIVIDER_S1)
# script entry point
if __name__ == "__main__":
    main()