pattern_lens.figures
code for generating figures from attention patterns, using the functions decorated with register_attn_figure_func
"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`"""

import argparse
import fnmatch
import functools
import itertools
import json
import multiprocessing
import re
import warnings
from collections import defaultdict
from pathlib import Path

import numpy as np
from jaxtyping import Float

# custom utils
from muutils.json_serialize import json_serialize
from muutils.parallel import run_maybe_parallel
from muutils.spinner import SpinnerContext

# pattern_lens
from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
from pattern_lens.consts import (
    DATA_DIR,
    DIVIDER_S1,
    DIVIDER_S2,
    SPINNER_KWARGS,
    ActivationCacheNp,
    AttentionMatrix,
)
from pattern_lens.figure_util import AttentionMatrixFigureFunc
from pattern_lens.indexes import (
    generate_functions_jsonl,
    generate_models_jsonl,
    generate_prompts_jsonl,
)
from pattern_lens.load_activations import load_activations


class HTConfigMock:
    """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json

    can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
    - `n_layers: int`
    - `n_heads: int`
    - `model_name: str`

    we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
    """

    # NOTE: `**kwargs: str | int` annotates the *value* type of each kwarg (PEP 484);
    # the previous `dict[str, str | int]` annotation incorrectly claimed every value was a dict
    def __init__(self, **kwargs: str | int) -> None:
        "will pass all kwargs to `__dict__`"
        self.n_layers: int
        self.n_heads: int
        self.model_name: str
        self.__dict__.update(kwargs)

    def serialize(self) -> dict:
        """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
        # its fine, we know its a dict
        return json_serialize(self.__dict__)  # type: ignore[return-value]

    @classmethod
    def load(cls, data: dict) -> "HTConfigMock":
        "try to load a config from a dict, using the `__init__` method"
        return cls(**data)


def process_single_head(
    layer_idx: int,
    head_idx: int,
    attn_pattern: AttentionMatrix,
    save_dir: Path,
    figure_funcs: list[AttentionMatrixFigureFunc],
    force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
    """process a single head's attention pattern, running all the functions in `figure_funcs` on the attention pattern

    > [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
    > it will skip all figures for that function if any are already saved
    > and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures

    # Parameters:
    - `layer_idx : int`
    - `head_idx : int`
    - `attn_pattern : AttentionMatrix`
        attention pattern for the head
    - `save_dir : Path`
        directory to save the figures to
    - `figure_funcs : list[AttentionMatrixFigureFunc]`
        figure functions to run on the attention pattern
    - `force_overwrite : bool`
        whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)

    # Returns:
    - `dict[str, bool | Exception]`
        a dictionary of the status of each function, with the function name as the key and the status as the value
    """
    funcs_status: dict[str, bool | Exception] = dict()

    for func in figure_funcs:
        func_name: str = func.__name__
        # any existing file named `{func_name}.*` counts as "already done" for this func
        fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))

        if not force_overwrite and len(fig_path) > 0:
            funcs_status[func_name] = True
            continue

        try:
            func(attn_pattern, save_dir)
            funcs_status[func_name] = True

        # blindly catch any exception -- one failing figure func should not kill the whole run
        except Exception as e:  # noqa: BLE001
            # persist the error next to where the figure would have gone, for later inspection
            error_file = save_dir / f"{func.__name__}.error.txt"
            error_file.write_text(str(e))
            warnings.warn(
                f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}",
                stacklevel=2,
            )
            funcs_status[func_name] = e

    return funcs_status


def compute_and_save_figures(
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    activations_path: Path,
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
    figure_funcs: list[AttentionMatrixFigureFunc],
    save_path: Path = Path(DATA_DIR),
    force_overwrite: bool = False,
    track_results: bool = False,
) -> None:
    """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
        configuration of the model, used for loading the activations
    - `activations_path : Path`
        path to the activations file for this prompt; figures are saved next to it (under its parent dir)
    - `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
        activation cache containing actual patterns for the prompt we are processing
    - `figure_funcs : list[AttentionMatrixFigureFunc]`
        list of functions to run
    - `save_path : Path`
        directory to save the figures to
        (defaults to `Path(DATA_DIR)`)
    - `force_overwrite : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)
    - `track_results : bool`
        whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
        (defaults to `False`)
    """
    prompt_dir: Path = activations_path.parent

    if track_results:
        results: defaultdict[
            str,  # func name
            dict[
                tuple[int, int],  # layer, head
                bool | Exception,  # success or exception
            ],
        ] = defaultdict(dict)

    for layer_idx, head_idx in itertools.product(
        range(model_cfg.n_layers),
        range(model_cfg.n_heads),
    ):
        attn_pattern: AttentionMatrix
        if isinstance(cache, dict):
            # dict cache uses transformer_lens hook names; [0, head_idx] drops the batch dim
            attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
        elif isinstance(cache, np.ndarray):
            attn_pattern = cache[layer_idx, head_idx]
        else:
            msg = (
                f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
            )
            raise TypeError(
                msg,
            )

        save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
        save_dir.mkdir(parents=True, exist_ok=True)
        head_res: dict[str, bool | Exception] = process_single_head(
            layer_idx=layer_idx,
            head_idx=head_idx,
            attn_pattern=attn_pattern,
            save_dir=save_dir,
            force_overwrite=force_overwrite,
            figure_funcs=figure_funcs,
        )

        if track_results:
            for func_name, status in head_res.items():
                results[func_name][(layer_idx, head_idx)] = status

    # TODO: do something with results

    generate_prompts_jsonl(save_path / model_cfg.model_name)


def process_prompt(
    prompt: dict,
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    save_path: Path,
    figure_funcs: list[AttentionMatrixFigureFunc],
    force_overwrite: bool = False,
) -> None:
    """process a single prompt, loading the activations and computing and saving the figures

    basically just calls `load_activations` and then `compute_and_save_figures`

    # Parameters:
    - `prompt : dict`
        prompt to process, should be a dict with the following keys:
        - `"text"`: the prompt string
        - `"hash"`: the hash of the prompt
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
        configuration of the model, used for figuring out where to save
    - `save_path : Path`
        directory to save the figures to
    - `figure_funcs : list[AttentionMatrixFigureFunc]`
        list of functions to run
    - `force_overwrite : bool`
        (defaults to `False`)
    """
    # load the activations
    activations_path: Path
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
    activations_path, cache = load_activations(
        model_name=model_cfg.model_name,
        prompt=prompt,
        save_path=save_path,
        return_fmt="numpy",
    )

    # compute and save the figures
    compute_and_save_figures(
        model_cfg=model_cfg,
        activations_path=activations_path,
        cache=cache,
        figure_funcs=figure_funcs,
        save_path=save_path,
        force_overwrite=force_overwrite,
    )


def select_attn_figure_funcs(
    figure_funcs_select: set[str] | str | None = None,
) -> list[AttentionMatrixFigureFunc]:
    """given a selector, figure out which functions from `ATTENTION_MATRIX_FIGURE_FUNCS` to use

    - if arg is `None`, will use all functions
    - if a string, will use the function names which match the string (glob/fnmatch syntax)
    - if a set, will use functions whose names are in the set

    """
    # figure out which functions to use
    figure_funcs: list[AttentionMatrixFigureFunc]
    if figure_funcs_select is None:
        # all if nothing specified
        figure_funcs = ATTENTION_MATRIX_FIGURE_FUNCS
    elif isinstance(figure_funcs_select, str):
        # if a string, assume a glob pattern
        pattern: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select))
        figure_funcs = [
            func
            for func in ATTENTION_MATRIX_FIGURE_FUNCS
            if pattern.match(func.__name__)
        ]
    elif isinstance(figure_funcs_select, set):
        # if a set, assume a set of function names
        figure_funcs = [
            func
            for func in ATTENTION_MATRIX_FIGURE_FUNCS
            if func.__name__ in figure_funcs_select
        ]
    else:
        err_msg: str = (
            f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }"
            f"\n{figure_funcs_select = }"
        )
        raise TypeError(err_msg)
    return figure_funcs


def figures_main(
    model_name: str,
    save_path: str,
    n_samples: int | None,
    force: bool,
    figure_funcs_select: set[str] | str | None = None,
    parallel: bool | int = True,
) -> None:
    """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_name : str`
        model name to use, used for loading the model config, prompts, activations, and saving the figures
    - `save_path : str`
        base path to look in
    - `n_samples : int | None`
        max number of samples to process, process all if `None`
    - `force : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
    - `figure_funcs_select : set[str]|str|None`
        figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set
        (defaults to `None`)
    - `parallel : bool | int`
        whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
        (defaults to `True`)
    """
    with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
        # save model info or check if it exists
        save_path_p: Path = Path(save_path)
        model_path: Path = save_path_p / model_name
        with open(model_path / "model_cfg.json", "r") as f:
            model_cfg = HTConfigMock.load(json.load(f))

    with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
        # load prompts, skipping blank lines (a trailing newline in the jsonl
        # would otherwise make `json.loads` raise on the empty string)
        with open(model_path / "prompts.jsonl", "r") as f:
            prompts: list[dict] = [json.loads(line) for line in f if line.strip()]
        # truncate to n_samples (slicing with None keeps everything)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs(
        figure_funcs_select=figure_funcs_select,
    )
    print(f"{len(figure_funcs)} figure functions loaded")
    print("\t" + ", ".join([func.__name__ for func in figure_funcs]))

    # aim for ~5 chunks per core so the progress bar stays responsive
    chunksize: int = int(
        max(
            1,
            len(prompts) // (5 * multiprocessing.cpu_count()),
        ),
    )
    print(f"chunksize: {chunksize}")

    list(
        run_maybe_parallel(
            func=functools.partial(
                process_prompt,
                model_cfg=model_cfg,
                save_path=save_path_p,
                figure_funcs=figure_funcs,
                force_overwrite=force,
            ),
            iterable=prompts,
            parallel=parallel,
            chunksize=chunksize,
            pbar="tqdm",
            pbar_kwargs=dict(
                desc="Making figures",
                unit="prompt",
            ),
        ),
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and functions",
        **SPINNER_KWARGS,
    ):
        generate_models_jsonl(save_path_p)
        generate_functions_jsonl(save_path_p)


def _parse_args() -> tuple[
    argparse.Namespace,
    list[str],  # models
    set[str] | str | None,  # figure_funcs_select
]:
    "parse CLI args, returning the namespace plus the parsed model list and figure-func selector"
    arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
    # input and output
    arg_parser.add_argument(
        "--model",
        "-m",
        type=str,
        required=True,
        help="The model name(s) to use. comma separated with no whitespace if multiple",
    )
    arg_parser.add_argument(
        "--save-path",
        "-s",
        type=str,
        required=False,
        help="The path to save the attention patterns",
        default=DATA_DIR,
    )
    # number of samples
    arg_parser.add_argument(
        "--n-samples",
        "-n",
        type=int,
        required=False,
        help="The max number of samples to process, do all in the file if None",
        default=None,
    )
    # force overwrite of existing figures
    # NOTE: previously `type=bool`, which is a classic argparse bug -- bool("False")
    # is True, so any supplied value forced overwrite. `store_true` is the correct idiom.
    arg_parser.add_argument(
        "--force",
        "-f",
        action="store_true",
        help="Force overwrite of existing figures",
    )
    # figure functions
    arg_parser.add_argument(
        "--figure-funcs",
        type=str,
        required=False,
        help="The figure functions to use. if 'None' (default), will use all functions. if a string, will use the function names which match the string. if a comma-separated list of strings, will use the function names in the set",
        default=None,
    )

    args: argparse.Namespace = arg_parser.parse_args()

    # figure out models
    models: list[str]
    if "," in args.model:
        models = args.model.split(",")
    else:
        models = [args.model]

    # figure out figures
    figure_funcs_select: set[str] | str | None
    if (args.figure_funcs is None) or (args.figure_funcs.lower().strip() == "none"):
        figure_funcs_select = None
    elif "," in args.figure_funcs:
        figure_funcs_select = {x.strip() for x in args.figure_funcs.split(",")}
    else:
        figure_funcs_select = args.figure_funcs.strip()

    return args, models, figure_funcs_select


def main() -> None:
    "generates figures from the activations using the functions decorated with `register_attn_figure_func`"
    # parse args
    print(DIVIDER_S1)
    args: argparse.Namespace
    models: list[str]
    figure_funcs_select: set[str] | str | None
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        args, models, figure_funcs_select = _parse_args()
    print(f"\targs parsed: '{args}'")
    print(f"\tmodels: '{models}'")
    print(f"\tfigure_funcs_select: '{figure_funcs_select}'")

    # compute for each model
    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx + 1} / {n_models}: {model}")
        print(DIVIDER_S2)
        figures_main(
            model_name=model,
            save_path=args.save_path,
            n_samples=args.n_samples,
            force=args.force,
            figure_funcs_select=figure_funcs_select,
        )

    print(DIVIDER_S1)


if __name__ == "__main__":
    main()
42class HTConfigMock: 43 """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json 44 45 can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes: 46 - `n_layers: int` 47 - `n_heads: int` 48 - `model_name: str` 49 50 we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly 51 """ 52 53 def __init__(self, **kwargs: dict[str, str | int]) -> None: 54 "will pass all kwargs to `__dict__`" 55 self.n_layers: int 56 self.n_heads: int 57 self.model_name: str 58 self.__dict__.update(kwargs) 59 60 def serialize(self) -> dict: 61 """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`""" 62 # its fine, we know its a dict 63 return json_serialize(self.__dict__) # type: ignore[return-value] 64 65 @classmethod 66 def load(cls, data: dict) -> "HTConfigMock": 67 "try to load a config from a dict, using the `__init__` method" 68 return cls(**data)
Mock of transformer_lens.HookedTransformerConfig for type hinting and loading config json
can be initialized with any kwargs, and will update its __dict__ with them. does, however, require the following attributes:
- `n_layers: int`
- `n_heads: int`
- `model_name: str`
we do this to avoid having to import torch and transformer_lens, since this would have to be done for each process in the parallelization and probably slows things down significantly
53 def __init__(self, **kwargs: dict[str, str | int]) -> None: 54 "will pass all kwargs to `__dict__`" 55 self.n_layers: int 56 self.n_heads: int 57 self.model_name: str 58 self.__dict__.update(kwargs)
will pass all kwargs to __dict__
60 def serialize(self) -> dict: 61 """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`""" 62 # its fine, we know its a dict 63 return json_serialize(self.__dict__) # type: ignore[return-value]
serialize the config to json. values which aren't serializable will be converted via muutils.json_serialize.json_serialize
71def process_single_head( 72 layer_idx: int, 73 head_idx: int, 74 attn_pattern: AttentionMatrix, 75 save_dir: Path, 76 figure_funcs: list[AttentionMatrixFigureFunc], 77 force_overwrite: bool = False, 78) -> dict[str, bool | Exception]: 79 """process a single head's attention pattern, running all the functions in `figure_funcs` on the attention pattern 80 81 > [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function, 82 > it will skip all figures for that function if any are already saved 83 > and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures 84 85 # Parameters: 86 - `layer_idx : int` 87 - `head_idx : int` 88 - `attn_pattern : AttentionMatrix` 89 attention pattern for the head 90 - `save_dir : Path` 91 directory to save the figures to 92 - `force_overwrite : bool` 93 whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure 94 (defaults to `False`) 95 96 # Returns: 97 - `dict[str, bool | Exception]` 98 a dictionary of the status of each function, with the function name as the key and the status as the value 99 """ 100 funcs_status: dict[str, bool | Exception] = dict() 101 102 for func in figure_funcs: 103 func_name: str = func.__name__ 104 fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*")) 105 106 if not force_overwrite and len(fig_path) > 0: 107 funcs_status[func_name] = True 108 continue 109 110 try: 111 func(attn_pattern, save_dir) 112 funcs_status[func_name] = True 113 114 # bling catch any exception 115 except Exception as e: # noqa: BLE001 116 error_file = save_dir / f"{func.__name__}.error.txt" 117 error_file.write_text(str(e)) 118 warnings.warn( 119 f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}", 120 stacklevel=2, 121 ) 122 funcs_status[func_name] = e 123 124 return funcs_status
process a single head's attention pattern, running all the functions in figure_funcs on the attention pattern
[gotcha:] if
`force_overwrite` is `False`, and we used a multi-figure function, it will skip all figures for that function if any are already saved, and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures
Parameters:
- `layer_idx : int`
- `head_idx : int`
- `attn_pattern : AttentionMatrix` — attention pattern for the head
- `save_dir : Path` — directory to save the figures to
- `force_overwrite : bool` — whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure (defaults to `False`)
Returns:
dict[str, bool | Exception]a dictionary of the status of each function, with the function name as the key and the status as the value
127def compute_and_save_figures( 128 model_cfg: "HookedTransformerConfig|HTConfigMock", # type: ignore[name-defined] # noqa: F821 129 activations_path: Path, 130 cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], 131 figure_funcs: list[AttentionMatrixFigureFunc], 132 save_path: Path = Path(DATA_DIR), 133 force_overwrite: bool = False, 134 track_results: bool = False, 135) -> None: 136 """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` 137 138 # Parameters: 139 - `model_cfg : HookedTransformerConfig|HTConfigMock` 140 configuration of the model, used for loading the activations 141 - `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]` 142 activation cache containing actual patterns for the prompt we are processing 143 - `figure_funcs : list[AttentionMatrixFigureFunc]` 144 list of functions to run 145 - `save_path : Path` 146 directory to save the figures to 147 (defaults to `Path(DATA_DIR)`) 148 - `force_overwrite : bool` 149 force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure 150 (defaults to `False`) 151 - `track_results : bool` 152 whether to track the results of each function for each head. 
Isn't used for anything yet, but this is a TODO 153 (defaults to `False`) 154 """ 155 prompt_dir: Path = activations_path.parent 156 157 if track_results: 158 results: defaultdict[ 159 str, # func name 160 dict[ 161 tuple[int, int], # layer, head 162 bool | Exception, # success or exception 163 ], 164 ] = defaultdict(dict) 165 166 for layer_idx, head_idx in itertools.product( 167 range(model_cfg.n_layers), 168 range(model_cfg.n_heads), 169 ): 170 attn_pattern: AttentionMatrix 171 if isinstance(cache, dict): 172 attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx] 173 elif isinstance(cache, np.ndarray): 174 attn_pattern = cache[layer_idx, head_idx] 175 else: 176 msg = ( 177 f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }" 178 ) 179 raise TypeError( 180 msg, 181 ) 182 183 save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}" 184 save_dir.mkdir(parents=True, exist_ok=True) 185 head_res: dict[str, bool | Exception] = process_single_head( 186 layer_idx=layer_idx, 187 head_idx=head_idx, 188 attn_pattern=attn_pattern, 189 save_dir=save_dir, 190 force_overwrite=force_overwrite, 191 figure_funcs=figure_funcs, 192 ) 193 194 if track_results: 195 for func_name, status in head_res.items(): 196 results[func_name][(layer_idx, head_idx)] = status 197 198 # TODO: do something with results 199 200 generate_prompts_jsonl(save_path / model_cfg.model_name)
compute and save figures for all heads in the model, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS
Parameters:
model_cfg : HookedTransformerConfig|HTConfigMockconfiguration of the model, used for loading the activationscache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]activation cache containing actual patterns for the prompt we are processingfigure_funcs : list[AttentionMatrixFigureFunc]list of functions to runsave_path : Pathdirectory to save the figures to (defaults toPath(DATA_DIR))force_overwrite : boolforce overwrite of existing figures. ifFalse, will skip any functions which have already saved a figure (defaults toFalse)track_results : boolwhether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO (defaults toFalse)
203def process_prompt( 204 prompt: dict, 205 model_cfg: "HookedTransformerConfig|HTConfigMock", # type: ignore[name-defined] # noqa: F821 206 save_path: Path, 207 figure_funcs: list[AttentionMatrixFigureFunc], 208 force_overwrite: bool = False, 209) -> None: 210 """process a single prompt, loading the activations and computing and saving the figures 211 212 basically just calls `load_activations` and then `compute_and_save_figures` 213 214 # Parameters: 215 - `prompt : dict` 216 prompt to process, should be a dict with the following keys: 217 - `"text"`: the prompt string 218 - `"hash"`: the hash of the prompt 219 - `model_cfg : HookedTransformerConfig|HTConfigMock` 220 configuration of the model, used for figuring out where to save 221 - `save_path : Path` 222 directory to save the figures to 223 - `figure_funcs : list[AttentionMatrixFigureFunc]` 224 list of functions to run 225 - `force_overwrite : bool` 226 (defaults to `False`) 227 """ 228 # load the activations 229 activations_path: Path 230 cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] 231 activations_path, cache = load_activations( 232 model_name=model_cfg.model_name, 233 prompt=prompt, 234 save_path=save_path, 235 return_fmt="numpy", 236 ) 237 238 # compute and save the figures 239 compute_and_save_figures( 240 model_cfg=model_cfg, 241 activations_path=activations_path, 242 cache=cache, 243 figure_funcs=figure_funcs, 244 save_path=save_path, 245 force_overwrite=force_overwrite, 246 )
process a single prompt, loading the activations and computing and saving the figures
basically just calls load_activations and then compute_and_save_figures
Parameters:
prompt : dictprompt to process, should be a dict with the following keys:"text": the prompt string"hash": the hash of the prompt
model_cfg : HookedTransformerConfig|HTConfigMockconfiguration of the model, used for figuring out where to savesave_path : Pathdirectory to save the figures tofigure_funcs : list[AttentionMatrixFigureFunc]list of functions to runforce_overwrite : bool(defaults toFalse)
249def select_attn_figure_funcs( 250 figure_funcs_select: set[str] | str | None = None, 251) -> list[AttentionMatrixFigureFunc]: 252 """given a selector, figure out which functions from `ATTENTION_MATRIX_FIGURE_FUNCS` to use 253 254 - if arg is `None`, will use all functions 255 - if a string, will use the function names which match the string (glob/fnmatch syntax) 256 - if a set, will use functions whose names are in the set 257 258 """ 259 # figure out which functions to use 260 figure_funcs: list[AttentionMatrixFigureFunc] 261 if figure_funcs_select is None: 262 # all if nothing specified 263 figure_funcs = ATTENTION_MATRIX_FIGURE_FUNCS 264 elif isinstance(figure_funcs_select, str): 265 # if a string, assume a glob pattern 266 pattern: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select)) 267 figure_funcs = [ 268 func 269 for func in ATTENTION_MATRIX_FIGURE_FUNCS 270 if pattern.match(func.__name__) 271 ] 272 elif isinstance(figure_funcs_select, set): 273 # if a set, assume a set of function names 274 figure_funcs = [ 275 func 276 for func in ATTENTION_MATRIX_FIGURE_FUNCS 277 if func.__name__ in figure_funcs_select 278 ] 279 else: 280 err_msg: str = ( 281 f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }" 282 f"\n{figure_funcs_select = }" 283 ) 284 raise TypeError(err_msg) 285 return figure_funcs
given a selector, figure out which functions from ATTENTION_MATRIX_FIGURE_FUNCS to use
- if arg is `None`, will use all functions
- if a string, will use the function names which match the string (glob/fnmatch syntax)
- if a set, will use functions whose names are in the set
288def figures_main( 289 model_name: str, 290 save_path: str, 291 n_samples: int, 292 force: bool, 293 figure_funcs_select: set[str] | str | None = None, 294 parallel: bool | int = True, 295) -> None: 296 """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` 297 298 # Parameters: 299 - `model_name : str` 300 model name to use, used for loading the model config, prompts, activations, and saving the figures 301 - `save_path : str` 302 base path to look in 303 - `n_samples : int` 304 max number of samples to process 305 - `force : bool` 306 force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure 307 - `figure_funcs_select : set[str]|str|None` 308 figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set 309 (defaults to `None`) 310 - `parallel : bool | int` 311 whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. 
if an int, will try to use that many cores 312 (defaults to `True`) 313 """ 314 with SpinnerContext(message="setting up paths", **SPINNER_KWARGS): 315 # save model info or check if it exists 316 save_path_p: Path = Path(save_path) 317 model_path: Path = save_path_p / model_name 318 with open(model_path / "model_cfg.json", "r") as f: 319 model_cfg = HTConfigMock.load(json.load(f)) 320 321 with SpinnerContext(message="loading prompts", **SPINNER_KWARGS): 322 # load prompts 323 with open(model_path / "prompts.jsonl", "r") as f: 324 prompts: list[dict] = [json.loads(line) for line in f.readlines()] 325 # truncate to n_samples 326 prompts = prompts[:n_samples] 327 328 print(f"{len(prompts)} prompts loaded") 329 330 figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs( 331 figure_funcs_select=figure_funcs_select, 332 ) 333 print(f"{len(figure_funcs)} figure functions loaded") 334 print("\t" + ", ".join([func.__name__ for func in figure_funcs])) 335 336 chunksize: int = int( 337 max( 338 1, 339 len(prompts) // (5 * multiprocessing.cpu_count()), 340 ), 341 ) 342 print(f"chunksize: {chunksize}") 343 344 list( 345 run_maybe_parallel( 346 func=functools.partial( 347 process_prompt, 348 model_cfg=model_cfg, 349 save_path=save_path_p, 350 figure_funcs=figure_funcs, 351 force_overwrite=force, 352 ), 353 iterable=prompts, 354 parallel=parallel, 355 chunksize=chunksize, 356 pbar="tqdm", 357 pbar_kwargs=dict( 358 desc="Making figures", 359 unit="prompt", 360 ), 361 ), 362 ) 363 364 with SpinnerContext( 365 message="updating jsonl metadata for models and functions", 366 **SPINNER_KWARGS, 367 ): 368 generate_models_jsonl(save_path_p) 369 generate_functions_jsonl(save_path_p)
main function for generating figures from attention patterns, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS
Parameters:
model_name : strmodel name to use, used for loading the model config, prompts, activations, and saving the figuressave_path : strbase path to look inn_samples : intmax number of samples to processforce : boolforce overwrite of existing figures. ifFalse, will skip any functions which have already saved a figurefigure_funcs_select : set[str]|str|Nonefigure functions to use. ifNone, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set (defaults toNone)parallel : bool | intwhether to run in parallel. ifTrue, will use all available cores. ifFalse, will run in serial. if an int, will try to use that many cores (defaults toTrue)
442def main() -> None: 443 "generates figures from the activations using the functions decorated with `register_attn_figure_func`" 444 # parse args 445 print(DIVIDER_S1) 446 args: argparse.Namespace 447 models: list[str] 448 figure_funcs_select: set[str] | str | None 449 with SpinnerContext(message="parsing args", **SPINNER_KWARGS): 450 args, models, figure_funcs_select = _parse_args() 451 print(f"\targs parsed: '{args}'") 452 print(f"\tmodels: '{models}'") 453 print(f"\tfigure_funcs_select: '{figure_funcs_select}'") 454 455 # compute for each model 456 n_models: int = len(models) 457 for idx, model in enumerate(models): 458 print(DIVIDER_S2) 459 print(f"processing model {idx + 1} / {n_models}: {model}") 460 print(DIVIDER_S2) 461 figures_main( 462 model_name=model, 463 save_path=args.save_path, 464 n_samples=args.n_samples, 465 force=args.force, 466 figure_funcs_select=figure_funcs_select, 467 ) 468 469 print(DIVIDER_S1)
generates figures from the activations using the functions decorated with register_attn_figure_func