Coverage for pattern_lens / figure_util.py: 89%

143 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:15 -0700

1"""implements a bunch of types, default values, and templates which are useful for figure functions 

2 

3notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures 

4""" 

5 

import base64
import functools
import gzip
import io
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Literal, get_args, overload

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Float, UInt8
from matplotlib.colors import Colormap
from PIL import Image

from pattern_lens.consts import AttentionMatrix

22 

# ---------------------------------------------------------------------------
# type aliases
# ---------------------------------------------------------------------------

AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias: callable taking an attention matrix and a directory, which saves one or more figures into that directory"

Matrix2D = Float[np.ndarray, "n m"]
"Type alias: plottable 2D float matrix with `n` rows and `m` columns"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias: `n` by `m` image with 3 uint8 channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias: callable mapping an attention matrix to a 2D matrix"

MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias: formats available when saving a raw matrix (as opposed to a matplotlib figure)"

# ---------------------------------------------------------------------------
# matplotlib figure defaults
# ---------------------------------------------------------------------------

MATPLOTLIB_FIGURE_FMT: str = "svgz"
"default file format (extension) used when saving matplotlib figures"

# ---------------------------------------------------------------------------
# raw-matrix saving defaults
# ---------------------------------------------------------------------------

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to rescale the matrix into [0, 1] before colormapping"

MATRIX_SAVE_CMAP: str = "viridis"
"default colormap name used when saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format used when saving matrices"

MATRIX_SAVE_SVG_TEMPLATE: str = (
    '<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" '
    'viewBox="0 0 {m} {n}" image-rendering="pixelated"> '
    '<image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> '
    "</svg>"
)
"SVG wrapper embedding a base64 PNG of an `n`-row by `m`-column matrix: `{m}` fills the width, `{n}` the height; used for svg/svgz output"

52 

53 

# TYPING: mypy hates it when we dont pass func=None or None as the first arg
@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None]],
    AttentionMatrixFigureFunc,
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> (
    AttentionMatrixFigureFunc
    | Callable[
        [Callable[[AttentionMatrix, plt.Axes], None]],
        AttentionMatrixFigureFunc,
    ]
):
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    The wrapped function creates a new figure (`figsize=(10, 10)`) with a single
    axes, calls `func(attn_matrix, ax)`, applies `tight_layout`, and saves the
    figure to `save_dir / f"{func.__name__}.{fmt}"`. The figure is always closed
    afterwards, even if `func` or saving raises. The wrapped function also gets a
    `figure_save_fmt` attribute set to `fmt`.

    Can be used bare (`@matplotlib_figure_saver`) or with arguments
    (`@matplotlib_figure_saver(fmt="png")`).

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None] | None`
        your function, which should take an attention matrix and predefined `ax` object.
        `None` when the decorator is called with arguments.
     - `fmt : str`
        format (file extension) for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
        your function, after we wrap it to save a figure (when `func` is given)
     - `Callable[[Callable[[AttentionMatrix, plt.Axes], None]], AttentionMatrixFigureFunc]`
        a decorator to apply to your function (when `func` is `None`)

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = (
                save_dir / f"{getattr(func, '__name__', '<unknown>')}.{fmt}"
            )

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # act on `fig` explicitly rather than pyplot's "current figure",
                # which `func` could have changed (e.g. by calling `plt.figure()`)
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if `func` or saving raised,
                # so failures don't accumulate open pyplot figures
                plt.close(fig)

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    # Handle arguments case
    return decorator

128 

129 

def matplotlib_multifigure_saver(
    names: Sequence[str],
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    # decorator takes in function
    # which takes a matrix and a dictionary of axes corresponding to the names
    [Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
    # returns the decorated function
    AttentionMatrixFigureFunc,
]:
    """decorate a function such that it saves multiple figures, one for each name in `names`

    For each name, the wrapped function creates a figure (`figsize=(10, 10)`)
    with a single axes. It calls `func(attn_matrix, axes_dict)` and then saves
    each figure to `save_dir / f"{func.__name__}.{name}.{fmt}"`. Every figure it
    created is closed afterwards, even if an error occurred. The wrapped function
    also gets a `figure_save_fmt` attribute set to `fmt`.

    # Parameters:
     - `names : Sequence[str]`
        the names of the figures to save
     - `fmt : str`
        format for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc]`
        the decorator, which will then be applied to the function
        we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

    """

    def decorator(
        func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
    ) -> AttentionMatrixFigureFunc:
        func_name: str = getattr(func, "__name__", "<unknown>")

        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            # set up axes and corresponding figures
            axes_dict: dict[str, plt.Axes] = {}
            figs_dict: dict[str, plt.Figure] = {}
            # every figure created, including ones whose name repeats in `names`
            # (those get overwritten in `figs_dict`), so that all of them get closed
            created_figs: list[plt.Figure] = []

            try:
                # Create all figures and axes inside the `try`, so a failure
                # partway through still cleans up the figures already made
                for name in names:
                    fig, ax = plt.subplots(figsize=(10, 10))
                    created_figs.append(fig)
                    axes_dict[name] = ax
                    figs_dict[name] = fig

                # Run the function to make plots
                func(attn_matrix, axes_dict)

                # Save each figure
                for name, fig_ in figs_dict.items():
                    fig_path: Path = save_dir / f"{func_name}.{name}.{fmt}"
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout" [union-attr]
                    fig_.tight_layout()  # type: ignore[union-attr]
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig" [union-attr]
                    fig_.savefig(fig_path)  # type: ignore[union-attr]
            finally:
                # Always clean up figures, even if an error occurred
                for created in created_figs:
                    # TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None" [arg-type]
                    plt.close(created)  # type: ignore[arg-type]

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    return decorator

196 

197 

def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    The matrix is mapped into [0, 1], passed through the colormap, and the alpha
    channel is dropped. The result is a `uint8` RGB image with the same number
    of rows and columns as `matrix`.

    How values are mapped into [0, 1]:
     - `normalize=True, diverging_colormap=False`: `(matrix - min_val) / (max - min_val)`,
       where `min_val` is `normalize_min` if given, otherwise the matrix min.
       A constant matrix maps to all zeros (instead of NaN from 0/0).
     - `normalize=True, diverging_colormap=True`: divided by twice the largest
       absolute value and shifted by 0.5, so 0 maps to the colormap center.
       An all-zero matrix maps to all 0.5 (instead of NaN from 0/0).
     - `normalize=False, diverging_colormap=True`: values must lie in [-1, 1] and
       are mapped linearly by `(matrix + 1) / 2`, so 0 maps to the colormap center.
     - `normalize=False, diverging_colormap=False`: values must already lie in [0, 1].

    # Parameters:
     - `matrix : Matrix2D`
        input matrix
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1]
        (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix; strings are looked up in `matplotlib.colormaps`
        (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the value mapped to
        the bottom of the colormap (generally set this to zero). It must be at most the matrix min
        and strictly less than the matrix max.
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:
     - `Matrix2Drgb`
        `uint8` array of shape `(n, m, 3)`

    # Raises:
     - `AssertionError` if the matrix is not 2D or is empty, if values fall outside the
       range required when `normalize=False`, or if `normalize_min` is used in an
       unsupported combination or is out of range
     - `TypeError` if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims (2 is not that magic of a value here, hence noqa)
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert not diverging_colormap, (
            "normalize_min cannot be used with diverging_colormap=True"
        )
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Map the matrix into range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: avoid 0/0 (NaN) and put everything at the center
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # values below `min_val` would normalize to < 0, outside [0, 1]
                assert min_val <= matrix.min(), (
                    "normalize_min must be less than or equal to matrix min"
                )
            else:
                min_val = matrix.min()

            value_range = max_val - min_val
            if value_range == 0:
                # constant matrix: avoid 0/0 (NaN); map to the bottom of the colormap
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    elif diverging_colormap:
        assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
            "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
        )
        # colormaps expect [0, 1]; shift [-1, 1] so that 0 lands on the center
        # instead of negatives collapsing onto the colormap's "under" color
        normalized_matrix = (matrix + 1) / 2
    else:
        assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
            "Matrix values must be in range [0, 1], or normalize must be True"
        )
        normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = mpl.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        raise TypeError(
            msg,
        )

    # Apply the colormap, drop the alpha channel, and scale to 0-255
    rgb_matrix: Matrix2Drgb = (
        cmap_(normalized_matrix)[:, :, :3] * 255
    ).astype(np.uint8)

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix

296 

297 

@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb,
    buffer: io.BytesIO | None = None,
) -> bytes | None:
    """Encode a `Matrix2Drgb` image as PNG using PIL

    - with a `buffer`: the PNG data is written into it and `None` is returned
    - without a `buffer`: the PNG data is returned as `bytes`

    # Parameters:
     - `matrix : Matrix2Drgb`
        the RGB image to encode (opened by PIL in `"RGB"` mode)
     - `buffer : io.BytesIO | None`
        destination for the PNG data
        (defaults to `None`, in which case the PNG bytes are returned)

    # Returns:
     - `bytes|None`
        the PNG `bytes` when `buffer` is `None`, otherwise `None`
    """
    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    if buffer is not None:
        # caller supplied the destination: write into it, return nothing
        image.save(buffer, format="PNG")
        return None

    # no destination supplied: encode into a private buffer and return its contents
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()

329 

330 

def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    The matrix is colormapped with `matrix_to_image_preprocess` and encoded as a PNG.
    The PNG is embedded as base64 in `MATRIX_SAVE_SVG_TEMPLATE`, whose width is the
    number of columns and whose height is the number of rows.

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
        a 2D matrix to convert to an SVG image
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
        (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
        (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:
     - `str`
        the SVG content for the matrix

    # Raises:
     - `AssertionError` if the matrix is not 2D, or if preprocessing rejects it
    """
    # Get the dimensions of the matrix
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
    # `Matrix2D` is "n m": `n` rows (image height) by `m` columns (image width).
    # The template uses `{m}` for the width and `{n}` for the height, so unpack in
    # that order; swapping them would stretch non-square matrices.
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    return MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

386 

387 

@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """Decorator for functions that turn an attention matrix into a 2D matrix, saving the result as an image.

    The wrapped function calls `func(attn_matrix)` and writes the returned matrix to
    `save_dir / f"{func.__name__}.{fmt}"`. The output is a PNG, a plain SVG, or a
    gzip-compressed SVG (`svgz`); both SVG variants embed the PNG as base64. The
    wrapped function also gets a `figure_save_fmt` attribute set to `fmt`.

    Can handle both argumentless usage and usage with (keyword-only) arguments.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool, keyword-only`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None, keyword-only`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

     - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
     - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator, which will then be applied to the function (with arguments case)

    # Raises:

     - `AssertionError` if extra positional arguments are given, or if `fmt` is not one of the `MatrixSaveFormat` values

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix_plasma(matrix):
        return matrix * 2
    ```

    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    # validate eagerly, at decoration time, rather than on first save
    valid_formats: tuple[str, ...] = get_args(MatrixSaveFormat)
    assert fmt in valid_formats, (
        f"Invalid format {fmt = }, must be one of {valid_formats}"
    )

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = (
                save_dir / f"{getattr(func, '__name__', '<unknown>')}.{fmt}"
            )
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)
                return

            svg_content: str = matrix_as_svg(
                processed_matrix,
                normalize=normalize,
                cmap=cmap,
                diverging_colormap=diverging_colormap,
                normalize_min=normalize_min,
            )

            if fmt == "svgz":
                # pin the encoding (as the plain `svg` branch does) so the output
                # does not depend on the platform's locale default
                with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                    f.write(svg_content)
            else:
                fig_path.write_text(svg_content, encoding="utf-8")

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    # Handle arguments case
    return decorator