Coverage for pattern_lens\figures.py: 91% (104 statements)
coverage.py v7.6.12, created at 2025-03-04 01:50 -0700
1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`"""
3import argparse
4import functools
5import itertools
6import json
7import warnings
8from collections import defaultdict
9from pathlib import Path
11import numpy as np
12from jaxtyping import Float
14# custom utils
15from muutils.json_serialize import json_serialize
16from muutils.parallel import run_maybe_parallel
17from muutils.spinner import SpinnerContext
19# pattern_lens
20from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
21from pattern_lens.consts import (
22 DATA_DIR,
23 DIVIDER_S1,
24 DIVIDER_S2,
25 SPINNER_KWARGS,
26 ActivationCacheNp,
27 AttentionMatrix,
28)
29from pattern_lens.indexes import (
30 generate_functions_jsonl,
31 generate_models_jsonl,
32 generate_prompts_jsonl,
33)
34from pattern_lens.load_activations import load_activations

class HTConfigMock:
    """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json

    can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
    - `n_layers: int`
    - `n_heads: int`
    - `model_name: str`

    we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
    """

    def __init__(self, **kwargs: dict[str, str | int]) -> None:
        "will pass all kwargs to `__dict__`"
        self.n_layers: int
        self.n_heads: int
        self.model_name: str
        self.__dict__.update(kwargs)

    def serialize(self) -> dict:
        """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
        # it's fine, we know it's a dict
        return json_serialize(self.__dict__)  # type: ignore[return-value]
    @classmethod
    def load(cls, data: dict) -> "HTConfigMock":
        "try to load a config from a dict, using the `__init__` method"
        return cls(**data)
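
# A minimal usage sketch for `HTConfigMock` (illustrative only; the config values
# below are hypothetical and not taken from any real `model_cfg.json`):
#
#     cfg = HTConfigMock.load({"n_layers": 12, "n_heads": 12, "model_name": "gpt2"})
#     assert cfg.n_heads == 12
#     cfg_json: dict = cfg.serialize()  # JSON-serializable dict, via `muutils.json_serialize`
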
def process_single_head(
    layer_idx: int,
    head_idx: int,
    attn_pattern: AttentionMatrix,
    save_dir: Path,
    force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
    """process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern

    > [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
    > it will skip all figures for that function if any are already saved
    > and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures

    # Parameters:
    - `layer_idx : int`
    - `head_idx : int`
    - `attn_pattern : AttentionMatrix`
        attention pattern for the head
    - `save_dir : Path`
        directory to save the figures to
    - `force_overwrite : bool`
        whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)

    # Returns:
    - `dict[str, bool | Exception]`
        a dictionary of the status of each function, with the function name as the key and the status as the value
    """
    funcs_status: dict[str, bool | Exception] = dict()

    for func in ATTENTION_MATRIX_FIGURE_FUNCS:
        func_name: str = func.__name__
        fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))

        if not force_overwrite and len(fig_path) > 0:
            funcs_status[func_name] = True
            continue

        try:
            func(attn_pattern, save_dir)
            funcs_status[func_name] = True

        # blindly catch any exception
        except Exception as e:  # noqa: BLE001
            error_file = save_dir / f"{func.__name__}.error.txt"
            error_file.write_text(str(e))
            warnings.warn(
                f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}",
                stacklevel=2,
            )
            funcs_status[func_name] = e

    return funcs_status
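
# Sketch of calling `process_single_head` directly on a dummy pattern; the path and
# array below are hypothetical, and `save_dir` must already exist (normally the caller,
# `compute_and_save_figures`, creates it):
#
#     pattern: AttentionMatrix = np.ones((5, 5), dtype=np.float32) / 5.0  # uniform attention over 5 tokens
#     status = process_single_head(
#         layer_idx=0,
#         head_idx=0,
#         attn_pattern=pattern,
#         save_dir=Path("attn_data/mymodel/prompts/abc123/L0/H0"),
#         force_overwrite=True,
#     )
#     # `status` maps each figure function's name to `True` (success or skipped) or to the raised Exception
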
def compute_and_save_figures(
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    activations_path: Path,
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
    save_path: Path = Path(DATA_DIR),
    force_overwrite: bool = False,
    track_results: bool = False,
) -> None:
    """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
    - `activations_path : Path`
        path to the prompt's activations file; figures are saved under its parent directory
    - `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
    - `save_path : Path`
        (defaults to `Path(DATA_DIR)`)
    - `force_overwrite : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)
    - `track_results : bool`
        whether to track the results of each function for each head. isn't used for anything yet, but this is a TODO
        (defaults to `False`)
    """
    prompt_dir: Path = activations_path.parent

    if track_results:
        results: defaultdict[
            str,  # func name
            dict[
                tuple[int, int],  # layer, head
                bool | Exception,  # success or exception
            ],
        ] = defaultdict(dict)

    for layer_idx, head_idx in itertools.product(
        range(model_cfg.n_layers),
        range(model_cfg.n_heads),
    ):
        attn_pattern: AttentionMatrix
        if isinstance(cache, dict):
            attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
        elif isinstance(cache, np.ndarray):
            attn_pattern = cache[layer_idx, head_idx]
        else:
            msg = (
                f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
            )
            raise TypeError(
                msg,
            )

        save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
        save_dir.mkdir(parents=True, exist_ok=True)
        head_res: dict[str, bool | Exception] = process_single_head(
            layer_idx=layer_idx,
            head_idx=head_idx,
            attn_pattern=attn_pattern,
            save_dir=save_dir,
            force_overwrite=force_overwrite,
        )

        if track_results:
            for func_name, status in head_res.items():
                results[func_name][(layer_idx, head_idx)] = status

    # TODO: do something with results

    generate_prompts_jsonl(save_path / model_cfg.model_name)
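
# Cache format, as assumed by the loop above: when `cache` is a dict (the
# `ActivationCacheNp` case), each entry
#     cache[f"blocks.{layer_idx}.attn.hook_pattern"]
# is expected to carry a leading batch dimension, i.e. shape (1, n_heads, n_ctx, n_ctx),
# hence the `[0, head_idx]` indexing. When `cache` is a bare ndarray, it is expected to
# already be stacked as (n_layers, n_heads, n_ctx, n_ctx).
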
def process_prompt(
    prompt: dict,
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    save_path: Path,
    force_overwrite: bool = False,
) -> None:
    """process a single prompt, loading the activations and computing and saving the figures

    basically just calls `load_activations` and then `compute_and_save_figures`

    # Parameters:
    - `prompt : dict`
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
    - `save_path : Path`
    - `force_overwrite : bool`
        (defaults to `False`)
    """
    activations_path: Path
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
    activations_path, cache = load_activations(
        model_name=model_cfg.model_name,
        prompt=prompt,
        save_path=save_path,
        return_fmt="numpy",
    )

    compute_and_save_figures(
        model_cfg=model_cfg,
        activations_path=activations_path,
        cache=cache,
        save_path=save_path,
        force_overwrite=force_overwrite,
    )
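
# Rough on-disk layout implied by the functions above (the per-prompt directory is
# whatever `load_activations` resolves `activations_path` to, shown here only
# schematically as <prompt_dir>; file extensions depend on the figure functions):
#
#     <save_path>/<model_name>/
#         model_cfg.json
#         prompts.jsonl
#         ...<prompt_dir>/                         # parent of `activations_path`
#             L<layer_idx>/H<head_idx>/
#                 <func_name>.<fmt>                # single-figure functions
#                 <func_name>.<figure_name>.<fmt>  # multi-figure functions
#                 <func_name>.error.txt            # written when a figure function raises
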
def figures_main(
    model_name: str,
    save_path: str,
    n_samples: int | None,
    force: bool,
    parallel: bool | int = True,
) -> None:
    """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_name : str`
        model name to use; used for loading the model config, prompts, and activations, and for saving the figures
    - `save_path : str`
        base path to look in
    - `n_samples : int | None`
        max number of samples to process; processes all of them if `None`
    - `force : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
    - `parallel : bool | int`
        whether to run in parallel. if `True`, uses all available cores; if `False`, runs serially; if an int, tries to use that many cores
        (defaults to `True`)
    """
    with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
        # save model info or check if it exists
        save_path_p: Path = Path(save_path)
        model_path: Path = save_path_p / model_name
        with open(model_path / "model_cfg.json", "r") as f:
            model_cfg = HTConfigMock.load(json.load(f))

    with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
        # load prompts
        with open(model_path / "prompts.jsonl", "r") as f:
            prompts: list[dict] = [json.loads(line) for line in f.readlines()]
        # truncate to n_samples
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded")
    print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS]))

    list(
        run_maybe_parallel(
            func=functools.partial(
                process_prompt,
                model_cfg=model_cfg,
                save_path=save_path_p,
                force_overwrite=force,
            ),
            iterable=prompts,
            parallel=parallel,
            pbar="tqdm",
            pbar_kwargs=dict(
                desc="Making figures",
                unit="prompt",
            ),
        ),
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and functions",
        **SPINNER_KWARGS,
    ):
        generate_models_jsonl(save_path_p)
        generate_functions_jsonl(save_path_p)
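
# Programmatic use sketch (paths and counts are hypothetical; `<save_path>/<model_name>/`
# must already contain `model_cfg.json` and `prompts.jsonl`, presumably produced by the
# activations step):
#
#     figures_main(
#         model_name="gpt2",
#         save_path="attn_data",
#         n_samples=100,
#         force=False,
#         parallel=4,  # passed through to `muutils.parallel.run_maybe_parallel`
#     )
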
def main() -> None:
    "generates figures from the activations using the functions decorated with `register_attn_figure_func`"
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )
        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )
        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )
322 arg_parser.add_argument(
323 "--force",
324 "-f",
325 type=bool,
326 required=False,
327 help="Force overwrite of existing figures",
328 default=False,
329 )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    models: list[str]
    if "," in args.model:
        models = args.model.split(",")
    else:
        models = [args.model]

    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx + 1} / {n_models}: {model}")
        print(DIVIDER_S2)
        figures_main(
            model_name=model,
            save_path=args.save_path,
            n_samples=args.n_samples,
            force=args.force,
        )

    print(DIVIDER_S1)

if __name__ == "__main__":
    main()
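
# Example invocation (the module path follows the package layout implied by the imports
# above; model names and paths are hypothetical):
#
#     python -m pattern_lens.figures --model gpt2,pythia-70m --save-path attn_data -n 100 --force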