pattern_lens.figures
code for generating figures from attention patterns, using the functions decorated with register_attn_figure_func
1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`""" 2 3import argparse 4import functools 5import itertools 6import json 7import warnings 8from collections import defaultdict 9from pathlib import Path 10 11import numpy as np 12from jaxtyping import Float 13 14# custom utils 15from muutils.json_serialize import json_serialize 16from muutils.parallel import run_maybe_parallel 17from muutils.spinner import SpinnerContext 18 19# pattern_lens 20from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS 21from pattern_lens.consts import ( 22 DATA_DIR, 23 DIVIDER_S1, 24 DIVIDER_S2, 25 SPINNER_KWARGS, 26 ActivationCacheNp, 27 AttentionMatrix, 28) 29from pattern_lens.indexes import ( 30 generate_functions_jsonl, 31 generate_models_jsonl, 32 generate_prompts_jsonl, 33) 34from pattern_lens.load_activations import load_activations 35 36 37class HTConfigMock: 38 """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json 39 40 can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes: 41 - `n_layers: int` 42 - `n_heads: int` 43 - `model_name: str` 44 45 we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly 46 """ 47 48 def __init__(self, **kwargs: dict[str, str | int]) -> None: 49 "will pass all kwargs to `__dict__`" 50 self.n_layers: int 51 self.n_heads: int 52 self.model_name: str 53 self.__dict__.update(kwargs) 54 55 def serialize(self) -> dict: 56 """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`""" 57 # its fine, we know its a dict 58 return json_serialize(self.__dict__) # type: ignore[return-value] 59 60 @classmethod 61 def load(cls, data: dict) -> "HTConfigMock": 62 "try to load a config from a dict, using the `__init__` method" 63 return cls(**data) 64 65 66def process_single_head( 67 layer_idx: int, 68 head_idx: int, 69 attn_pattern: AttentionMatrix, 70 save_dir: Path, 71 force_overwrite: bool = False, 72) -> dict[str, bool | Exception]: 73 """process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern 74 75 > [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function, 76 > it will skip all figures for that function if any are already saved 77 > and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures 78 79 # Parameters: 80 - `layer_idx : int` 81 - `head_idx : int` 82 - `attn_pattern : AttentionMatrix` 83 attention pattern for the head 84 - `save_dir : Path` 85 directory to save the figures to 86 - `force_overwrite : bool` 87 whether to overwrite existing figures. 
if `False`, will skip any functions which have already saved a figure 88 (defaults to `False`) 89 90 # Returns: 91 - `dict[str, bool | Exception]` 92 a dictionary of the status of each function, with the function name as the key and the status as the value 93 """ 94 funcs_status: dict[str, bool | Exception] = dict() 95 96 for func in ATTENTION_MATRIX_FIGURE_FUNCS: 97 func_name: str = func.__name__ 98 fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*")) 99 100 if not force_overwrite and len(fig_path) > 0: 101 funcs_status[func_name] = True 102 continue 103 104 try: 105 func(attn_pattern, save_dir) 106 funcs_status[func_name] = True 107 108 # bling catch any exception 109 except Exception as e: # noqa: BLE001 110 error_file = save_dir / f"{func.__name__}.error.txt" 111 error_file.write_text(str(e)) 112 warnings.warn( 113 f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}", 114 stacklevel=2, 115 ) 116 funcs_status[func_name] = e 117 118 return funcs_status 119 120 121def compute_and_save_figures( 122 model_cfg: "HookedTransformerConfig|HTConfigMock", # type: ignore[name-defined] # noqa: F821 123 activations_path: Path, 124 cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], 125 save_path: Path = Path(DATA_DIR), 126 force_overwrite: bool = False, 127 track_results: bool = False, 128) -> None: 129 """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` 130 131 # Parameters: 132 - `model_cfg : HookedTransformerConfig|HTConfigMock` 133 - `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]` 134 - `save_path : Path` 135 (defaults to `Path(DATA_DIR)`) 136 - `force_overwrite : bool` 137 force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure 138 (defaults to `False`) 139 - `track_results : bool` 140 whether to track the results of each function for each head. 
Isn't used for anything yet, but this is a TODO 141 (defaults to `False`) 142 """ 143 prompt_dir: Path = activations_path.parent 144 145 if track_results: 146 results: defaultdict[ 147 str, # func name 148 dict[ 149 tuple[int, int], # layer, head 150 bool | Exception, # success or exception 151 ], 152 ] = defaultdict(dict) 153 154 for layer_idx, head_idx in itertools.product( 155 range(model_cfg.n_layers), 156 range(model_cfg.n_heads), 157 ): 158 attn_pattern: AttentionMatrix 159 if isinstance(cache, dict): 160 attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx] 161 elif isinstance(cache, np.ndarray): 162 attn_pattern = cache[layer_idx, head_idx] 163 else: 164 msg = ( 165 f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }" 166 ) 167 raise TypeError( 168 msg, 169 ) 170 171 save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}" 172 save_dir.mkdir(parents=True, exist_ok=True) 173 head_res: dict[str, bool | Exception] = process_single_head( 174 layer_idx=layer_idx, 175 head_idx=head_idx, 176 attn_pattern=attn_pattern, 177 save_dir=save_dir, 178 force_overwrite=force_overwrite, 179 ) 180 181 if track_results: 182 for func_name, status in head_res.items(): 183 results[func_name][(layer_idx, head_idx)] = status 184 185 # TODO: do something with results 186 187 generate_prompts_jsonl(save_path / model_cfg.model_name) 188 189 190def process_prompt( 191 prompt: dict, 192 model_cfg: "HookedTransformerConfig|HTConfigMock", # type: ignore[name-defined] # noqa: F821 193 save_path: Path, 194 force_overwrite: bool = False, 195) -> None: 196 """process a single prompt, loading the activations and computing and saving the figures 197 198 basically just calls `load_activations` and then `compute_and_save_figures` 199 200 # Parameters: 201 - `prompt : dict` 202 - `model_cfg : HookedTransformerConfig|HTConfigMock` 203 - `force_overwrite : bool` 204 (defaults to `False`) 205 """ 206 activations_path: Path 207 cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] 208 activations_path, cache = load_activations( 209 model_name=model_cfg.model_name, 210 prompt=prompt, 211 save_path=save_path, 212 return_fmt="numpy", 213 ) 214 215 compute_and_save_figures( 216 model_cfg=model_cfg, 217 activations_path=activations_path, 218 cache=cache, 219 save_path=save_path, 220 force_overwrite=force_overwrite, 221 ) 222 223 224def figures_main( 225 model_name: str, 226 save_path: str, 227 n_samples: int, 228 force: bool, 229 parallel: bool | int = True, 230) -> None: 231 """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` 232 233 # Parameters: 234 - `model_name : str` 235 model name to use, used for loading the model config, prompts, activations, and saving the figures 236 - `save_path : str` 237 base path to look in 238 - `n_samples : int` 239 max number of samples to process 240 - `force : bool` 241 force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure 242 - `parallel : bool | int` 243 whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. 
if an int, will try to use that many cores 244 (defaults to `True`) 245 """ 246 with SpinnerContext(message="setting up paths", **SPINNER_KWARGS): 247 # save model info or check if it exists 248 save_path_p: Path = Path(save_path) 249 model_path: Path = save_path_p / model_name 250 with open(model_path / "model_cfg.json", "r") as f: 251 model_cfg = HTConfigMock.load(json.load(f)) 252 253 with SpinnerContext(message="loading prompts", **SPINNER_KWARGS): 254 # load prompts 255 with open(model_path / "prompts.jsonl", "r") as f: 256 prompts: list[dict] = [json.loads(line) for line in f.readlines()] 257 # truncate to n_samples 258 prompts = prompts[:n_samples] 259 260 print(f"{len(prompts)} prompts loaded") 261 262 print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded") 263 print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS])) 264 265 list( 266 run_maybe_parallel( 267 func=functools.partial( 268 process_prompt, 269 model_cfg=model_cfg, 270 save_path=save_path_p, 271 force_overwrite=force, 272 ), 273 iterable=prompts, 274 parallel=parallel, 275 pbar="tqdm", 276 pbar_kwargs=dict( 277 desc="Making figures", 278 unit="prompt", 279 ), 280 ), 281 ) 282 283 with SpinnerContext( 284 message="updating jsonl metadata for models and functions", 285 **SPINNER_KWARGS, 286 ): 287 generate_models_jsonl(save_path_p) 288 generate_functions_jsonl(save_path_p) 289 290 291def main() -> None: 292 "generates figures from the activations using the functions decorated with `register_attn_figure_func`" 293 print(DIVIDER_S1) 294 with SpinnerContext(message="parsing args", **SPINNER_KWARGS): 295 arg_parser: argparse.ArgumentParser = argparse.ArgumentParser() 296 # input and output 297 arg_parser.add_argument( 298 "--model", 299 "-m", 300 type=str, 301 required=True, 302 help="The model name(s) to use. comma separated with no whitespace if multiple", 303 ) 304 arg_parser.add_argument( 305 "--save-path", 306 "-s", 307 type=str, 308 required=False, 309 help="The path to save the attention patterns", 310 default=DATA_DIR, 311 ) 312 # number of samples 313 arg_parser.add_argument( 314 "--n-samples", 315 "-n", 316 type=int, 317 required=False, 318 help="The max number of samples to process, do all in the file if None", 319 default=None, 320 ) 321 # force overwrite of existing figures 322 arg_parser.add_argument( 323 "--force", 324 "-f", 325 type=bool, 326 required=False, 327 help="Force overwrite of existing figures", 328 default=False, 329 ) 330 331 args: argparse.Namespace = arg_parser.parse_args() 332 333 print(f"args parsed: {args}") 334 335 models: list[str] 336 if "," in args.model: 337 models = args.model.split(",") 338 else: 339 models = [args.model] 340 341 n_models: int = len(models) 342 for idx, model in enumerate(models): 343 print(DIVIDER_S2) 344 print(f"processing model {idx + 1} / {n_models}: {model}") 345 print(DIVIDER_S2) 346 figures_main( 347 model_name=model, 348 save_path=args.save_path, 349 n_samples=args.n_samples, 350 force=args.force, 351 ) 352 353 print(DIVIDER_S1) 354 355 356if __name__ == "__main__": 357 main()
class HTConfigMock:
	"""Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json

	can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
	- `n_layers: int`
	- `n_heads: int`
	- `model_name: str`

	we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
	"""

	def __init__(self, **kwargs: dict[str, str | int]) -> None:
		"will pass all kwargs to `__dict__`"
		self.n_layers: int
		self.n_heads: int
		self.model_name: str
		self.__dict__.update(kwargs)

	def serialize(self) -> dict:
		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
		# it's fine, we know it's a dict
		return json_serialize(self.__dict__)  # type: ignore[return-value]

	@classmethod
	def load(cls, data: dict) -> "HTConfigMock":
		"try to load a config from a dict, using the `__init__` method"
		return cls(**data)
Mock of transformer_lens.HookedTransformerConfig for type hinting and loading config json
can be initialized with any kwargs, and will update its __dict__ with them. does, however, require the following attributes:
- n_layers: int
- n_heads: int
- model_name: str
we do this to avoid having to import torch and transformer_lens, since this would have to be done for each process in the parallelization and probably slows things down significantly
	def __init__(self, **kwargs: dict[str, str | int]) -> None:
		"will pass all kwargs to `__dict__`"
		self.n_layers: int
		self.n_heads: int
		self.model_name: str
		self.__dict__.update(kwargs)
will pass all kwargs to __dict__
	def serialize(self) -> dict:
		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
		# it's fine, we know it's a dict
		return json_serialize(self.__dict__)  # type: ignore[return-value]
serialize the config to json. values which aren't serializable will be converted via muutils.json_serialize.json_serialize
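A minimal sketch of round-tripping a config through the mock (the field values here are made up):

from pattern_lens.figures import HTConfigMock

# construct with the required attributes (plus anything else the real config had)
cfg = HTConfigMock(n_layers=12, n_heads=12, model_name="gpt2")

# serialize to a json-safe dict, then load it back
cfg_json: dict = cfg.serialize()
cfg_roundtrip = HTConfigMock.load(cfg_json)
assert cfg_roundtrip.n_layers == 12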
def process_single_head(
	layer_idx: int,
	head_idx: int,
	attn_pattern: AttentionMatrix,
	save_dir: Path,
	force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
	"""process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern

	> [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
	> it will skip all figures for that function if any are already saved
	> and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures

	# Parameters:
	 - `layer_idx : int`
	 - `head_idx : int`
	 - `attn_pattern : AttentionMatrix`
	   attention pattern for the head
	 - `save_dir : Path`
	   directory to save the figures to
	 - `force_overwrite : bool`
	   whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
	   (defaults to `False`)

	# Returns:
	 - `dict[str, bool | Exception]`
	   a dictionary of the status of each function, with the function name as the key and the status as the value
	"""
	funcs_status: dict[str, bool | Exception] = dict()

	for func in ATTENTION_MATRIX_FIGURE_FUNCS:
		func_name: str = func.__name__
		fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))

		if not force_overwrite and len(fig_path) > 0:
			funcs_status[func_name] = True
			continue

		try:
			func(attn_pattern, save_dir)
			funcs_status[func_name] = True

		# blindly catch any exception
		except Exception as e:  # noqa: BLE001
			error_file = save_dir / f"{func.__name__}.error.txt"
			error_file.write_text(str(e))
			warnings.warn(
				f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}",
				stacklevel=2,
			)
			funcs_status[func_name] = e

	return funcs_status
process a single head's attention pattern, running all the functions in ATTENTION_MATRIX_FIGURE_FUNCS on the attention pattern
[gotcha:] if force_overwrite is False, and we used a multi-figure function, it will skip all figures for that function if any are already saved, and it assumes a format of {func_name}.{figure_name}.{fmt} for the saved figures
Parameters:
- layer_idx : int
- head_idx : int
- attn_pattern : AttentionMatrix
  attention pattern for the head
- save_dir : Path
  directory to save the figures to
- force_overwrite : bool
  whether to overwrite existing figures. if False, will skip any functions which have already saved a figure
  (defaults to False)
Returns:
- dict[str, bool | Exception]
  a dictionary of the status of each function, with the function name as the key and the status as the value
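As a rough sketch of running it on one head, with a random matrix standing in for a real attention pattern (the output directory is made up, and the figure functions are whatever has been registered via register_attn_figure_func):

from pathlib import Path

import numpy as np

from pattern_lens.figures import process_single_head

# fake (n_ctx, n_ctx) attention pattern; rows of a real pattern sum to 1
n_ctx: int = 16
attn = np.random.rand(n_ctx, n_ctx)
attn /= attn.sum(axis=-1, keepdims=True)

save_dir = Path("demo/L0/H0")  # hypothetical output directory
save_dir.mkdir(parents=True, exist_ok=True)

status = process_single_head(
	layer_idx=0,
	head_idx=0,
	attn_pattern=attn,
	save_dir=save_dir,
	force_overwrite=True,
)
# maps each figure function name to True on success, or to the raised Exception
print(status)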
def compute_and_save_figures(
	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
	activations_path: Path,
	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
	save_path: Path = Path(DATA_DIR),
	force_overwrite: bool = False,
	track_results: bool = False,
) -> None:
	"""compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

	# Parameters:
	 - `model_cfg : HookedTransformerConfig|HTConfigMock`
	 - `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
	 - `save_path : Path`
	   (defaults to `Path(DATA_DIR)`)
	 - `force_overwrite : bool`
	   force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
	   (defaults to `False`)
	 - `track_results : bool`
	   whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
	   (defaults to `False`)
	"""
	prompt_dir: Path = activations_path.parent

	if track_results:
		results: defaultdict[
			str,  # func name
			dict[
				tuple[int, int],  # layer, head
				bool | Exception,  # success or exception
			],
		] = defaultdict(dict)

	for layer_idx, head_idx in itertools.product(
		range(model_cfg.n_layers),
		range(model_cfg.n_heads),
	):
		attn_pattern: AttentionMatrix
		if isinstance(cache, dict):
			attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
		elif isinstance(cache, np.ndarray):
			attn_pattern = cache[layer_idx, head_idx]
		else:
			msg = (
				f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
			)
			raise TypeError(
				msg,
			)

		save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
		save_dir.mkdir(parents=True, exist_ok=True)
		head_res: dict[str, bool | Exception] = process_single_head(
			layer_idx=layer_idx,
			head_idx=head_idx,
			attn_pattern=attn_pattern,
			save_dir=save_dir,
			force_overwrite=force_overwrite,
		)

		if track_results:
			for func_name, status in head_res.items():
				results[func_name][(layer_idx, head_idx)] = status

	# TODO: do something with results

	generate_prompts_jsonl(save_path / model_cfg.model_name)
compute and save figures for all heads in the model, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS
Parameters:
- model_cfg : HookedTransformerConfig|HTConfigMock
- cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
- save_path : Path
  (defaults to Path(DATA_DIR))
- force_overwrite : bool
  force overwrite of existing figures. if False, will skip any functions which have already saved a figure
  (defaults to False)
- track_results : bool
  whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
  (defaults to False)
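The cache can come in two forms: the dict of numpy activations keyed by blocks.{layer}.attn.hook_pattern (with a leading batch dimension of 1), or a single stacked array. A small sketch of how the two forms index out the same per-head matrix (shapes are made up):

import numpy as np

n_layers, n_heads, n_ctx = 2, 4, 16
stacked = np.random.rand(n_layers, n_heads, n_ctx, n_ctx)

# equivalent dict-style cache, one entry per layer, with a batch dimension of 1
cache_dict = {
	f"blocks.{layer}.attn.hook_pattern": stacked[layer][np.newaxis, ...]
	for layer in range(n_layers)
}

# compute_and_save_figures pulls out the same (n_ctx, n_ctx) matrix either way
layer_idx, head_idx = 1, 3
from_dict = cache_dict[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
from_array = stacked[layer_idx, head_idx]
assert np.array_equal(from_dict, from_array)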
def process_prompt(
	prompt: dict,
	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
	save_path: Path,
	force_overwrite: bool = False,
) -> None:
	"""process a single prompt, loading the activations and computing and saving the figures

	basically just calls `load_activations` and then `compute_and_save_figures`

	# Parameters:
	 - `prompt : dict`
	 - `model_cfg : HookedTransformerConfig|HTConfigMock`
	 - `force_overwrite : bool`
	   (defaults to `False`)
	"""
	activations_path: Path
	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
	activations_path, cache = load_activations(
		model_name=model_cfg.model_name,
		prompt=prompt,
		save_path=save_path,
		return_fmt="numpy",
	)

	compute_and_save_figures(
		model_cfg=model_cfg,
		activations_path=activations_path,
		cache=cache,
		save_path=save_path,
		force_overwrite=force_overwrite,
	)
process a single prompt, loading the activations and computing and saving the figures
basically just calls load_activations and then compute_and_save_figures
Parameters:
- prompt : dict
- model_cfg : HookedTransformerConfig|HTConfigMock
- force_overwrite : bool
  (defaults to False)
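A sketch of calling it directly for one prompt, assuming activations have already been generated under a data directory (the directory and model name here are hypothetical):

import json
from pathlib import Path

from pattern_lens.figures import HTConfigMock, process_prompt

save_path = Path("attn_data")  # hypothetical data directory
model_path = save_path / "gpt2"  # hypothetical model name

# config and prompt records as written out by the activation-generation step
model_cfg = HTConfigMock.load(json.loads((model_path / "model_cfg.json").read_text()))
prompt: dict = json.loads((model_path / "prompts.jsonl").read_text().splitlines()[0])

process_prompt(
	prompt=prompt,
	model_cfg=model_cfg,
	save_path=save_path,
	force_overwrite=False,
)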
def figures_main(
	model_name: str,
	save_path: str,
	n_samples: int,
	force: bool,
	parallel: bool | int = True,
) -> None:
	"""main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

	# Parameters:
	 - `model_name : str`
	   model name to use, used for loading the model config, prompts, activations, and saving the figures
	 - `save_path : str`
	   base path to look in
	 - `n_samples : int`
	   max number of samples to process
	 - `force : bool`
	   force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
	 - `parallel : bool | int`
	   whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
	   (defaults to `True`)
	"""
	with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
		# save model info or check if it exists
		save_path_p: Path = Path(save_path)
		model_path: Path = save_path_p / model_name
		with open(model_path / "model_cfg.json", "r") as f:
			model_cfg = HTConfigMock.load(json.load(f))

	with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
		# load prompts
		with open(model_path / "prompts.jsonl", "r") as f:
			prompts: list[dict] = [json.loads(line) for line in f.readlines()]
		# truncate to n_samples
		prompts = prompts[:n_samples]

	print(f"{len(prompts)} prompts loaded")

	print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded")
	print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS]))

	list(
		run_maybe_parallel(
			func=functools.partial(
				process_prompt,
				model_cfg=model_cfg,
				save_path=save_path_p,
				force_overwrite=force,
			),
			iterable=prompts,
			parallel=parallel,
			pbar="tqdm",
			pbar_kwargs=dict(
				desc="Making figures",
				unit="prompt",
			),
		),
	)

	with SpinnerContext(
		message="updating jsonl metadata for models and functions",
		**SPINNER_KWARGS,
	):
		generate_models_jsonl(save_path_p)
		generate_functions_jsonl(save_path_p)
main function for generating figures from attention patterns, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS
Parameters:
- model_name : str
  model name to use, used for loading the model config, prompts, activations, and saving the figures
- save_path : str
  base path to look in
- n_samples : int
  max number of samples to process
- force : bool
  force overwrite of existing figures. if False, will skip any functions which have already saved a figure
- parallel : bool | int
  whether to run in parallel. if True, will use all available cores. if False, will run in serial. if an int, will try to use that many cores
  (defaults to True)
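Called from Python rather than the CLI, this might look like the following (the model name and path are placeholders, and the activations are assumed to already exist on disk):

from pattern_lens.figures import figures_main

figures_main(
	model_name="gpt2",
	save_path="attn_data",
	n_samples=100,
	force=False,
	parallel=4,  # cap the worker count instead of using every core
)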
def main() -> None:
	"generates figures from the activations using the functions decorated with `register_attn_figure_func`"
	print(DIVIDER_S1)
	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
		arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
		# input and output
		arg_parser.add_argument(
			"--model",
			"-m",
			type=str,
			required=True,
			help="The model name(s) to use. comma separated with no whitespace if multiple",
		)
		arg_parser.add_argument(
			"--save-path",
			"-s",
			type=str,
			required=False,
			help="The path to save the attention patterns",
			default=DATA_DIR,
		)
		# number of samples
		arg_parser.add_argument(
			"--n-samples",
			"-n",
			type=int,
			required=False,
			help="The max number of samples to process, do all in the file if None",
			default=None,
		)
		# force overwrite of existing figures
		arg_parser.add_argument(
			"--force",
			"-f",
			action="store_true",
			required=False,
			help="Force overwrite of existing figures",
		)

		args: argparse.Namespace = arg_parser.parse_args()

	print(f"args parsed: {args}")

	models: list[str]
	if "," in args.model:
		models = args.model.split(",")
	else:
		models = [args.model]

	n_models: int = len(models)
	for idx, model in enumerate(models):
		print(DIVIDER_S2)
		print(f"processing model {idx + 1} / {n_models}: {model}")
		print(DIVIDER_S2)
		figures_main(
			model_name=model,
			save_path=args.save_path,
			n_samples=args.n_samples,
			force=args.force,
		)

	print(DIVIDER_S1)
generates figures from the activations using the functions decorated with register_attn_figure_func
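From the command line this looks something like python -m pattern_lens.figures --model gpt2 --save-path attn_data --n-samples 100 (the model name and path are placeholders); multiple models can be passed to --model as a comma-separated list, and --force overwrites any existing figures.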