Coverage for pattern_lens/activations.py: 92%

173 statements  

coverage.py v7.6.12, created at 2025-04-06 15:09 -0600

1"""computing and saving activations given a model and prompts 

2 

3# Usage: 

4 

5from the command line: 

6 

7```bash 

8python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples> 

9``` 

10 

11from a script: 

12 

```python
from pattern_lens.activations import activations_main
activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
    raw_prompts=True,
    min_chars=100,
    max_chars=1000,
    force=False,
    n_samples=100,
    no_index_html=False,
)
```

22""" 

23 

24import argparse 

25import functools 

26import json 

27import re 

28from collections.abc import Callable 

29from dataclasses import asdict 

30from pathlib import Path 

31from typing import Literal, overload 

32 

33import numpy as np 

34import torch 

35import tqdm 

36from jaxtyping import Float 

37from muutils.json_serialize import json_serialize 

38from muutils.misc.numerical import shorten_numerical_to_str 

39 

40# custom utils 

41from muutils.spinner import SpinnerContext 

42from transformer_lens import ( # type: ignore[import-untyped] 

43 ActivationCache, 

44 HookedTransformer, 

45 HookedTransformerConfig, 

46) 

47 

48# pattern_lens 

49from pattern_lens.consts import ( 

50 ATTN_PATTERN_REGEX, 

51 DATA_DIR, 

52 DIVIDER_S1, 

53 DIVIDER_S2, 

54 SPINNER_KWARGS, 

55 ActivationCacheNp, 

56 ReturnCache, 

57) 

58from pattern_lens.indexes import ( 

59 generate_models_jsonl, 

60 generate_prompts_jsonl, 

61 write_html_index, 

62) 

63from pattern_lens.load_activations import ( 

64 ActivationsMissingError, 

65 augment_prompt_with_hash, 

66 load_activations, 

67) 

68from pattern_lens.prompts import load_text_data 

69 

70 

71# return nothing, but `stack_heads` still affects how we save the activations 

72@overload 

73def compute_activations( 

74 prompt: dict, 

75 model: HookedTransformer | None = None, 

76 save_path: Path = Path(DATA_DIR), 

77 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 

78 return_cache: Literal[None] = None, 

79 stack_heads: bool = False, 

80) -> tuple[Path, None]: ... 

81# return stacked heads in numpy or torch form 

82@overload 

83def compute_activations( 

84 prompt: dict, 

85 model: HookedTransformer | None = None, 

86 save_path: Path = Path(DATA_DIR), 

87 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 

88 return_cache: Literal["torch"] = "torch", 

89 stack_heads: Literal[True] = True, 

90) -> tuple[Path, Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]]: ... 

91@overload 

92def compute_activations( 

93 prompt: dict, 

94 model: HookedTransformer | None = None, 

95 save_path: Path = Path(DATA_DIR), 

96 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 

97 return_cache: Literal["numpy"] = "numpy", 

98 stack_heads: Literal[True] = True, 

99) -> tuple[Path, Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]]: ... 

100# return dicts in numpy or torch form 

101@overload 

102def compute_activations( 

103 prompt: dict, 

104 model: HookedTransformer | None = None, 

105 save_path: Path = Path(DATA_DIR), 

106 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 

107 return_cache: Literal["numpy"] = "numpy", 

108 stack_heads: Literal[False] = False, 

109) -> tuple[Path, ActivationCacheNp]: ... 

110@overload 

111def compute_activations( 

112 prompt: dict, 

113 model: HookedTransformer | None = None, 

114 save_path: Path = Path(DATA_DIR), 

115 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 

116 return_cache: Literal["torch"] = "torch", 

117 stack_heads: Literal[False] = False, 

118) -> tuple[Path, ActivationCache]: ... 

119# actual function body 

120def compute_activations( # noqa: PLR0915 

121 prompt: dict, 

122 model: HookedTransformer | None = None, 

123 save_path: Path = Path(DATA_DIR), 

124 names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX, 

125 return_cache: ReturnCache = "torch", 

126 stack_heads: bool = False, 

127) -> tuple[ 

128 Path, 

129 ActivationCacheNp 

130 | ActivationCache 

131 | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] 

132 | Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] 

133 | None, 

134]: 

135 """get activations for a given model and prompt, possibly from a cache 

136 

137 if from a cache, prompt_meta must be passed and contain the prompt hash 

138 

139 # Parameters: 

140 - `prompt : dict | None` 

141 (defaults to `None`) 

142 - `model : HookedTransformer` 

143 - `save_path : Path` 

144 (defaults to `Path(DATA_DIR)`) 

145 - `names_filter : Callable[[str], bool]|re.Pattern` 

146 a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None` 

147 (defaults to `ATTN_PATTERN_REGEX`) 

148 - `return_cache : Literal[None, "numpy", "torch"]` 

149 will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (False) or a single tensor (True) 

150 (defaults to `None`) 

151 - `stack_heads : bool` 

152 whether the heads should be stacked in the output. this causes a number of changes: 

153 - `npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor saved for each prompt instead of `npz` file with dict by layer 

154 - `cache` will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer if `return_cache` is `True` 

155 will assert that everything in the activation cache is only attention patterns, and is all of the attention patterns. raises an exception if not. 

156 

    # Returns:
    ```
    tuple[
        Path,
        Union[
            None,
            ActivationCacheNp, ActivationCache,
            Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
        ]
    ]
    ```
    """
    # check inputs
    assert model is not None, "model must be passed"
    assert "text" in prompt, "prompt must contain 'text' key"
    prompt_str: str = prompt["text"]

    # compute or get prompt metadata
    prompt_tokenized: list[str] = prompt.get(
        "tokens",
        model.tokenizer.tokenize(prompt_str),
    )
    prompt.update(
        dict(
            n_tokens=len(prompt_tokenized),
            tokens=prompt_tokenized,
        ),
    )

    # save metadata
    prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    # set up names filter
    names_filter_fn: Callable[[str], bool]
    if isinstance(names_filter, re.Pattern):
        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
    else:
        names_filter_fn = names_filter

    # compute activations
    cache_torch: ActivationCache
    with torch.no_grad():
        model.eval()
        # TODO: batching?
        _, cache_torch = model.run_with_cache(
            prompt_str,
            names_filter=names_filter_fn,
            return_type=None,
        )

    activations_path: Path
    # saving and returning
    if stack_heads:
        n_layers: int = model.cfg.n_layers
        key_pattern: str = "blocks.{i}.attn.hook_pattern"
        # NOTE: this only works for stacking heads at the moment
        # activations_specifier: str = key_pattern.format(i=f'0-{n_layers}')
        activations_specifier: str = key_pattern.format(i="-")
        activations_path = prompt_dir / f"activations-{activations_specifier}.npy"

        # check the keys are only attention heads
        head_keys: list[str] = [key_pattern.format(i=i) for i in range(n_layers)]
        cache_torch_keys_set: set[str] = set(cache_torch.keys())
        assert cache_torch_keys_set == set(head_keys), (
            f"unexpected keys!\n{set(head_keys).symmetric_difference(cache_torch_keys_set) = }\n{cache_torch_keys_set} != {set(head_keys)}"
        )

        # stack heads
        patterns_stacked: Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] = (
            torch.stack([cache_torch[k] for k in head_keys], dim=1)
        )
        # check shape
        pattern_shape_no_ctx: tuple[int, ...] = tuple(patterns_stacked.shape[:3])
        assert pattern_shape_no_ctx == (1, n_layers, model.cfg.n_heads), (
            f"unexpected shape: {patterns_stacked.shape[:3] = } ({pattern_shape_no_ctx = }), expected {(1, n_layers, model.cfg.n_heads) = }"
        )

        patterns_stacked_np: Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = (
            patterns_stacked.cpu().numpy()
        )

        # save
        np.save(activations_path, patterns_stacked_np)

        # return
        match return_cache:
            case "numpy":
                return activations_path, patterns_stacked_np
            case "torch":
                return activations_path, patterns_stacked
            case None:
                return activations_path, None
            case _:
                msg = f"invalid return_cache: {return_cache = }"
                raise ValueError(msg)
    else:
        activations_path = prompt_dir / "activations.npz"

        # save
        cache_np: ActivationCacheNp = {
            k: v.detach().cpu().numpy() for k, v in cache_torch.items()
        }

        np.savez_compressed(
            activations_path,
            **cache_np,
        )

        # return
        match return_cache:
            case "numpy":
                return activations_path, cache_np
            case "torch":
                return activations_path, cache_torch
            case None:
                return activations_path, None
            case _:
                msg = f"invalid return_cache: {return_cache = }"
                raise ValueError(msg)


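# A minimal usage sketch of `compute_activations` (added for illustration; not
# part of the original module). It assumes "gpt2" is available via
# `HookedTransformer.from_pretrained`, and uses `augment_prompt_with_hash` to
# give the prompt the 'hash' key the function expects.
def _demo_compute_activations() -> None:
    prompt: dict = {"text": "The quick brown fox jumps over the lazy dog."}
    augment_prompt_with_hash(prompt)
    model: HookedTransformer = HookedTransformer.from_pretrained("gpt2")

    # stack_heads=True: saves a single .npy file holding a tensor of shape
    # (1, n_layers, n_heads, n_ctx, n_ctx) and returns it as a numpy array
    path, patterns = compute_activations(
        prompt=prompt,
        model=model,
        return_cache="numpy",
        stack_heads=True,
    )
    print(path, patterns.shape)

    # default stack_heads=False: saves an .npz archive keyed by hook name;
    # the same data can be read back later with `np.load(path)`
    path_npz, cache_np = compute_activations(
        prompt=prompt,
        model=model,
        return_cache="numpy",
    )
    print(path_npz, cache_np["blocks.0.attn.hook_pattern"].shape)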

@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[None] = None,
) -> tuple[Path, None]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal["torch"] = "torch",
) -> tuple[Path, ActivationCache]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal["numpy"] = "numpy",
) -> tuple[Path, ActivationCacheNp]: ...
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: ReturnCache = "numpy",
) -> tuple[Path, ActivationCacheNp | ActivationCache | None]:
    """given a prompt and a model, save or load activations

    # Parameters:
    - `prompt : dict`
        expected to contain the 'text' key
    - `model : HookedTransformer | str`
        either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
    - `save_path : Path`
        path to save the activations to (and load from)
        (defaults to `Path(DATA_DIR)`)
    - `allow_disk_cache : bool`
        whether to allow loading from disk cache
        (defaults to `True`)
    - `return_cache : Literal[None, "numpy", "torch"]`
        whether to return the cache, and in what format
        (defaults to `"numpy"`)

    # Returns:
    - `tuple[Path, ActivationCacheNp | ActivationCache | None]`
        the path to the activations and the cache if `return_cache is not None`

    """
    # add hash to prompt
    augment_prompt_with_hash(prompt)

    # get the model
    model_name: str = (
        model.cfg.model_name if isinstance(model, HookedTransformer) else model
    )

    # from cache
    if allow_disk_cache:
        try:
            path, cache = load_activations(
                model_name=model_name,
                prompt=prompt,
                save_path=save_path,
            )
            if return_cache:
                return path, cache
            else:
                # TODO: this basically does nothing, since we load the activations and then immediately get rid of them.
                # maybe refactor this so that load_activations can take a parameter to simply assert that the cache exists?
                # this will let us avoid loading it, which slows things down
                return path, None
        except ActivationsMissingError:
            pass

    # compute them
    if isinstance(model, str):
        model = HookedTransformer.from_pretrained(model_name)

    return compute_activations(
        prompt=prompt,
        model=model,
        save_path=save_path,
        return_cache=return_cache,
    )


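# A minimal usage sketch of `get_activations` (added for illustration; not
# part of the original module). The prompt text is arbitrary, and "gpt2" is
# assumed to be available via `HookedTransformer.from_pretrained`.
def _demo_get_activations() -> None:
    path, cache = get_activations(
        prompt={"text": "The quick brown fox jumps over the lazy dog."},
        model="gpt2",
        return_cache="numpy",
    )
    # on a cache miss this computes and saves activations; on a hit it loads
    # them from disk. `cache` maps hook names like "blocks.0.attn.hook_pattern"
    # to arrays of shape (1, n_heads, n_ctx, n_ctx)
    print(path, sorted(cache.keys())[:3])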

DEFAULT_DEVICE: torch.device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu",
)


def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
    stacked_heads: bool = False,
    device: str | torch.device = DEFAULT_DEVICE,
) -> None:
    """main function for computing activations

    # Parameters:
    - `model_name : str`
        name of a model to load with `HookedTransformer.from_pretrained`
    - `save_path : str`
        path to save the activations to
    - `prompts_path : str`
        path to the prompts file
    - `raw_prompts : bool`
        whether the prompts are raw, not yet filtered by length. `load_text_data` will be called on `prompts_path` if `True`; otherwise prompts are read from the `prompts.jsonl` in the model directory
    - `min_chars : int`
        minimum number of characters for a prompt
    - `max_chars : int`
        maximum number of characters for a prompt
    - `force : bool`
        whether to overwrite existing files
    - `n_samples : int`
        maximum number of samples to process
    - `no_index_html : bool`
        whether to skip writing an index.html file
    - `shuffle : bool`
        whether to shuffle the prompts
        (defaults to `False`)
    - `stacked_heads : bool`
        whether to stack the heads in the output tensor, saving a `.npy` instead of a `.npz` file. currently raises `NotImplementedError`
        (defaults to `False`)
    - `device : str | torch.device`
        the device to use. if a string, will be passed to `torch.device`
    """
    # figure out the device to use
    device_: torch.device
    if isinstance(device, torch.device):
        device_ = device
    elif isinstance(device, str):
        device_ = torch.device(device)
    else:
        msg = f"invalid device: {device}"
        raise TypeError(msg)

    print(f"using device: {device_}")

    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
        model: HookedTransformer = HookedTransformer.from_pretrained(
            model_name,
            device=device_,
        )
        model.model_name = model_name
        model.cfg.model_name = model_name
        n_params: int = sum(p.numel() for p in model.parameters())
    print(
        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters",
    )
    print(f"\tmodel devices: { {p.device for p in model.parameters()} }")

    save_path_p: Path = Path(save_path)
    save_path_p.mkdir(parents=True, exist_ok=True)
    model_path: Path = save_path_p / model_name
    with SpinnerContext(
        message=f"saving model info to {model_path.as_posix()}",
        **SPINNER_KWARGS,
    ):
        model_cfg: HookedTransformerConfig
        model_cfg = model.cfg
        model_path.mkdir(parents=True, exist_ok=True)
        with open(model_path / "model_cfg.json", "w") as f:
            json.dump(json_serialize(asdict(model_cfg)), f)

    # load prompts
    with SpinnerContext(
        message=f"loading prompts from {prompts_path = }",
        **SPINNER_KWARGS,
    ):
        prompts: list[dict]
        if raw_prompts:
            prompts = load_text_data(
                Path(prompts_path),
                min_chars=min_chars,
                max_chars=max_chars,
                shuffle=shuffle,
            )
        else:
            with open(model_path / "prompts.jsonl", "r") as f:
                prompts = [json.loads(line) for line in f.readlines()]
        # truncate to n_samples
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    # write index.html
    with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
        if not no_index_html:
            write_html_index(save_path_p)

    # TODO: not implemented yet
    if stacked_heads:
        raise NotImplementedError("stacked_heads not implemented yet")

    # get activations
    list(
        tqdm.tqdm(
            map(
                functools.partial(
                    get_activations,
                    model=model,
                    save_path=save_path_p,
                    allow_disk_cache=not force,
                    return_cache=None,
                    # stacked_heads=stacked_heads,
                ),
                prompts,
            ),
            total=len(prompts),
            desc="Computing activations",
            unit="prompt",
        ),
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and prompts",
        **SPINNER_KWARGS,
    ):
        generate_models_jsonl(save_path_p)
        generate_prompts_jsonl(save_path_p / model_name)


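# A script-level usage sketch of `activations_main` (added for illustration;
# not part of the original module). The prompts file path is hypothetical, and
# the argument values mirror the CLI defaults defined in `main` below.
def _demo_activations_main() -> None:
    activations_main(
        model_name="gpt2",
        save_path="demo/",
        prompts_path="data/pile_1k.jsonl",  # hypothetical jsonl with a 'text' key per line
        raw_prompts=True,
        min_chars=100,
        max_chars=1000,
        force=False,
        n_samples=16,
        no_index_html=False,
        device="cpu",  # force CPU even when CUDA is available
    )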

def main() -> None:
    "generate attention pattern activations for a model and prompts"
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )

        arg_parser.add_argument(
            "--prompts",
            "-p",
            type=str,
            required=False,
            help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
            default=None,
        )

        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )

        # min and max prompt lengths
        arg_parser.add_argument(
            "--min-chars",
            type=int,
            required=False,
            help="The minimum number of characters for a prompt",
            default=100,
        )
        arg_parser.add_argument(
            "--max-chars",
            type=int,
            required=False,
            help="The maximum number of characters for a prompt",
            default=1000,
        )

        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )

        # force overwrite
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="If passed, will overwrite existing files",
        )

        # no index html
        arg_parser.add_argument(
            "--no-index-html",
            action="store_true",
            help="If passed, will not write an index.html file for the model",
        )

        # raw prompts
        arg_parser.add_argument(
            "--raw-prompts",
            "-r",
            action="store_true",
            help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
        )

        # shuffle
        arg_parser.add_argument(
            "--shuffle",
            action="store_true",
            help="If passed, will shuffle the prompts",
        )

        # stack heads
        arg_parser.add_argument(
            "--stacked-heads",
            action="store_true",
            help="If passed, will stack the heads in the output tensor",
        )

        # device
        arg_parser.add_argument(
            "--device",
            type=str,
            required=False,
            help="The device to use for the model",
            default="cuda" if torch.cuda.is_available() else "cpu",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    models: list[str]
    if "," in args.model:
        models = args.model.split(",")
    else:
        models = [args.model]

    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx + 1} / {n_models}: {model}")
        print(DIVIDER_S2)

        activations_main(
            model_name=model,
            save_path=args.save_path,
            prompts_path=args.prompts,
            raw_prompts=args.raw_prompts,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            force=args.force,
            n_samples=args.n_samples,
            no_index_html=args.no_index_html,
            shuffle=args.shuffle,
            stacked_heads=args.stacked_heads,
            device=args.device,
        )
        del model

    print(DIVIDER_S1)


if __name__ == "__main__":
    main()
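
# Example CLI invocations (illustrative; the prompts path is hypothetical,
# model names are examples of what `HookedTransformer.from_pretrained` accepts):
#
#   # single model, raw prompts filtered to 100-1000 characters
#   python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --raw-prompts
#
#   # several models in one run (comma separated, no whitespace), overwriting
#   # previously saved activations, processing at most 100 prompts each
#   python -m pattern_lens.activations --model gpt2,pythia-70m --prompts data/pile_1k.jsonl --raw-prompts --force -n 100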