pattern_lens.figures
code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`
1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`""" 2 3import argparse 4import fnmatch 5import functools 6import itertools 7import json 8import multiprocessing 9import re 10import warnings 11from collections import defaultdict 12from pathlib import Path 13 14import numpy as np 15from jaxtyping import Float 16 17# custom utils 18from muutils.json_serialize import json_serialize 19from muutils.parallel import run_maybe_parallel 20from muutils.spinner import SpinnerContext 21 22# pattern_lens 23from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS 24from pattern_lens.consts import ( 25 DATA_DIR, 26 DIVIDER_S1, 27 DIVIDER_S2, 28 SPINNER_KWARGS, 29 ActivationCacheNp, 30 AttentionMatrix, 31) 32from pattern_lens.figure_util import AttentionMatrixFigureFunc 33from pattern_lens.indexes import ( 34 generate_functions_jsonl, 35 generate_models_jsonl, 36 generate_prompts_jsonl, 37) 38from pattern_lens.load_activations import load_activations 39 40 41class HTConfigMock: 42 """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json 43 44 can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes: 45 - `n_layers: int` 46 - `n_heads: int` 47 - `model_name: str` 48 49 we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly 50 """ 51 52 def __init__(self, **kwargs: dict[str, str | int]) -> None: 53 "will pass all kwargs to `__dict__`" 54 self.n_layers: int 55 self.n_heads: int 56 self.model_name: str 57 self.__dict__.update(kwargs) 58 59 def serialize(self) -> dict: 60 """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`""" 61 # its fine, we know its a dict 62 return json_serialize(self.__dict__) # type: ignore[return-value] 63 64 @classmethod 65 def load(cls, data: dict) -> "HTConfigMock": 66 "try to load a config from a dict, using the `__init__` method" 67 return cls(**data) 68 69 70def process_single_head( 71 layer_idx: int, 72 head_idx: int, 73 attn_pattern: AttentionMatrix, 74 save_dir: Path, 75 figure_funcs: list[AttentionMatrixFigureFunc], 76 force_overwrite: bool = False, 77) -> dict[str, bool | Exception]: 78 """process a single head's attention pattern, running all the functions in `figure_funcs` on the attention pattern 79 80 > [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function, 81 > it will skip all figures for that function if any are already saved 82 > and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures 83 84 # Parameters: 85 - `layer_idx : int` 86 - `head_idx : int` 87 - `attn_pattern : AttentionMatrix` 88 attention pattern for the head 89 - `save_dir : Path` 90 directory to save the figures to 91 - `force_overwrite : bool` 92 whether to overwrite existing figures. 
if `False`, will skip any functions which have already saved a figure 93 (defaults to `False`) 94 95 # Returns: 96 - `dict[str, bool | Exception]` 97 a dictionary of the status of each function, with the function name as the key and the status as the value 98 """ 99 funcs_status: dict[str, bool | Exception] = dict() 100 101 for func in figure_funcs: 102 func_name: str = getattr(func, "__name__", "<unknown>") 103 fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*")) 104 105 if not force_overwrite and len(fig_path) > 0: 106 funcs_status[func_name] = True 107 continue 108 109 try: 110 func(attn_pattern, save_dir) 111 funcs_status[func_name] = True 112 113 # bling catch any exception 114 except Exception as e: # noqa: BLE001 115 error_file = save_dir / f"{func_name}.error.txt" 116 error_file.write_text(str(e)) 117 warnings.warn( 118 f"Error in {func_name} for L{layer_idx}H{head_idx}: {e!s}", 119 stacklevel=2, 120 ) 121 funcs_status[func_name] = e 122 123 return funcs_status 124 125 126def compute_and_save_figures( 127 model_cfg: "HookedTransformerConfig|HTConfigMock", # type: ignore[name-defined] # noqa: F821 128 activations_path: Path, 129 cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], 130 figure_funcs: list[AttentionMatrixFigureFunc], 131 save_path: Path = Path(DATA_DIR), 132 force_overwrite: bool = False, 133 track_results: bool = False, 134) -> None: 135 """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` 136 137 # Parameters: 138 - `model_cfg : HookedTransformerConfig|HTConfigMock` 139 configuration of the model, used for loading the activations 140 - `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]` 141 activation cache containing actual patterns for the prompt we are processing 142 - `figure_funcs : list[AttentionMatrixFigureFunc]` 143 list of functions to run 144 - `save_path : Path` 145 directory to save the figures to 146 (defaults to `Path(DATA_DIR)`) 147 - `force_overwrite : bool` 148 force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure 149 (defaults to `False`) 150 - `track_results : bool` 151 whether to track the results of each function for each head. 
Isn't used for anything yet, but this is a TODO 152 (defaults to `False`) 153 """ 154 prompt_dir: Path = activations_path.parent 155 156 if track_results: 157 results: defaultdict[ 158 str, # func name 159 dict[ 160 tuple[int, int], # layer, head 161 bool | Exception, # success or exception 162 ], 163 ] = defaultdict(dict) 164 165 for layer_idx, head_idx in itertools.product( 166 range(model_cfg.n_layers), 167 range(model_cfg.n_heads), 168 ): 169 attn_pattern: AttentionMatrix 170 if isinstance(cache, dict): 171 attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx] 172 elif isinstance(cache, np.ndarray): 173 attn_pattern = cache[layer_idx, head_idx] 174 else: 175 msg = ( 176 f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }" 177 ) 178 raise TypeError( 179 msg, 180 ) 181 182 save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}" 183 save_dir.mkdir(parents=True, exist_ok=True) 184 head_res: dict[str, bool | Exception] = process_single_head( 185 layer_idx=layer_idx, 186 head_idx=head_idx, 187 attn_pattern=attn_pattern, 188 save_dir=save_dir, 189 force_overwrite=force_overwrite, 190 figure_funcs=figure_funcs, 191 ) 192 193 if track_results: 194 for func_name, status in head_res.items(): 195 results[func_name][(layer_idx, head_idx)] = status 196 197 # TODO: do something with results 198 199 generate_prompts_jsonl(save_path / model_cfg.model_name) 200 201 202def process_prompt( 203 prompt: dict, 204 model_cfg: "HookedTransformerConfig|HTConfigMock", # type: ignore[name-defined] # noqa: F821 205 save_path: Path, 206 figure_funcs: list[AttentionMatrixFigureFunc], 207 force_overwrite: bool = False, 208) -> None: 209 """process a single prompt, loading the activations and computing and saving the figures 210 211 basically just calls `load_activations` and then `compute_and_save_figures` 212 213 # Parameters: 214 - `prompt : dict` 215 prompt to process, should be a dict with the following keys: 216 - `"text"`: the prompt string 217 - `"hash"`: the hash of the prompt 218 - `model_cfg : HookedTransformerConfig|HTConfigMock` 219 configuration of the model, used for figuring out where to save 220 - `save_path : Path` 221 directory to save the figures to 222 - `figure_funcs : list[AttentionMatrixFigureFunc]` 223 list of functions to run 224 - `force_overwrite : bool` 225 (defaults to `False`) 226 """ 227 # load the activations 228 activations_path: Path 229 cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] 230 activations_path, cache = load_activations( 231 model_name=model_cfg.model_name, 232 prompt=prompt, 233 save_path=save_path, 234 return_fmt="numpy", 235 ) 236 237 # compute and save the figures 238 compute_and_save_figures( 239 model_cfg=model_cfg, 240 activations_path=activations_path, 241 cache=cache, 242 figure_funcs=figure_funcs, 243 save_path=save_path, 244 force_overwrite=force_overwrite, 245 ) 246 247 248def select_attn_figure_funcs( 249 figure_funcs_select: set[str] | str | None = None, 250) -> list[AttentionMatrixFigureFunc]: 251 """given a selector, figure out which functions from `ATTENTION_MATRIX_FIGURE_FUNCS` to use 252 253 - if arg is `None`, will use all functions 254 - if a string, will use the function names which match the string (glob/fnmatch syntax) 255 - if a set, will use functions whose names are in the set 256 257 """ 258 # figure out which functions to use 259 figure_funcs: list[AttentionMatrixFigureFunc] 260 if figure_funcs_select is None: 261 # all if nothing specified 262 figure_funcs = 
ATTENTION_MATRIX_FIGURE_FUNCS 263 elif isinstance(figure_funcs_select, str): 264 # if a string, assume a glob pattern 265 pattern: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select)) 266 figure_funcs = [ 267 func 268 for func in ATTENTION_MATRIX_FIGURE_FUNCS 269 if pattern.match(getattr(func, "__name__", "<unknown>")) 270 ] 271 elif isinstance(figure_funcs_select, set): 272 # if a set, assume a set of function names 273 figure_funcs = [ 274 func 275 for func in ATTENTION_MATRIX_FIGURE_FUNCS 276 if getattr(func, "__name__", "<unknown>") in figure_funcs_select 277 ] 278 else: 279 err_msg: str = ( 280 f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }" 281 f"\n{figure_funcs_select = }" 282 ) 283 raise TypeError(err_msg) 284 return figure_funcs 285 286 287def figures_main( 288 model_name: str, 289 save_path: str | Path, 290 n_samples: int, 291 force: bool, 292 figure_funcs_select: set[str] | str | None = None, 293 parallel: bool | int = True, 294) -> None: 295 """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` 296 297 # Parameters: 298 - `model_name : str` 299 model name to use, used for loading the model config, prompts, activations, and saving the figures 300 - `save_path : str | Path` 301 base path to look in 302 - `n_samples : int` 303 max number of samples to process 304 - `force : bool` 305 force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure 306 - `figure_funcs_select : set[str]|str|None` 307 figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set 308 (defaults to `None`) 309 - `parallel : bool | int` 310 whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. 
if an int, will try to use that many cores 311 (defaults to `True`) 312 """ 313 with SpinnerContext(message="setting up paths", **SPINNER_KWARGS): 314 # save model info or check if it exists 315 save_path_p: Path = Path(save_path) 316 model_path: Path = save_path_p / model_name 317 with open(model_path / "model_cfg.json", "r") as f: 318 model_cfg = HTConfigMock.load(json.load(f)) 319 320 with SpinnerContext(message="loading prompts", **SPINNER_KWARGS): 321 # load prompts 322 with open(model_path / "prompts.jsonl", "r") as f: 323 prompts: list[dict] = [json.loads(line) for line in f.readlines()] 324 # truncate to n_samples 325 prompts = prompts[:n_samples] 326 327 print(f"{len(prompts)} prompts loaded") 328 329 figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs( 330 figure_funcs_select=figure_funcs_select, 331 ) 332 print(f"{len(figure_funcs)} figure functions loaded") 333 print( 334 "\t" 335 + ", ".join([getattr(func, "__name__", "<unknown>") for func in figure_funcs]), 336 ) 337 338 chunksize: int = int( 339 max( 340 1, 341 len(prompts) // (5 * multiprocessing.cpu_count()), 342 ), 343 ) 344 print(f"chunksize: {chunksize}") 345 346 list( 347 run_maybe_parallel( 348 func=functools.partial( 349 process_prompt, 350 model_cfg=model_cfg, 351 save_path=save_path_p, 352 figure_funcs=figure_funcs, 353 force_overwrite=force, 354 ), 355 iterable=prompts, 356 parallel=parallel, 357 chunksize=chunksize, 358 pbar="tqdm", 359 pbar_kwargs=dict( 360 desc="Making figures", 361 unit="prompt", 362 ), 363 ), 364 ) 365 366 with SpinnerContext( 367 message="updating jsonl metadata for models and functions", 368 **SPINNER_KWARGS, 369 ): 370 generate_models_jsonl(save_path_p) 371 generate_functions_jsonl(save_path_p) 372 373 374def _parse_args() -> tuple[ 375 argparse.Namespace, 376 list[str], # models 377 set[str] | str | None, # figure_funcs_select 378]: 379 arg_parser: argparse.ArgumentParser = argparse.ArgumentParser() 380 # input and output 381 arg_parser.add_argument( 382 "--model", 383 "-m", 384 type=str, 385 required=True, 386 help="The model name(s) to use. comma separated with no whitespace if multiple", 387 ) 388 arg_parser.add_argument( 389 "--save-path", 390 "-s", 391 type=str, 392 required=False, 393 help="The path to save the attention patterns", 394 default=DATA_DIR, 395 ) 396 # number of samples 397 arg_parser.add_argument( 398 "--n-samples", 399 "-n", 400 type=int, 401 required=False, 402 help="The max number of samples to process, do all in the file if None", 403 default=None, 404 ) 405 # force overwrite of existing figures 406 arg_parser.add_argument( 407 "--force", 408 "-f", 409 type=bool, 410 required=False, 411 help="Force overwrite of existing figures", 412 default=False, 413 ) 414 # figure functions 415 arg_parser.add_argument( 416 "--figure-funcs", 417 type=str, 418 required=False, 419 help="The figure functions to use. if 'None' (default), will use all functions. if a string, will use the function names which match the string. 
if a comma-separated list of strings, will use the function names in the set", 420 default=None, 421 ) 422 423 args: argparse.Namespace = arg_parser.parse_args() 424 425 # figure out models 426 models: list[str] 427 if "," in args.model: 428 models = args.model.split(",") 429 else: 430 models = [args.model] 431 432 # figure out figures 433 figure_funcs_select: set[str] | str | None 434 if (args.figure_funcs is None) or (args.figure_funcs.lower().strip() == "none"): 435 figure_funcs_select = None 436 elif "," in args.figure_funcs: 437 figure_funcs_select = {x.strip() for x in args.figure_funcs.split(",")} 438 else: 439 figure_funcs_select = args.figure_funcs.strip() 440 441 return args, models, figure_funcs_select 442 443 444def main() -> None: 445 "generates figures from the activations using the functions decorated with `register_attn_figure_func`" 446 # parse args 447 print(DIVIDER_S1) 448 args: argparse.Namespace 449 models: list[str] 450 figure_funcs_select: set[str] | str | None 451 with SpinnerContext(message="parsing args", **SPINNER_KWARGS): 452 args, models, figure_funcs_select = _parse_args() 453 print(f"\targs parsed: '{args}'") 454 print(f"\tmodels: '{models}'") 455 print(f"\tfigure_funcs_select: '{figure_funcs_select}'") 456 457 # compute for each model 458 n_models: int = len(models) 459 for idx, model in enumerate(models): 460 print(DIVIDER_S2) 461 print(f"processing model {idx + 1} / {n_models}: {model}") 462 print(DIVIDER_S2) 463 figures_main( 464 model_name=model, 465 save_path=args.save_path, 466 n_samples=args.n_samples, 467 force=args.force, 468 figure_funcs_select=figure_funcs_select, 469 ) 470 471 print(DIVIDER_S1) 472 473 474if __name__ == "__main__": 475 main()
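Figure functions themselves are registered into `ATTENTION_MATRIX_FIGURE_FUNCS` by the `register_attn_figure_func` decorator mentioned in the module docstring. The sketch below shows roughly what such a function might look like; it assumes the decorator is exported from `pattern_lens.attn_figure_funcs` and can be applied bare, and it follows the `{func_name}.{fmt}` file-naming convention that `process_single_head` globs for.

```python
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# assumption: the decorator lives next to ATTENTION_MATRIX_FIGURE_FUNCS
from pattern_lens.attn_figure_funcs import register_attn_figure_func


@register_attn_figure_func
def raw_heatmap(attn_matrix: np.ndarray, save_dir: Path) -> None:
    """save the raw attention pattern as a heatmap

    figure functions are called as `func(attn_pattern, save_dir)` by
    `process_single_head`, and should write files named `{func_name}.{fmt}`
    so the overwrite check `save_dir.glob(f"{func_name}.*")` can find them
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.matshow(attn_matrix, cmap="viridis")
    ax.axis("off")
    fig.savefig(save_dir / "raw_heatmap.png", bbox_inches="tight")
    plt.close(fig)
```

Anything registered this way is then picked up by `select_attn_figure_funcs` and the CLI described below.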
```python
class HTConfigMock:
```
Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json

can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:

- `n_layers: int`
- `n_heads: int`
- `model_name: str`

we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
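For illustration, the mock can be built directly from keyword arguments (the values here are made up):

```python
from pattern_lens.figures import HTConfigMock

# any kwargs are stored on __dict__; these three attributes are the required ones
cfg = HTConfigMock(model_name="gpt2", n_layers=12, n_heads=12)

assert cfg.n_layers == 12
assert cfg.model_name == "gpt2"
```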
```python
def __init__(self, **kwargs: dict[str, str | int]) -> None:
```
will pass all kwargs to `__dict__`
```python
def serialize(self) -> dict:
```
serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`
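Together with `HTConfigMock.load`, this allows a round trip through JSON, mirroring how `figures_main` reads `model_cfg.json` below:

```python
import json

from pattern_lens.figures import HTConfigMock

cfg = HTConfigMock(model_name="gpt2", n_layers=12, n_heads=12)

# serialize() returns a plain dict, which can be dumped to json
cfg_json: str = json.dumps(cfg.serialize())

# load() rebuilds the mock by passing the parsed dict back to __init__
cfg_roundtrip = HTConfigMock.load(json.loads(cfg_json))
assert cfg_roundtrip.n_heads == cfg.n_heads
```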
```python
def process_single_head(
    layer_idx: int,
    head_idx: int,
    attn_pattern: AttentionMatrix,
    save_dir: Path,
    figure_funcs: list[AttentionMatrixFigureFunc],
    force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
```
process a single head's attention pattern, running all the functions in `figure_funcs` on the attention pattern

> [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function, it will skip all figures for that function if any are already saved, and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures

Parameters:

- `layer_idx : int`
- `head_idx : int`
- `attn_pattern : AttentionMatrix`
  attention pattern for the head
- `save_dir : Path`
  directory to save the figures to
- `figure_funcs : list[AttentionMatrixFigureFunc]`
  list of figure functions to run on the attention pattern
- `force_overwrite : bool`
  whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
  (defaults to `False`)

Returns:

- `dict[str, bool | Exception]`
  a dictionary of the status of each function, with the function name as the key and the status as the value
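A minimal sketch of calling this on a synthetic pattern, assuming the registered figure functions accept a plain numpy array (which is how this module passes patterns to them); the output directory is made up:

```python
from pathlib import Path

import numpy as np

from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
from pattern_lens.figures import process_single_head

# fake (n_ctx, n_ctx) causal attention pattern, rows summing to 1
n_ctx: int = 16
attn = np.tril(np.random.rand(n_ctx, n_ctx))
attn = attn / attn.sum(axis=1, keepdims=True)

# illustrative output directory, matching the L{layer}/H{head} layout
save_dir = Path("demo") / "L0" / "H0"
save_dir.mkdir(parents=True, exist_ok=True)

status = process_single_head(
    layer_idx=0,
    head_idx=0,
    attn_pattern=attn,
    save_dir=save_dir,
    figure_funcs=ATTENTION_MATRIX_FIGURE_FUNCS,
    force_overwrite=True,
)
# maps each figure function name to True (saved) or the Exception it raised
print(status)
```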
```python
def compute_and_save_figures(
    model_cfg: "HookedTransformerConfig|HTConfigMock",
    activations_path: Path,
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
    figure_funcs: list[AttentionMatrixFigureFunc],
    save_path: Path = Path(DATA_DIR),
    force_overwrite: bool = False,
    track_results: bool = False,
) -> None:
```
compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

Parameters:

- `model_cfg : HookedTransformerConfig|HTConfigMock`
  configuration of the model, used for loading the activations
- `activations_path : Path`
  path to the prompt's activations file; figures are saved in `L{layer}/H{head}` directories next to it
- `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
  activation cache containing the attention patterns for the prompt we are processing
- `figure_funcs : list[AttentionMatrixFigureFunc]`
  list of functions to run
- `save_path : Path`
  directory to save the figures to
  (defaults to `Path(DATA_DIR)`)
- `force_overwrite : bool`
  force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
  (defaults to `False`)
- `track_results : bool`
  whether to track the results of each function for each head. not used for anything yet; this is a TODO
  (defaults to `False`)
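A sketch of driving this directly with a stacked numpy cache. Note that the activations file itself is not read here (only its parent directory matters, since figures are written next to it); the paths below are purely illustrative:

```python
from pathlib import Path

import numpy as np

from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
from pattern_lens.figures import HTConfigMock, compute_and_save_figures

cfg = HTConfigMock(model_name="demo-model", n_layers=2, n_heads=4)

# stacked (n_layers, n_heads, n_ctx, n_ctx) cache of attention patterns
cache = np.random.rand(cfg.n_layers, cfg.n_heads, 8, 8)

# figures land in L{layer}/H{head} subdirectories next to this path
activations_path = Path("data") / cfg.model_name / "prompts" / "demo" / "activations.npz"
activations_path.parent.mkdir(parents=True, exist_ok=True)

compute_and_save_figures(
    model_cfg=cfg,
    activations_path=activations_path,
    cache=cache,
    figure_funcs=ATTENTION_MATRIX_FIGURE_FUNCS,
    save_path=Path("data"),
    force_overwrite=True,
)
```

In normal use, `activations_path` and `cache` come from `load_activations` via `process_prompt`, so the directory layout is already correct.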
```python
def process_prompt(
    prompt: dict,
    model_cfg: "HookedTransformerConfig|HTConfigMock",
    save_path: Path,
    figure_funcs: list[AttentionMatrixFigureFunc],
    force_overwrite: bool = False,
) -> None:
```
process a single prompt, loading the activations and computing and saving the figures

basically just calls `load_activations` and then `compute_and_save_figures`

Parameters:

- `prompt : dict`
  prompt to process, should be a dict with the following keys:
  - `"text"`: the prompt string
  - `"hash"`: the hash of the prompt
- `model_cfg : HookedTransformerConfig|HTConfigMock`
  configuration of the model, used for figuring out where to save
- `save_path : Path`
  directory to save the figures to
- `figure_funcs : list[AttentionMatrixFigureFunc]`
  list of functions to run
- `force_overwrite : bool`
  (defaults to `False`)
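A sketch of processing one prompt, assuming its activations were already saved under `save_path / model_name` by the activation-generation step; the `"hash"` value is a placeholder for whatever hash that step recorded in `prompts.jsonl`:

```python
from pathlib import Path

from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
from pattern_lens.consts import DATA_DIR
from pattern_lens.figures import HTConfigMock, process_prompt

cfg = HTConfigMock(model_name="gpt2", n_layers=12, n_heads=12)

# one entry as it would appear in prompts.jsonl
prompt = {"text": "The quick brown fox", "hash": "<prompt-hash>"}

process_prompt(
    prompt=prompt,
    model_cfg=cfg,
    save_path=Path(DATA_DIR),
    figure_funcs=ATTENTION_MATRIX_FIGURE_FUNCS,
    force_overwrite=False,
)
```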
```python
def select_attn_figure_funcs(
    figure_funcs_select: set[str] | str | None = None,
) -> list[AttentionMatrixFigureFunc]:
```
given a selector, figure out which functions from `ATTENTION_MATRIX_FIGURE_FUNCS` to use

- if arg is `None`, will use all functions
- if a string, will use the function names which match the string (glob/fnmatch syntax)
- if a set, will use functions whose names are in the set
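For example (the function names here are placeholders for whatever is actually registered):

```python
from pattern_lens.figures import select_attn_figure_funcs

# None selects every registered function
all_funcs = select_attn_figure_funcs(None)

# a string is treated as a glob pattern over function names
raw_funcs = select_attn_figure_funcs("raw*")

# a set selects by exact function name
picked = select_attn_figure_funcs({"raw_heatmap"})

print([getattr(f, "__name__", "<unknown>") for f in all_funcs])
```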
```python
def figures_main(
    model_name: str,
    save_path: str | Path,
    n_samples: int,
    force: bool,
    figure_funcs_select: set[str] | str | None = None,
    parallel: bool | int = True,
) -> None:
```
main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

Parameters:

- `model_name : str`
  model name to use, used for loading the model config, prompts, activations, and saving the figures
- `save_path : str | Path`
  base path to look in
- `n_samples : int`
  max number of samples to process
- `force : bool`
  force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
- `figure_funcs_select : set[str]|str|None`
  figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set
  (defaults to `None`)
- `parallel : bool | int`
  whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
  (defaults to `True`)
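A sketch of a typical call, assuming `{save_path}/{model_name}/model_cfg.json`, `prompts.jsonl`, and the per-prompt activations already exist from the activation-generation step:

```python
from pattern_lens.consts import DATA_DIR
from pattern_lens.figures import figures_main

figures_main(
    model_name="gpt2",
    save_path=DATA_DIR,
    n_samples=100,
    force=False,
    figure_funcs_select=None,  # use all registered figure functions
    parallel=4,  # or True for all cores, False for serial
)
```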
```python
def main() -> None:
```
generates figures from the activations using the functions decorated with `register_attn_figure_func`
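The module is runnable as a script (there is an `if __name__ == "__main__"` guard in the source above), so the flags parsed by `_parse_args` can be exercised either from a shell or, as sketched here, by setting `sys.argv` before calling `main()`; the model name and glob are placeholders:

```python
import sys

from pattern_lens.figures import main

# equivalent to: python -m pattern_lens.figures --model gpt2 --n-samples 50 --figure-funcs "raw*"
sys.argv = [
    "pattern_lens.figures",
    "--model", "gpt2",
    "--n-samples", "50",
    "--figure-funcs", "raw*",
]
main()
```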