Coverage for pattern_lens/figure_util.py: 90%

153 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-06 15:09 -0600

1"""implements a bunch of types, default values, and templates which are useful for figure functions 

2 

3notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures 

4""" 

5 

6import base64 

7import functools 

8import gzip 

9import io 

10from collections.abc import Callable, Sequence 

11from pathlib import Path 

12from typing import Literal, overload 

13 

14import matplotlib as mpl 

15import matplotlib.pyplot as plt 

16import numpy as np 

17from jaxtyping import Float, UInt8 

18from matplotlib.colors import Colormap 

19from PIL import Image 

20 

21from pattern_lens.consts import AttentionMatrix 

22 

# signature of a figure function: called as `f(attn_matrix, save_dir)`, returns nothing
AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves one or more figures"

# float array with two named axes, `n` then `m`
Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

# uint8 array with a trailing axis of fixed size 3
Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

# default `fmt` of `matplotlib_figure_saver` and `matplotlib_multifigure_saver`
MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

# `save_matrix_wrapper` asserts that its `fmt` is one of these literals
MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

# placeholders: `{m}` is the width, `{n}` is the height, `{png_base64}` is the embedded PNG payload
MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix as an svg/svgz"

52 

53 

# TYPING: mypy hates it when we dont pass func=None or None as the first arg
@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None], str],
    AttentionMatrixFigureFunc,
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> (
    AttentionMatrixFigureFunc
    | Callable[
        [Callable[[AttentionMatrix, plt.Axes], None], str],
        AttentionMatrixFigureFunc,
    ]
):
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    the wrapped function creates a 10x10 inch figure, calls your function with its axes,
    and saves the figure to `save_dir / f"{func.__name__}.{fmt}"`.
    the figure is always closed afterwards, even if your function or the save raises.

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None] | None`
       your function, which should take an attention matrix and predefined `ax` object.
       `None` when the decorator is used with arguments, e.g. `@matplotlib_figure_saver(fmt="png")`
     - `fmt : str`
       format for saving matplotlib figures, used as the file extension
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we wrap it to save a figure (when `func` is given).
       when `func` is `None`, the decorator itself is returned instead.
       the wrapped function carries the chosen format in its `figure_save_fmt` attribute

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on `fig` explicitly rather than pyplot's "current figure",
                # so the right figure is saved even if `func` changed the current one
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if plotting or saving failed
                plt.close(fig)

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator

126 

127 

def matplotlib_multifigure_saver(
    names: Sequence[str],
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    # decorator takes in function
    # which takes a matrix and a dictionary of axes corresponding to the names
    [Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
    # returns the decorated function
    AttentionMatrixFigureFunc,
]:
    """decorate a function such that it saves multiple figures, one for each name in `names`

    each figure is saved to `save_dir / f"{func.__name__}.{name}.{fmt}"`.
    all figures are closed afterwards, even if figure creation, plotting, or saving raises.

    # Parameters:
     - `names : Sequence[str]`
       the names of the figures to save. a repeated name gets a single figure
     - `fmt : str`
       format for saving matplotlib figures, used as the file extension
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc]`
       the decorator, which will then be applied to the function
       we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names.
       the wrapped function carries the chosen format in its `figure_save_fmt` attribute

    """

    def decorator(
        func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
    ) -> AttentionMatrixFigureFunc:
        func_name: str = func.__name__

        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            # set up axes and corresponding figures
            axes_dict: dict[str, plt.Axes] = {}
            figs_dict: dict[str, plt.Figure] = {}

            try:
                # Create all figures and axes -- inside the `try`, so that a failure
                # partway through still closes the figures which were already created
                for name in names:
                    if name in figs_dict:
                        # a repeated name would overwrite (and orphan) the earlier figure
                        continue
                    fig, ax = plt.subplots(figsize=(10, 10))
                    axes_dict[name] = ax
                    figs_dict[name] = fig

                # Run the function to make plots
                func(attn_matrix, axes_dict)

                # Save each figure
                for name, fig_ in figs_dict.items():
                    fig_path: Path = save_dir / f"{func_name}.{name}.{fmt}"
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout"  [union-attr]
                    fig_.tight_layout()  # type: ignore[union-attr]
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig"  [union-attr]
                    fig_.savefig(fig_path)  # type: ignore[union-attr]
            finally:
                # Always clean up figures, even if an error occurred
                for fig in figs_dict.values():
                    # TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None"  [arg-type]
                    plt.close(fig)  # type: ignore[arg-type]

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    return decorator

194 

195 

def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    the matrix is mapped into the colormap's input range (optionally normalizing it first),
    the colormap is applied, the alpha channel is dropped, and the result is scaled to `uint8`.

    # Parameters:
     - `matrix : Matrix2D`
        input matrix, must be 2D and non-empty
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1].
        a constant matrix (zero value range) maps to `0.0` everywhere,
        or to `0.5` for an all-zero matrix with `diverging_colormap=True`, instead of dividing by zero
        (defaults to `False`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix -- looked up in `matplotlib.colormaps` if it's a string
        (defaults to `"viridis"`)
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
        must be less than the matrix max, and no greater than the matrix min (so normalized values stay in [0, 1]).
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:
     - `Matrix2Drgb`
        `uint8` array of shape `(n, m, 3)`

    # Raises:
     - `AssertionError` : if the matrix is not 2D or is empty, if `normalize_min` is combined with
       incompatible options or is out of bounds, or if `normalize=False` and the values are outside
       [0, 1] (or [-1, 1] for `diverging_colormap=True`)
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims (2 is not that magic of a value here, hence noqa)
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert not diverging_colormap, (
            "normalize_min cannot be used with diverging_colormap=True"
        )
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: every value sits at the center of the colormap
                normalized_matrix = np.full_like(matrix, 0.5, dtype=float)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # a floor above the matrix min would push normalized values below 0
                assert min_val <= matrix.min(), (
                    "normalize_min must be less than or equal to matrix min"
                )
            else:
                min_val = matrix.min()

            value_range: float = max_val - min_val
            if value_range == 0:
                # constant matrix: avoid 0/0 -> NaN, map everything to the bottom of the colormap
                normalized_matrix = np.zeros_like(matrix, dtype=float)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    else:
        if diverging_colormap:
            assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
                "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            )
            normalized_matrix = matrix
        else:
            assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
                "Matrix values must be in range [0, 1], or normalize must be True"
            )
            normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = mpl.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        raise TypeError(
            msg,
        )

    # Apply the colormap
    rgb_matrix: Matrix2Drgb = (
        cmap_(normalized_matrix)[:, :, :3] * 255
    ).astype(np.uint8)  # Drop alpha channel

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix

294 

295 

@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb,
    buffer: io.BytesIO | None = None,
) -> bytes | None:
    """Encode a `Matrix2Drgb` as PNG using PIL

    # Parameters:
     - `matrix : Matrix2Drgb`
        the RGB image data to encode
     - `buffer : io.BytesIO | None`
        optional destination for the PNG bytes
        (defaults to `None`, in which case the PNG bytes are returned)

    # Returns:
     - `bytes|None`
        the PNG bytes when no `buffer` was passed; `None` when the PNG was written into `buffer`
    """
    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # caller supplied a destination: write into it and return nothing
    if buffer is not None:
        image.save(buffer, format="PNG")
        return None

    # no destination given: encode into a scratch buffer and hand back its contents
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()

327 

328 

def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is converted to an RGB image via `matrix_to_image_preprocess`, encoded as PNG,
    and embedded as base64 into `MATRIX_SAVE_SVG_TEMPLATE`.

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
        a 2D matrix to convert to an SVG image
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
        (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
        (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:
     - `str`
        the SVG content for the matrix

    # Raises:
     - `AssertionError` : if the matrix is not 2D
    """
    # Get the dimensions of the matrix
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
    # shape is `(n, m)`; the template uses `m` as the width and `n` as the height,
    # so unpack in that order or a non-square matrix gets transposed dimensions
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content

384 

385 

@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """Decorator for functions that process an attention matrix and save it as a PNG, SVG, or SVGZ image.

    Can handle both argumentless usage and with arguments.
    The image is written to `save_dir / f"{func.__name__}.{fmt}"`.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    In both cases the wrapped function carries the chosen format in its `figure_save_fmt` attribute.

    # Raises:

     - `AssertionError` : if positional arguments other than `func` are passed, or `fmt` is not a valid `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2
    ```

    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                # raw PNG: colormap the matrix and write the encoded bytes directly
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                # svg / svgz: same SVG text, optionally gzip-compressed
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding, so both branches write utf-8 regardless of locale
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator