Coverage for pattern_lens\figures.py: 91%

104 statements  


1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`""" 

2 

3import argparse 

4import functools 

5import itertools 

6import json 

7import warnings 

8from collections import defaultdict 

9from pathlib import Path 

10 

11import numpy as np 

12from jaxtyping import Float 

13 

14# custom utils 

15from muutils.json_serialize import json_serialize 

16from muutils.parallel import run_maybe_parallel 

17from muutils.spinner import SpinnerContext 

18 

19# pattern_lens 

20from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS 

21from pattern_lens.consts import ( 

22 DATA_DIR, 

23 DIVIDER_S1, 

24 DIVIDER_S2, 

25 SPINNER_KWARGS, 

26 ActivationCacheNp, 

27 AttentionMatrix, 

28) 

29from pattern_lens.indexes import ( 

30 generate_functions_jsonl, 

31 generate_models_jsonl, 

32 generate_prompts_jsonl, 

33) 

34from pattern_lens.load_activations import load_activations 

35 

36 

class HTConfigMock:
    """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json

    can be initialized with any kwargs, and will update its `__dict__` with them. it does, however, require the following attributes:
    - `n_layers: int`
    - `n_heads: int`
    - `model_name: str`

    we do this to avoid having to import `torch` and `transformer_lens`, since that import would have to happen in every worker process during parallelization and would probably slow things down significantly
    """

    def __init__(self, **kwargs: dict[str, str | int]) -> None:
        "will pass all kwargs to `__dict__`"
        self.n_layers: int
        self.n_heads: int
        self.model_name: str
        self.__dict__.update(kwargs)

    def serialize(self) -> dict:
        """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
        # it's fine, we know it's a dict
        return json_serialize(self.__dict__)  # type: ignore[return-value]

    @classmethod
    def load(cls, data: dict) -> "HTConfigMock":
        "try to load a config from a dict, using the `__init__` method"
        return cls(**data)

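# a minimal usage sketch for `HTConfigMock` (illustrative only; the kwargs shown
# here are hypothetical -- any extra keys are simply stored on `__dict__`):
#
#     cfg = HTConfigMock(n_layers=12, n_heads=12, model_name="gpt2", d_model=768)
#     cfg_json: dict = cfg.serialize()
#     cfg_roundtrip = HTConfigMock.load(cfg_json)
#     assert cfg_roundtrip.model_name == "gpt2"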

def process_single_head(
    layer_idx: int,
    head_idx: int,
    attn_pattern: AttentionMatrix,
    save_dir: Path,
    force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
    """process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern

    > [gotcha:] if `force_overwrite` is `False` and a multi-figure function was used,
    > all figures for that function will be skipped if any of them are already saved,
    > since saved figures are assumed to follow the format `{func_name}.{figure_name}.{fmt}`

    # Parameters:
    - `layer_idx : int`
    - `head_idx : int`
    - `attn_pattern : AttentionMatrix`
        attention pattern for the head
    - `save_dir : Path`
        directory to save the figures to
    - `force_overwrite : bool`
        whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)

    # Returns:
    - `dict[str, bool | Exception]`
        a dictionary of the status of each function, with the function name as the key and the status as the value
    """
    funcs_status: dict[str, bool | Exception] = dict()

    for func in ATTENTION_MATRIX_FIGURE_FUNCS:
        func_name: str = func.__name__
        fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))

        if not force_overwrite and len(fig_path) > 0:
            funcs_status[func_name] = True
            continue

        try:
            func(attn_pattern, save_dir)
            funcs_status[func_name] = True

        # blind catch of any exception, so one failing figure function doesn't kill the whole run
        except Exception as e:  # noqa: BLE001
            error_file = save_dir / f"{func.__name__}.error.txt"
            error_file.write_text(str(e))
            warnings.warn(
                f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}",
                stacklevel=2,
            )
            funcs_status[func_name] = e

    return funcs_status

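# a minimal usage sketch for `process_single_head` (illustrative only; the random
# attention pattern and output directory below are made up):
#
#     rng = np.random.default_rng(0)
#     pattern = rng.random((16, 16))
#     pattern /= pattern.sum(axis=-1, keepdims=True)  # rows sum to 1, like softmaxed attention
#     save_dir = Path("demo/L0/H0")
#     save_dir.mkdir(parents=True, exist_ok=True)
#     status = process_single_head(
#         layer_idx=0,
#         head_idx=0,
#         attn_pattern=pattern,
#         save_dir=save_dir,
#     )
#     # `status` maps each figure function name to `True` or the exception it raised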

def compute_and_save_figures(
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    activations_path: Path,
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
    save_path: Path = Path(DATA_DIR),
    force_overwrite: bool = False,
    track_results: bool = False,
) -> None:
    """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
    - `activations_path : Path`
        path to the prompt's activations file. figures are saved under its parent directory
    - `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
    - `save_path : Path`
        (defaults to `Path(DATA_DIR)`)
    - `force_overwrite : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)
    - `track_results : bool`
        whether to track the per-head results of each function. the collected results aren't used for anything yet (TODO)
        (defaults to `False`)
    """
    prompt_dir: Path = activations_path.parent

    if track_results:
        results: defaultdict[
            str,  # func name
            dict[
                tuple[int, int],  # layer, head
                bool | Exception,  # success or exception
            ],
        ] = defaultdict(dict)

    for layer_idx, head_idx in itertools.product(
        range(model_cfg.n_layers),
        range(model_cfg.n_heads),
    ):
        attn_pattern: AttentionMatrix
        if isinstance(cache, dict):
            attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
        elif isinstance(cache, np.ndarray):
            attn_pattern = cache[layer_idx, head_idx]
        else:
            msg = (
                f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
            )
            raise TypeError(
                msg,
            )

        save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
        save_dir.mkdir(parents=True, exist_ok=True)
        head_res: dict[str, bool | Exception] = process_single_head(
            layer_idx=layer_idx,
            head_idx=head_idx,
            attn_pattern=attn_pattern,
            save_dir=save_dir,
            force_overwrite=force_overwrite,
        )

        if track_results:
            for func_name, status in head_res.items():
                results[func_name][(layer_idx, head_idx)] = status

    # TODO: do something with results

    generate_prompts_jsonl(save_path / model_cfg.model_name)

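# a minimal usage sketch for `compute_and_save_figures` (illustrative only; the
# paths and the 2-layer / 4-head config below are hypothetical, and the
# `data/demo-model` layout is assumed to already exist). figures land at
# `{activations_path.parent}/L{layer}/H{head}/{func_name}.{...}`:
#
#     cfg = HTConfigMock(n_layers=2, n_heads=4, model_name="demo-model")
#     stacked = np.random.default_rng(0).random((2, 4, 16, 16))
#     compute_and_save_figures(
#         model_cfg=cfg,
#         activations_path=Path("data/demo-model/prompts/abc123/activations.npz"),
#         cache=stacked,
#         save_path=Path("data"),
#     )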

def process_prompt(
    prompt: dict,
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    save_path: Path,
    force_overwrite: bool = False,
) -> None:
    """process a single prompt, loading the activations and computing and saving the figures

    basically just calls `load_activations` and then `compute_and_save_figures`

    # Parameters:
    - `prompt : dict`
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
    - `save_path : Path`
        base path where the activations are stored and the figures will be saved
    - `force_overwrite : bool`
        (defaults to `False`)
    """
    activations_path: Path
    cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
    activations_path, cache = load_activations(
        model_name=model_cfg.model_name,
        prompt=prompt,
        save_path=save_path,
        return_fmt="numpy",
    )

    compute_and_save_figures(
        model_cfg=model_cfg,
        activations_path=activations_path,
        cache=cache,
        save_path=save_path,
        force_overwrite=force_overwrite,
    )

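# a minimal usage sketch for `process_prompt` (illustrative only; `some_prompt_dict`
# is a placeholder -- it must be in whatever format `load_activations` expects, for
# a prompt whose activations have already been saved under `save_path`):
#
#     process_prompt(
#         prompt=some_prompt_dict,
#         model_cfg=HTConfigMock(n_layers=2, n_heads=4, model_name="demo-model"),
#         save_path=Path("data"),
#     )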

def figures_main(
    model_name: str,
    save_path: str,
    n_samples: int | None,
    force: bool,
    parallel: bool | int = True,
) -> None:
    """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_name : str`
        model name to use, used for loading the model config, prompts, activations, and saving the figures
    - `save_path : str`
        base path to look for the model config, prompts, and activations in, and to save the figures to
    - `n_samples : int | None`
        max number of samples to process; process all prompts in the file if `None`
    - `force : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
    - `parallel : bool | int`
        whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
        (defaults to `True`)
    """
    with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
        # load the model config saved by the activations step
        save_path_p: Path = Path(save_path)
        model_path: Path = save_path_p / model_name
        with open(model_path / "model_cfg.json", "r") as f:
            model_cfg = HTConfigMock.load(json.load(f))

    with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
        # load prompts
        with open(model_path / "prompts.jsonl", "r") as f:
            prompts: list[dict] = [json.loads(line) for line in f.readlines()]
        # truncate to n_samples
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded")
    print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS]))

    list(
        run_maybe_parallel(
            func=functools.partial(
                process_prompt,
                model_cfg=model_cfg,
                save_path=save_path_p,
                force_overwrite=force,
            ),
            iterable=prompts,
            parallel=parallel,
            pbar="tqdm",
            pbar_kwargs=dict(
                desc="Making figures",
                unit="prompt",
            ),
        ),
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and functions",
        **SPINNER_KWARGS,
    ):
        generate_models_jsonl(save_path_p)
        generate_functions_jsonl(save_path_p)

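# a minimal usage sketch for `figures_main` (illustrative only; "demo-model" and
# the "data" path are placeholders, and assume the activations step has already
# written `data/demo-model/model_cfg.json`, `data/demo-model/prompts.jsonl`, and
# the per-prompt activation files):
#
#     figures_main(
#         model_name="demo-model",
#         save_path="data",
#         n_samples=100,
#         force=False,
#         parallel=4,  # use 4 worker processes
#     )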

def main() -> None:
    "generates figures from the activations using the functions decorated with `register_attn_figure_func`"
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )
        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The base path to load activations from and save figures to",
            default=DATA_DIR,
        )
        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )
        # force overwrite of existing figures
        # note: `action="store_true"` rather than `type=bool`, since argparse's
        # `type=bool` treats any non-empty string (including "False") as truthy
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="Force overwrite of existing figures",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    models: list[str]
    if "," in args.model:
        models = args.model.split(",")
    else:
        models = [args.model]

    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx + 1} / {n_models}: {model}")
        print(DIVIDER_S2)
        figures_main(
            model_name=model,
            save_path=args.save_path,
            n_samples=args.n_samples,
            force=args.force,
        )

    print(DIVIDER_S1)


if __name__ == "__main__":
    main()

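# example CLI invocations (illustrative; "gpt2", "pythia-70m", and the "demo_data"
# path are placeholders -- the models must already have activations saved there by
# the activations step):
#
#     python -m pattern_lens.figures --model gpt2 --save-path demo_data -n 100
#     python -m pattern_lens.figures -m gpt2,pythia-70m --force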