Coverage for pattern_lens\figures.py: 68%
97 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 20:39 -0700
1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`"""
3import argparse
4from collections import defaultdict
5import functools
6import itertools
7import json
8import warnings
9from pathlib import Path
11from muutils.json_serialize import json_serialize
12from muutils.spinner import SpinnerContext
13from muutils.parallel import run_maybe_parallel
15from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
16from pattern_lens.consts import (
17 DATA_DIR,
18 AttentionMatrix,
19 SPINNER_KWARGS,
20 ActivationCacheNp,
21 DIVIDER_S1,
22 DIVIDER_S2,
23)
24from pattern_lens.indexes import (
25 generate_functions_jsonl,
26 generate_models_jsonl,
27 generate_prompts_jsonl,
28)
29from pattern_lens.load_activations import load_activations
class HTConfigMock:
    """Stand-in for `transformer_lens.HookedTransformerConfig`, for type hinting and loading config json

    accepts arbitrary kwargs and stores them as instance attributes. the rest of
    the code expects at least these attributes to be present:
    - `n_layers: int`
    - `n_heads: int`
    - `model_name: str`
    """

    # declared for type checkers only -- actual values are injected via `__init__` kwargs
    n_layers: int
    n_heads: int
    model_name: str

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def serialize(self):
        """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
        return json_serialize(self.__dict__)

    @classmethod
    def load(cls, data: dict):
        "construct a config from a dict by forwarding it to `__init__`"
        return cls(**data)
def process_single_head(
    layer_idx: int,
    head_idx: int,
    attn_pattern: AttentionMatrix,
    save_dir: Path,
    force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
    """process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern

    # Parameters:
     - `layer_idx : int`
     - `head_idx : int`
     - `attn_pattern : AttentionMatrix`
        attention pattern for the head
     - `save_dir : Path`
        directory to save the figures to
     - `force_overwrite : bool`
        whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
       (defaults to `False`)

    # Returns:
     - `dict[str, bool | Exception]`
        a dictionary of the status of each function, with the function name as the key and the status as the value
    """
    funcs_status: dict[str, bool | Exception] = dict()

    for func in ATTENTION_MATRIX_FIGURE_FUNCS:
        func_name: str = func.__name__

        # existing figures for this function (any extension). exclude the
        # `.error.txt` marker written on failure -- otherwise a previously
        # failed function would be skipped forever and misreported as `True`
        existing_figs: list[Path] = [
            p
            for p in save_dir.glob(f"{func_name}.*")
            if not p.name.endswith(".error.txt")
        ]

        if not force_overwrite and existing_figs:
            funcs_status[func_name] = True
            continue

        try:
            func(attn_pattern, save_dir)
            funcs_status[func_name] = True

        except Exception as e:
            # record the error next to where the figure would have been saved
            error_file = save_dir / f"{func_name}.error.txt"
            error_file.write_text(str(e))
            warnings.warn(
                f"Error in {func_name} for L{layer_idx}H{head_idx}: {str(e)}"
            )
            funcs_status[func_name] = e

    return funcs_status
def compute_and_save_figures(
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    activations_path: Path,
    cache: ActivationCacheNp,
    save_path: Path = Path(DATA_DIR),
    force_overwrite: bool = False,
    track_results: bool = False,
) -> None:
    """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
     - `model_cfg : HookedTransformerConfig|HTConfigMock`
     - `cache : ActivationCacheNp`
     - `save_path : Path`
       (defaults to `Path(DATA_DIR)`)
     - `force_overwrite : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
       (defaults to `False`)
     - `track_results : bool`
        whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
       (defaults to `False`)
    """
    prompt_dir: Path = activations_path.parent

    # maps func name -> {(layer, head): success flag or the raised exception}
    results: defaultdict[str, dict[tuple[int, int], bool | Exception]]
    if track_results:
        results = defaultdict(dict)

    for layer in range(model_cfg.n_layers):
        pattern_key: str = f"blocks.{layer}.attn.hook_pattern"
        for head in range(model_cfg.n_heads):
            head_dir: Path = prompt_dir / f"L{layer}" / f"H{head}"
            head_dir.mkdir(parents=True, exist_ok=True)

            head_status: dict[str, bool | Exception] = process_single_head(
                layer_idx=layer,
                head_idx=head,
                attn_pattern=cache[pattern_key][0, head],
                save_dir=head_dir,
                force_overwrite=force_overwrite,
            )

            if track_results:
                for fname, outcome in head_status.items():
                    results[fname][(layer, head)] = outcome

    # TODO: do something with results

    generate_prompts_jsonl(save_path / model_cfg.model_name)
def process_prompt(
    prompt: dict,
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    save_path: Path,
    force_overwrite: bool = False,
) -> None:
    """process a single prompt: load its activations, then compute and save the figures

    thin wrapper which chains `load_activations` into `compute_and_save_figures`

    # Parameters:
     - `prompt : dict`
     - `model_cfg : HookedTransformerConfig|HTConfigMock`
     - `force_overwrite : bool`
       (defaults to `False`)
    """
    loaded_path: Path
    loaded_cache: ActivationCacheNp
    loaded_path, loaded_cache = load_activations(
        model_name=model_cfg.model_name,
        prompt=prompt,
        save_path=save_path,
        return_fmt="numpy",
    )
    compute_and_save_figures(
        model_cfg=model_cfg,
        activations_path=loaded_path,
        cache=loaded_cache,
        save_path=save_path,
        force_overwrite=force_overwrite,
    )
def figures_main(
    model_name: str,
    save_path: str,
    n_samples: int,
    force: bool,
    parallel: bool | int = True,
) -> None:
    """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
     - `model_name : str`
        model name to use, used for loading the model config, prompts, activations, and saving the figures
     - `save_path : str`
        base path to look in
     - `n_samples : int`
        max number of samples to process
     - `force : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
     - `parallel : bool | int`
        whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
       (defaults to `True`)
    """
    with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
        # save model info or check if it exists
        save_path_p: Path = Path(save_path)
        model_path: Path = save_path_p / model_name
        # explicit encoding: json config may contain non-ascii model metadata
        with open(model_path / "model_cfg.json", "r", encoding="utf-8") as f:
            model_cfg = HTConfigMock.load(json.load(f))

    with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
        # load prompts, iterating the file directly instead of readlines().
        # slicing with n_samples=None keeps all prompts
        with open(model_path / "prompts.jsonl", "r", encoding="utf-8") as f:
            prompts: list[dict] = [json.loads(line) for line in f][:n_samples]

    print(f"{len(prompts)} prompts loaded")

    print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded")
    print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS]))

    # consume the (possibly lazy) parallel iterator to actually run the work
    list(
        run_maybe_parallel(
            func=functools.partial(
                process_prompt,
                model_cfg=model_cfg,
                save_path=save_path_p,
                force_overwrite=force,
            ),
            iterable=prompts,
            parallel=parallel,
            pbar="tqdm",
            pbar_kwargs=dict(
                desc="Making figures",
                unit="prompt",
            ),
        )
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and functions", **SPINNER_KWARGS
    ):
        generate_models_jsonl(save_path_p)
        generate_functions_jsonl(save_path_p)
def main():
    """CLI entry point: parse arguments and run `figures_main` for each requested model"""
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )
        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )
        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )
        # force overwrite of existing figures.
        # NOTE: was `type=bool`, which is a classic argparse pitfall -- any
        # supplied string (even "False") is truthy. a flag is the correct form
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="Force overwrite of existing figures",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    # str.split handles the single-model case too: "m".split(",") == ["m"]
    models: list[str] = args.model.split(",")

    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx+1} / {n_models}: {model}")
        print(DIVIDER_S2)
        figures_main(
            model_name=model,
            save_path=args.save_path,
            n_samples=args.n_samples,
            force=args.force,
        )
        print(DIVIDER_S1)
# allow running this module directly as a script
if __name__ == "__main__":
    main()