pattern_lens.activations
computing and saving activations given a model and prompts
Usage:
from the command line:
python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
from a script:
from pattern_lens.activations import activations_main
activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
)
1"""computing and saving activations given a model and prompts 2 3# Usage: 4 5from the command line: 6 7```bash 8python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples> 9``` 10 11from a script: 12 13```python 14from pattern_lens.activations import activations_main 15activations_main( 16 model_name="gpt2", 17 save_path="demo/" 18 prompts_path="data/pile_1k.jsonl", 19) 20``` 21 22""" 23 24import argparse 25import functools 26import json 27import re 28from collections.abc import Callable 29from dataclasses import asdict 30from pathlib import Path 31from typing import Literal, overload 32 33import numpy as np 34import torch 35import tqdm 36from jaxtyping import Float 37from muutils.json_serialize import json_serialize 38from muutils.misc.numerical import shorten_numerical_to_str 39 40# custom utils 41from muutils.spinner import SpinnerContext 42from transformer_lens import ( # type: ignore[import-untyped] 43 ActivationCache, 44 HookedTransformer, 45 HookedTransformerConfig, 46) 47 48# pattern_lens 49from pattern_lens.consts import ( 50 ATTN_PATTERN_REGEX, 51 DATA_DIR, 52 DIVIDER_S1, 53 DIVIDER_S2, 54 SPINNER_KWARGS, 55 ActivationCacheNp, 56 ReturnCache, 57) 58from pattern_lens.indexes import ( 59 generate_models_jsonl, 60 generate_prompts_jsonl, 61 write_html_index, 62) 63from pattern_lens.load_activations import ( 64 ActivationsMissingError, 65 augment_prompt_with_hash, 66 load_activations, 67) 68from pattern_lens.prompts import load_text_data 69 70 71# return nothing, but `stack_heads` still affects how we save the activations 72@overload 73def compute_activations( 74 prompt: dict, 75 model: HookedTransformer | None = None, 76 save_path: Path = Path(DATA_DIR), 77 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 78 return_cache: Literal[None] = None, 79 stack_heads: bool = False, 80) -> tuple[Path, None]: ... 81# return stacked heads in numpy or torch form 82@overload 83def compute_activations( 84 prompt: dict, 85 model: HookedTransformer | None = None, 86 save_path: Path = Path(DATA_DIR), 87 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 88 return_cache: Literal["torch"] = "torch", 89 stack_heads: Literal[True] = True, 90) -> tuple[Path, Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]]: ... 91@overload 92def compute_activations( 93 prompt: dict, 94 model: HookedTransformer | None = None, 95 save_path: Path = Path(DATA_DIR), 96 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 97 return_cache: Literal["numpy"] = "numpy", 98 stack_heads: Literal[True] = True, 99) -> tuple[Path, Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]]: ... 100# return dicts in numpy or torch form 101@overload 102def compute_activations( 103 prompt: dict, 104 model: HookedTransformer | None = None, 105 save_path: Path = Path(DATA_DIR), 106 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 107 return_cache: Literal["numpy"] = "numpy", 108 stack_heads: Literal[False] = False, 109) -> tuple[Path, ActivationCacheNp]: ... 110@overload 111def compute_activations( 112 prompt: dict, 113 model: HookedTransformer | None = None, 114 save_path: Path = Path(DATA_DIR), 115 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 116 return_cache: Literal["torch"] = "torch", 117 stack_heads: Literal[False] = False, 118) -> tuple[Path, ActivationCache]: ... 
119# actual function body 120def compute_activations( # noqa: PLR0915 121 prompt: dict, 122 model: HookedTransformer | None = None, 123 save_path: Path = Path(DATA_DIR), 124 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 125 return_cache: ReturnCache = "torch", 126 stack_heads: bool = False, 127) -> tuple[ 128 Path, 129 ActivationCacheNp 130 | ActivationCache 131 | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] 132 | Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] 133 | None, 134]: 135 """get activations for a given model and prompt, possibly from a cache 136 137 if from a cache, prompt_meta must be passed and contain the prompt hash 138 139 # Parameters: 140 - `prompt : dict | None` 141 (defaults to `None`) 142 - `model : HookedTransformer` 143 - `save_path : Path` 144 (defaults to `Path(DATA_DIR)`) 145 - `names_filter : Callable[[str], bool]|re.Pattern` 146 a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None` 147 (defaults to `ATTN_PATTERN_REGEX`) 148 - `return_cache : Literal[None, "numpy", "torch"]` 149 will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (False) or a single tensor (True) 150 (defaults to `None`) 151 - `stack_heads : bool` 152 whether the heads should be stacked in the output. this causes a number of changes: 153 - `npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor saved for each prompt instead of `npz` file with dict by layer 154 - `cache` will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer if `return_cache` is `True` 155 will assert that everything in the activation cache is only attention patterns, and is all of the attention patterns. raises an exception if not. 156 157 # Returns: 158 ``` 159 tuple[ 160 Path, 161 Union[ 162 None, 163 ActivationCacheNp, ActivationCache, 164 Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"], 165 ] 166 ] 167 ``` 168 """ 169 # check inputs 170 assert model is not None, "model must be passed" 171 assert "text" in prompt, "prompt must contain 'text' key" 172 prompt_str: str = prompt["text"] 173 174 # compute or get prompt metadata 175 prompt_tokenized: list[str] = prompt.get( 176 "tokens", 177 model.tokenizer.tokenize(prompt_str), 178 ) 179 prompt.update( 180 dict( 181 n_tokens=len(prompt_tokenized), 182 tokens=prompt_tokenized, 183 ), 184 ) 185 186 # save metadata 187 prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / prompt["hash"] 188 prompt_dir.mkdir(parents=True, exist_ok=True) 189 with open(prompt_dir / "prompt.json", "w") as f: 190 json.dump(prompt, f) 191 192 # set up names filter 193 names_filter_fn: Callable[[str], bool] 194 if isinstance(names_filter, re.Pattern): 195 names_filter_fn = lambda key: names_filter.match(key) is not None # noqa: E731 196 else: 197 names_filter_fn = names_filter 198 199 # compute activations 200 cache_torch: ActivationCache 201 with torch.no_grad(): 202 model.eval() 203 # TODO: batching? 
204 _, cache_torch = model.run_with_cache( 205 prompt_str, 206 names_filter=names_filter_fn, 207 return_type=None, 208 ) 209 210 activations_path: Path 211 # saving and returning 212 if stack_heads: 213 n_layers: int = model.cfg.n_layers 214 key_pattern: str = "blocks.{i}.attn.hook_pattern" 215 # NOTE: this only works for stacking heads at the moment 216 # activations_specifier: str = key_pattern.format(i=f'0-{n_layers}') 217 activations_specifier: str = key_pattern.format(i="-") 218 activations_path = prompt_dir / f"activations-{activations_specifier}.npy" 219 220 # check the keys are only attention heads 221 head_keys: list[str] = [key_pattern.format(i=i) for i in range(n_layers)] 222 cache_torch_keys_set: set[str] = set(cache_torch.keys()) 223 assert cache_torch_keys_set == set(head_keys), ( 224 f"unexpected keys!\n{set(head_keys).symmetric_difference(cache_torch_keys_set) = }\n{cache_torch_keys_set} != {set(head_keys)}" 225 ) 226 227 # stack heads 228 patterns_stacked: Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] = ( 229 torch.stack([cache_torch[k] for k in head_keys], dim=1) 230 ) 231 # check shape 232 pattern_shape_no_ctx: tuple[int, ...] = tuple(patterns_stacked.shape[:3]) 233 assert pattern_shape_no_ctx == (1, n_layers, model.cfg.n_heads), ( 234 f"unexpected shape: {patterns_stacked.shape[:3] = } ({pattern_shape_no_ctx = }), expected {(1, n_layers, model.cfg.n_heads) = }" 235 ) 236 237 patterns_stacked_np: Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = ( 238 patterns_stacked.cpu().numpy() 239 ) 240 241 # save 242 np.save(activations_path, patterns_stacked_np) 243 244 # return 245 match return_cache: 246 case "numpy": 247 return activations_path, patterns_stacked_np 248 case "torch": 249 return activations_path, patterns_stacked 250 case None: 251 return activations_path, None 252 case _: 253 msg = f"invalid return_cache: {return_cache = }" 254 raise ValueError(msg) 255 else: 256 activations_path = prompt_dir / "activations.npz" 257 258 # save 259 cache_np: ActivationCacheNp = { 260 k: v.detach().cpu().numpy() for k, v in cache_torch.items() 261 } 262 263 np.savez_compressed( 264 activations_path, 265 **cache_np, 266 ) 267 268 # return 269 match return_cache: 270 case "numpy": 271 return activations_path, cache_np 272 case "torch": 273 return activations_path, cache_torch 274 case None: 275 return activations_path, None 276 case _: 277 msg = f"invalid return_cache: {return_cache = }" 278 raise ValueError(msg) 279 280 281@overload 282def get_activations( 283 prompt: dict, 284 model: HookedTransformer | str, 285 save_path: Path = Path(DATA_DIR), 286 allow_disk_cache: bool = True, 287 return_cache: Literal[None] = None, 288) -> tuple[Path, None]: ... 289@overload 290def get_activations( 291 prompt: dict, 292 model: HookedTransformer | str, 293 save_path: Path = Path(DATA_DIR), 294 allow_disk_cache: bool = True, 295 return_cache: Literal["torch"] = "torch", 296) -> tuple[Path, ActivationCache]: ... 297@overload 298def get_activations( 299 prompt: dict, 300 model: HookedTransformer | str, 301 save_path: Path = Path(DATA_DIR), 302 allow_disk_cache: bool = True, 303 return_cache: Literal["numpy"] = "numpy", 304) -> tuple[Path, ActivationCacheNp]: ... 
305def get_activations( 306 prompt: dict, 307 model: HookedTransformer | str, 308 save_path: Path = Path(DATA_DIR), 309 allow_disk_cache: bool = True, 310 return_cache: ReturnCache = "numpy", 311) -> tuple[Path, ActivationCacheNp | ActivationCache | None]: 312 """given a prompt and a model, save or load activations 313 314 # Parameters: 315 - `prompt : dict` 316 expected to contain the 'text' key 317 - `model : HookedTransformer | str` 318 either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained` 319 - `save_path : Path` 320 path to save the activations to (and load from) 321 (defaults to `Path(DATA_DIR)`) 322 - `allow_disk_cache : bool` 323 whether to allow loading from disk cache 324 (defaults to `True`) 325 - `return_cache : Literal[None, "numpy", "torch"]` 326 whether to return the cache, and in what format 327 (defaults to `"numpy"`) 328 329 # Returns: 330 - `tuple[Path, ActivationCacheNp | ActivationCache | None]` 331 the path to the activations and the cache if `return_cache is not None` 332 333 """ 334 # add hash to prompt 335 augment_prompt_with_hash(prompt) 336 337 # get the model 338 model_name: str = ( 339 model.cfg.model_name if isinstance(model, HookedTransformer) else model 340 ) 341 342 # from cache 343 if allow_disk_cache: 344 try: 345 path, cache = load_activations( 346 model_name=model_name, 347 prompt=prompt, 348 save_path=save_path, 349 ) 350 if return_cache: 351 return path, cache 352 else: 353 # TODO: this basically does nothing, since we load the activations and then immediately get rid of them. 354 # maybe refactor this so that load_activations can take a parameter to simply assert that the cache exists? 355 # this will let us avoid loading it, which slows things down 356 return path, None 357 except ActivationsMissingError: 358 pass 359 360 # compute them 361 if isinstance(model, str): 362 model = HookedTransformer.from_pretrained(model_name) 363 364 return compute_activations( 365 prompt=prompt, 366 model=model, 367 save_path=save_path, 368 return_cache=return_cache, 369 ) 370 371 372DEFAULT_DEVICE: torch.device = torch.device( 373 "cuda" if torch.cuda.is_available() else "cpu", 374) 375 376 377def activations_main( 378 model_name: str, 379 save_path: str, 380 prompts_path: str, 381 raw_prompts: bool, 382 min_chars: int, 383 max_chars: int, 384 force: bool, 385 n_samples: int, 386 no_index_html: bool, 387 shuffle: bool = False, 388 stacked_heads: bool = False, 389 device: str | torch.device = DEFAULT_DEVICE, 390) -> None: 391 """main function for computing activations 392 393 # Parameters: 394 - `model_name : str` 395 name of a model to load with `HookedTransformer.from_pretrained` 396 - `save_path : str` 397 path to save the activations to 398 - `prompts_path : str` 399 path to the prompts file 400 - `raw_prompts : bool` 401 whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path` 402 - `min_chars : int` 403 minimum number of characters for a prompt 404 - `max_chars : int` 405 maximum number of characters for a prompt 406 - `force : bool` 407 whether to overwrite existing files 408 - `n_samples : int` 409 maximum number of samples to process 410 - `no_index_html : bool` 411 whether to write an index.html file 412 - `shuffle : bool` 413 whether to shuffle the prompts 414 (defaults to `False`) 415 - `stacked_heads : bool` 416 whether to stack the heads in the output tensor. 
will save as `.npy` instead of `.npz` if `True` 417 (defaults to `False`) 418 - `device : str | torch.device` 419 the device to use. if a string, will be passed to `torch.device` 420 """ 421 # figure out the device to use 422 device_: torch.device 423 if isinstance(device, torch.device): 424 device_ = device 425 elif isinstance(device, str): 426 device_ = torch.device(device) 427 else: 428 msg = f"invalid device: {device}" 429 raise TypeError(msg) 430 431 print(f"using device: {device_}") 432 433 with SpinnerContext(message="loading model", **SPINNER_KWARGS): 434 model: HookedTransformer = HookedTransformer.from_pretrained( 435 model_name, 436 device=device_, 437 ) 438 model.model_name = model_name 439 model.cfg.model_name = model_name 440 n_params: int = sum(p.numel() for p in model.parameters()) 441 print( 442 f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters", 443 ) 444 print(f"\tmodel devices: { {p.device for p in model.parameters()} }") 445 446 save_path_p: Path = Path(save_path) 447 save_path_p.mkdir(parents=True, exist_ok=True) 448 model_path: Path = save_path_p / model_name 449 with SpinnerContext( 450 message=f"saving model info to {model_path.as_posix()}", 451 **SPINNER_KWARGS, 452 ): 453 model_cfg: HookedTransformerConfig 454 model_cfg = model.cfg 455 model_path.mkdir(parents=True, exist_ok=True) 456 with open(model_path / "model_cfg.json", "w") as f: 457 json.dump(json_serialize(asdict(model_cfg)), f) 458 459 # load prompts 460 with SpinnerContext( 461 message=f"loading prompts from {prompts_path = }", 462 **SPINNER_KWARGS, 463 ): 464 prompts: list[dict] 465 if raw_prompts: 466 prompts = load_text_data( 467 Path(prompts_path), 468 min_chars=min_chars, 469 max_chars=max_chars, 470 shuffle=shuffle, 471 ) 472 else: 473 with open(model_path / "prompts.jsonl", "r") as f: 474 prompts = [json.loads(line) for line in f.readlines()] 475 # truncate to n_samples 476 prompts = prompts[:n_samples] 477 478 print(f"{len(prompts)} prompts loaded") 479 480 # write index.html 481 with SpinnerContext(message="writing index.html", **SPINNER_KWARGS): 482 if not no_index_html: 483 write_html_index(save_path_p) 484 485 # TODO: not implemented yet 486 if stacked_heads: 487 raise NotImplementedError("stacked_heads not implemented yet") 488 489 # get activations 490 list( 491 tqdm.tqdm( 492 map( 493 functools.partial( 494 get_activations, 495 model=model, 496 save_path=save_path_p, 497 allow_disk_cache=not force, 498 return_cache=None, 499 # stacked_heads=stacked_heads, 500 ), 501 prompts, 502 ), 503 total=len(prompts), 504 desc="Computing activations", 505 unit="prompt", 506 ), 507 ) 508 509 with SpinnerContext( 510 message="updating jsonl metadata for models and prompts", 511 **SPINNER_KWARGS, 512 ): 513 generate_models_jsonl(save_path_p) 514 generate_prompts_jsonl(save_path_p / model_name) 515 516 517def main() -> None: 518 "generate attention pattern activations for a model and prompts" 519 print(DIVIDER_S1) 520 with SpinnerContext(message="parsing args", **SPINNER_KWARGS): 521 arg_parser: argparse.ArgumentParser = argparse.ArgumentParser() 522 # input and output 523 arg_parser.add_argument( 524 "--model", 525 "-m", 526 type=str, 527 required=True, 528 help="The model name(s) to use. comma separated with no whitespace if multiple", 529 ) 530 531 arg_parser.add_argument( 532 "--prompts", 533 "-p", 534 type=str, 535 required=False, 536 help="The path to the prompts file (jsonl with 'text' key on each line). 
If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory", 537 default=None, 538 ) 539 540 arg_parser.add_argument( 541 "--save-path", 542 "-s", 543 type=str, 544 required=False, 545 help="The path to save the attention patterns", 546 default=DATA_DIR, 547 ) 548 549 # min and max prompt lengths 550 arg_parser.add_argument( 551 "--min-chars", 552 type=int, 553 required=False, 554 help="The minimum number of characters for a prompt", 555 default=100, 556 ) 557 arg_parser.add_argument( 558 "--max-chars", 559 type=int, 560 required=False, 561 help="The maximum number of characters for a prompt", 562 default=1000, 563 ) 564 565 # number of samples 566 arg_parser.add_argument( 567 "--n-samples", 568 "-n", 569 type=int, 570 required=False, 571 help="The max number of samples to process, do all in the file if None", 572 default=None, 573 ) 574 575 # force overwrite 576 arg_parser.add_argument( 577 "--force", 578 "-f", 579 action="store_true", 580 help="If passed, will overwrite existing files", 581 ) 582 583 # no index html 584 arg_parser.add_argument( 585 "--no-index-html", 586 action="store_true", 587 help="If passed, will not write an index.html file for the model", 588 ) 589 590 # raw prompts 591 arg_parser.add_argument( 592 "--raw-prompts", 593 "-r", 594 action="store_true", 595 help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)", 596 ) 597 598 # shuffle 599 arg_parser.add_argument( 600 "--shuffle", 601 action="store_true", 602 help="If passed, will shuffle the prompts", 603 ) 604 605 # stack heads 606 arg_parser.add_argument( 607 "--stacked-heads", 608 action="store_true", 609 help="If passed, will stack the heads in the output tensor", 610 ) 611 612 # device 613 arg_parser.add_argument( 614 "--device", 615 type=str, 616 required=False, 617 help="The device to use for the model", 618 default="cuda" if torch.cuda.is_available() else "cpu", 619 ) 620 621 args: argparse.Namespace = arg_parser.parse_args() 622 623 print(f"args parsed: {args}") 624 625 models: list[str] 626 if "," in args.model: 627 models = args.model.split(",") 628 else: 629 models = [args.model] 630 631 n_models: int = len(models) 632 for idx, model in enumerate(models): 633 print(DIVIDER_S2) 634 print(f"processing model {idx + 1} / {n_models}: {model}") 635 print(DIVIDER_S2) 636 637 activations_main( 638 model_name=model, 639 save_path=args.save_path, 640 prompts_path=args.prompts, 641 raw_prompts=args.raw_prompts, 642 min_chars=args.min_chars, 643 max_chars=args.max_chars, 644 force=args.force, 645 n_samples=args.n_samples, 646 no_index_html=args.no_index_html, 647 shuffle=args.shuffle, 648 stacked_heads=args.stacked_heads, 649 device=args.device, 650 ) 651 del model 652 653 print(DIVIDER_S1) 654 655 656if __name__ == "__main__": 657 main()
def compute_activations(
    prompt: dict,
    model: transformer_lens.HookedTransformer.HookedTransformer | None = None,
    save_path: pathlib.Path = PosixPath('attn_data'),
    names_filter: Callable[[str], bool] | re.Pattern = re.compile('blocks\\.(\\d+)\\.attn\\.hook_pattern'),
    return_cache: Literal[None, 'numpy', 'torch'] = 'torch',
    stack_heads: bool = False,
) -> tuple[
    pathlib.Path,
    dict[str, numpy.ndarray]
    | transformer_lens.ActivationCache.ActivationCache
    | jaxtyping.Float[ndarray, 'n_layers n_heads n_ctx n_ctx']
    | jaxtyping.Float[Tensor, 'n_layers n_heads n_ctx n_ctx']
    | None,
]:
get activations for a given model and prompt, possibly from a cache

if from a cache, prompt_meta must be passed and contain the prompt hash

Parameters:

- `prompt : dict`
  must contain the 'text' key; the save directory is derived from `prompt["hash"]`
- `model : HookedTransformer`
  required despite the `None` default in the signature; passing `None` raises an `AssertionError`
- `save_path : Path`
  (defaults to `Path(DATA_DIR)`)
- `names_filter : Callable[[str], bool] | re.Pattern`
  a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
  (defaults to `ATTN_PATTERN_REGEX`)
- `return_cache : Literal[None, "numpy", "torch"]`
  will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (`False`) or a single tensor (`True`)
  (defaults to `"torch"`)
- `stack_heads : bool`
  whether the heads should be stacked in the output. this causes a number of changes:
  - an `npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor is saved for each prompt, instead of an `npz` file with a dict keyed by layer
  - the cache will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict keyed by layer, when `return_cache` is not `None`
  also asserts that the activation cache contains all of the attention patterns and nothing else, raising an exception otherwise
  (defaults to `False`)

Returns:

tuple[
    Path,
    Union[
        None,
        ActivationCacheNp, ActivationCache,
        Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
    ]
]
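A minimal sketch of a direct call (ordinarily `get_activations` wraps this). The prompt text and the `demo/` path are arbitrary; the `hash` key that `compute_activations` expects is filled in via `augment_prompt_with_hash`:

```python
from pathlib import Path

from transformer_lens import HookedTransformer

from pattern_lens.activations import compute_activations
from pattern_lens.load_activations import augment_prompt_with_hash

model = HookedTransformer.from_pretrained("gpt2")
prompt = {"text": "The quick brown fox jumps over the lazy dog."}
augment_prompt_with_hash(prompt)  # compute_activations builds the save dir from prompt["hash"]

# stack_heads=True writes a single .npy file and returns one stacked array
path, patterns = compute_activations(
    prompt,
    model=model,
    save_path=Path("demo/"),
    return_cache="numpy",
    stack_heads=True,
)
print(path)
print(patterns.shape)  # note the leading batch dim: (1, n_layers, n_heads, n_ctx, n_ctx)
```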
def get_activations(
    prompt: dict,
    model: transformer_lens.HookedTransformer.HookedTransformer | str,
    save_path: pathlib.Path = PosixPath('attn_data'),
    allow_disk_cache: bool = True,
    return_cache: Literal[None, 'numpy', 'torch'] = 'numpy',
) -> tuple[pathlib.Path, dict[str, numpy.ndarray] | transformer_lens.ActivationCache.ActivationCache | None]:
given a prompt and a model, save or load activations

Parameters:

- `prompt : dict`
  expected to contain the 'text' key
- `model : HookedTransformer | str`
  either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
- `save_path : Path`
  path to save the activations to (and load from)
  (defaults to `Path(DATA_DIR)`)
- `allow_disk_cache : bool`
  whether to allow loading from disk cache
  (defaults to `True`)
- `return_cache : Literal[None, "numpy", "torch"]`
  whether to return the cache, and in what format
  (defaults to `"numpy"`)

Returns:

- `tuple[Path, ActivationCacheNp | ActivationCache | None]`
  the path to the activations, and the cache if `return_cache is not None`
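A sketch of the disk-cache behavior, assuming the `gpt2` weights can be fetched and using an arbitrary `demo/` save path. The first call runs the model and writes `activations.npz`; on the second call `load_activations` finds the file, so the model string is never even loaded:

```python
from pathlib import Path

from pattern_lens.activations import get_activations

prompt = {"text": "Attention is all you need."}

# first call: nothing on disk yet, so gpt2 is loaded and run
path, cache = get_activations(
    prompt,
    model="gpt2",  # strings are resolved via HookedTransformer.from_pretrained
    save_path=Path("demo/"),
    return_cache="numpy",
)

# second call: served from the disk cache, no model load and no forward pass
path_cached, cache_cached = get_activations(
    prompt,
    model="gpt2",
    save_path=Path("demo/"),
    return_cache="numpy",
)
print(path == path_cached)  # expected: True, the same activations file
```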
DEFAULT_DEVICE: torch.device = device(type='cuda')
def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
    stacked_heads: bool = False,
    device: str | torch.device = device(type='cuda'),
) -> None:
main function for computing activations

Parameters:

- `model_name : str`
  name of a model to load with `HookedTransformer.from_pretrained`
- `save_path : str`
  path to save the activations to
- `prompts_path : str`
  path to the prompts file
- `raw_prompts : bool`
  whether the prompts are raw, not yet filtered by length. `load_text_data` will be called on `prompts_path` if `True`; otherwise prompts are read from the existing `prompts.jsonl` in the model directory
- `min_chars : int`
  minimum number of characters for a prompt
- `max_chars : int`
  maximum number of characters for a prompt
- `force : bool`
  whether to overwrite existing files (disables the disk cache)
- `n_samples : int`
  maximum number of samples to process
- `no_index_html : bool`
  whether to skip writing an index.html file
- `shuffle : bool`
  whether to shuffle the prompts
  (defaults to `False`)
- `stacked_heads : bool`
  whether to stack the heads in the output tensor; will save as `.npy` instead of `.npz` if `True` (currently raises `NotImplementedError`)
  (defaults to `False`)
- `device : str | torch.device`
  the device to use. if a string, will be passed to `torch.device`
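The short example in the module docstring omits the parameters that have no defaults; a fuller sketch of a call, where the numeric values mirror the CLI defaults and `n_samples=100` is an arbitrary illustrative cap:

```python
from pattern_lens.activations import activations_main

activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
    raw_prompts=True,      # treat prompts_path as raw text, filtered by length
    min_chars=100,         # CLI default
    max_chars=1000,        # CLI default
    force=False,           # reuse activations already on disk
    n_samples=100,         # arbitrary cap for illustration
    no_index_html=False,   # do write the index.html viewer page
)
```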
def main() -> None:
generate attention pattern activations for a model and prompts
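Not part of the module itself, but useful for orientation: a sketch of reading back what a run leaves under the save path, assuming a prior run with model `gpt2`, `save_path="demo/"`, and the default (unstacked) layout:

```python
import json
from pathlib import Path

import numpy as np

# each prompt gets its own directory, named by the hash that augment_prompt_with_hash adds
for prompt_dir in Path("demo/gpt2/prompts").iterdir():
    meta = json.loads((prompt_dir / "prompt.json").read_text())
    acts = np.load(prompt_dir / "activations.npz")  # one array per layer, keyed "blocks.{i}.attn.hook_pattern"
    first_key = acts.files[0]
    print(prompt_dir.name, meta["n_tokens"], first_key, acts[first_key].shape)
```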