pattern_lens.figure_util
implements a bunch of types, default values, and templates which are useful for figure functions
notably, you can use the decorators matplotlib_figure_saver, save_matrix_wrapper to make your functions save figures
"""implements a bunch of types, default values, and templates which are useful for figure functions

notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
"""

import base64
import functools
import gzip
import io
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Literal, overload

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Float, UInt8
from matplotlib.colors import Colormap
from PIL import Image

from pattern_lens.consts import AttentionMatrix

AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves one or more figures"

Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix as an svg/svgz"


# TYPING: mypy hates it when we dont pass func=None or None as the first arg
@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None], str],
    AttentionMatrixFigureFunc,
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> (
    AttentionMatrixFigureFunc
    | Callable[
        [Callable[[AttentionMatrix, plt.Axes], None], str],
        AttentionMatrixFigureFunc,
    ]
):
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    the figure is written to `save_dir / f"{func.__name__}.{fmt}"` and is always closed afterwards,
    even if the decorated function or the save raises

    # Parameters:
    - `func : Callable[[AttentionMatrix, plt.Axes], None]`
        your function, which should take an attention matrix and predefined `ax` object
    - `fmt : str`
        format for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
    - `AttentionMatrixFigureFunc`
        your function, after we wrap it to save a figure
        (or, if `func` is not given, a decorator producing such a function)

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on `fig` explicitly: the pyplot "current figure" may
                # have been changed by `func`
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if plotting or saving failed
                plt.close(fig)

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator


def matplotlib_multifigure_saver(
    names: Sequence[str],
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    # decorator takes in function
    # which takes a matrix and a dictionary of axes corresponding to the names
    [Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
    # returns the decorated function
    AttentionMatrixFigureFunc,
]:
    """decorate a function such that it saves multiple figures, one for each name in `names`

    each figure is written to `save_dir / f"{func.__name__}.{name}.{fmt}"`

    # Parameters:
    - `names : Sequence[str]`
        the names of the figures to save
    - `fmt : str`
        format for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
    - `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc]`
        the decorator, which will then be applied to the function
        we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

    """

    def decorator(
        func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
    ) -> AttentionMatrixFigureFunc:
        func_name: str = func.__name__

        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            # set up axes and corresponding figures
            axes_dict: dict[str, plt.Axes] = {}
            figs_dict: dict[str, plt.Figure] = {}
            # every figure we open, so that none can escape cleanup
            # (e.g. one replaced in `figs_dict` by a duplicate name)
            created_figs: list[plt.Figure] = []

            try:
                # Create all figures and axes (inside the `try`, so a failure
                # partway through still closes the ones already created)
                for name in names:
                    fig, ax = plt.subplots(figsize=(10, 10))
                    created_figs.append(fig)
                    axes_dict[name] = ax
                    figs_dict[name] = fig

                # Run the function to make plots
                func(attn_matrix, axes_dict)

                # Save each figure
                for name, fig_ in figs_dict.items():
                    fig_path: Path = save_dir / f"{func_name}.{name}.{fmt}"
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout" [union-attr]
                    fig_.tight_layout()  # type: ignore[union-attr]
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig" [union-attr]
                    fig_.savefig(fig_path)  # type: ignore[union-attr]
            finally:
                # Always clean up figures, even if an error occurred
                for fig in created_figs:
                    # TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None" [arg-type]
                    plt.close(fig)  # type: ignore[arg-type]

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    return decorator


def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
    - `matrix : Matrix2D`
        input matrix
    - `normalize : bool`
        whether to normalize the matrix to range [0, 1]
        (defaults to `MATRIX_SAVE_NORMALIZE`)
    - `cmap : str|Colormap`
        the colormap to use for the matrix
        (defaults to `MATRIX_SAVE_CMAP`)
    - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
    - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
        it must not exceed the matrix minimum, so that the output stays in [0, 1].
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:
    - `Matrix2Drgb`
        `uint8` array of shape `(n, m, 3)`

    # Raises:
    - `AssertionError` : if the matrix is not 2D, is empty, is out of range for the chosen mode, or `normalize_min` is used inconsistently
    - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims (2 is not that magic of a value here, hence noqa)
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert not diverging_colormap, (
            "normalize_min cannot be used with diverging_colormap=True"
        )
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: everything sits at the center, avoid 0/0 -> NaN
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # a floor above the matrix minimum would push values below 0
                assert min_val <= matrix.min(), (
                    "normalize_min must be less than or equal to matrix min"
                )
            else:
                min_val = matrix.min()

            value_range: float = max_val - min_val
            if value_range == 0:
                # constant matrix: map to the bottom of the colormap, avoid 0/0 -> NaN
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    else:
        if diverging_colormap:
            assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
                "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            )
            # shift [-1, 1] onto the colormap's [0, 1] domain so 0 lands on the center
            normalized_matrix = (matrix + 1) / 2
        else:
            assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
                "Matrix values must be in range [0, 1], or normalize must be True"
            )
            normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = mpl.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        raise TypeError(
            msg,
        )

    # Apply the colormap, drop the alpha channel, and scale to 8-bit
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8,
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix


@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb,
    buffer: io.BytesIO | None = None,
) -> bytes | None:
    """Convert a `Matrix2Drgb` to valid PNG bytes via PIL

    - if `buffer` is provided, it will write the PNG bytes to the buffer and return `None`
    - if `buffer` is not provided, it will return the PNG bytes

    # Parameters:
    - `matrix : Matrix2Drgb`
    - `buffer : io.BytesIO | None`
        (defaults to `None`, in which case it will return the PNG bytes)

    # Returns:
    - `bytes|None`
        `bytes` if `buffer` is `None`, otherwise `None`
    """
    pil_img: Image.Image = Image.fromarray(matrix, mode="RGB")
    if buffer is None:
        buffer = io.BytesIO()
        pil_img.save(buffer, format="PNG")
        buffer.seek(0)
        return buffer.read()
    else:
        pil_img.save(buffer, format="PNG")
        return None


def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    # Parameters:
    - `matrix : Float[np.ndarray, 'n m']`
        a 2D matrix to convert to an SVG image
    - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
        (defaults to `False`)
    - `cmap : str|Colormap`
        the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
        (defaults to `"viridis"`)
    - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
    - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)


    # Returns:
    - `str`
        the SVG content for the matrix
    """
    # Get the dimensions of the matrix: `n` rows (SVG height), `m` columns (SVG width)
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content


@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args: tuple[()],
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """Decorator for functions that process an attention matrix and save it as a PNG, SVG, or SVGZ image.

    Can handle both argumentless usage and with arguments.
    The image is written to `save_dir / f"{func.__name__}.{fmt}"`.

    # Parameters:

    - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
    - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
    - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
    - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
    - `diverging_colormap : bool, keyword-only`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
    - `normalize_min : float|None, keyword-only`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

    - `AssertionError` : if positional arguments other than `func` are passed, or `fmt` is not a valid `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix_plasma(matrix):
        return matrix * 2
    ```

    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding, matching the plain-svg branch below,
                    # instead of depending on the platform default
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
Type alias for a function that, given an attention matrix, saves one or more figures
Type alias for a 2D matrix (plottable)
Type alias for a 2D matrix with 3 channels (RGB)
Type alias for a function that, given an attention matrix, returns a 2D matrix
format for saving matplotlib figures
Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure
default for whether to normalize the matrix to range [0, 1]
default colormap for saving matrices
default format for saving matrices
template for saving an n by m matrix as an svg/svgz
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> (
    AttentionMatrixFigureFunc
    | Callable[
        [Callable[[AttentionMatrix, plt.Axes], None], str],
        AttentionMatrixFigureFunc,
    ]
):
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    the figure is written to `save_dir / f"{func.__name__}.{fmt}"` and is always closed afterwards,
    even if the decorated function or the save raises

    # Parameters:
    - `func : Callable[[AttentionMatrix, plt.Axes], None]`
        your function, which should take an attention matrix and predefined `ax` object
    - `fmt : str`
        format for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
    - `AttentionMatrixFigureFunc`
        your function, after we wrap it to save a figure
        (or, if `func` is not given, a decorator producing such a function)

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on `fig` explicitly: the pyplot "current figure" may
                # have been changed by `func`
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if plotting or saving failed
                plt.close(fig)

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
decorator for functions which take an attention matrix and predefined ax object, making it save a figure
Parameters:
- `func : Callable[[AttentionMatrix, plt.Axes], None]` — your function, which should take an attention matrix and predefined `ax` object
- `fmt : str` — format for saving matplotlib figures (defaults to `MATPLOTLIB_FIGURE_FMT`)
Returns:
- `AttentionMatrixFigureFunc` — your function, after we wrap it to save a figure
Usage:
@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("Raw Attention Pattern")
    ax.axis("off")
def matplotlib_multifigure_saver(
    names: Sequence[str],
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    # decorator takes in function
    # which takes a matrix and a dictionary of axes corresponding to the names
    [Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
    # returns the decorated function
    AttentionMatrixFigureFunc,
]:
    """decorate a function such that it saves multiple figures, one for each name in `names`

    each figure is written to `save_dir / f"{func.__name__}.{name}.{fmt}"`

    # Parameters:
    - `names : Sequence[str]`
        the names of the figures to save
    - `fmt : str`
        format for saving matplotlib figures
        (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
    - `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc]`
        the decorator, which will then be applied to the function
        we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

    """

    def decorator(
        func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
    ) -> AttentionMatrixFigureFunc:
        func_name: str = func.__name__

        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            # set up axes and corresponding figures
            axes_dict: dict[str, plt.Axes] = {}
            figs_dict: dict[str, plt.Figure] = {}
            # every figure we open, so that none can escape cleanup
            # (e.g. one replaced in `figs_dict` by a duplicate name)
            created_figs: list[plt.Figure] = []

            try:
                # Create all figures and axes (inside the `try`, so a failure
                # partway through still closes the ones already created)
                for name in names:
                    fig, ax = plt.subplots(figsize=(10, 10))
                    created_figs.append(fig)
                    axes_dict[name] = ax
                    figs_dict[name] = fig

                # Run the function to make plots
                func(attn_matrix, axes_dict)

                # Save each figure
                for name, fig_ in figs_dict.items():
                    fig_path: Path = save_dir / f"{func_name}.{name}.{fmt}"
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout" [union-attr]
                    fig_.tight_layout()  # type: ignore[union-attr]
                    # TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig" [union-attr]
                    fig_.savefig(fig_path)  # type: ignore[union-attr]
            finally:
                # Always clean up figures, even if an error occurred
                for fig in created_figs:
                    # TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None" [arg-type]
                    plt.close(fig)  # type: ignore[arg-type]

        # it doesn't normally have this attribute, but we're adding it
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    return decorator
decorate a function such that it saves multiple figures, one for each name in names
Parameters:
- `names : Sequence[str]` — the names of the figures to save
- `fmt : str` — format for saving matplotlib figures (defaults to `MATPLOTLIB_FIGURE_FMT`)
Returns:
- `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc]` — the decorator, which will then be applied to the function. We expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names.
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
    - `matrix : Matrix2D`
        input matrix
    - `normalize : bool`
        whether to normalize the matrix to range [0, 1]
        (defaults to `MATRIX_SAVE_NORMALIZE`)
    - `cmap : str|Colormap`
        the colormap to use for the matrix
        (defaults to `MATRIX_SAVE_CMAP`)
    - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
    - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
        it must not exceed the matrix minimum, so that the output stays in [0, 1].
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:
    - `Matrix2Drgb`
        `uint8` array of shape `(n, m, 3)`

    # Raises:
    - `AssertionError` : if the matrix is not 2D, is empty, is out of range for the chosen mode, or `normalize_min` is used inconsistently
    - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims (2 is not that magic of a value here, hence noqa)
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert not diverging_colormap, (
            "normalize_min cannot be used with diverging_colormap=True"
        )
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: everything sits at the center, avoid 0/0 -> NaN
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # a floor above the matrix minimum would push values below 0
                assert min_val <= matrix.min(), (
                    "normalize_min must be less than or equal to matrix min"
                )
            else:
                min_val = matrix.min()

            value_range: float = max_val - min_val
            if value_range == 0:
                # constant matrix: map to the bottom of the colormap, avoid 0/0 -> NaN
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    else:
        if diverging_colormap:
            assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
                "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            )
            # shift [-1, 1] onto the colormap's [0, 1] domain so 0 lands on the center
            normalized_matrix = (matrix + 1) / 2
        else:
            assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
                "Matrix values must be in range [0, 1], or normalize must be True"
            )
            normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = mpl.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        raise TypeError(
            msg,
        )

    # Apply the colormap, drop the alpha channel, and scale to 8-bit
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8,
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix
preprocess a 2D matrix into a plottable heatmap image
Parameters:
- `matrix : Matrix2D` — input matrix
- `normalize : bool` — whether to normalize the matrix to range [0, 1] (defaults to `MATRIX_SAVE_NORMALIZE`)
- `cmap : str|Colormap` — the colormap to use for the matrix (defaults to `MATRIX_SAVE_CMAP`)
- `diverging_colormap : bool` — if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- `normalize_min : float|None` — if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?). If `None`, then the minimum value of the matrix is used. If `diverging_colormap=True` OR `normalize=False`, this must be `None`. (defaults to `None`)
Returns:
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb,
    buffer: io.BytesIO | None = None,
) -> bytes | None:
    """Encode a `Matrix2Drgb` as PNG using PIL

    the destination depends on `buffer`:

    - `buffer` given: the PNG bytes are written into it and `None` is returned
    - `buffer` omitted: the PNG bytes are returned directly

    # Parameters:
    - `matrix : Matrix2Drgb`
        the RGB image data to encode
    - `buffer : io.BytesIO | None`
        optional destination for the encoded bytes
        (defaults to `None`, in which case the PNG bytes are returned)

    # Returns:
    - `bytes|None`
        `bytes` if `buffer` is `None`, otherwise `None`
    """
    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # caller supplied a destination: write there and return nothing
    if buffer is not None:
        image.save(buffer, format="PNG")
        return None

    # no destination: encode into a scratch buffer and hand back its contents
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()
Convert a Matrix2Drgb to valid PNG bytes via PIL
- if `buffer` is provided, it will write the PNG bytes to the buffer and return `None`
- if `buffer` is not provided, it will return the PNG bytes
Parameters:
- `matrix : Matrix2Drgb`
- `buffer : io.BytesIO | None` (defaults to `None`, in which case it will return the PNG bytes)
Returns:
- `bytes|None` — `bytes` if `buffer` is `None`, otherwise `None`
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is colormapped to RGB, encoded as a PNG, and embedded as base64 into `MATRIX_SAVE_SVG_TEMPLATE`

    # Parameters:
    - `matrix : Float[np.ndarray, 'n m']`
        a 2D matrix to convert to an SVG image
    - `normalize : bool`
        whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
        (defaults to `False`)
    - `cmap : str|Colormap`
        the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
        (defaults to `"viridis"`)
    - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
    - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)


    # Returns:
    - `str`
        the SVG content for the matrix

    # Raises:
    - `AssertionError` : if `matrix` is not 2D
    """
    # Get the dimensions of the matrix: `n` rows (SVG height), `m` columns (SVG width)
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content
quickly convert a 2D matrix to an SVG image, without matplotlib
Parameters:
- `matrix : Float[np.ndarray, 'n m']` — a 2D matrix to convert to an SVG image
- `normalize : bool` — whether to normalize the matrix to range [0, 1]. If it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError` (defaults to `False`)
- `cmap : str` — the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string (defaults to `"viridis"`)
- `diverging_colormap : bool` — if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- `normalize_min : float|None` — if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?). If `None`, then the minimum value of the matrix is used. If `diverging_colormap=True` OR `normalize=False`, this must be `None` (defaults to `None`)
Returns:
- str -- the SVG content for the matrix
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """Decorator for functions that process an attention matrix and save it as an image (PNG, SVG, or SVGZ).

    Can handle both argumentless usage and with arguments.

    The decorated function is called as `wrapped(attn_matrix, save_dir)` and writes
    its output to `save_dir / f"{func.__name__}.{fmt}"`. The chosen format is also
    exposed on the wrapped function as the attribute `figure_save_fmt`.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
       Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
       The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
       Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
       The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool, keyword-only`
       if True and using a diverging colormap, ensures 0 values map to the center of the colormap
       (defaults to False)
     - `normalize_min : float|None, keyword-only`
       if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
       if `None`, then the minimum value of the matrix is used
       if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
       (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

     - `AssertionError` : if any extra positional arguments are passed, or if `fmt` is not one of the `MatrixSaveFormat` options

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix_plasma(matrix):
        return matrix * 2
    ```

    """
    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            # output file is named after the decorated function
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                # raw PNG: colormap the matrix and write the encoded bytes directly
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                # "svg" / "svgz": same SVG text, optionally gzip-compressed
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding so the output does not depend on the
                    # locale default, matching the plain-svg branch below
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # record the save format on the wrapped function itself
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case: `@save_matrix_wrapper`
        return decorator(func)
    else:
        # Handle arguments case: `@save_matrix_wrapper(...)`
        return decorator
Decorator for functions that process an attention matrix and save it as an image (PNG, SVG, or SVGZ, depending on fmt).
Can handle both argumentless usage and with arguments.
Parameters:
- func : AttentionMatrixToMatrixFunc|None -- Either the function to decorate (in the no-arguments case) or None when used with arguments.
- fmt : MatrixSaveFormat, keyword-only -- The format to save the matrix as. Defaults to MATRIX_SAVE_FMT.
- normalize : bool, keyword-only -- Whether to normalize the matrix to range [0, 1]. Defaults to False.
- cmap : str, keyword-only -- The colormap to use for the matrix. Defaults to MATRIX_SAVE_CMAP.
- diverging_colormap : bool -- if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- normalize_min : float|None -- if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero?). If None, then the minimum value of the matrix is used. If diverging_colormap=True OR normalize=False, this must be None (defaults to None)
Returns:
AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
- AttentionMatrixFigureFunc if func is AttentionMatrixToMatrixFunc (no-arguments case)
- Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc] if func is None -- returns the decorator, which will then be applied to the function (with-arguments case)
Usage:
@save_matrix_wrapper
def identity_matrix(matrix):
return matrix
@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
return matrix * 2
@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix(matrix):
return matrix * 2