Coverage for pattern_lens\activations.py: 67%
123 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 20:39 -0700
1"""computing and saving activations given a model and prompts
4# Usage:
6from the command line:
8```bash
9python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
10```
12from a script:
14```python
15from pattern_lens.activations import activations_main
16activations_main(
17 model_name="gpt2",
18	save_path="demo/",
19 prompts_path="data/pile_1k.jsonl",
20)
21```
23"""
25import argparse
26import functools
27import json
28from dataclasses import asdict
29from pathlib import Path
30import re
31from typing import Callable, Literal, overload
33import numpy as np
34import torch
35import tqdm
36from muutils.spinner import SpinnerContext
37from muutils.misc.numerical import shorten_numerical_to_str
38from muutils.json_serialize import json_serialize
39from transformer_lens import HookedTransformer, HookedTransformerConfig # type: ignore[import-untyped]
41from pattern_lens.consts import (
42 ATTN_PATTERN_REGEX,
43 DATA_DIR,
44 ActivationCacheNp,
45 SPINNER_KWARGS,
46 DIVIDER_S1,
47 DIVIDER_S2,
48)
49from pattern_lens.indexes import (
50 generate_models_jsonl,
51 generate_prompts_jsonl,
52 write_html_index,
53)
54from pattern_lens.load_activations import (
55 ActivationsMissingError,
56 augment_prompt_with_hash,
57 load_activations,
58)
59from pattern_lens.prompts import load_text_data
def compute_activations(
	prompt: dict,
	model: HookedTransformer | None = None,
	save_path: Path = Path(DATA_DIR),
	return_cache: bool = True,
	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
) -> tuple[Path, ActivationCacheNp | None]:
	"""compute activations for a prompt, save them to disk, and optionally return them

	tokenizes the prompt (reusing an existing "tokens" key if present), writes the
	augmented prompt metadata to `<save_path>/<model_name>/prompts/<hash>/prompt.json`,
	runs the model with a cache, and saves the filtered activations to
	`activations.npz` in the same directory.

	# Parameters:
	 - `prompt : dict`
		must contain a "text" key and a "hash" key
		(add the hash with `augment_prompt_with_hash` -- `get_activations` does this)
	 - `model : HookedTransformer | None`
		required despite the `None` default; raises `AssertionError` if missing
		(defaults to `None`)
	 - `save_path : Path`
		(defaults to `Path(DATA_DIR)`)
	 - `return_cache : bool`
		will return `None` as the second element if `False`
		(defaults to `True`)
	 - `names_filter : Callable[[str], bool]|re.Pattern`
		a filter for the names of the activations to save. if an `re.Pattern`,
		a key is kept when `names_filter.match(key)` is not `None`
		(defaults to `ATTN_PATTERN_REGEX`)

	# Returns:
	 - `tuple[Path, ActivationCacheNp|None]`
		path to the saved `activations.npz`, and the numpy cache (or `None`)
	"""
	assert model is not None, "model must be passed"
	assert "text" in prompt, "prompt must contain 'text' key"
	# the hash names the on-disk directory -- fail early with a clear message
	# instead of a bare KeyError below
	assert "hash" in prompt, (
		"prompt must contain 'hash' key -- use `augment_prompt_with_hash`"
	)
	prompt_str: str = prompt["text"]

	# compute or get prompt metadata (don't re-tokenize if tokens were provided)
	prompt_tokenized: list[str] = prompt.get(
		"tokens",
		model.tokenizer.tokenize(prompt_str),
	)
	prompt.update(
		dict(
			n_tokens=len(prompt_tokenized),
			tokens=prompt_tokenized,
		)
	)

	# save metadata
	prompt_dir: Path = save_path / model.model_name / "prompts" / prompt["hash"]
	prompt_dir.mkdir(parents=True, exist_ok=True)
	with open(prompt_dir / "prompt.json", "w") as f:
		json.dump(prompt, f)

	# set up names filter -- wrap a regex pattern in a predicate
	# (def instead of assigned lambda, per E731)
	names_filter_fn: Callable[[str], bool]
	if isinstance(names_filter, re.Pattern):
		pattern: re.Pattern = names_filter

		def names_filter_fn(key: str) -> bool:
			return pattern.match(key) is not None
	else:
		names_filter_fn = names_filter

	# compute activations
	with torch.no_grad():
		model.eval()
		# TODO: batching?
		_, cache = model.run_with_cache(
			prompt_str,
			names_filter=names_filter_fn,
			return_type=None,
		)

	# detach to cpu numpy arrays for serialization
	cache_np: ActivationCacheNp = {
		k: v.detach().cpu().numpy() for k, v in cache.items()
	}

	# save activations
	activations_path: Path = prompt_dir / "activations.npz"
	np.savez_compressed(
		activations_path,
		**cache_np,
	)

	# return path and cache
	if return_cache:
		return activations_path, cache_np
	else:
		return activations_path, None
@overload
def get_activations(
	prompt: dict,
	model: HookedTransformer | str,
	save_path: Path = Path(DATA_DIR),
	allow_disk_cache: bool = True,
	return_cache: Literal[False] = False,
) -> tuple[Path, None]: ...
@overload
def get_activations(
	prompt: dict,
	model: HookedTransformer | str,
	save_path: Path = Path(DATA_DIR),
	allow_disk_cache: bool = True,
	return_cache: Literal[True] = True,
) -> tuple[Path, ActivationCacheNp]: ...
def get_activations(
	prompt: dict,
	model: HookedTransformer | str,
	save_path: Path = Path(DATA_DIR),
	allow_disk_cache: bool = True,
	return_cache: bool = True,
) -> tuple[Path, ActivationCacheNp | None]:
	"""given a prompt and a model, save or load activations

	# Parameters:
	 - `prompt : dict`
		expected to contain the 'text' key
	 - `model : HookedTransformer | str`
		either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
	 - `save_path : Path`
		path to save the activations to (and load from)
		(defaults to `Path(DATA_DIR)`)
	 - `allow_disk_cache : bool`
		whether to allow loading from disk cache
		(defaults to `True`)
	 - `return_cache : bool`
		whether to return the cache. if `False`, will return `None` as the second element
		(defaults to `True`)

	# Returns:
	 - `tuple[Path, ActivationCacheNp | None]`
		the path to the activations and the cache if `return_cache` is `True`
	"""
	# add hash to prompt (names the per-prompt directory on disk)
	augment_prompt_with_hash(prompt)

	# resolve the model name without loading the model
	model_name: str = (
		model.model_name if isinstance(model, HookedTransformer) else model
	)

	# try the disk cache first
	if allow_disk_cache:
		try:
			path, cache = load_activations(
				model_name=model_name,
				prompt=prompt,
				save_path=save_path,
			)
			if return_cache:
				return path, cache
			else:
				return path, None
		except ActivationsMissingError:
			# cache miss -- fall through to computing
			pass

	# compute them, loading the model if we were only given a name
	if isinstance(model, str):
		model = HookedTransformer.from_pretrained(model_name)

	# BUGFIX: forward `return_cache` (was hard-coded `True`, so the compute path
	# returned a cache even when the caller -- per the `Literal[False]` overload --
	# asked for `None`)
	return compute_activations(
		prompt=prompt,
		model=model,
		save_path=save_path,
		return_cache=return_cache,
	)
def activations_main(
	model_name: str,
	save_path: str,
	prompts_path: str,
	raw_prompts: bool,
	min_chars: int,
	max_chars: int,
	force: bool,
	n_samples: int,
	no_index_html: bool,
	shuffle: bool = False,
	device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
) -> None:
	"""main function for computing activations

	# Parameters:
	 - `model_name : str`
		name of a model to load with `HookedTransformer.from_pretrained`
	 - `save_path : str`
		path to save the activations to
	 - `prompts_path : str`
		path to the prompts file
	 - `raw_prompts : bool`
		whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
	 - `min_chars : int`
		minimum number of characters for a prompt
	 - `max_chars : int`
		maximum number of characters for a prompt
	 - `force : bool`
		whether to overwrite existing files
	 - `n_samples : int`
		maximum number of samples to process (`None` processes all of them)
	 - `no_index_html : bool`
		whether to write an index.html file
	 - `shuffle : bool`
		whether to shuffle the prompts
		(defaults to `False`)
	 - `device : str | torch.device`
		the device to use. if a string, will be passed to `torch.device`
	"""
	# figure out the device to use
	device_: torch.device
	if isinstance(device, torch.device):
		device_ = device
	elif isinstance(device, str):
		device_ = torch.device(device)
	else:
		raise ValueError(f"invalid device: {device}")

	print(f"using device: {device_}")

	with SpinnerContext(message="loading model", **SPINNER_KWARGS):
		model: HookedTransformer = HookedTransformer.from_pretrained(
			model_name, device=device_
		)
		# pin the user-facing name so it matches on-disk paths and indexes
		model.model_name = model_name
		model.cfg.model_name = model_name
	n_params: int = sum(p.numel() for p in model.parameters())
	print(
		f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters"
	)
	print(f"\tmodel devices: {set(p.device for p in model.parameters())}")

	save_path_p: Path = Path(save_path)
	save_path_p.mkdir(parents=True, exist_ok=True)
	model_path: Path = save_path_p / model_name
	with SpinnerContext(
		message=f"saving model info to {model_path.as_posix()}", **SPINNER_KWARGS
	):
		model_cfg: HookedTransformerConfig
		model_cfg = model.cfg
		model_path.mkdir(parents=True, exist_ok=True)
		with open(model_path / "model_cfg.json", "w") as f:
			json.dump(json_serialize(asdict(model_cfg)), f)

	# load prompts
	with SpinnerContext(
		message=f"loading prompts from {prompts_path = }", **SPINNER_KWARGS
	):
		prompts: list[dict]
		if raw_prompts:
			prompts = load_text_data(
				Path(prompts_path),
				min_chars=min_chars,
				max_chars=max_chars,
				shuffle=shuffle,
			)
		else:
			# BUGFIX: read from `prompts_path` as documented -- previously this
			# opened `<model_path>/prompts.jsonl` and ignored the argument
			with open(prompts_path, "r") as f:
				prompts = [json.loads(line) for line in f]
		# truncate to n_samples (slicing with `None` keeps everything)
		prompts = prompts[:n_samples]

	print(f"{len(prompts)} prompts loaded")

	# write index.html (skip the spinner entirely when disabled)
	if not no_index_html:
		with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
			write_html_index(save_path_p)

	# get activations -- list() forces the lazy map; results are discarded
	# since the activations are written to disk
	list(
		tqdm.tqdm(
			map(
				functools.partial(
					get_activations,
					model=model,
					save_path=save_path_p,
					allow_disk_cache=not force,
					return_cache=False,
				),
				prompts,
			),
			total=len(prompts),
			desc="Computing activations",
			unit="prompt",
		)
	)

	with SpinnerContext(
		message="updating jsonl metadata for models and prompts", **SPINNER_KWARGS
	):
		generate_models_jsonl(save_path_p)
		generate_prompts_jsonl(save_path_p / model_name)
def main():
	"""command-line entry point: parse args and run `activations_main` per model"""
	print(DIVIDER_S1)
	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
		arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
		# input and output
		arg_parser.add_argument(
			"--model",
			"-m",
			type=str,
			required=True,
			help="The model name(s) to use. comma separated with no whitespace if multiple",
		)

		arg_parser.add_argument(
			"--prompts",
			"-p",
			type=str,
			required=False,
			help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
			default=None,
		)

		arg_parser.add_argument(
			"--save-path",
			"-s",
			type=str,
			required=False,
			help="The path to save the attention patterns",
			default=DATA_DIR,
		)

		# min and max prompt lengths
		arg_parser.add_argument(
			"--min-chars",
			type=int,
			required=False,
			help="The minimum number of characters for a prompt",
			default=100,
		)
		arg_parser.add_argument(
			"--max-chars",
			type=int,
			required=False,
			help="The maximum number of characters for a prompt",
			default=1000,
		)

		# number of samples
		arg_parser.add_argument(
			"--n-samples",
			"-n",
			type=int,
			required=False,
			help="The max number of samples to process, do all in the file if None",
			default=None,
		)

		# force overwrite
		arg_parser.add_argument(
			"--force",
			"-f",
			action="store_true",
			help="If passed, will overwrite existing files",
		)

		# no index html
		arg_parser.add_argument(
			"--no-index-html",
			action="store_true",
			help="If passed, will not write an index.html file for the model",
		)

		# raw prompts
		arg_parser.add_argument(
			"--raw-prompts",
			"-r",
			action="store_true",
			help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
		)

		# shuffle
		arg_parser.add_argument(
			"--shuffle",
			action="store_true",
			help="If passed, will shuffle the prompts",
		)

		# device
		arg_parser.add_argument(
			"--device",
			type=str,
			required=False,
			help="The device to use for the model",
			default="cuda" if torch.cuda.is_available() else "cpu",
		)

		args: argparse.Namespace = arg_parser.parse_args()

	print(f"args parsed: {args}")

	# `str.split` already yields a single-element list when there is no comma,
	# so no special-casing is needed
	models: list[str] = args.model.split(",")
	n_models: int = len(models)
	for idx, model in enumerate(models):
		print(DIVIDER_S2)
		print(f"processing model {idx + 1} / {n_models}: {model}")
		print(DIVIDER_S2)

		activations_main(
			model_name=model,
			save_path=args.save_path,
			prompts_path=args.prompts,
			raw_prompts=args.raw_prompts,
			min_chars=args.min_chars,
			max_chars=args.max_chars,
			force=args.force,
			n_samples=args.n_samples,
			no_index_html=args.no_index_html,
			shuffle=args.shuffle,
			device=args.device,
		)

	print(DIVIDER_S1)


if __name__ == "__main__":
	main()