Coverage for pattern_lens / figure_util.py: 89%
143 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
1"""implements a bunch of types, default values, and templates which are useful for figure functions
3notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
4"""
6import base64
7import functools
8import gzip
9import io
10from collections.abc import Callable, Sequence
11from pathlib import Path
12from typing import Literal, overload
14import matplotlib as mpl
15import matplotlib.pyplot as plt
16import numpy as np
17from jaxtyping import Float, UInt8
18from matplotlib.colors import Colormap
19from PIL import Image
21from pattern_lens.consts import AttentionMatrix
AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias: a function which receives an attention matrix and a path, and saves one or more figures"

Matrix2D = Float[np.ndarray, "n m"]
"Type alias: a plottable 2D float matrix"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias: a 2D matrix carrying 3 `uint8` channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias: a function which maps an attention matrix to a 2D matrix"

MATPLOTLIB_FIGURE_FMT: str = "svgz"
"file format used when saving matplotlib figures"

MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias: the formats available when saving a raw matrix (as opposed to a matplotlib figure)"

MATRIX_SAVE_NORMALIZE: bool = False
"default: whether a matrix is normalized to the range [0, 1] before saving"

MATRIX_SAVE_CMAP: str = "viridis"
"default: colormap applied when saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default: format used when saving matrices"

# adjacent literals are concatenated by the parser; the resulting string is a single line
MATRIX_SAVE_SVG_TEMPLATE: str = (
    '<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" '
    'viewBox="0 0 {m} {n}" image-rendering="pixelated"> '
    '<image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> '
    "</svg>"
)
"svg/svgz template for an `n` by `m` matrix: fill in `m` (width), `n` (height) and `png_base64`"
# TYPING: mypy hates it when we dont pass func=None or None as the first arg
@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None]],
    AttentionMatrixFigureFunc,
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> (
    AttentionMatrixFigureFunc
    | Callable[
        [Callable[[AttentionMatrix, plt.Axes], None]],
        AttentionMatrixFigureFunc,
    ]
):
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    the wrapped function creates a fresh 10x10 figure, calls `func(attn_matrix, ax)`,
    and saves the figure to `save_dir / "{func.__name__}.{fmt}"`.
    the figure is always closed afterwards, even if `func` or the save raises.
    the wrapped function also gets a `figure_save_fmt` attribute set to `fmt`.

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
       your function, which should take an attention matrix and predefined `ax` object.
       `None` when the decorator is used with arguments
     - `fmt : str`
       format for saving matplotlib figures
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we wrap it to save a figure
       (or, if `func` is `None`, the decorator which will do so)

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = (
                save_dir / f"{getattr(func, '__name__', '<unknown>')}.{fmt}"
            )

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on `fig` explicitly rather than through the pyplot
                # "current figure": `func` may have created or activated another one
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if plotting or saving failed
                plt.close(fig)

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
def matplotlib_multifigure_saver(
    names: Sequence[str],
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    # decorator takes in function
    # which takes a matrix and a dictionary of axes corresponding to the names
    [Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
    # returns the decorated function
    AttentionMatrixFigureFunc,
]:
    """build a decorator which makes a plotting function save one figure per entry of `names`

    each figure is written to `save_dir / "{func.__name__}.{name}.{fmt}"`,
    and the wrapped function gets a `figure_save_fmt` attribute set to `fmt`

    # Parameters:
     - `names : Sequence[str]`
       one figure (and one axes object) is created for each name
     - `fmt : str`
       format for saving matplotlib figures
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc]`
       the decorator. the function it decorates must accept an attention pattern
       and a dict mapping each name to its axes

    """

    def decorator(
        func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
    ) -> AttentionMatrixFigureFunc:
        base_name: str = getattr(func, "__name__", "<unknown>")

        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            figures: dict[str, plt.Figure] = {}
            axes_by_name: dict[str, plt.Axes] = {}

            # one independent figure/axes pair per requested name
            for key in names:
                new_fig, new_ax = plt.subplots(figsize=(10, 10))
                axes_by_name[key] = new_ax
                figures[key] = new_fig

            try:
                # let the user function draw onto all of the axes
                func(attn_matrix, axes_by_name)

                # then write every figure to its own file
                for key, figure in figures.items():
                    out_path: Path = save_dir / f"{base_name}.{key}.{fmt}"
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout" [union-attr]
                    figure.tight_layout()  # type: ignore[union-attr]
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig" [union-attr]
                    figure.savefig(out_path)  # type: ignore[union-attr]
            finally:
                # figures are closed whether or not plotting/saving succeeded
                for figure in figures.values():
                    # TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None" [arg-type]
                    plt.close(figure)  # type: ignore[arg-type]

        # not a standard function attribute -- we attach it ourselves
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    return decorator
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    the matrix is first mapped into the colormap's input range [0, 1]:

    - `normalize=True, diverging_colormap=False`: `(matrix - min) / (max - min)`,
      where `min` is `normalize_min` if given. a constant matrix maps to all zeros
    - `normalize=True, diverging_colormap=True`: scaled by the largest absolute value,
      so that 0 maps to 0.5. an all-zero matrix maps to 0.5 everywhere
    - `normalize=False, diverging_colormap=False`: values must already be in [0, 1]
    - `normalize=False, diverging_colormap=True`: values must be in [-1, 1], and are
      shifted to [0, 1] so that 0 maps to 0.5

    then the colormap is applied and the alpha channel is dropped.

    # Parameters:
     - `matrix : Matrix2D`
       input matrix
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix
       (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
       if True and using a diverging colormap, ensures 0 values map to the center of the colormap
       (defaults to False)
     - `normalize_min : float|None`
       if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
       must be strictly below the matrix maximum and no greater than the matrix minimum.
       if `None`, then the minimum value of the matrix is used.
       if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
       (defaults to `None`)

    # Returns:
     - `Matrix2Drgb`
       `uint8` array of shape `(n, m, 3)`

    # Raises:
     - `AssertionError` : if the matrix is not 2D, is empty, is out of range when `normalize=False`,
       or if `normalize_min` is used in an unsupported combination or is out of bounds
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims (2 is not that magic of a value here, hence noqa)
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert not diverging_colormap, (
            "normalize_min cannot be used with diverging_colormap=True"
        )
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Map the matrix into the colormap input range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: avoid 0/0, every entry sits at the center
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # the floor must not exceed any entry, otherwise entries would
                # normalize to negative values outside of [0, 1]
                assert min_val <= matrix.min(), (
                    "normalize_min must be less than or equal to matrix min"
                )
            else:
                min_val = matrix.min()

            if max_val == min_val:
                # constant matrix: avoid 0/0, every entry equals the minimum
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / (max_val - min_val)
    else:
        if diverging_colormap:
            assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
                "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            )
            # shift [-1, 1] onto [0, 1] so that 0 lands on the colormap center
            normalized_matrix = (matrix + 1) / 2
        else:
            assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
                "Matrix values must be in range [0, 1], or normalize must be True"
            )
            normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = mpl.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        raise TypeError(
            msg,
        )

    # Apply the colormap (returns RGBA floats in [0, 1]), drop the alpha
    # channel, and convert to 8-bit
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8,
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb,
    buffer: io.BytesIO | None = None,
) -> bytes | None:
    """Encode a `Matrix2Drgb` as PNG using PIL

    - with a `buffer`, the PNG bytes are written into it and `None` is returned
    - without a `buffer`, the PNG bytes themselves are returned

    # Parameters:
     - `matrix : Matrix2Drgb`
       the RGB image data to encode
     - `buffer : io.BytesIO | None`
       optional destination for the PNG bytes
       (defaults to `None`, in which case the bytes are returned)

    # Returns:
     - `bytes|None`
       `bytes` if `buffer` is `None`, otherwise `None`
    """
    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # caller supplied a destination: write into it, nothing to return
    if buffer is not None:
        image.save(buffer, format="PNG")
        return None

    # otherwise encode into a private buffer and hand back its full contents
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is rendered to a PNG (one pixel per entry), base64-encoded,
    and embedded in `MATRIX_SAVE_SVG_TEMPLATE`

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
       (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
       if True and using a diverging colormap, ensures 0 values map to the center of the colormap
       (defaults to False)
     - `normalize_min : float|None`
       if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
       if `None`, then the minimum value of the matrix is used
       if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
       (defaults to `None`)

    # Returns:
     - `str`
       the SVG content for the matrix

    # Raises:
     - `AssertionError` : if the matrix is not 2D
    """
    # Get the dimensions of the matrix
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
    # shape is (rows, columns): `n` rows become the image height and
    # `m` columns the image width, matching the template's width={m} height={n}
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content
@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """Decorator for functions that process an attention matrix and save it as a PNG, SVG, or SVGZ image.

    Can handle both argumentless usage and with arguments.

    The wrapped function calls `func(attn_matrix)`, renders the result, and writes it
    to `save_dir / "{func.__name__}.{fmt}"`. It also gets a `figure_save_fmt` attribute set to `fmt`.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
       Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
       The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
       Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
       The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool, keyword-only`
       if True and using a diverging colormap, ensures 0 values map to the center of the colormap
       (defaults to False)
     - `normalize_min : float|None, keyword-only`
       if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
       if `None`, then the minimum value of the matrix is used
       if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
       (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

     - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
     - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

     - `AssertionError` : if positional arguments other than `func` are passed, or if `fmt` is not a valid `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2
    ```

    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = (
                save_dir / f"{getattr(func, '__name__', '<unknown>')}.{fmt}"
            )
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding, so the compressed and plain SVG outputs
                    # are both UTF-8 regardless of the locale default
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator