pattern_lens.figure_util
Implements a collection of types, default values, and templates that are useful for figure functions.
Notably, you can use the decorators `matplotlib_figure_saver` and `save_matrix_wrapper` to make your functions save figures.
"""implements a bunch of types, default values, and templates which are useful for figure functions

notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
"""

import base64
import functools
import gzip
import io
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Literal, overload

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Float, UInt8
from matplotlib.colors import Colormap
from PIL import Image

from pattern_lens.consts import AttentionMatrix

AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves one or more figures"

Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix as an svg/svgz"


# TYPING: mypy hates it when we don't pass func=None or None as the first arg
@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None]],
    AttentionMatrixFigureFunc,
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> (
    AttentionMatrixFigureFunc
    | Callable[
        [Callable[[AttentionMatrix, plt.Axes], None]],
        AttentionMatrixFigureFunc,
    ]
):
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    the wrapped function creates a fresh 10x10 figure, passes its axes to `func`, and saves
    that figure to `save_dir / f"{func.__name__}.{fmt}"`. the figure is always closed
    afterwards, even if `func` or saving raises.

    # Parameters:
    - `func : Callable[[AttentionMatrix, plt.Axes], None] | None`
        your function, which should take an attention matrix and predefined `ax` object.
        leave as `None` (i.e. `@matplotlib_figure_saver(fmt=...)`) to get a decorator back
    - `fmt : str`
        format for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
    - `AttentionMatrixFigureFunc`
        your function, after we wrap it to save a figure (when `func` is given)
    - `Callable[[Callable[[AttentionMatrix, plt.Axes], None]], AttentionMatrixFigureFunc]`
        a decorator (when `func` is `None`)

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        func_name: str = getattr(func, "__name__", "<unknown>")

        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func_name}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # act on `fig` explicitly rather than pyplot's "current figure",
                # which `func` may have changed (e.g. by creating another figure)
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if `func` or saving raised
                plt.close(fig)

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    # Handle arguments case
    return decorator


def matplotlib_multifigure_saver(
    names: Sequence[str],
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    # decorator takes in function
    # which takes a matrix and a dictionary of axes corresponding to the names
    [Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
    # returns the decorated function
    AttentionMatrixFigureFunc,
]:
    """decorate a function such that it saves multiple figures, one for each name in `names`

    each figure is saved to `save_dir / f"{func.__name__}.{name}.{fmt}"`. every figure
    created is closed afterwards, even if figure creation, `func`, or saving raises.

    # Parameters:
    - `names : Sequence[str]`
        the names of the figures to save
    - `fmt : str`
        format for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
    - `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]], AttentionMatrixFigureFunc]`
        the decorator, which will then be applied to the function
        we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

    """

    def decorator(
        func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
    ) -> AttentionMatrixFigureFunc:
        func_name: str = getattr(func, "__name__", "<unknown>")

        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            # set up axes and corresponding figures
            axes_dict: dict[str, plt.Axes] = {}
            figs_dict: dict[str, plt.Figure] = {}
            # every figure we create (including any shadowed by a duplicate name),
            # so the cleanup below can close all of them
            all_figs: list[plt.Figure] = []

            try:
                # Create all figures and axes inside the `try` so a failure
                # part-way through still closes the ones already created
                for name in names:
                    fig, ax = plt.subplots(figsize=(10, 10))
                    all_figs.append(fig)
                    axes_dict[name] = ax
                    figs_dict[name] = fig

                # Run the function to make plots
                func(attn_matrix, axes_dict)

                # Save each figure
                for name, fig_ in figs_dict.items():
                    fig_path: Path = save_dir / f"{func_name}.{name}.{fmt}"
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout"  [union-attr]
                    fig_.tight_layout()  # type: ignore[union-attr]
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig"  [union-attr]
                    fig_.savefig(fig_path)  # type: ignore[union-attr]
            finally:
                # Always clean up figures, even if an error occurred
                for fig in all_figs:
                    # TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None"  [arg-type]
                    plt.close(fig)  # type: ignore[arg-type]

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    return decorator


def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
    - `matrix : Matrix2D`
        input matrix (must be 2D and non-empty)
    - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if `False`, the matrix must already be
        in [0, 1] (or in [-1, 1] when `diverging_colormap=True`)
        (defaults to `False`)
    - `cmap : str|Colormap`
        the colormap to use for the matrix -- looked up in `matplotlib.colormaps` if a string
        (defaults to `"viridis"`)
    - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
    - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to
        normalize to (generally set this to zero?). it must be `<=` the matrix minimum and `<` the
        matrix maximum, so that normalized values stay within [0, 1].
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:
    - `Matrix2Drgb`
        `uint8` array of shape `(n, m, 3)`; the colormap's alpha channel is dropped

    # Raises:
    - `AssertionError` : if the matrix is not 2D, is empty, is out of range when not
      normalizing, or `normalize_min` is used incorrectly
    - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`

    when normalizing, degenerate inputs do not produce NaNs: an all-zero matrix with
    `diverging_colormap=True` maps to the colormap center, and a constant matrix otherwise
    maps to the bottom of the colormap.
    """
    # check dims (2 is not that magic of a value here, hence noqa)
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert not diverging_colormap, (
            "normalize_min cannot be used with diverging_colormap=True"
        )
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Map the matrix into the colormap's input range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: avoid 0/0, everything sits at the center
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # values below `min_val` would normalize to < 0 and be clipped
                assert min_val <= matrix.min(), (
                    "normalize_min must be less than or equal to matrix min"
                )
            else:
                min_val = matrix.min()

            value_range = max_val - min_val
            if value_range == 0:
                # constant matrix: avoid 0/0 (NaN), map to the bottom of the colormap
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    elif diverging_colormap:
        assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
            "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
        )
        # shift [-1, 1] -> [0, 1] so that 0 lands on the colormap center
        # (passing negative values directly would clip them to the bottom color)
        normalized_matrix = (matrix + 1) / 2
    else:
        assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
            "Matrix values must be in range [0, 1], or normalize must be True"
        )
        normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = mpl.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        raise TypeError(
            msg,
        )

    # Apply the colormap (returns RGBA floats in [0, 1]), drop alpha, scale to bytes
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8,
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix


@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb,
    buffer: io.BytesIO | None = None,
) -> bytes | None:
    """Convert a `Matrix2Drgb` to valid PNG bytes via PIL

    - if `buffer` is provided, it will write the PNG bytes to the buffer and return `None`
    - if `buffer` is not provided, it will return the PNG bytes

    # Parameters:
    - `matrix : Matrix2Drgb`
        `uint8` array of shape `(n, m, 3)`; PIL infers the `RGB` mode from this dtype/shape
    - `buffer : io.BytesIO | None`
        (defaults to `None`, in which case it will return the PNG bytes)

    # Returns:
    - `bytes|None`
        `bytes` if `buffer` is `None`, otherwise `None`
    """
    # the `mode=` argument of `Image.fromarray` is deprecated (Pillow 11.3);
    # a uint8 (n, m, 3) array is inferred as RGB anyway
    pil_img: Image.Image = Image.fromarray(matrix)
    if buffer is None:
        out = io.BytesIO()
        pil_img.save(out, format="PNG")
        return out.getvalue()
    pil_img.save(buffer, format="PNG")
    return None


def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is colormapped, encoded as a base64 PNG, and embedded in an SVG that is
    `m` pixels wide and `n` pixels tall for an `(n, m)` matrix.

    # Parameters:
    - `matrix : Float[np.ndarray, 'n m']`
        a 2D matrix to convert to an SVG image
    - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
        (defaults to `False`)
    - `cmap : str|Colormap`
        the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
        (defaults to `"viridis"`)
    - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
    - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:
    - `str`
        the SVG content for the matrix
    """
    # Get the dimensions of the matrix: `n` rows (SVG height) by `m` columns (SVG width),
    # matching `MATRIX_SAVE_SVG_TEMPLATE` and the PNG, which is `m` pixels wide
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content


@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """Decorator for functions that process an attention matrix and save the result as an image.

    The image is written to `save_dir / f"{func.__name__}.{fmt}"` as a `png`, a plain `svg`,
    or a gzip-compressed `svgz`, depending on `fmt`.
    Can handle both argumentless usage and with arguments.

    # Parameters:

    - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
    - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
    - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
    - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
    - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
    - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

    - `AssertionError` : if positional arguments beyond `func` are passed, or `fmt` is not a `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix_plasma(matrix):
        return matrix * 2
    ```

    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = (
                save_dir / f"{getattr(func, '__name__', '<unknown>')}.{fmt}"
            )
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)
            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding, matching the plain-svg branch, instead of
                    # depending on the platform's locale default
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)
                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    # Handle arguments case
    return decorator
Type alias for a function that, given an attention matrix, saves one or more figures
Type alias for a 2D matrix (plottable)
Type alias for a 2D matrix with 3 channels (RGB)
Type alias for a function that, given an attention matrix, returns a 2D matrix
format for saving matplotlib figures
Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure
default for whether to normalize the matrix to range [0, 1]
default colormap for saving matrices
default format for saving matrices
template for saving an n by m matrix as an svg/svgz
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> (
    AttentionMatrixFigureFunc
    | Callable[
        [Callable[[AttentionMatrix, plt.Axes], None]],
        AttentionMatrixFigureFunc,
    ]
):
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    the wrapped function creates a fresh 10x10 figure, passes its axes to `func`, and saves
    that figure to `save_dir / f"{func.__name__}.{fmt}"`. the figure is always closed
    afterwards, even if `func` or saving raises.

    # Parameters:
    - `func : Callable[[AttentionMatrix, plt.Axes], None] | None`
        your function, which should take an attention matrix and predefined `ax` object.
        leave as `None` (i.e. `@matplotlib_figure_saver(fmt=...)`) to get a decorator back
    - `fmt : str`
        format for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
    - `AttentionMatrixFigureFunc`
        your function, after we wrap it to save a figure (when `func` is given)
    - `Callable[[Callable[[AttentionMatrix, plt.Axes], None]], AttentionMatrixFigureFunc]`
        a decorator (when `func` is `None`)

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        func_name: str = getattr(func, "__name__", "<unknown>")

        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func_name}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # act on `fig` explicitly rather than pyplot's "current figure",
                # which `func` may have changed (e.g. by creating another figure)
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if `func` or saving raised
                plt.close(fig)

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    # Handle arguments case
    return decorator
Decorator for functions that take an attention matrix and a predefined `ax` object, making them save a figure.
Parameters:
- `func : Callable[[AttentionMatrix, plt.Axes], None]`: your function, which should take an attention matrix and a predefined `ax` object
- `fmt : str`: format for saving matplotlib figures (defaults to `MATPLOTLIB_FIGURE_FMT`)
Returns:
- `AttentionMatrixFigureFunc`: your function, after we wrap it to save a figure
Usage:
@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
ax.matshow(attn_matrix, cmap="viridis")
ax.set_title("Raw Attention Pattern")
ax.axis("off")
def matplotlib_multifigure_saver(
    names: Sequence[str],
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    # decorator takes in function
    # which takes a matrix and a dictionary of axes corresponding to the names
    [Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
    # returns the decorated function
    AttentionMatrixFigureFunc,
]:
    """decorate a function such that it saves multiple figures, one for each name in `names`

    each figure is saved to `save_dir / f"{func.__name__}.{name}.{fmt}"`. every figure
    created is closed afterwards, even if figure creation, `func`, or saving raises.

    # Parameters:
    - `names : Sequence[str]`
        the names of the figures to save
    - `fmt : str`
        format for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
    - `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]], AttentionMatrixFigureFunc]`
        the decorator, which will then be applied to the function
        we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

    """

    def decorator(
        func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
    ) -> AttentionMatrixFigureFunc:
        func_name: str = getattr(func, "__name__", "<unknown>")

        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            # set up axes and corresponding figures
            axes_dict: dict[str, plt.Axes] = {}
            figs_dict: dict[str, plt.Figure] = {}
            # every figure we create (including any shadowed by a duplicate name),
            # so the cleanup below can close all of them
            all_figs: list[plt.Figure] = []

            try:
                # Create all figures and axes inside the `try` so a failure
                # part-way through still closes the ones already created
                for name in names:
                    fig, ax = plt.subplots(figsize=(10, 10))
                    all_figs.append(fig)
                    axes_dict[name] = ax
                    figs_dict[name] = fig

                # Run the function to make plots
                func(attn_matrix, axes_dict)

                # Save each figure
                for name, fig_ in figs_dict.items():
                    fig_path: Path = save_dir / f"{func_name}.{name}.{fmt}"
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout"  [union-attr]
                    fig_.tight_layout()  # type: ignore[union-attr]
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig"  [union-attr]
                    fig_.savefig(fig_path)  # type: ignore[union-attr]
            finally:
                # Always clean up figures, even if an error occurred
                for fig in all_figs:
                    # TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None"  [arg-type]
                    plt.close(fig)  # type: ignore[arg-type]

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    return decorator
Decorate a function so that it saves multiple figures, one for each name in `names`.
Parameters:
- `names : Sequence[str]`: the names of the figures to save
- `fmt : str`: format for saving matplotlib figures (defaults to `MATPLOTLIB_FIGURE_FMT`)
Returns:
- `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]], AttentionMatrixFigureFunc]`: the decorator, which will then be applied to the function. We expect the decorated function to take an attention pattern and a dict of axes corresponding to the names.
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
    - `matrix : Matrix2D`
        input matrix (must be 2D and non-empty)
    - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if `False`, the matrix must already be
        in [0, 1] (or in [-1, 1] when `diverging_colormap=True`)
        (defaults to `False`)
    - `cmap : str|Colormap`
        the colormap to use for the matrix -- looked up in `matplotlib.colormaps` if a string
        (defaults to `"viridis"`)
    - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
    - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to
        normalize to (generally set this to zero?). it must be `<=` the matrix minimum and `<` the
        matrix maximum, so that normalized values stay within [0, 1].
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:
    - `Matrix2Drgb`
        `uint8` array of shape `(n, m, 3)`; the colormap's alpha channel is dropped

    # Raises:
    - `AssertionError` : if the matrix is not 2D, is empty, is out of range when not
      normalizing, or `normalize_min` is used incorrectly
    - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`

    when normalizing, degenerate inputs do not produce NaNs: an all-zero matrix with
    `diverging_colormap=True` maps to the colormap center, and a constant matrix otherwise
    maps to the bottom of the colormap.
    """
    # check dims (2 is not that magic of a value here, hence noqa)
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert not diverging_colormap, (
            "normalize_min cannot be used with diverging_colormap=True"
        )
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Map the matrix into the colormap's input range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: avoid 0/0, everything sits at the center
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # values below `min_val` would normalize to < 0 and be clipped
                assert min_val <= matrix.min(), (
                    "normalize_min must be less than or equal to matrix min"
                )
            else:
                min_val = matrix.min()

            value_range = max_val - min_val
            if value_range == 0:
                # constant matrix: avoid 0/0 (NaN), map to the bottom of the colormap
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    elif diverging_colormap:
        assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
            "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
        )
        # shift [-1, 1] -> [0, 1] so that 0 lands on the colormap center
        # (passing negative values directly would clip them to the bottom color)
        normalized_matrix = (matrix + 1) / 2
    else:
        assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
            "Matrix values must be in range [0, 1], or normalize must be True"
        )
        normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = mpl.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        raise TypeError(
            msg,
        )

    # Apply the colormap (returns RGBA floats in [0, 1]), drop alpha, scale to bytes
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8,
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix
preprocess a 2D matrix into a plottable heatmap image
Parameters:
- `matrix : Matrix2D`: input matrix
- `normalize : bool`: whether to normalize the matrix to range [0, 1] (defaults to `MATRIX_SAVE_NORMALIZE`)
- `cmap : str|Colormap`: the colormap to use for the matrix (defaults to `MATRIX_SAVE_CMAP`)
- `diverging_colormap : bool`: if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- `normalize_min : float|None`: if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?). If `None`, then the minimum value of the matrix is used. If `diverging_colormap=True` OR `normalize=False`, this **must** be `None`. (defaults to `None`)
Returns:
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb,
    buffer: io.BytesIO | None = None,
) -> bytes | None:
    """Convert a `Matrix2Drgb` to valid PNG bytes via PIL

    - if `buffer` is provided, it will write the PNG bytes to the buffer and return `None`
    - if `buffer` is not provided, it will return the PNG bytes

    # Parameters:
    - `matrix : Matrix2Drgb`
        `uint8` array of shape `(n, m, 3)`; PIL infers the `RGB` mode from this dtype/shape
    - `buffer : io.BytesIO | None`
        (defaults to `None`, in which case it will return the PNG bytes)

    # Returns:
    - `bytes|None`
        `bytes` if `buffer` is `None`, otherwise `None`
    """
    # the `mode=` argument of `Image.fromarray` is deprecated (Pillow 11.3);
    # a uint8 (n, m, 3) array is inferred as RGB anyway
    pil_img: Image.Image = Image.fromarray(matrix)
    if buffer is None:
        out = io.BytesIO()
        pil_img.save(out, format="PNG")
        return out.getvalue()
    pil_img.save(buffer, format="PNG")
    return None
Convert a Matrix2Drgb to valid PNG bytes via PIL
- if `buffer` is provided, it will write the PNG bytes to the buffer and return `None`
- if `buffer` is not provided, it will return the PNG bytes
Parameters:
- `matrix : Matrix2Drgb`
- `buffer : io.BytesIO | None`: (defaults to `None`, in which case it will return the PNG bytes)
Returns:
- `bytes|None`: `bytes` if `buffer` is `None`, otherwise `None`
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is colormapped, encoded as a base64 PNG, and embedded in an SVG that is
    `m` pixels wide and `n` pixels tall for an `(n, m)` matrix.

    # Parameters:
    - `matrix : Float[np.ndarray, 'n m']`
        a 2D matrix to convert to an SVG image
    - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
        (defaults to `False`)
    - `cmap : str|Colormap`
        the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
        (defaults to `"viridis"`)
    - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
    - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:
    - `str`
        the SVG content for the matrix
    """
    # Get the dimensions of the matrix: `n` rows (SVG height) by `m` columns (SVG width),
    # matching `MATRIX_SAVE_SVG_TEMPLATE` and the PNG, which is `m` pixels wide
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content
quickly convert a 2D matrix to an SVG image, without matplotlib
Parameters:
- `matrix : Float[np.ndarray, 'n m']`: a 2D matrix to convert to an SVG image
- `normalize : bool`: whether to normalize the matrix to range [0, 1]. If it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError` (defaults to `False`)
- `cmap : str|Colormap`: the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string (defaults to `"viridis"`)
- `diverging_colormap : bool`: if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- `normalize_min : float|None`: if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?). If `None`, then the minimum value of the matrix is used. If `diverging_colormap=True` OR `normalize=False`, this **must** be `None` (defaults to `None`)
Returns:
- `str` — the SVG content for the matrix
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """Decorator for functions that process an attention matrix and save it as an image.

    The image is written as PNG, SVG, or gzip-compressed SVG (SVGZ), depending
    on `fmt`, to `save_dir / f"{func.__name__}.{fmt}"`.

    Can handle both argumentless usage and with arguments.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `False`.
     - `cmap : str | Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (e.g. zero)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

     - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
     - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the decorated function (with arguments case)

    # Raises:

     - `AssertionError`
        if extra positional arguments are passed, or if `fmt` is not one of `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2
    ```

    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    # validate the format eagerly, at decoration time, rather than on first save
    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = (
                save_dir / f"{getattr(func, '__name__', '<unknown>')}.{fmt}"
            )
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding: text-mode gzip otherwise falls back to
                    # the locale's default encoding, unlike the plain-SVG branch
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # record the output format on the function so callers can discover it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # no-arguments case: `@save_matrix_wrapper` applied directly
        return decorator(func)
    else:
        # arguments case: `@save_matrix_wrapper(...)` returns the decorator
        return decorator
Decorator for functions that process an attention matrix and save it as an image (PNG, SVG, or SVGZ, depending on `fmt`).
Can handle both argumentless usage and with arguments.
Parameters:
- `func : AttentionMatrixToMatrixFunc | None` — either the function to decorate (in the no-arguments case) or `None` when used with arguments.
- `fmt : MatrixSaveFormat, keyword-only` — the format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
- `normalize : bool, keyword-only` — whether to normalize the matrix to range [0, 1]. Defaults to `False`.
- `cmap : str | Colormap, keyword-only` — the colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
- `diverging_colormap : bool` — if `True` and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to `False`)
- `normalize_min : float | None` — if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (e.g. zero). If `None`, the minimum value of the matrix is used. If `diverging_colormap=True` OR `normalize=False`, this **must** be `None` (defaults to `None`)
Returns:
AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
- `AttentionMatrixFigureFunc` if `func` is an `AttentionMatrixToMatrixFunc` (no-arguments case)
- `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` — returns the decorator, which will then be applied to the decorated function (with-arguments case)
Usage:
@save_matrix_wrapper
def identity_matrix(matrix):
    return matrix

@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
    return matrix * 2

@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix(matrix):
    return matrix * 2