Coverage for pattern_lens/activations.py: 92%
173 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
"""computing and saving activations given a model and prompts

# Usage:

from the command line:

```bash
python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
```

from a script:

```python
from pattern_lens.activations import activations_main
activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
)
```

"""
24import argparse
25import functools
26import json
27import re
28from collections.abc import Callable
29from dataclasses import asdict
30from pathlib import Path
31from typing import Literal, overload
33import numpy as np
34import torch
35import tqdm
36from jaxtyping import Float
37from muutils.json_serialize import json_serialize
38from muutils.misc.numerical import shorten_numerical_to_str
40# custom utils
41from muutils.spinner import SpinnerContext
42from transformer_lens import ( # type: ignore[import-untyped]
43 ActivationCache,
44 HookedTransformer,
45 HookedTransformerConfig,
46)
48# pattern_lens
49from pattern_lens.consts import (
50 ATTN_PATTERN_REGEX,
51 DATA_DIR,
52 DIVIDER_S1,
53 DIVIDER_S2,
54 SPINNER_KWARGS,
55 ActivationCacheNp,
56 ReturnCache,
57)
58from pattern_lens.indexes import (
59 generate_models_jsonl,
60 generate_prompts_jsonl,
61 write_html_index,
62)
63from pattern_lens.load_activations import (
64 ActivationsMissingError,
65 augment_prompt_with_hash,
66 load_activations,
67)
68from pattern_lens.prompts import load_text_data
71# return nothing, but `stack_heads` still affects how we save the activations
72@overload
73def compute_activations(
74 prompt: dict,
75 model: HookedTransformer | None = None,
76 save_path: Path = Path(DATA_DIR),
77 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
78 return_cache: Literal[None] = None,
79 stack_heads: bool = False,
80) -> tuple[Path, None]: ...
81# return stacked heads in numpy or torch form
82@overload
83def compute_activations(
84 prompt: dict,
85 model: HookedTransformer | None = None,
86 save_path: Path = Path(DATA_DIR),
87 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
88 return_cache: Literal["torch"] = "torch",
89 stack_heads: Literal[True] = True,
90) -> tuple[Path, Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]]: ...
91@overload
92def compute_activations(
93 prompt: dict,
94 model: HookedTransformer | None = None,
95 save_path: Path = Path(DATA_DIR),
96 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
97 return_cache: Literal["numpy"] = "numpy",
98 stack_heads: Literal[True] = True,
99) -> tuple[Path, Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]]: ...
100# return dicts in numpy or torch form
101@overload
102def compute_activations(
103 prompt: dict,
104 model: HookedTransformer | None = None,
105 save_path: Path = Path(DATA_DIR),
106 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
107 return_cache: Literal["numpy"] = "numpy",
108 stack_heads: Literal[False] = False,
109) -> tuple[Path, ActivationCacheNp]: ...
110@overload
111def compute_activations(
112 prompt: dict,
113 model: HookedTransformer | None = None,
114 save_path: Path = Path(DATA_DIR),
115 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
116 return_cache: Literal["torch"] = "torch",
117 stack_heads: Literal[False] = False,
118) -> tuple[Path, ActivationCache]: ...
119# actual function body
120def compute_activations( # noqa: PLR0915
121 prompt: dict,
122 model: HookedTransformer | None = None,
123 save_path: Path = Path(DATA_DIR),
124 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
125 return_cache: ReturnCache = "torch",
126 stack_heads: bool = False,
127) -> tuple[
128 Path,
129 ActivationCacheNp
130 | ActivationCache
131 | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
132 | Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]
133 | None,
134]:
135 """get activations for a given model and prompt, possibly from a cache
137 if from a cache, prompt_meta must be passed and contain the prompt hash
139 # Parameters:
140 - `prompt : dict | None`
141 (defaults to `None`)
142 - `model : HookedTransformer`
143 - `save_path : Path`
144 (defaults to `Path(DATA_DIR)`)
145 - `names_filter : Callable[[str], bool]|re.Pattern`
146 a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
147 (defaults to `ATTN_PATTERN_REGEX`)
148 - `return_cache : Literal[None, "numpy", "torch"]`
149 will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (False) or a single tensor (True)
150 (defaults to `None`)
151 - `stack_heads : bool`
152 whether the heads should be stacked in the output. this causes a number of changes:
153 - `npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor saved for each prompt instead of `npz` file with dict by layer
154 - `cache` will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer if `return_cache` is `True`
155 will assert that everything in the activation cache is only attention patterns, and is all of the attention patterns. raises an exception if not.
157 # Returns:
158 ```
159 tuple[
160 Path,
161 Union[
162 None,
163 ActivationCacheNp, ActivationCache,
164 Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
165 ]
166 ]
167 ```
168 """
169 # check inputs
170 assert model is not None, "model must be passed"
171 assert "text" in prompt, "prompt must contain 'text' key"
172 prompt_str: str = prompt["text"]
174 # compute or get prompt metadata
175 prompt_tokenized: list[str] = prompt.get(
176 "tokens",
177 model.tokenizer.tokenize(prompt_str),
178 )
179 prompt.update(
180 dict(
181 n_tokens=len(prompt_tokenized),
182 tokens=prompt_tokenized,
183 ),
184 )
186 # save metadata
187 prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
188 prompt_dir.mkdir(parents=True, exist_ok=True)
189 with open(prompt_dir / "prompt.json", "w") as f:
190 json.dump(prompt, f)
192 # set up names filter
193 names_filter_fn: Callable[[str], bool]
194 if isinstance(names_filter, re.Pattern):
195 names_filter_fn = lambda key: names_filter.match(key) is not None # noqa: E731
196 else:
197 names_filter_fn = names_filter
199 # compute activations
200 cache_torch: ActivationCache
201 with torch.no_grad():
202 model.eval()
203 # TODO: batching?
204 _, cache_torch = model.run_with_cache(
205 prompt_str,
206 names_filter=names_filter_fn,
207 return_type=None,
208 )
210 activations_path: Path
211 # saving and returning
212 if stack_heads:
213 n_layers: int = model.cfg.n_layers
214 key_pattern: str = "blocks.{i}.attn.hook_pattern"
215 # NOTE: this only works for stacking heads at the moment
216 # activations_specifier: str = key_pattern.format(i=f'0-{n_layers}')
217 activations_specifier: str = key_pattern.format(i="-")
218 activations_path = prompt_dir / f"activations-{activations_specifier}.npy"
220 # check the keys are only attention heads
221 head_keys: list[str] = [key_pattern.format(i=i) for i in range(n_layers)]
222 cache_torch_keys_set: set[str] = set(cache_torch.keys())
223 assert cache_torch_keys_set == set(head_keys), (
224 f"unexpected keys!\n{set(head_keys).symmetric_difference(cache_torch_keys_set) = }\n{cache_torch_keys_set} != {set(head_keys)}"
225 )
227 # stack heads
228 patterns_stacked: Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] = (
229 torch.stack([cache_torch[k] for k in head_keys], dim=1)
230 )
231 # check shape
232 pattern_shape_no_ctx: tuple[int, ...] = tuple(patterns_stacked.shape[:3])
233 assert pattern_shape_no_ctx == (1, n_layers, model.cfg.n_heads), (
234 f"unexpected shape: {patterns_stacked.shape[:3] = } ({pattern_shape_no_ctx = }), expected {(1, n_layers, model.cfg.n_heads) = }"
235 )
237 patterns_stacked_np: Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = (
238 patterns_stacked.cpu().numpy()
239 )
241 # save
242 np.save(activations_path, patterns_stacked_np)
244 # return
245 match return_cache:
246 case "numpy":
247 return activations_path, patterns_stacked_np
248 case "torch":
249 return activations_path, patterns_stacked
250 case None:
251 return activations_path, None
252 case _:
253 msg = f"invalid return_cache: {return_cache = }"
254 raise ValueError(msg)
255 else:
256 activations_path = prompt_dir / "activations.npz"
258 # save
259 cache_np: ActivationCacheNp = {
260 k: v.detach().cpu().numpy() for k, v in cache_torch.items()
261 }
263 np.savez_compressed(
264 activations_path,
265 **cache_np,
266 )
268 # return
269 match return_cache:
270 case "numpy":
271 return activations_path, cache_np
272 case "torch":
273 return activations_path, cache_torch
274 case None:
275 return activations_path, None
276 case _:
277 msg = f"invalid return_cache: {return_cache = }"
278 raise ValueError(msg)
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[None] = None,
) -> tuple[Path, None]: ...


@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal["torch"] = "torch",
) -> tuple[Path, ActivationCache]: ...


@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal["numpy"] = "numpy",
) -> tuple[Path, ActivationCacheNp]: ...


def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: ReturnCache = "numpy",
) -> tuple[Path, ActivationCacheNp | ActivationCache | None]:
    """load activations for a prompt from disk if allowed, else compute and save them

    # Parameters:
    - `prompt : dict`
        expected to contain the 'text' key
    - `model : HookedTransformer | str`
        either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
    - `save_path : Path`
        path to save the activations to (and load from)
        (defaults to `Path(DATA_DIR)`)
    - `allow_disk_cache : bool`
        whether to allow loading from disk cache
        (defaults to `True`)
    - `return_cache : Literal[None, "numpy", "torch"]`
        whether to return the cache, and in what format
        (defaults to `"numpy"`)

    # Returns:
    - `tuple[Path, ActivationCacheNp | ActivationCache | None]`
        the path to the activations and the cache if `return_cache is not None`
    """
    # the hash added here names the prompt's directory on disk
    augment_prompt_with_hash(prompt)

    # resolve the model name without forcing a model load
    model_name: str = model if isinstance(model, str) else model.cfg.model_name

    # try the disk cache first
    if allow_disk_cache:
        try:
            cached_path, cached = load_activations(
                model_name=model_name,
                prompt=prompt,
                save_path=save_path,
            )
        except ActivationsMissingError:
            # nothing cached -- fall through and compute
            pass
        else:
            # TODO: when `return_cache` is falsy we load the activations and
            # then immediately discard them. maybe refactor so that
            # `load_activations` can simply assert that the cache exists?
            # that would let us avoid loading it, which slows things down.
            # NOTE(review): `cached` comes back in whatever format
            # `load_activations` produces, regardless of `return_cache` --
            # confirm this matches the overload-declared return types.
            return cached_path, (cached if return_cache else None)

    # cache miss (or caching disabled): compute from scratch
    if isinstance(model, str):
        model = HookedTransformer.from_pretrained(model_name)

    return compute_activations(
        prompt=prompt,
        model=model,
        save_path=save_path,
        return_cache=return_cache,
    )
# prefer the GPU when torch can see one, otherwise fall back to CPU
DEFAULT_DEVICE: torch.device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
    stacked_heads: bool = False,
    device: str | torch.device = DEFAULT_DEVICE,
) -> None:
    """main function for computing activations

    # Parameters:
    - `model_name : str`
        name of a model to load with `HookedTransformer.from_pretrained`
    - `save_path : str`
        path to save the activations to
    - `prompts_path : str`
        path to the prompts file. only read when `raw_prompts=True`; otherwise
        prompts come from `<save_path>/<model_name>/prompts.jsonl`
    - `raw_prompts : bool`
        whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise prompts are loaded from a previously generated `prompts.jsonl` in the model directory
    - `min_chars : int`
        minimum number of characters for a prompt
    - `max_chars : int`
        maximum number of characters for a prompt
    - `force : bool`
        whether to overwrite existing files (disables the disk cache)
    - `n_samples : int`
        maximum number of samples to process
    - `no_index_html : bool`
        whether to skip writing an index.html file
    - `shuffle : bool`
        whether to shuffle the prompts
        (defaults to `False`)
    - `stacked_heads : bool`
        whether to stack the heads in the output tensor. not implemented yet -- passing `True` raises `NotImplementedError`
        (defaults to `False`)
    - `device : str | torch.device`
        the device to use. if a string, will be passed to `torch.device`

    # Raises:
    - `NotImplementedError` : if `stacked_heads` is `True`
    - `TypeError` : if `device` is neither a string nor a `torch.device`
    """
    # TODO: not implemented yet
    # fail fast, before loading the model or any prompts
    if stacked_heads:
        raise NotImplementedError("stacked_heads not implemented yet")

    # figure out the device to use
    device_: torch.device
    if isinstance(device, torch.device):
        device_ = device
    elif isinstance(device, str):
        device_ = torch.device(device)
    else:
        msg = f"invalid device: {device}"
        raise TypeError(msg)

    print(f"using device: {device_}")

    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
        model: HookedTransformer = HookedTransformer.from_pretrained(
            model_name,
            device=device_,
        )
        # store the requested name so save paths use it, not whatever
        # transformer_lens resolves internally
        model.model_name = model_name
        model.cfg.model_name = model_name
    n_params: int = sum(p.numel() for p in model.parameters())
    print(
        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters",
    )
    print(f"\tmodel devices: { {p.device for p in model.parameters()} }")

    save_path_p: Path = Path(save_path)
    save_path_p.mkdir(parents=True, exist_ok=True)
    model_path: Path = save_path_p / model_name

    with SpinnerContext(
        message=f"saving model info to {model_path.as_posix()}",
        **SPINNER_KWARGS,
    ):
        model_cfg: HookedTransformerConfig = model.cfg
        model_path.mkdir(parents=True, exist_ok=True)
        with open(model_path / "model_cfg.json", "w") as f:
            json.dump(json_serialize(asdict(model_cfg)), f)

    # load prompts
    with SpinnerContext(
        message=f"loading prompts from {prompts_path = }",
        **SPINNER_KWARGS,
    ):
        prompts: list[dict]
        if raw_prompts:
            prompts = load_text_data(
                Path(prompts_path),
                min_chars=min_chars,
                max_chars=max_chars,
                shuffle=shuffle,
            )
        else:
            # NOTE(review): this reads from the model directory, not from
            # `prompts_path` -- it expects a `prompts.jsonl` generated by a
            # previous run (see `generate_prompts_jsonl` below). confirm this
            # is intended before changing.
            with open(model_path / "prompts.jsonl", "r") as f:
                prompts = [json.loads(line) for line in f]
        # truncate to n_samples (slicing with `None` keeps everything)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    # write index.html
    if not no_index_html:
        with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
            write_html_index(save_path_p)

    # compute (or load cached) activations for every prompt; results are
    # discarded since `return_cache=None` -- we only want the files on disk
    process_prompt = functools.partial(
        get_activations,
        model=model,
        save_path=save_path_p,
        allow_disk_cache=not force,
        return_cache=None,
        # stacked_heads=stacked_heads,
    )
    for prompt in tqdm.tqdm(
        prompts,
        total=len(prompts),
        desc="Computing activations",
        unit="prompt",
    ):
        process_prompt(prompt)

    with SpinnerContext(
        message="updating jsonl metadata for models and prompts",
        **SPINNER_KWARGS,
    ):
        generate_models_jsonl(save_path_p)
        generate_prompts_jsonl(save_path_p / model_name)
def main() -> None:
    "generate attention pattern activations for a model and prompts"
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )

        arg_parser.add_argument(
            "--prompts",
            "-p",
            type=str,
            required=False,
            # help text previously referenced a `--figures` flag that does not
            # exist in this parser; corrected to describe actual behavior
            help="The path to the prompts file (jsonl with 'text' key on each line). Only read when `--raw-prompts` is passed; otherwise prompts are loaded from `<save-path>/<model>/prompts.jsonl`",
            default=None,
        )

        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )

        # min and max prompt lengths
        arg_parser.add_argument(
            "--min-chars",
            type=int,
            required=False,
            help="The minimum number of characters for a prompt",
            default=100,
        )
        arg_parser.add_argument(
            "--max-chars",
            type=int,
            required=False,
            help="The maximum number of characters for a prompt",
            default=1000,
        )

        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )

        # force overwrite
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="If passed, will overwrite existing files",
        )

        # no index html
        arg_parser.add_argument(
            "--no-index-html",
            action="store_true",
            help="If passed, will not write an index.html file for the model",
        )

        # raw prompts
        arg_parser.add_argument(
            "--raw-prompts",
            "-r",
            action="store_true",
            help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
        )

        # shuffle
        arg_parser.add_argument(
            "--shuffle",
            action="store_true",
            help="If passed, will shuffle the prompts",
        )

        # stack heads
        arg_parser.add_argument(
            "--stacked-heads",
            action="store_true",
            help="If passed, will stack the heads in the output tensor",
        )

        # device
        arg_parser.add_argument(
            "--device",
            type=str,
            required=False,
            help="The device to use for the model",
            default="cuda" if torch.cuda.is_available() else "cpu",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    # `"x".split(",") == ["x"]`, so this handles one model or many
    models: list[str] = args.model.split(",")

    n_models: int = len(models)
    for idx, model_name in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx + 1} / {n_models}: {model_name}")
        print(DIVIDER_S2)

        activations_main(
            model_name=model_name,
            save_path=args.save_path,
            prompts_path=args.prompts,
            raw_prompts=args.raw_prompts,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            force=args.force,
            n_samples=args.n_samples,
            no_index_html=args.no_index_html,
            shuffle=args.shuffle,
            stacked_heads=args.stacked_heads,
            device=args.device,
        )
        # (removed a `del` of the loop variable -- it only unbound the name of
        # a string, freeing nothing)

    print(DIVIDER_S1)
656if __name__ == "__main__":
657 main()