1"""computing and saving activations given a model and prompts 

2 

3# Usage: 

4 

5from the command line: 

6 

7```bash 

8python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples> 

9``` 

10 

11from a script: 

12 

13```python 

14from pattern_lens.activations import activations_main 

15activations_main( 

16 model_name="gpt2", 

17 save_path="demo/" 

18 prompts_path="data/pile_1k.jsonl", 

19) 

20``` 

21 

22""" 

import argparse
import json
import re
from collections.abc import Callable
from dataclasses import asdict
from pathlib import Path
from typing import Literal, overload

import numpy as np
import torch
import tqdm
from jaxtyping import Float
from muutils.json_serialize import json_serialize
from muutils.misc.numerical import shorten_numerical_to_str

# custom utils
from muutils.spinner import SpinnerContext
from transformer_lens import (  # type: ignore[import-untyped]
    ActivationCache,
    HookedTransformer,
    HookedTransformerConfig,
)

# pattern_lens
from pattern_lens.consts import (
    ATTN_PATTERN_REGEX,
    DATA_DIR,
    DIVIDER_S1,
    DIVIDER_S2,
    SPINNER_KWARGS,
    ActivationCacheNp,
    ReturnCache,
)
from pattern_lens.indexes import (
    generate_models_jsonl,
    generate_prompts_jsonl,
    write_html_index,
)
from pattern_lens.load_activations import (
    ActivationsMissingError,
    activations_exist,
    augment_prompt_with_hash,
    load_activations,
)
from pattern_lens.prompts import load_text_data


def _rel_path(p: Path) -> str:
    """Return path relative to cwd if possible, otherwise absolute."""
    try:
        return p.relative_to(Path.cwd()).as_posix()
    except ValueError:
        return p.as_posix()


# return nothing, but `stack_heads` still affects how we save the activations
@overload
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: None = None,
    stack_heads: bool = False,
) -> tuple[Path, None]: ...
# return stacked heads in numpy or torch form
@overload
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: Literal["torch"] = "torch",
    stack_heads: Literal[True] = True,
) -> tuple[Path, Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]]: ...
@overload
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: Literal["numpy"] = "numpy",
    stack_heads: Literal[True] = True,
) -> tuple[Path, Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]]: ...
# return dicts in numpy or torch form
@overload
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: Literal["numpy"] = "numpy",
    stack_heads: Literal[False] = False,
) -> tuple[Path, ActivationCacheNp]: ...
@overload
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: Literal["torch"] = "torch",
    stack_heads: Literal[False] = False,
) -> tuple[Path, ActivationCache]: ...
# actual function body
def compute_activations(  # noqa: PLR0915
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    return_cache: ReturnCache = "torch",
    stack_heads: bool = False,
) -> tuple[
    Path,
    ActivationCacheNp
    | ActivationCache
    | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
    | Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]
    | None,
]:
    """compute activations for a single prompt and save to disk

    always runs a forward pass -- does NOT load from disk cache.
    for cache-aware loading, use `get_activations`, which tries disk first.

    # Parameters:
    - `prompt : dict`
        must contain a `'text'` key and a `'hash'` key (see `augment_prompt_with_hash`)
    - `model : HookedTransformer | None`
        the model to run. must not actually be `None` -- the function asserts this
        (defaults to `None`)
    - `save_path : Path`
        (defaults to `Path(DATA_DIR)`)
    - `names_filter : Callable[[str], bool] | re.Pattern`
        a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
        (defaults to `ATTN_PATTERN_REGEX`)
    - `return_cache : Literal[None, "numpy", "torch"]`
        will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (`False`) or a single tensor (`True`)
        (defaults to `"torch"`)
    - `stack_heads : bool`
        whether the heads should be stacked in the output. this causes a number of changes:
        - an `npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor is saved for each prompt instead of an `npz` file with a dict by layer
        - `cache` will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer if `return_cache` is not `None`

        asserts that everything in the activation cache is attention patterns, and that all attention patterns are present. raises an exception if not.

    # Returns:
    ```
    tuple[
        Path,
        Union[
            None,
            ActivationCacheNp, ActivationCache,
            Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
        ]
    ]
    ```
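
    # Usage:
    a minimal sketch, assuming an already-loaded `HookedTransformer` bound to `model`:

    ```python
    from pattern_lens.load_activations import augment_prompt_with_hash

    prompt = {"text": "The quick brown fox"}
    augment_prompt_with_hash(prompt)  # adds the 'hash' key this function expects
    path, cache = compute_activations(prompt, model, return_cache="numpy")
    ```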

    """
    # check inputs
    assert model is not None, "model must be passed"
    assert "text" in prompt, "prompt must contain 'text' key"
    prompt_str: str = prompt["text"]

    # compute or get prompt metadata
    assert model.tokenizer is not None
    prompt_tokenized: list[str] = prompt.get(
        "tokens",
        model.tokenizer.tokenize(prompt_str),
    )
    # n_tokens counts subword tokens (no BOS); attention patterns include BOS
    # so have dim n_tokens+1. see also compute_activations_batched Phase B.
    prompt.update(
        dict(
            n_tokens=len(prompt_tokenized),
            tokens=prompt_tokenized,
        ),
    )

    # save metadata
    prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    # set up names filter
    names_filter_fn: Callable[[str], bool]
    if isinstance(names_filter, re.Pattern):
        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
    else:
        names_filter_fn = names_filter

    # compute activations
    # NOTE: no padding_side kwarg here -- it's only meaningful for multi-sequence
    # batches where padding is needed. single-string input has no padding.
    # see compute_activations_batched for the batched path that passes padding_side="right".
    cache_torch: ActivationCache
    with torch.no_grad():
        model.eval()
        _, cache_torch = model.run_with_cache(
            prompt_str,
            names_filter=names_filter_fn,
            return_type=None,
        )

    activations_path: Path
    # saving and returning
    if stack_heads:
        n_layers: int = model.cfg.n_layers
        key_pattern: str = "blocks.{i}.attn.hook_pattern"
        # NOTE: this only works for stacking heads at the moment
        # activations_specifier: str = key_pattern.format(i=f'0-{n_layers}')
        activations_specifier: str = key_pattern.format(i="-")
        activations_path = prompt_dir / f"activations-{activations_specifier}.npy"

        # check the keys are only attention heads
        head_keys: list[str] = [key_pattern.format(i=i) for i in range(n_layers)]
        cache_torch_keys_set: set[str] = set(cache_torch.keys())
        assert cache_torch_keys_set == set(head_keys), (
            f"unexpected keys!\n{set(head_keys).symmetric_difference(cache_torch_keys_set) = }\n{cache_torch_keys_set} != {set(head_keys)}"
        )

        # stack heads
        patterns_stacked: Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] = (
            torch.stack([cache_torch[k] for k in head_keys], dim=1)
        )
        # check shape
        pattern_shape_no_ctx: tuple[int, ...] = tuple(patterns_stacked.shape[:3])
        assert pattern_shape_no_ctx == (1, n_layers, model.cfg.n_heads), (
            f"unexpected shape: {patterns_stacked.shape[:3] = } ({pattern_shape_no_ctx = }), expected {(1, n_layers, model.cfg.n_heads) = }"
        )
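        # note: the stacked tensor keeps the leading batch dim of 1 from
        # run_with_cache (as the assert above checks), so the saved array is
        # [1, n_layers, n_heads, n_ctx, n_ctx]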

        patterns_stacked_np: Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = (
            patterns_stacked.cpu().numpy()
        )

        # save
        np.save(activations_path, patterns_stacked_np)

        # return
        match return_cache:
            case "numpy":
                return activations_path, patterns_stacked_np
            case "torch":
                return activations_path, patterns_stacked
            case None:
                return activations_path, None
            case _:
                msg = f"invalid return_cache: {return_cache = }"
                raise ValueError(msg)
    else:
        activations_path = prompt_dir / "activations.npz"

        # save
        cache_np: ActivationCacheNp = {
            k: v.detach().cpu().numpy() for k, v in cache_torch.items()
        }

        np.savez_compressed(
            activations_path,
            **cache_np,  # type: ignore[arg-type]
        )

        # return
        match return_cache:
            case "numpy":
                return activations_path, cache_np
            case "torch":
                return activations_path, cache_torch
            case None:
                return activations_path, None
            case _:
                msg = f"invalid return_cache: {return_cache = }"
                raise ValueError(msg)


def compute_activations_batched(
    prompts: list[dict],
    model: HookedTransformer,
    save_path: Path = Path(DATA_DIR),
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
    seq_lens: list[int] | None = None,
) -> list[Path]:
    """compute and save activations for a batch of prompts in a single forward pass

    Batched companion to `compute_activations` -- instead of one forward pass per
    prompt, this runs a single `model.run_with_cache(list_of_strings)` call for the
    whole batch. TransformerLens tokenizes and right-pads automatically. Each prompt's
    attention patterns are then trimmed to their actual (unpadded) size and saved
    individually, producing files identical to the single-prompt path.

    Does not support `stack_heads` or `return_cache` -- this function is intended for
    the bulk processing path in `activations_main`, not for interactive use. Use
    `compute_activations` directly for single-prompt use cases that need those features.

    ## Why right-padding makes trimming correct without an explicit attention mask

    With right-padding, pad tokens sit at positions `seq_len, seq_len + 1, ...,
    max_seq_len - 1` -- higher than any real token. The causal attention mask prevents
    position `i` from attending to any `j > i`, so a real token at position
    `i < seq_len` can only attend to positions `0..i`, all of which are real tokens.
    The softmax is therefore computed over the same set of positions as in
    single-prompt inference, producing identical attention patterns.

    We explicitly pass `padding_side="right"` to `run_with_cache` to guarantee this
    regardless of the model's default padding side.
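
    For example (illustrative numbers): batching prompts of token lengths 5 and 3
    pads the second to length 5; its pattern tensor is `[1, n_heads, 5, 5]`, but
    rows and columns `0..2` are unaffected by the pad positions, so trimming
    recovers the unbatched result. A hypothetical equivalence check:

    ```python
    key = "blocks.0.attn.hook_pattern"
    _, single = model.run_with_cache("short prompt", return_type=None)
    _, batched = model.run_with_cache(
        ["short prompt", "a much longer prompt than the first one"],
        return_type=None,
        padding_side="right",
    )
    n = single[key].shape[-1]
    torch.testing.assert_close(batched[key][0:1, :, :n, :n], single[key])
    ```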

    # Parameters:
    - `prompts : list[dict]`
        each prompt must contain 'text' and 'hash' keys. call
        `augment_prompt_with_hash` on each prompt before passing them here.
    - `model : HookedTransformer`
        the model to compute activations with
    - `save_path : Path`
        path to save the activations to
        (defaults to `Path(DATA_DIR)`)
    - `names_filter : Callable[[str], bool] | re.Pattern`
        filter for which activations to save. must only match activations with
        4D shape `[batch, n_heads, seq, seq]` (e.g. attention patterns).
        non-attention activations will cause incorrect trimming.
        (defaults to `ATTN_PATTERN_REGEX`)
    - `seq_lens : list[int] | None`
        pre-computed model sequence lengths per prompt (from `model.to_tokens`).
        if `None`, will be computed internally. pass this to avoid redundant
        tokenization when lengths are already known (e.g. from length-sorting).
        **important**: these must be from `model.to_tokens()` (includes BOS),
        NOT from `model.tokenizer.tokenize()` (excludes BOS).
        (defaults to `None`)

    # Returns:
    - `list[Path]`
        paths to the saved activations files, one per prompt

    # Modifies:
    each prompt dict in `prompts` -- adds/overwrites `n_tokens` and `tokens` keys
    with tokenization metadata (same mutation as `compute_activations`).
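
    # Usage:
    a minimal sketch, assuming an already-loaded `HookedTransformer` bound to `model`:

    ```python
    prompts = [{"text": "The cat sat on the mat."}, {"text": "Hello world"}]
    for p in prompts:
        augment_prompt_with_hash(p)  # adds the required 'hash' key
    paths = compute_activations_batched(prompts, model, save_path=Path("demo/"))
    ```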

    """
    assert model is not None, "model must be passed"
    assert len(prompts) > 0, "prompts must not be empty"
    assert "text" in prompts[0], f"prompt must contain 'text' key: {prompts[0].keys()}"
    assert "hash" in prompts[0], (
        f"prompt must contain 'hash' key (call augment_prompt_with_hash first): {prompts[0].keys()}"
    )

    # --- Phase A: get actual model sequence lengths ---
    # model.to_tokens() includes BOS if applicable, matching the attention pattern dims
    # model.tokenizer.tokenize() gives subword strings WITHOUT BOS, used for metadata
    # these differ by 1 when BOS is prepended -- using the wrong one for trimming
    # would silently truncate or include garbage
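    # e.g. with GPT-2's tokenizer (illustrative): model.to_tokens("Hello world")
    # has shape [1, 3] (BOS + 2 subword tokens), while
    # len(model.tokenizer.tokenize("Hello world")) == 2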
    if seq_lens is None:
        seq_lens = [model.to_tokens(p["text"]).shape[1] for p in prompts]
    assert len(seq_lens) == len(prompts), (
        f"seq_lens length mismatch: {len(seq_lens)} != {len(prompts)}"
    )

    # --- Phase B: save prompt metadata (mirrors compute_activations's metadata logic) ---
    assert model.tokenizer is not None
    for p in prompts:
        prompt_str: str = p["text"]
        prompt_tokenized: list[str] = p.get(
            "tokens",
            model.tokenizer.tokenize(prompt_str),
        )
        # n_tokens counts subword tokens (no BOS); attention patterns include BOS so have dim n_tokens+1
        p.update(
            dict(
                n_tokens=len(prompt_tokenized),
                tokens=prompt_tokenized,
            ),
        )
        prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / p["hash"]
        prompt_dir.mkdir(parents=True, exist_ok=True)
        with open(prompt_dir / "prompt.json", "w") as f:
            json.dump(p, f)

    # --- Phase C: batched forward pass ---
    names_filter_fn: Callable[[str], bool]
    if isinstance(names_filter, re.Pattern):
        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
    else:
        names_filter_fn = names_filter

    texts: list[str] = [p["text"] for p in prompts]
    cache_torch: ActivationCache
    with torch.no_grad():
        model.eval()
        _, cache_torch = model.run_with_cache(
            texts,
            names_filter=names_filter_fn,
            return_type=None,
            padding_side="right",
        )

    # --- Phase D: split, trim padding, and save per-prompt ---
    # For each prompt i with actual sequence length seq_len_i:
    #   v[i : i+1, :, :seq_len_i, :seq_len_i]
    #     ^^^^^^^ i:i+1 not i -- keeps batch dim [1,...] for
    #             format compatibility with compute_activations
    #              ^ all attention heads
    #                 ^^^^^^^^^^  ^^^^^^^^^^ trim both query and key dims to
    #                             actual length, discarding meaningless padding
    paths: list[Path] = []
    for i, (prompt, seq_len) in enumerate(zip(prompts, seq_lens, strict=True)):
        prompt_dir = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
        activations_path: Path = prompt_dir / "activations.npz"
        cache_np: ActivationCacheNp = {}
        for k, v in cache_torch.items():
            assert v.ndim == 4, (  # noqa: PLR2004
                f"expected 4D attention pattern tensor for {k!r}, "
                f"got shape {v.shape}. names_filter must only match "
                f"attention pattern activations [batch, n_heads, seq, seq]"
            )
            cache_np[k] = v[i : i + 1, :, :seq_len, :seq_len].detach().cpu().numpy()

        np.savez_compressed(
            activations_path,
            **cache_np,  # type: ignore[arg-type]
        )
        paths.append(activations_path)

    return paths


@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: None = None,
) -> tuple[Path, None]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal["torch"] = "torch",
) -> tuple[Path, ActivationCache]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal["numpy"] = "numpy",
) -> tuple[Path, ActivationCacheNp]: ...
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: ReturnCache = "numpy",
) -> tuple[Path, ActivationCacheNp | ActivationCache | None]:
    """given a prompt and a model, save or load activations

    # Parameters:
    - `prompt : dict`
        expected to contain the 'text' key
    - `model : HookedTransformer | str`
        either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
    - `save_path : Path`
        path to save the activations to (and load from)
        (defaults to `Path(DATA_DIR)`)
    - `allow_disk_cache : bool`
        whether to allow loading from disk cache
        (defaults to `True`)
    - `return_cache : Literal[None, "numpy", "torch"]`
        whether to return the cache, and in what format
        (defaults to `"numpy"`)

    # Returns:
    - `tuple[Path, ActivationCacheNp | ActivationCache | None]`
        the path to the activations and the cache if `return_cache is not None`
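
    # Usage:
    a minimal sketch -- `model` may be a plain name string, loaded on demand:

    ```python
    path, cache = get_activations(
        {"text": "The quick brown fox"},
        "gpt2",
        return_cache="numpy",
    )
    ```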

    """
    # add hash to prompt
    augment_prompt_with_hash(prompt)

    # get the model
    model_name: str = (
        model.cfg.model_name if isinstance(model, HookedTransformer) else model
    )

    # from cache
    if allow_disk_cache:
        if return_cache is None:
            # fast path: check file existence without loading data into memory.
            # activations_exist just calls .exists() on two paths, whereas
            # load_activations would decompress the full .npz into numpy arrays
            # only for us to discard them immediately.
            if activations_exist(model_name, prompt, save_path):
                prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
                return prompt_dir / "activations.npz", None
        else:
            try:
                path, cache = load_activations(
                    model_name=model_name,
                    prompt=prompt,
                    save_path=save_path,
                )
            except ActivationsMissingError:
                pass
            else:
                return path, cache

    # compute them
    if isinstance(model, str):
        model = HookedTransformer.from_pretrained(model_name)

    return compute_activations(  # type: ignore[return-value]
        prompt=prompt,
        model=model,
        save_path=save_path,
        return_cache=return_cache,
    )


DEFAULT_DEVICE: torch.device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu",
)


def activations_main(  # noqa: C901, PLR0912, PLR0915
    model_name: str,
    save_path: str | Path,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
    stacked_heads: bool = False,
    device: str | torch.device = DEFAULT_DEVICE,
    batch_size: int = 32,
) -> None:
    """main function for computing activations

    # Parameters:
    - `model_name : str`
        name of a model to load with `HookedTransformer.from_pretrained`
    - `save_path : str | Path`
        path to save the activations to
    - `prompts_path : str`
        path to the prompts file
    - `raw_prompts : bool`
        whether the prompts are raw, i.e. not yet filtered by length. if `True`, `load_text_data` is called on `prompts_path`; otherwise, prompts are loaded from the `prompts.jsonl` in the model directory under `save_path`
    - `min_chars : int`
        minimum number of characters for a prompt
    - `max_chars : int`
        maximum number of characters for a prompt
    - `force : bool`
        whether to overwrite existing files
    - `n_samples : int`
        maximum number of samples to process
    - `no_index_html : bool`
        whether to skip writing an index.html file
    - `shuffle : bool`
        whether to shuffle the prompts
        (defaults to `False`)
    - `stacked_heads : bool`
        whether to stack the heads in the output tensor. will save as `.npy` instead of `.npz` if `True`
        (defaults to `False`)
    - `device : str | torch.device`
        the device to use. if a string, will be passed to `torch.device`
    - `batch_size : int`
        number of prompts per forward pass. prompts are sorted by token length
        (longest first) and grouped so that similar-length prompts share a batch,
        minimizing padding waste. use `batch_size=1` for one prompt per forward
        pass (largely equivalent to the old sequential behavior, but note: prompts
        are still sorted by length, and cache checking uses file existence only,
        unlike the old path, which processed prompts in order and validated cache
        contents via `load_activations`).
        the single-prompt functions `compute_activations` and `get_activations`
        are still available for programmatic use outside of `activations_main`.
        (defaults to `32`)
    """
    # figure out the device to use
    device_: torch.device
    if isinstance(device, torch.device):
        device_ = device
    elif isinstance(device, str):
        device_ = torch.device(device)
    else:
        msg = f"invalid device: {device}"
        raise TypeError(msg)

    print(f"using device: {device_}")

    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
        model: HookedTransformer = HookedTransformer.from_pretrained(
            model_name,
            device=device_,
        )
        model.model_name = model_name  # type: ignore[unresolved-attribute]
        model.cfg.model_name = model_name
        n_params: int = sum(p.numel() for p in model.parameters())
    print(
        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters",
    )
    print(f"\tmodel devices: { {p.device for p in model.parameters()} }")

    save_path_p: Path = Path(save_path)
    save_path_p.mkdir(parents=True, exist_ok=True)
    model_path: Path = save_path_p / model_name
    with SpinnerContext(
        message=f"saving model info to {_rel_path(model_path)}",
        **SPINNER_KWARGS,
    ):
        model_cfg: HookedTransformerConfig
        model_cfg = model.cfg
        model_path.mkdir(parents=True, exist_ok=True)
        with open(model_path / "model_cfg.json", "w") as f:
            json.dump(json_serialize(asdict(model_cfg)), f)

    # load prompts
    with SpinnerContext(
        message=f"loading prompts from {Path(prompts_path).as_posix()}",
        **SPINNER_KWARGS,
    ):
        prompts: list[dict]
        if raw_prompts:
            prompts = load_text_data(
                Path(prompts_path),
                min_chars=min_chars,
                max_chars=max_chars,
                shuffle=shuffle,
            )
        else:
            with open(model_path / "prompts.jsonl", "r") as f:
                prompts = [json.loads(line) for line in f.readlines()]
        # truncate to n_samples
        prompts = prompts[:n_samples]

    print(f" {len(prompts)} prompts loaded")

    # write index.html
    if not no_index_html:
        with SpinnerContext(
            message=f"writing {_rel_path(save_path_p / 'index.html')}",
            **SPINNER_KWARGS,
        ):
            write_html_index(save_path_p)

    # TODO: not implemented yet
    if stacked_heads:
        raise NotImplementedError("stacked_heads not implemented yet")

    # augment all prompts with hashes
    for prompt in prompts:
        augment_prompt_with_hash(prompt)

    # filter out cached prompts
    if not force:
        uncached: list[dict] = [
            p for p in prompts if not activations_exist(model_name, p, save_path_p)
        ]
        n_cached: int = len(prompts) - len(uncached)
        if n_cached > 0:
            print(f" {n_cached} prompts already cached, {len(uncached)} to compute")
    else:
        uncached = list(prompts)

    if uncached:
        # sort by token length descending so that:
        # 1. the longest (slowest, most memory-hungry) batches run first --
        #    OOM errors surface immediately rather than after all the cheap work,
        #    and tqdm's ETA stabilizes early for better progress estimation
        # 2. similar-length prompts are grouped together, minimizing padding waste
        #    (see the illustrative example below)
        #
        # pre-tokenization is a separate step from compute_activations_batched because
        # we need token lengths *before* batching to sort and group. the resulting
        # seq_lens are then passed through so compute_activations_batched can skip
        # re-tokenizing each prompt internally.
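        # illustrative (hypothetical numbers): with batch_size=3 and sorted token
        # lengths [900, 870, 860, 120, 110, 90], the batches are [900, 870, 860]
        # and [120, 110, 90] -- each batch pads only to its own max, rather than
        # mixing a 900-token prompt with a 90-token one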
        with SpinnerContext(
            message="pre-tokenizing prompts for length sorting",
            **SPINNER_KWARGS,
        ):
            uncached_with_lens: list[tuple[dict, int]] = [
                (p, model.to_tokens(p["text"]).shape[1]) for p in uncached
            ]
            uncached_with_lens.sort(key=lambda x: x[1], reverse=True)
            sorted_uncached: list[dict] = [p for p, _ in uncached_with_lens]
            sorted_seq_lens: list[int] = [sl for _, sl in uncached_with_lens]

        # process in batches
        n_prompts: int = len(sorted_uncached)
        with tqdm.tqdm(
            total=n_prompts,
            desc="Computing activations",
            unit="prompt",
        ) as pbar:
            for batch_start in range(0, n_prompts, batch_size):
                batch_end: int = min(batch_start + batch_size, n_prompts)
                batch: list[dict] = sorted_uncached[batch_start:batch_end]
                batch_seq_lens: list[int] = sorted_seq_lens[batch_start:batch_end]
                pbar.set_postfix(
                    n_ctx=batch_seq_lens[0],
                )  # longest in batch (sorted descending)
                compute_activations_batched(
                    prompts=batch,
                    model=model,
                    save_path=save_path_p,
                    seq_lens=batch_seq_lens,
                )
                pbar.update(len(batch))
    else:
        print(" all prompts cached, nothing to compute")

    with SpinnerContext(
        message="updating jsonl metadata for models and prompts",
        **SPINNER_KWARGS,
    ):
        generate_models_jsonl(save_path_p)
        generate_prompts_jsonl(save_path_p / model_name)


def main() -> None:
    "generate attention pattern activations for a model and prompts"
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )

        arg_parser.add_argument(
            "--prompts",
            "-p",
            type=str,
            required=False,
            help="The path to the prompts file (jsonl with a 'text' key on each line)",
            default=None,
        )

        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )

        # min and max prompt lengths
        arg_parser.add_argument(
            "--min-chars",
            type=int,
            required=False,
            help="The minimum number of characters for a prompt",
            default=100,
        )
        arg_parser.add_argument(
            "--max-chars",
            type=int,
            required=False,
            help="The maximum number of characters for a prompt",
            default=1000,
        )

        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process; processes all in the file if None",
            default=None,
        )

        # batch size
        arg_parser.add_argument(
            "--batch-size",
            "-b",
            type=int,
            required=False,
            help="Batch size for computing activations (number of prompts per forward pass)",
            default=32,
        )

        # force overwrite
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="If passed, will overwrite existing files",
        )

        # no index html
        arg_parser.add_argument(
            "--no-index-html",
            action="store_true",
            help="If passed, will not write an index.html file for the model",
        )

        # raw prompts
        arg_parser.add_argument(
            "--raw-prompts",
            "-r",
            action="store_true",
            help="pass if the prompts have not been split and tokenized (each item still needs 'text' and 'meta' keys)",
        )

        # shuffle
        arg_parser.add_argument(
            "--shuffle",
            action="store_true",
            help="If passed, will shuffle the prompts",
        )

        # stack heads
        arg_parser.add_argument(
            "--stacked-heads",
            action="store_true",
            help="If passed, will stack the heads in the output tensor",
        )

        # device
        arg_parser.add_argument(
            "--device",
            type=str,
            required=False,
            help="The device to use for the model",
            default="cuda" if torch.cuda.is_available() else "cpu",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    models: list[str]
    if "," in args.model:
        models = args.model.split(",")
    else:
        models = [args.model]

    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx + 1} / {n_models}: {model}")
        print(DIVIDER_S2)

        activations_main(
            model_name=model,
            save_path=args.save_path,
            prompts_path=args.prompts,
            raw_prompts=args.raw_prompts,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            force=args.force,
            n_samples=args.n_samples,
            no_index_html=args.no_index_html,
            shuffle=args.shuffle,
            stacked_heads=args.stacked_heads,
            device=args.device,
            batch_size=args.batch_size,
        )
        del model

    print(DIVIDER_S1)


if __name__ == "__main__":
    main()