Coverage for pattern_lens/figure_util.py: 90%
153 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
1"""implements a bunch of types, default values, and templates which are useful for figure functions
3notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
4"""
6import base64
7import functools
8import gzip
9import io
10from collections.abc import Callable, Sequence
11from pathlib import Path
12from typing import Literal, overload
14import matplotlib as mpl
15import matplotlib.pyplot as plt
16import numpy as np
17from jaxtyping import Float, UInt8
18from matplotlib.colors import Colormap
19from PIL import Image
21from pattern_lens.consts import AttentionMatrix
# --- type aliases and module-level defaults ---------------------------------

# signature of a "figure function": called as `f(attn_matrix, save_dir)`, returns nothing
AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves one or more figures"

# float array with two named axes, `n` and `m`
Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

# uint8 array with a trailing axis of fixed size 3
Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

# a pure transform: attention matrix in, 2D matrix out (no saving involved)
AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

# used as the default `fmt` of the matplotlib-based saver decorators in this module
MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

# closed set of formats accepted by `save_matrix_wrapper`
MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

# a colormap name, looked up in `matplotlib.colormaps` when passed as a string
MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

# `str.format` placeholders: `{m}` is used as the width, `{n}` as the height,
# and `{png_base64}` is the base64-encoded PNG embedded as a data URI
MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix as an svg/svgz"
# TYPING: mypy hates it when we dont pass func=None or None as the first arg
@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None], str],
    AttentionMatrixFigureFunc,
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> (
    AttentionMatrixFigureFunc
    | Callable[
        [Callable[[AttentionMatrix, plt.Axes], None], str],
        AttentionMatrixFigureFunc,
    ]
):
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    the wrapped function has the signature `(attn_matrix, save_dir) -> None`:
    it creates a 10x10 inch figure, calls your function with the matrix and the axes,
    and saves the figure to `save_dir / f"{func.__name__}.{fmt}"`.
    the figure is always closed afterwards, even if your function or the save raises.

    the wrapped function also gets a `figure_save_fmt` attribute set to `fmt`.

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
       your function, which should take an attention matrix and predefined `ax` object
     - `fmt : str`
       format for saving matplotlib figures
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we wrap it to save a figure
       (if `func` is not given, the decorator itself is returned instead)

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on `fig` explicitly rather than on pyplot's "current figure",
                # so the right figure is saved even if `func` changed the current one
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if plotting or saving failed
                plt.close(fig)

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
def matplotlib_multifigure_saver(
    names: Sequence[str],
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    # the decorator accepts a plotting function
    # taking a matrix plus a dict mapping each name to its axes
    [Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
    # and hands back the figure-saving version of it
    AttentionMatrixFigureFunc,
]:
    """build a decorator that makes a plotting function save one figure per entry of `names`

    the decorated function is called once with the attention matrix and a dict
    `{name: axes}`; afterwards each figure is written to
    `save_dir / f"{func.__name__}.{name}.{fmt}"`. all figures are closed when the
    call finishes, whether or not it succeeded.

    the wrapped function also gets a `figure_save_fmt` attribute set to `fmt`.

    # Parameters:
     - `names : Sequence[str]`
       the names of the figures to save
     - `fmt : str`
       format for saving matplotlib figures
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc]`
       the decorator, which will then be applied to the function
       we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

    """

    def decorator(
        func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
    ) -> AttentionMatrixFigureFunc:
        base_name: str = func.__name__

        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            figs_by_name: dict[str, plt.Figure] = {}
            axes_by_name: dict[str, plt.Axes] = {}

            # one fresh figure/axes pair for every requested name
            for fig_name in names:
                new_fig, new_ax = plt.subplots(figsize=(10, 10))
                axes_by_name[fig_name] = new_ax
                figs_by_name[fig_name] = new_fig

            try:
                # let the user function draw onto all axes at once
                func(attn_matrix, axes_by_name)

                # then write every figure to its own file
                for fig_name, figure in figs_by_name.items():
                    out_path: Path = save_dir / f"{base_name}.{fig_name}.{fmt}"
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout" [union-attr]
                    figure.tight_layout()  # type: ignore[union-attr]
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig" [union-attr]
                    figure.savefig(out_path)  # type: ignore[union-attr]
            finally:
                # release every figure regardless of success or failure
                for figure in figs_by_name.values():
                    # TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None" [arg-type]
                    plt.close(figure)  # type: ignore[arg-type]

        # not a standard function attribute -- attached here on purpose
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    return decorator
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    the matrix is (optionally) rescaled, passed through the colormap,
    the alpha channel is dropped, and the result is converted to `uint8` RGB.

    # Parameters:
     - `matrix : Matrix2D`
        input matrix
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1]
        (defaults to `False`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix -- a string is looked up in `matplotlib.colormaps`
        (defaults to `"viridis"`)
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
        it must be strictly less than the matrix maximum and no greater than the matrix minimum,
        so that the normalized values stay within [0, 1].
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:
     - `Matrix2Drgb`
        `uint8` array of shape `(matrix.shape[0], matrix.shape[1], 3)`

    # Raises:
     - `AssertionError` : if the matrix is not 2D or is empty, if `normalize_min` is combined with
       incompatible options or is out of range, or if `normalize=False` and the values are outside
       the expected range ([0, 1], or [-1, 1] for `diverging_colormap=True`)
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims (2 is not that magic of a value here, hence noqa)
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert not diverging_colormap, (
            "normalize_min cannot be used with diverging_colormap=True"
        )
        assert normalize, "normalize_min cannot be used with normalize=False"

    matrix_min: float = matrix.min()
    matrix_max: float = matrix.max()

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix_max), abs(matrix_min))
            if max_abs == 0:
                # all-zero matrix: dividing would give NaN, and every value
                # belongs at the center of the colormap anyway
                normalized_matrix = np.full_like(matrix, 0.5, dtype=float)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix_max
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # a floor above the smallest entry would push values below 0
                assert min_val <= matrix_min, (
                    "normalize_min must be less than or equal to matrix min"
                )
            else:
                min_val = matrix_min

            value_range: float = max_val - min_val
            if value_range == 0:
                # constant matrix: avoid 0/0 -> NaN, map everything to the bottom
                # of the colormap (same convention as matplotlib's `Normalize`)
                normalized_matrix = np.zeros_like(matrix, dtype=float)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    else:
        if diverging_colormap:
            assert matrix_min >= -1 and matrix_max <= 1, (  # noqa: PT018
                "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            )
            normalized_matrix = matrix
        else:
            assert matrix_min >= 0 and matrix_max <= 1, (  # noqa: PT018
                "Matrix values must be in range [0, 1], or normalize must be True"
            )
            normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = mpl.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        raise TypeError(
            msg,
        )

    # Apply the colormap; it returns RGBA floats, so drop alpha and scale to bytes
    rgb_matrix: Matrix2Drgb = (
        cmap_(normalized_matrix)[:, :, :3] * 255
    ).astype(np.uint8)

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb,
    buffer: io.BytesIO | None = None,
) -> bytes | None:
    """Encode a `Matrix2Drgb` as PNG using PIL

    Two calling modes:

    - with a `buffer`: the PNG data is written into it and nothing is returned
    - without a `buffer`: the PNG data is returned as `bytes`

    # Parameters:
     - `matrix : Matrix2Drgb`
        the RGB image data to encode
     - `buffer : io.BytesIO | None`
        destination for the PNG data
        (defaults to `None`, in which case the PNG bytes are returned)

    # Returns:
     - `bytes|None`
        `bytes` if `buffer` is `None`, otherwise `None`
    """
    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # caller supplied a destination: write there and return nothing
    if buffer is not None:
        image.save(buffer, format="PNG")
        return None

    # otherwise encode into a private buffer and hand back its contents
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is rendered to an RGB heatmap via `matrix_to_image_preprocess`,
    encoded as a PNG, and embedded as a base64 data URI in `MATRIX_SAVE_SVG_TEMPLATE`.
    the SVG is one unit wide per matrix column and one unit tall per matrix row.

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
        a 2D matrix to convert to an SVG image
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
        (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
        (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:
     - `str`
        the SVG content for the matrix

    # Raises:
     - `AssertionError` : if the matrix is not 2D
    """
    # Get the dimensions of the matrix
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
    # shape is (rows, columns): `n` rows give the image height, `m` columns its width,
    # which is how the template uses `{n}` and `{m}`
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content
@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """Decorator for functions that process an attention matrix and save it as a PNG, SVG, or SVGZ image.

    Can handle both argumentless usage and with arguments.

    The wrapped function has the signature `(attn_matrix, save_dir) -> None`: it calls your
    function on the attention matrix and writes the result to `save_dir / f"{func.__name__}.{fmt}"`.
    It also gets a `figure_save_fmt` attribute set to `fmt`.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool, keyword-only`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None, keyword-only`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

     - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
     - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

     - `AssertionError` : if extra positional arguments are passed, or `fmt` is not one of the `MatrixSaveFormat` options

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2
    ```

    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                # raw PNG: colormap the matrix and write the encoded bytes directly
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                # svg / svgz: same image, wrapped in an SVG document
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding so the output does not depend on the locale,
                    # matching the plain-svg branch below
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator