Coverage for pattern_lens\activations.py: 67%
123 statements
coverage.py v7.6.10, created at 2025-01-16 20:39 -0700

1"""computing and saving activations given a model and prompts 

2 

3 

4# Usage: 

5 

6from the command line: 

7 

8```bash 

9python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples> 

10``` 

11 

12from a script: 

13 

14```python 

15from pattern_lens.activations import activations_main 

16activations_main( 

17 model_name="gpt2", 

18 save_path="demo/" 

19 prompts_path="data/pile_1k.jsonl", 

20) 

21``` 

22 

23""" 

import argparse
import functools
import json
import re
from dataclasses import asdict
from pathlib import Path
from typing import Callable, Literal, overload

import numpy as np
import torch
import tqdm
from muutils.json_serialize import json_serialize
from muutils.misc.numerical import shorten_numerical_to_str
from muutils.spinner import SpinnerContext
from transformer_lens import HookedTransformer, HookedTransformerConfig  # type: ignore[import-untyped]

from pattern_lens.consts import (
    ATTN_PATTERN_REGEX,
    DATA_DIR,
    DIVIDER_S1,
    DIVIDER_S2,
    SPINNER_KWARGS,
    ActivationCacheNp,
)
from pattern_lens.indexes import (
    generate_models_jsonl,
    generate_prompts_jsonl,
    write_html_index,
)
from pattern_lens.load_activations import (
    ActivationsMissingError,
    augment_prompt_with_hash,
    load_activations,
)
from pattern_lens.prompts import load_text_data


def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    return_cache: bool = True,
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
) -> tuple[Path, ActivationCacheNp | None]:
69 """get activations for a given model and prompt, possibly from a cache 

70 

71 if from a cache, prompt_meta must be passed and contain the prompt hash 

72 

73 # Parameters: 

74 - `prompt : dict | None` 

75 (defaults to `None`) 

76 - `model : HookedTransformer` 

77 - `save_path : Path` 

78 (defaults to `Path(DATA_DIR)`) 

79 - `return_cache : bool` 

80 will return `None` as the second element if `False` 

81 (defaults to `True`) 

82 - `names_filter : Callable[[str], bool]|re.Pattern` 

83 a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None` 

84 (defaults to `ATTN_PATTERN_REGEX`) 

85 

86 # Returns: 

87 - `tuple[Path, ActivationCacheNp|None]` 
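
    # Usage:

    a minimal sketch, assuming a loaded model and a prompt dict that already
    carries a `hash` key (e.g. via `augment_prompt_with_hash`):

    ```python
    from transformer_lens import HookedTransformer

    from pattern_lens.activations import compute_activations
    from pattern_lens.load_activations import augment_prompt_with_hash

    model = HookedTransformer.from_pretrained("gpt2")
    prompt = {"text": "The quick brown fox jumps over the lazy dog."}
    augment_prompt_with_hash(prompt)
    path, cache = compute_activations(prompt, model=model)
    ```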
    """
    assert model is not None, "model must be passed"
    assert "text" in prompt, "prompt must contain 'text' key"
    prompt_str: str = prompt["text"]

    # compute or get prompt metadata
    prompt_tokenized: list[str] = prompt.get(
        "tokens",
        model.tokenizer.tokenize(prompt_str),
    )
    prompt.update(
        dict(
            n_tokens=len(prompt_tokenized),
            tokens=prompt_tokenized,
        )
    )

    # save metadata
    prompt_dir: Path = save_path / model.model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    # set up names filter
    names_filter_fn: Callable[[str], bool]
    if isinstance(names_filter, re.Pattern):
        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
    else:
        names_filter_fn = names_filter

    # compute activations
    with torch.no_grad():
        model.eval()
        # TODO: batching?
        _, cache = model.run_with_cache(
            prompt_str,
            names_filter=names_filter_fn,
            return_type=None,
        )

    cache_np: ActivationCacheNp = {
        k: v.detach().cpu().numpy() for k, v in cache.items()
    }

    # save activations
    activations_path: Path = prompt_dir / "activations.npz"
    np.savez_compressed(
        activations_path,
        **cache_np,
    )

    # return path and cache
    if return_cache:
        return activations_path, cache_np
    else:
        return activations_path, None


@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[False] = False,
) -> tuple[Path, None]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[True] = True,
) -> tuple[Path, ActivationCacheNp]: ...
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: bool = True,
) -> tuple[Path, ActivationCacheNp | None]:
169 """given a prompt and a model, save or load activations 

170 

171 # Parameters: 

172 - `prompt : dict` 

173 expected to contain the 'text' key 

174 - `model : HookedTransformer | str` 

175 either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained` 

176 - `save_path : Path` 

177 path to save the activations to (and load from) 

178 (defaults to `Path(DATA_DIR)`) 

179 - `allow_disk_cache : bool` 

180 whether to allow loading from disk cache 

181 (defaults to `True`) 

182 - `return_cache : bool` 

183 whether to return the cache. if `False`, will return `None` as the second element 

184 (defaults to `True`) 

185 

186 # Returns: 

187 - `tuple[Path, ActivationCacheNp | None]` 

188 the path to the activations and the cache if `return_cache` is `True` 
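
    # Usage:

    a minimal sketch; passing the model by name assumes it can be fetched via
    `HookedTransformer.from_pretrained`, and the second call should hit the disk cache:

    ```python
    from pattern_lens.activations import get_activations

    prompt = {"text": "The quick brown fox jumps over the lazy dog."}
    path, cache = get_activations(prompt, model="gpt2")
    path_again, _ = get_activations(prompt, model="gpt2", return_cache=False)
    ```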
    """
    # add hash to prompt
    augment_prompt_with_hash(prompt)

    # get the model
    model_name: str = (
        model.model_name if isinstance(model, HookedTransformer) else model
    )

    # from cache
    if allow_disk_cache:
        try:
            path, cache = load_activations(
                model_name=model_name,
                prompt=prompt,
                save_path=save_path,
            )
            if return_cache:
                return path, cache
            else:
                return path, None
        except ActivationsMissingError:
            pass

    # compute them
    if isinstance(model, str):
        model = HookedTransformer.from_pretrained(model_name)

    # propagate return_cache so the overloaded return types hold
    return compute_activations(
        prompt=prompt,
        model=model,
        save_path=save_path,
        return_cache=return_cache,
    )


def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int | None,
    no_index_html: bool,
    shuffle: bool = False,
    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
) -> None:
239 """main function for computing activations 

240 

241 # Parameters: 

242 - `model_name : str` 

243 name of a model to load with `HookedTransformer.from_pretrained` 

244 - `save_path : str` 

245 path to save the activations to 

246 - `prompts_path : str` 

247 path to the prompts file 

248 - `raw_prompts : bool` 

249 whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path` 

250 - `min_chars : int` 

251 minimum number of characters for a prompt 

252 - `max_chars : int` 

253 maximum number of characters for a prompt 

254 - `force : bool` 

255 whether to overwrite existing files 

256 - `n_samples : int` 

257 maximum number of samples to process 

258 - `no_index_html : bool` 

259 whether to write an index.html file 

260 - `shuffle : bool` 

261 whether to shuffle the prompts 

262 (defaults to `False`) 

263 - `device : str | torch.device` 

264 the device to use. if a string, will be passed to `torch.device` 
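
    # Usage:

    a hypothetical invocation from the command line; flags mirror the argparse
    setup in `main` below, and the paths are placeholders:

    ```bash
    python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --raw-prompts --save-path demo/ --n-samples 100
    ```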
    """

    # figure out the device to use
    device_: torch.device
    if isinstance(device, torch.device):
        device_ = device
    elif isinstance(device, str):
        device_ = torch.device(device)
    else:
        raise ValueError(f"invalid device: {device}")

    print(f"using device: {device_}")

    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
        model: HookedTransformer = HookedTransformer.from_pretrained(
            model_name, device=device_
        )
        model.model_name = model_name
        model.cfg.model_name = model_name

    n_params: int = sum(p.numel() for p in model.parameters())
    print(
        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters"
    )
    print(f"\tmodel devices: {set(p.device for p in model.parameters())}")

    save_path_p: Path = Path(save_path)
    save_path_p.mkdir(parents=True, exist_ok=True)
    model_path: Path = save_path_p / model_name
    with SpinnerContext(
        message=f"saving model info to {model_path.as_posix()}", **SPINNER_KWARGS
    ):
        model_cfg: HookedTransformerConfig = model.cfg
        model_path.mkdir(parents=True, exist_ok=True)
        with open(model_path / "model_cfg.json", "w") as f:
            json.dump(json_serialize(asdict(model_cfg)), f)

    # load prompts
    with SpinnerContext(
        message=f"loading prompts from {prompts_path = }", **SPINNER_KWARGS
    ):
        prompts: list[dict]
        if raw_prompts:
            prompts = load_text_data(
                Path(prompts_path),
                min_chars=min_chars,
                max_chars=max_chars,
                shuffle=shuffle,
            )
        else:
            with open(model_path / "prompts.jsonl", "r") as f:
                prompts = [json.loads(line) for line in f.readlines()]
        # truncate to n_samples (a `None` slice bound keeps everything)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    # write index.html (skip the spinner entirely when disabled)
    if not no_index_html:
        with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
            write_html_index(save_path_p)

    # get activations
    list(
        tqdm.tqdm(
            map(
                functools.partial(
                    get_activations,
                    model=model,
                    save_path=save_path_p,
                    allow_disk_cache=not force,
                    return_cache=False,
                ),
                prompts,
            ),
            total=len(prompts),
            desc="Computing activations",
            unit="prompt",
        )
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and prompts", **SPINNER_KWARGS
    ):
        generate_models_jsonl(save_path_p)
        generate_prompts_jsonl(save_path_p / model_name)


def main():
    """command line interface for computing activations"""
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()

        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )

        arg_parser.add_argument(
            "--prompts",
            "-p",
            type=str,
            required=False,
            help="The path to the prompts file (jsonl with a 'text' key on each line). If not passed, pre-saved prompts are loaded from `prompts.jsonl` in the model directory (requires that `--raw-prompts` is not set)",
            default=None,
        )

        # min and max prompt lengths
        arg_parser.add_argument(
            "--min-chars",
            type=int,
            required=False,
            help="The minimum number of characters for a prompt",
            default=100,
        )
        arg_parser.add_argument(
            "--max-chars",
            type=int,
            required=False,
            help="The maximum number of characters for a prompt",
            default=1000,
        )

        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )

        # force overwrite
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="If passed, will overwrite existing files",
        )

        # no index html
        arg_parser.add_argument(
            "--no-index-html",
            action="store_true",
            help="If passed, will not write an index.html file for the model",
        )

        # raw prompts
        arg_parser.add_argument(
            "--raw-prompts",
            "-r",
            action="store_true",
            help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
        )

        # shuffle
        arg_parser.add_argument(
            "--shuffle",
            action="store_true",
            help="If passed, will shuffle the prompts",
        )

        # device
        arg_parser.add_argument(
            "--device",
            type=str,
            required=False,
            help="The device to use for the model",
            default="cuda" if torch.cuda.is_available() else "cpu",
        )

449 args: argparse.Namespace = arg_parser.parse_args() 

450 

451 print(f"args parsed: {args}") 

452 

453 models: list[str] 

454 if "," in args.model: 

455 models = args.model.split(",") 

456 else: 

457 models = [args.model] 

458 

459 n_models: int = len(models) 

460 for idx, model in enumerate(models): 

461 print(DIVIDER_S2) 

462 print(f"processing model {idx+1} / {n_models}: {model}") 

463 print(DIVIDER_S2) 

464 

465 activations_main( 

466 model_name=model, 

467 save_path=args.save_path, 

468 prompts_path=args.prompts, 

469 raw_prompts=args.raw_prompts, 

470 min_chars=args.min_chars, 

471 max_chars=args.max_chars, 

472 force=args.force, 

473 n_samples=args.n_samples, 

474 no_index_html=args.no_index_html, 

475 shuffle=args.shuffle, 

476 device=args.device, 

477 ) 

478 

479 print(DIVIDER_S1) 

480 

481 

482if __name__ == "__main__": 

483 main()