Coverage for pattern_lens\figure_util.py: 88%
129 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 20:39 -0700
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 20:39 -0700
"""Types, default values, and templates shared by figure functions.

Notably, the decorators `matplotlib_figure_saver` and `save_matrix_wrapper` turn
simple functions into "figure functions" with the signature
`(attn_matrix, save_dir) -> None`. Each such function writes its figure to
`save_dir / f"{func.__name__}.{fmt}"` and records the format on the wrapped
function as `figure_save_fmt`.

- `matplotlib_figure_saver` wraps functions that draw onto a provided `plt.Axes`.
- `save_matrix_wrapper` wraps functions that return a 2D matrix. The matrix is
  color-mapped and saved as png/svg/svgz without going through matplotlib figures
  (see `matrix_to_image_preprocess`, `matrix2drgb_to_png_bytes`, `matrix_as_svg`).
"""
6from pathlib import Path
7from typing import Callable, Literal, overload, Union
8import functools
9import base64
10import gzip
11import io
13from PIL import Image
14import numpy as np
15from jaxtyping import Float, UInt8
16import matplotlib
17import matplotlib.pyplot as plt
18from matplotlib.colors import Colormap
20from pattern_lens.consts import AttentionMatrix
# ---------------------------------------------------------------------------
# type aliases
# ---------------------------------------------------------------------------

AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Figure function: receives an attention matrix and a directory, and saves a figure into it"

Matrix2D = Float[np.ndarray, "n m"]
"A plottable 2D float matrix with `n` rows and `m` columns"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"A 2D image as uint8 with a trailing RGB channel axis of size 3"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Function mapping an attention matrix to a 2D matrix"

MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Formats available when saving a raw matrix (as opposed to a matplotlib figure)"

# ---------------------------------------------------------------------------
# defaults
# ---------------------------------------------------------------------------

MATPLOTLIB_FIGURE_FMT: str = "svgz"
"File format used when saving matplotlib figures"

MATRIX_SAVE_NORMALIZE: bool = False
"Default for whether matrices are rescaled into [0, 1] before color mapping"

MATRIX_SAVE_CMAP: str = "viridis"
"Default colormap name for saved matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"Default file format for saved matrices"

MATRIX_SAVE_SVG_TEMPLATE: str = (
    '<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" '
    'viewBox="0 0 {m} {n}" image-rendering="pixelated"> '
    '<image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> '
    "</svg>"
)
"SVG wrapper around a base64 PNG for an `n`-row by `m`-column matrix (width `m`, height `n`)"
@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None]], AttentionMatrixFigureFunc
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Union[
    AttentionMatrixFigureFunc,
    Callable[
        [Callable[[AttentionMatrix, plt.Axes], None]], AttentionMatrixFigureFunc
    ],
]:
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    The wrapped function has signature `(attn_matrix, save_dir) -> None`. It creates a
    10x10 inch figure, calls your function on its `ax`, and saves the figure to
    `save_dir / f"{func.__name__}.{fmt}"`. The figure is always closed afterwards,
    even if your function or saving raises.

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None] | None`
        your function, which should take an attention matrix and predefined `ax` object.
        pass nothing (use keyword arguments only) to get a decorator back instead
     - `fmt : str`
        format for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc` if `func` is given, otherwise a decorator producing one.
       the wrapped function also carries a `figure_save_fmt` attribute equal to `fmt`

    # Raises:
     - `AssertionError` if extra positional arguments are passed

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")

    @matplotlib_figure_saver(fmt="png")
    def raw_png(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix)
    ```
    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # act on `fig` explicitly rather than pyplot's "current figure",
                # which `func` could have switched to a different figure
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # release the figure even on failure so repeated calls
                # do not accumulate open figures
                plt.close(fig)

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]
        return wrapped

    if callable(func):
        # bare `@matplotlib_figure_saver` usage
        return decorator(func)
    # `@matplotlib_figure_saver(fmt=...)` usage
    return decorator
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    values are mapped into `[0, 1]` (see parameters), passed through `cmap`,
    scaled to `0..255`, and the alpha channel is dropped.

    # Parameters:
     - `matrix : Matrix2D`
        input matrix (must be 2D and non-empty)
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. a constant matrix
        (max == min) maps every entry to 0 rather than dividing by zero
        (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix -- looked up in `matplotlib.colormaps` if a string
        (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
        if True, 0 values map to the center of the colormap:
        with `normalize=True` the matrix is divided by twice its max absolute value and shifted by 0.5
        (an all-zero matrix maps entirely to the center);
        with `normalize=False` values must lie in [-1, 1] and are shifted to [0, 1]
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the lower end of the
        normalization range (generally zero). it must be `<= matrix.min()` and `< matrix.max()`.
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:
     - `Matrix2Drgb`
        `uint8` array of shape `(matrix.shape[0], matrix.shape[1], 3)`

    # Raises:
     - `AssertionError` if the matrix is not 2D, is empty, has values outside the allowed
       range when `normalize=False`, or `normalize_min` is invalid
     - `TypeError` if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert (
            not diverging_colormap
        ), "normalize_min cannot be used with diverging_colormap=True"
        assert normalize, "normalize_min cannot be used with normalize=False"

    # map the matrix into the colormap's input range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # center on 0: [-max_abs, max_abs] -> [0, 1], so 0 lands on 0.5
            max_abs: float = float(max(abs(matrix.max()), abs(matrix.min())))
            if max_abs == 0:
                # all-zero matrix: avoid 0/0 (NaN); everything sits at the center
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # a lower bound above the data minimum would push values below 0
                assert (
                    min_val <= matrix.min()
                ), "normalize_min must be less than or equal to matrix min"
            else:
                min_val = matrix.min()

            value_range = max_val - min_val
            if value_range == 0:
                # constant matrix: avoid 0/0 (NaN); map everything to the bottom of the colormap
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    else:
        if diverging_colormap:
            assert (
                matrix.min() >= -1 and matrix.max() <= 1
            ), "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            # colormaps expect [0, 1]; shift [-1, 1] so that 0 maps to the center
            normalized_matrix = (matrix + 1) / 2
        else:
            assert (
                matrix.min() >= 0 and matrix.max() <= 1
            ), "Matrix values must be in range [0, 1], or normalize must be True"
            normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap (returns RGBA floats in [0, 1]), drop alpha, scale to bytes
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Encode a `Matrix2Drgb` as PNG using PIL.

    Two calling modes:
    - with `buffer`: the PNG is written into `buffer` at its current position,
      the buffer is not rewound, and `None` is returned
    - without `buffer`: the PNG is encoded into a private buffer and its bytes are returned

    # Parameters:
     - `matrix : Matrix2Drgb`
        the RGB image, converted with PIL mode `"RGB"`
     - `buffer : io.BytesIO | None`
        optional destination for the PNG data
        (defaults to `None`, meaning the bytes are returned)

    # Returns:
     - `bytes|None`
        the PNG bytes when no `buffer` was given, otherwise `None`
    """
    rgb_image: Image.Image = Image.fromarray(matrix, mode="RGB")

    caller_owns_buffer: bool = buffer is not None
    sink: io.BytesIO = buffer if caller_owns_buffer else io.BytesIO()  # type: ignore[assignment]
    rgb_image.save(sink, format="PNG")

    if caller_owns_buffer:
        return None
    # private buffer: hand back everything that was written
    return sink.getvalue()
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is color-mapped via `matrix_to_image_preprocess`, encoded as a PNG,
    and embedded as base64 in `MATRIX_SAVE_SVG_TEMPLATE`. The SVG is `n_cols` wide
    and `n_rows` tall, matching the embedded PNG.

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
        a 2D matrix to convert to an SVG image
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
        (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
        (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:
     - `str`
        the SVG content for the matrix

    # Raises:
     - `AssertionError` from `matrix_to_image_preprocess` on invalid input
    """
    # Preprocess the matrix into an RGB image (this also validates that it is 2D)
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # numpy shape is (rows, cols); the template uses `m` as width and `n` as height
    n_rows, n_cols = matrix.shape

    # Convert the RGB image to PNG bytes, then base64 for embedding in a data URI
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(
        m=n_cols, n=n_rows, png_base64=png_base64
    )

    return svg_content
@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that turn an attention matrix into a 2D matrix, making them
    save that matrix as an image (`png`, `svg`, or `svgz`, chosen by `fmt`).
    Can handle both argumentless usage and with arguments.

    The wrapped function has signature `(attn_matrix, save_dir) -> None`. It writes
    `save_dir / f"{func.__name__}.{fmt}"`:
    - `png`: the color-mapped matrix as PNG bytes
    - `svg`: an SVG embedding the PNG (see `matrix_as_svg`), written as UTF-8 text
    - `svgz`: the same SVG, gzip-compressed

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

     - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
     - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

     The wrapped function carries a `figure_save_fmt` attribute equal to `fmt`.

    # Raises:

     - `AssertionError` at decoration time if positional arguments are passed or `fmt` is invalid

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix_plasma(matrix):
        return matrix * 2
    ```
    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    # validate once, at decoration time, rather than on every save
    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)
            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding, matching the plain-svg branch; text-mode
                    # gzip would otherwise use the locale's default encoding
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)
                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    # Handle arguments case
    return decorator