Coverage for pattern_lens\figure_util.py: 88%

129 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-16 20:39 -0700

"""Types, default values, and decorators shared by figure functions.

Notably, the decorators `matplotlib_figure_saver` and `save_matrix_wrapper`
turn plain functions into figure functions that save their output to disk:

- `matplotlib_figure_saver` wraps a function that draws onto a provided `plt.Axes`
- `save_matrix_wrapper` wraps a function that maps an attention matrix to a 2D
  matrix, which is then saved as a png/svg/svgz heatmap without matplotlib
"""

5 

6from pathlib import Path 

7from typing import Callable, Literal, overload, Union 

8import functools 

9import base64 

10import gzip 

11import io 

12 

13from PIL import Image 

14import numpy as np 

15from jaxtyping import Float, UInt8 

16import matplotlib 

17import matplotlib.pyplot as plt 

18from matplotlib.colors import Colormap 

19 

20from pattern_lens.consts import AttentionMatrix 

21 

# ---------------------------------------------------------------------------
# type aliases
# ---------------------------------------------------------------------------

AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Signature of a figure function: given an attention matrix and an output directory, it writes a figure file there"

Matrix2D = Float[np.ndarray, "n m"]
"A plottable 2D float matrix with `n` rows and `m` columns"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"A 2D image of `n` rows by `m` columns with 3 `uint8` color channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Signature of a function mapping an attention matrix to a plottable 2D matrix"

MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Formats available when saving a raw matrix (as opposed to a matplotlib figure)"

# ---------------------------------------------------------------------------
# defaults
# ---------------------------------------------------------------------------

MATPLOTLIB_FIGURE_FMT: str = "svgz"
"Default file format (and extension) for figures saved through matplotlib"

MATRIX_SAVE_NORMALIZE: bool = False
"Default for whether a matrix is rescaled to the range [0, 1] before coloring"

MATRIX_SAVE_CMAP: str = "viridis"
"Default colormap name used when saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"Default file format used when saving matrices"

MATRIX_SAVE_SVG_TEMPLATE: str = (
    '<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" '
    'viewBox="0 0 {m} {n}" image-rendering="pixelated"> '
    '<image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> '
    "</svg>"
)
"""SVG wrapper embedding a base64-encoded PNG of an `n`-row by `m`-column matrix.

Fill with `.format(m=..., n=..., png_base64=...)`: `{m}` (the column count) is
used as the SVG width and `{n}` (the row count) as the height, one unit per
matrix cell. `image-rendering="pixelated"` requests nearest-neighbour scaling
so individual cells stay crisp."""

51 

52 

@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None]], AttentionMatrixFigureFunc
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Union[
    AttentionMatrixFigureFunc,
    Callable[
        [Callable[[AttentionMatrix, plt.Axes], None]], AttentionMatrixFigureFunc
    ],
]:
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    The wrapped function creates a fresh 10x10 inch figure, calls your function
    with the attention matrix and that figure's `ax`, then saves the figure to
    `save_dir / f"{func.__name__}.{fmt}"` and closes it (even if your function
    raises). The chosen format is also exposed as `wrapped.figure_save_fmt`.

    Can be used bare (`@matplotlib_figure_saver`) or with keyword arguments
    (`@matplotlib_figure_saver(fmt="png")`).

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None] | None`
        your function, which should take an attention matrix and predefined `ax` object.
        `None` when the decorator is called with keyword arguments.
     - `fmt : str`
        format (and file extension) for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
        your function, after we wrap it to save a figure -- when `func` is given
     - `Callable[[Callable[[AttentionMatrix, plt.Axes], None]], AttentionMatrixFigureFunc]`
        a decorator to apply to your function -- when `func` is `None`

    # Raises:
     - `AssertionError` : if positional arguments other than `func` are passed

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```
    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # act on `fig` explicitly rather than pyplot's "current figure",
                # which `func` may have changed (e.g. by creating another figure)
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # release the figure even if `func` raises, so failures don't
                # leave figures open in pyplot
                plt.close(fig)

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # bare usage: `@matplotlib_figure_saver`
        return decorator(func)
    # keyword usage: `@matplotlib_figure_saver(fmt=...)`
    return decorator

128 

129 

def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    Values are mapped onto the colormap's [0, 1] input range (see `normalize`
    and `diverging_colormap`), colorized, and the alpha channel is dropped.

    # Parameters:
     - `matrix : Matrix2D`
        input matrix; must be 2D and non-empty. Integer/bool matrices are
        treated as ordinary numeric values (they are converted to float first).
     - `normalize : bool`
        whether to rescale the matrix to range [0, 1]. If `False`, values must
        already lie in [0, 1] (or in [-1, 1] when `diverging_colormap=True`).
        (defaults to `False`)
     - `cmap : str|Colormap`
        the colormap to use -- looked up in `matplotlib.colormaps` if a string
        (defaults to `"viridis"`)
     - `diverging_colormap : bool`
        if True, 0 is mapped to the center of the colormap: with `normalize=True`
        the symmetric range [-max|x|, max|x|] is used, otherwise the fixed range [-1, 1].
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, this
        value (which must be <= the matrix minimum, e.g. zero) is mapped to the
        bottom of the colormap instead of the matrix minimum.
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    With `normalize=True`, a matrix with zero value range (constant matrix, or
    an all-zero matrix with `diverging_colormap=True`) maps to the bottom of the
    colormap (center, for diverging) instead of producing NaNs.

    # Returns:
     - `Matrix2Drgb`
        `uint8` array of shape `(n, m, 3)`

    # Raises:
     - `AssertionError` : on invalid shape, out-of-range values, or incompatible options
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert (
            not diverging_colormap
        ), "normalize_min cannot be used with diverging_colormap=True"
        assert normalize, "normalize_min cannot be used with normalize=False"

    # work in floating point: matplotlib colormaps interpret *integer* (and bool)
    # inputs as lookup-table indices in [0, N) rather than positions in [0, 1]
    matrix_f: np.ndarray = np.asarray(matrix, dtype=np.float64)
    data_min: float = float(matrix_f.min())
    data_max: float = float(matrix_f.max())

    # map the values onto the colormap input range [0, 1]
    normalized_matrix: np.ndarray
    if normalize:
        if diverging_colormap:
            # center around 0: [-max_abs, max_abs] -> [0, 1], so 0 -> 0.5
            max_abs: float = max(abs(data_max), abs(data_min))
            if max_abs == 0:
                # all-zero matrix: every cell sits exactly at the center
                normalized_matrix = np.full_like(matrix_f, 0.5)
            else:
                normalized_matrix = (matrix_f / (2 * max_abs)) + 0.5
        else:
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < data_max, "normalize_min must be less than matrix max"
                assert (
                    min_val <= data_min
                ), "normalize_min must be less than or equal to matrix min"
            else:
                min_val = data_min

            value_range: float = data_max - min_val
            if value_range == 0:
                # constant matrix: avoid 0/0 and put everything at the bottom
                normalized_matrix = np.zeros_like(matrix_f)
            else:
                normalized_matrix = (matrix_f - min_val) / value_range
    else:
        if diverging_colormap:
            assert (
                data_min >= -1 and data_max <= 1
            ), "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            # [-1, 1] -> [0, 1] so that 0 lands at the colormap center
            normalized_matrix = (matrix_f + 1) / 2
        else:
            assert (
                data_min >= 0 and data_max <= 1
            ), "Matrix values must be in range [0, 1], or normalize must be True"
            normalized_matrix = matrix_f

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # apply the colormap (RGBA floats in [0, 1]), drop alpha, scale to uint8
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix

227 

228 

@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Encode a `Matrix2Drgb` image as PNG using PIL.

    Two calling modes:
     - pass a `buffer`: the PNG is written into it (at its current position)
       and `None` is returned
     - omit `buffer`: the PNG is encoded in memory and its bytes are returned

    # Parameters:
     - `matrix : Matrix2Drgb`
        RGB image array to encode
     - `buffer : io.BytesIO | None`
        destination for the PNG data
        (defaults to `None`, meaning the bytes are returned instead)

    # Returns:
     - `bytes|None`
        the PNG bytes when no `buffer` is given, otherwise `None`
    """
    # NOTE(review): recent Pillow releases deprecate the `mode` argument of
    # `Image.fromarray`; confirm against the Pillow version this project pins
    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    if buffer is not None:
        # caller supplied the sink: encode straight into it, hand nothing back
        image.save(buffer, format="PNG")
        return None

    # no sink given: encode into a private in-memory buffer and return its contents
    scratch: io.BytesIO = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()

260 

261 

def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    The matrix is colorized via `matrix_to_image_preprocess`, encoded as a PNG,
    and embedded (base64) in `MATRIX_SAVE_SVG_TEMPLATE`. The SVG uses one unit
    per matrix cell: width = number of columns, height = number of rows.

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
        a 2D matrix to convert to an SVG image
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
        (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
        (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:
     - `str`
        the SVG content for the matrix

    # Raises:
     - `AssertionError` : propagated from `matrix_to_image_preprocess` on invalid input
     - `TypeError` : propagated from `matrix_to_image_preprocess` for an invalid `cmap` type
    """
    # Preprocess the matrix into an RGB image; this also validates that `matrix`
    # is a non-empty 2D array before we unpack its shape below
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # `Matrix2D` is shaped (n, m): n rows, m columns. The template uses `m` as
    # the SVG width and `n` as the height, so unpack in this order -- the
    # reverse would swap width and height for any non-square matrix.
    n, m = matrix.shape

    # Convert the RGB image to PNG bytes, then base64 for the data URI
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content

316 

317 

@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that map an attention matrix to a 2D matrix, making them save that matrix as an image.

    The wrapped function calls yours, colorizes the returned matrix, and writes it to
    `save_dir / f"{func.__name__}.{fmt}"` as a PNG, an SVG embedding a PNG, or a
    gzip-compressed SVG (svgz). The chosen format is exposed as `wrapped.figure_save_fmt`.
    Can handle both argumentless usage and with arguments.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

     - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
     - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to your function (with arguments case)

    # Raises:

     - `AssertionError` : if positional arguments other than `func` are passed, or if `fmt` is not a `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix_plasma(matrix):
        return matrix * 2
    ```
    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                # raw PNG: colorize, encode, write bytes
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)
            else:
                # svg / svgz: a PNG embedded in a minimal SVG wrapper
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # gzip-compressed SVG; encode explicitly (like the plain svg
                    # branch) instead of relying on the locale's default encoding
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)
                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    # Handle arguments case
    return decorator