Coverage for pattern_lens\figures.py: 68%

97 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-16 20:39 -0700

1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`""" 

2 

3import argparse 

4from collections import defaultdict 

5import functools 

6import itertools 

7import json 

8import warnings 

9from pathlib import Path 

10 

11from muutils.json_serialize import json_serialize 

12from muutils.spinner import SpinnerContext 

13from muutils.parallel import run_maybe_parallel 

14 

15from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS 

16from pattern_lens.consts import ( 

17 DATA_DIR, 

18 AttentionMatrix, 

19 SPINNER_KWARGS, 

20 ActivationCacheNp, 

21 DIVIDER_S1, 

22 DIVIDER_S2, 

23) 

24from pattern_lens.indexes import ( 

25 generate_functions_jsonl, 

26 generate_models_jsonl, 

27 generate_prompts_jsonl, 

28) 

29from pattern_lens.load_activations import load_activations 

30 

31 

class HTConfigMock:
    """Minimal stand-in for `transformer_lens.HookedTransformerConfig`.

    Used for type hinting and for loading a config back from json without
    importing transformer_lens. Accepts arbitrary keyword arguments and stores
    them on the instance; callers rely on at least these attributes existing:
    - `n_layers: int`
    - `n_heads: int`
    - `model_name: str`
    """

    def __init__(self, **kwargs):
        # bare annotations declare the required attributes for type checkers;
        # the actual values arrive via the kwargs
        self.n_layers: int
        self.n_heads: int
        self.model_name: str
        for key, value in kwargs.items():
            setattr(self, key, value)

    def serialize(self):
        """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
        return json_serialize(vars(self))

    @classmethod
    def load(cls, data: dict):
        """construct a config from a dict by forwarding it to `__init__`"""
        return cls(**data)

55 

56 

def process_single_head(
    layer_idx: int,
    head_idx: int,
    attn_pattern: AttentionMatrix,
    save_dir: Path,
    force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
    """process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern

    # Parameters:
    - `layer_idx : int`
    - `head_idx : int`
    - `attn_pattern : AttentionMatrix`
        attention pattern for the head
    - `save_dir : Path`
        directory to save the figures to
    - `force_overwrite : bool`
        whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)

    # Returns:
    - `dict[str, bool | Exception]`
        a dictionary of the status of each function, with the function name as the key and the status as the value
    """
    funcs_status: dict[str, bool | Exception] = dict()

    for func in ATTENTION_MATRIX_FIGURE_FUNCS:
        func_name: str = func.__name__
        # existing outputs for this function. exclude `<func_name>.error.txt`
        # markers left by previously failed runs -- otherwise the bare glob
        # `f"{func_name}.*"` would match the error file, and a failed function
        # would be skipped and wrongly reported as successful on reruns
        existing_figures: list[Path] = [
            p
            for p in save_dir.glob(f"{func_name}.*")
            if not p.name.endswith(".error.txt")
        ]

        if not force_overwrite and len(existing_figures) > 0:
            funcs_status[func_name] = True
            continue

        try:
            func(attn_pattern, save_dir)
            funcs_status[func_name] = True

        except Exception as e:
            # record the error next to the figures and warn, but keep going so
            # the remaining figure functions still run for this head
            error_file: Path = save_dir / f"{func_name}.error.txt"
            error_file.write_text(str(e))
            warnings.warn(
                f"Error in {func_name} for L{layer_idx}H{head_idx}: {str(e)}"
            )
            funcs_status[func_name] = e

    return funcs_status

104 

105 

def compute_and_save_figures(
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    activations_path: Path,
    cache: ActivationCacheNp,
    save_path: Path = Path(DATA_DIR),
    force_overwrite: bool = False,
    track_results: bool = False,
) -> None:
    """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
    - `activations_path : Path`
        path to the activations file; figures are written next to it, under `L{layer}/H{head}/`
    - `cache : ActivationCacheNp`
    - `save_path : Path`
        (defaults to `Path(DATA_DIR)`)
    - `force_overwrite : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)
    - `track_results : bool`
        whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
        (defaults to `False`)
    """
    prompt_dir: Path = activations_path.parent

    # maps func name -> {(layer, head): success-or-exception}; only built when tracking
    results: defaultdict[str, dict[tuple[int, int], bool | Exception]]
    if track_results:
        results = defaultdict(dict)

    for layer_idx in range(model_cfg.n_layers):
        for head_idx in range(model_cfg.n_heads):
            # batch index 0: activations are saved for a single prompt
            pattern: AttentionMatrix = cache[
                f"blocks.{layer_idx}.attn.hook_pattern"
            ][0, head_idx]
            head_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
            head_dir.mkdir(parents=True, exist_ok=True)

            head_status: dict[str, bool | Exception] = process_single_head(
                layer_idx=layer_idx,
                head_idx=head_idx,
                attn_pattern=pattern,
                save_dir=head_dir,
                force_overwrite=force_overwrite,
            )

            if track_results:
                for func_name, status in head_status.items():
                    results[func_name][(layer_idx, head_idx)] = status

    # TODO: do something with results

    generate_prompts_jsonl(save_path / model_cfg.model_name)

163 

164 

def process_prompt(
    prompt: dict,
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    save_path: Path,
    force_overwrite: bool = False,
) -> None:
    """process a single prompt, loading the activations and computing and saving the figures

    basically just calls `load_activations` and then `compute_and_save_figures`

    # Parameters:
    - `prompt : dict`
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
    - `save_path : Path`
        base path the activations live under and figures are saved to
    - `force_overwrite : bool`
        (defaults to `False`)
    """
    # fetch this prompt's cached activations as numpy arrays
    activations_path: Path
    cache: ActivationCacheNp
    activations_path, cache = load_activations(
        model_name=model_cfg.model_name,
        prompt=prompt,
        save_path=save_path,
        return_fmt="numpy",
    )

    compute_and_save_figures(
        model_cfg=model_cfg,
        activations_path=activations_path,
        cache=cache,
        save_path=save_path,
        force_overwrite=force_overwrite,
    )

197 

198 

def figures_main(
    model_name: str,
    save_path: str,
    n_samples: int | None,
    force: bool,
    parallel: bool | int = True,
) -> None:
    """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_name : str`
        model name to use, used for loading the model config, prompts, activations, and saving the figures
    - `save_path : str`
        base path to look in
    - `n_samples : int | None`
        max number of samples to process. `None` processes every prompt
        (the CLI in `main` passes `None` by default, so the annotation is `int | None`)
    - `force : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
    - `parallel : bool | int`
        whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
        (defaults to `True`)
    """
    with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
        # load the model config saved alongside the activations
        save_path_p: Path = Path(save_path)
        model_path: Path = save_path_p / model_name
        with open(model_path / "model_cfg.json", "r") as f:
            model_cfg = HTConfigMock.load(json.load(f))

    with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
        # load prompts, one json object per line
        with open(model_path / "prompts.jsonl", "r") as f:
            prompts: list[dict] = [json.loads(line) for line in f]
        # truncate to n_samples -- `[:None]` keeps the whole list
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded")
    print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS]))

    # run_maybe_parallel may return a lazy iterable; wrap in list() to force
    # all prompts to actually be processed
    list(
        run_maybe_parallel(
            func=functools.partial(
                process_prompt,
                model_cfg=model_cfg,
                save_path=save_path_p,
                force_overwrite=force,
            ),
            iterable=prompts,
            parallel=parallel,
            pbar="tqdm",
            pbar_kwargs=dict(
                desc="Making figures",
                unit="prompt",
            ),
        )
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and functions", **SPINNER_KWARGS
    ):
        generate_models_jsonl(save_path_p)
        generate_functions_jsonl(save_path_p)

263 

264 

def main():
    """CLI entry point: parse arguments and run `figures_main` for each requested model"""
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )
        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )
        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )
        # force overwrite of existing figures.
        # NOTE: this previously used `type=bool`, which is a known argparse
        # pitfall -- bool("False") is True, so any non-empty value parsed as
        # True. `store_true` makes `--force` a proper flag instead
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="Force overwrite of existing figures",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    # split handles both the single-model and comma-separated cases:
    # "gpt2".split(",") == ["gpt2"]
    models: list[str] = args.model.split(",")

    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx + 1} / {n_models}: {model}")
        print(DIVIDER_S2)
        figures_main(
            model_name=model,
            save_path=args.save_path,
            n_samples=args.n_samples,
            force=args.force,
        )

    print(DIVIDER_S1)


if __name__ == "__main__":
    main()