Coverage for pattern_lens / figures.py: 86%

133 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:15 -0700

1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`""" 

2 

3import argparse 

4import fnmatch 

5import functools 

6import itertools 

7import json 

8import multiprocessing 

9import re 

10import warnings 

11from collections import defaultdict 

12from pathlib import Path 

13 

14import numpy as np 

15from jaxtyping import Float 

16 

17# custom utils 

18from muutils.json_serialize import json_serialize 

19from muutils.parallel import run_maybe_parallel 

20from muutils.spinner import SpinnerContext 

21 

22# pattern_lens 

23from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS 

24from pattern_lens.consts import ( 

25 DATA_DIR, 

26 DIVIDER_S1, 

27 DIVIDER_S2, 

28 SPINNER_KWARGS, 

29 ActivationCacheNp, 

30 AttentionMatrix, 

31) 

32from pattern_lens.figure_util import AttentionMatrixFigureFunc 

33from pattern_lens.indexes import ( 

34 generate_functions_jsonl, 

35 generate_models_jsonl, 

36 generate_prompts_jsonl, 

37) 

38from pattern_lens.load_activations import load_activations 

39 

40 

41class HTConfigMock: 

42 """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json 

43 

44 can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes: 

45 - `n_layers: int` 

46 - `n_heads: int` 

47 - `model_name: str` 

48 

49 we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly 

50 """ 

51 

52 def __init__(self, **kwargs: dict[str, str | int]) -> None: 

53 "will pass all kwargs to `__dict__`" 

54 self.n_layers: int 

55 self.n_heads: int 

56 self.model_name: str 

57 self.__dict__.update(kwargs) 

58 

59 def serialize(self) -> dict: 

60 """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`""" 

61 # its fine, we know its a dict 

62 return json_serialize(self.__dict__) # type: ignore[return-value] 

63 

64 @classmethod 

65 def load(cls, data: dict) -> "HTConfigMock": 

66 "try to load a config from a dict, using the `__init__` method" 

67 return cls(**data) 

68 

69 

def process_single_head(
    layer_idx: int,
    head_idx: int,
    attn_pattern: AttentionMatrix,
    save_dir: Path,
    figure_funcs: list[AttentionMatrixFigureFunc],
    force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
    """run every function in `figure_funcs` on a single head's attention pattern

    > [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
    > it will skip all figures for that function if any are already saved
    > and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures

    # Parameters:
    - `layer_idx : int`
    - `head_idx : int`
    - `attn_pattern : AttentionMatrix`
        attention pattern for the head
    - `save_dir : Path`
        directory to save the figures to
    - `figure_funcs : list[AttentionMatrixFigureFunc]`
        figure functions to apply to the attention pattern
    - `force_overwrite : bool`
        whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)

    # Returns:
    - `dict[str, bool | Exception]`
        a dictionary of the status of each function, with the function name as the key and the status as the value
    """
    statuses: dict[str, bool | Exception] = {}

    for fig_func in figure_funcs:
        name: str = getattr(fig_func, "__name__", "<unknown>")

        # any existing `{name}.*` output (including a previous `{name}.error.txt`)
        # counts as "already done" unless we are forcing an overwrite
        if not force_overwrite and any(save_dir.glob(f"{name}.*")):
            statuses[name] = True
            continue

        try:
            fig_func(attn_pattern, save_dir)
        # bling catch any exception
        except Exception as err:  # noqa: BLE001
            # persist the error next to where the figure would have gone
            (save_dir / f"{name}.error.txt").write_text(str(err))
            warnings.warn(
                f"Error in {name} for L{layer_idx}H{head_idx}: {err!s}",
                stacklevel=2,
            )
            statuses[name] = err
        else:
            statuses[name] = True

    return statuses

124 

125 

def compute_and_save_figures(
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    activations_path: Path,
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
    figure_funcs: list[AttentionMatrixFigureFunc],
    save_path: Path = Path(DATA_DIR),
    force_overwrite: bool = False,
    track_results: bool = False,
) -> None:
    """compute and save figures for all heads in the model, using the functions in `figure_funcs`

    # Parameters:
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
        configuration of the model, provides `n_layers`, `n_heads`, and `model_name`
    - `activations_path : Path`
        path to the activations file -- figures are written under its parent directory
    - `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
        activation cache containing actual patterns for the prompt we are processing
    - `figure_funcs : list[AttentionMatrixFigureFunc]`
        list of functions to run
    - `save_path : Path`
        directory to save the figures to
        (defaults to `Path(DATA_DIR)`)
    - `force_overwrite : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)
    - `track_results : bool`
        whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
        (defaults to `False`)

    # Raises:
    - `TypeError` : if `cache` is neither a dict nor an `np.ndarray`
    """
    prompt_dir: Path = activations_path.parent

    # maps func name -> {(layer, head): success-or-exception}; always defined
    # (even when not tracking) so it can never be referenced before assignment
    results: defaultdict[
        str,  # func name
        dict[
            tuple[int, int],  # layer, head
            bool | Exception,  # success or exception
        ],
    ] = defaultdict(dict)

    for layer_idx, head_idx in itertools.product(
        range(model_cfg.n_layers),
        range(model_cfg.n_heads),
    ):
        # extract this head's pattern; dict caches are keyed by hook name,
        # ndarray caches are indexed [layer, head]
        attn_pattern: AttentionMatrix
        if isinstance(cache, dict):
            attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
        elif isinstance(cache, np.ndarray):
            attn_pattern = cache[layer_idx, head_idx]
        else:
            msg = (
                f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
            )
            raise TypeError(
                msg,
            )

        save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
        save_dir.mkdir(parents=True, exist_ok=True)
        head_res: dict[str, bool | Exception] = process_single_head(
            layer_idx=layer_idx,
            head_idx=head_idx,
            attn_pattern=attn_pattern,
            save_dir=save_dir,
            force_overwrite=force_overwrite,
            figure_funcs=figure_funcs,
        )

        if track_results:
            for func_name, status in head_res.items():
                results[func_name][(layer_idx, head_idx)] = status

    # TODO: do something with results

    generate_prompts_jsonl(save_path / model_cfg.model_name)

200 

201 

def process_prompt(
    prompt: dict,
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    save_path: Path,
    figure_funcs: list[AttentionMatrixFigureFunc],
    force_overwrite: bool = False,
) -> None:
    """process a single prompt, loading the activations and computing and saving the figures

    thin wrapper: chains `load_activations` into `compute_and_save_figures`

    # Parameters:
    - `prompt : dict`
        prompt to process, should be a dict with the following keys:
        - `"text"`: the prompt string
        - `"hash"`: the hash of the prompt
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
        configuration of the model, used for figuring out where to save
    - `save_path : Path`
        directory to save the figures to
    - `figure_funcs : list[AttentionMatrixFigureFunc]`
        list of functions to run
    - `force_overwrite : bool`
        (defaults to `False`)
    """
    # load the activations as numpy arrays
    loaded = load_activations(
        model_name=model_cfg.model_name,
        prompt=prompt,
        save_path=save_path,
        return_fmt="numpy",
    )
    activations_path: Path = loaded[0]
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = (
        loaded[1]
    )

    # compute and save the figures
    compute_and_save_figures(
        model_cfg=model_cfg,
        activations_path=activations_path,
        cache=cache,
        figure_funcs=figure_funcs,
        save_path=save_path,
        force_overwrite=force_overwrite,
    )

246 

247 

def select_attn_figure_funcs(
    figure_funcs_select: set[str] | str | None = None,
) -> list[AttentionMatrixFigureFunc]:
    """resolve a selector into a subset of `ATTENTION_MATRIX_FIGURE_FUNCS`

    - if arg is `None`, will use all functions
    - if a string, will use the function names which match the string (glob/fnmatch syntax)
    - if a set, will use functions whose names are in the set

    """
    if figure_funcs_select is None:
        # no selector given -> use everything
        return ATTENTION_MATRIX_FIGURE_FUNCS

    if isinstance(figure_funcs_select, str):
        # string selector -> treat as a glob pattern over function names
        compiled: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select))
        return [
            candidate
            for candidate in ATTENTION_MATRIX_FIGURE_FUNCS
            if compiled.match(getattr(candidate, "__name__", "<unknown>"))
        ]

    if isinstance(figure_funcs_select, set):
        # set selector -> exact membership of function names
        return [
            candidate
            for candidate in ATTENTION_MATRIX_FIGURE_FUNCS
            if getattr(candidate, "__name__", "<unknown>") in figure_funcs_select
        ]

    err_msg: str = (
        f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }"
        f"\n{figure_funcs_select = }"
    )
    raise TypeError(err_msg)

285 

286 

def figures_main(
    model_name: str,
    save_path: str | Path,
    n_samples: int | None,
    force: bool,
    figure_funcs_select: set[str] | str | None = None,
    parallel: bool | int = True,
) -> None:
    """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_name : str`
        model name to use, used for loading the model config, prompts, activations, and saving the figures
    - `save_path : str | Path`
        base path to look in
    - `n_samples : int | None`
        max number of samples to process; `None` (the CLI default) processes all prompts
    - `force : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
    - `figure_funcs_select : set[str]|str|None`
        figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set
        (defaults to `None`)
    - `parallel : bool | int`
        whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
        (defaults to `True`)
    """
    with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
        # save model info or check if it exists
        save_path_p: Path = Path(save_path)
        model_path: Path = save_path_p / model_name
        with open(model_path / "model_cfg.json", "r") as f:
            model_cfg = HTConfigMock.load(json.load(f))

    with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
        # load prompts, one json object per line
        with open(model_path / "prompts.jsonl", "r") as f:
            prompts: list[dict] = [json.loads(line) for line in f]
        # truncate to n_samples (slicing with `None` keeps everything)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs(
        figure_funcs_select=figure_funcs_select,
    )
    print(f"{len(figure_funcs)} figure functions loaded")
    print(
        "\t"
        + ", ".join([getattr(func, "__name__", "<unknown>") for func in figure_funcs]),
    )

    # aim for roughly 5 chunks per core so the progress bar stays responsive
    chunksize: int = int(
        max(
            1,
            len(prompts) // (5 * multiprocessing.cpu_count()),
        ),
    )
    print(f"chunksize: {chunksize}")

    # run over all prompts (possibly in parallel); list() drains the iterator
    list(
        run_maybe_parallel(
            func=functools.partial(
                process_prompt,
                model_cfg=model_cfg,
                save_path=save_path_p,
                figure_funcs=figure_funcs,
                force_overwrite=force,
            ),
            iterable=prompts,
            parallel=parallel,
            chunksize=chunksize,
            pbar="tqdm",
            pbar_kwargs=dict(
                desc="Making figures",
                unit="prompt",
            ),
        ),
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and functions",
        **SPINNER_KWARGS,
    ):
        generate_models_jsonl(save_path_p)
        generate_functions_jsonl(save_path_p)

372 

373 

374def _parse_args() -> tuple[ 

375 argparse.Namespace, 

376 list[str], # models 

377 set[str] | str | None, # figure_funcs_select 

378]: 

379 arg_parser: argparse.ArgumentParser = argparse.ArgumentParser() 

380 # input and output 

381 arg_parser.add_argument( 

382 "--model", 

383 "-m", 

384 type=str, 

385 required=True, 

386 help="The model name(s) to use. comma separated with no whitespace if multiple", 

387 ) 

388 arg_parser.add_argument( 

389 "--save-path", 

390 "-s", 

391 type=str, 

392 required=False, 

393 help="The path to save the attention patterns", 

394 default=DATA_DIR, 

395 ) 

396 # number of samples 

397 arg_parser.add_argument( 

398 "--n-samples", 

399 "-n", 

400 type=int, 

401 required=False, 

402 help="The max number of samples to process, do all in the file if None", 

403 default=None, 

404 ) 

405 # force overwrite of existing figures 

406 arg_parser.add_argument( 

407 "--force", 

408 "-f", 

409 type=bool, 

410 required=False, 

411 help="Force overwrite of existing figures", 

412 default=False, 

413 ) 

414 # figure functions 

415 arg_parser.add_argument( 

416 "--figure-funcs", 

417 type=str, 

418 required=False, 

419 help="The figure functions to use. if 'None' (default), will use all functions. if a string, will use the function names which match the string. if a comma-separated list of strings, will use the function names in the set", 

420 default=None, 

421 ) 

422 

423 args: argparse.Namespace = arg_parser.parse_args() 

424 

425 # figure out models 

426 models: list[str] 

427 if "," in args.model: 

428 models = args.model.split(",") 

429 else: 

430 models = [args.model] 

431 

432 # figure out figures 

433 figure_funcs_select: set[str] | str | None 

434 if (args.figure_funcs is None) or (args.figure_funcs.lower().strip() == "none"): 

435 figure_funcs_select = None 

436 elif "," in args.figure_funcs: 

437 figure_funcs_select = {x.strip() for x in args.figure_funcs.split(",")} 

438 else: 

439 figure_funcs_select = args.figure_funcs.strip() 

440 

441 return args, models, figure_funcs_select 

442 

443 

def main() -> None:
    "generates figures from the activations using the functions decorated with `register_attn_figure_func`"
    print(DIVIDER_S1)
    # parse CLI arguments inside a spinner
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        parsed = _parse_args()
    args: argparse.Namespace = parsed[0]
    models: list[str] = parsed[1]
    figure_funcs_select: set[str] | str | None = parsed[2]
    print(f"\targs parsed: '{args}'")
    print(f"\tmodels: '{models}'")
    print(f"\tfigure_funcs_select: '{figure_funcs_select}'")

    # generate figures once per requested model
    n_models: int = len(models)
    for position, model in enumerate(models, start=1):
        print(DIVIDER_S2)
        print(f"processing model {position} / {n_models}: {model}")
        print(DIVIDER_S2)
        figures_main(
            model_name=model,
            save_path=args.save_path,
            n_samples=args.n_samples,
            force=args.force,
            figure_funcs_select=figure_funcs_select,
        )

    print(DIVIDER_S1)


if __name__ == "__main__":
    main()