pattern_lens.figure_util
implements a bunch of types, default values, and templates which are useful for figure functions
notably, you can use the decorators matplotlib_figure_saver, save_matrix_wrapper to make your functions save figures
"""implements a bunch of types, default values, and templates which are useful for figure functions

notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
"""

from pathlib import Path
from typing import Callable, Literal, overload, Union
import functools
import base64
import gzip
import io

from PIL import Image
import numpy as np
from jaxtyping import Float, UInt8
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import Colormap

from pattern_lens.consts import AttentionMatrix

AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves a figure"

Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix as an svg/svgz"


@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Union[
    AttentionMatrixFigureFunc,
    Callable[
        [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
    ],
]:
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
       your function, which should take an attention matrix and predefined `ax` object
     - `fmt : str`
       format for saving matplotlib figures
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we wrap it to save a figure.
       if `func` is `None` (keyword-arguments usage), the decorator itself is returned instead

    # Raises:
     - `AssertionError` : if any extra positional arguments are passed

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on `fig` explicitly rather than pyplot's "current
                # figure", which `func` may have changed
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if plotting or saving failed
                plt.close(fig)

        # expose the save format on the wrapped function
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator


def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
     - `matrix : Matrix2D`
       input matrix
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]
       (defaults to `False`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- strings are looked up in `matplotlib.colormaps`
       (defaults to `"viridis"`)
     - `diverging_colormap : bool`
       if True and using a diverging colormap, ensures 0 values map to the center of the colormap
       (defaults to False)
     - `normalize_min : float|None`
       if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
       if `None`, then the minimum value of the matrix is used.
       if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
       (defaults to `None`)

    # Returns:
     - `Matrix2Drgb`
       `uint8` array of shape `(n, m, 3)`

    # Raises:
     - `AssertionError` : if the matrix is not 2D, is empty, is out of range for the chosen mode, or `normalize_min` is invalid
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert (
            not diverging_colormap
        ), "normalize_min cannot be used with diverging_colormap=True"
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Map the matrix into [0, 1], the float range the colormap call expects
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: everything sits at the colormap center
                # (avoids a 0/0 division)
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # the floor must not exceed the smallest entry, otherwise
                # normalized values would fall below 0
                assert (
                    min_val <= matrix.min()
                ), "normalize_min must be less than or equal to matrix min"
            else:
                min_val = matrix.min()

            value_range: float = max_val - min_val
            if value_range == 0:
                # constant matrix: map everything to the bottom of the
                # colormap instead of dividing by zero
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    else:
        if diverging_colormap:
            assert (
                matrix.min() >= -1 and matrix.max() <= 1
            ), "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            # shift [-1, 1] onto [0, 1] so that 0 lands on the colormap center
            normalized_matrix = (matrix + 1) / 2
        else:
            assert (
                matrix.min() >= 0 and matrix.max() <= 1
            ), "Matrix values must be in range [0, 1], or normalize must be True"
            normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap, drop the alpha channel, and scale to 8-bit
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix


@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Convert a `Matrix2Drgb` to valid PNG bytes via PIL

    - if `buffer` is provided, it will write the PNG bytes to the buffer and return `None`
    - if `buffer` is not provided, it will return the PNG bytes

    # Parameters:
     - `matrix : Matrix2Drgb`
     - `buffer : io.BytesIO | None`
       (defaults to `None`, in which case it will return the PNG bytes)

    # Returns:
     - `bytes|None`
       `bytes` if `buffer` is `None`, otherwise `None`
    """

    pil_img: Image.Image = Image.fromarray(matrix, mode="RGB")
    if buffer is None:
        buffer = io.BytesIO()
        pil_img.save(buffer, format="PNG")
        buffer.seek(0)
        return buffer.read()
    else:
        pil_img.save(buffer, format="PNG")
        return None


def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
       (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
       if True and using a diverging colormap, ensures 0 values map to the center of the colormap
       (defaults to False)
     - `normalize_min : float|None`
       if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
       if `None`, then the minimum value of the matrix is used
       if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
       (defaults to `None`)


    # Returns:
     - `str`
       the SVG content for the matrix
    """
    # Preprocess the matrix into an RGB image (this also validates it is 2D)
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # shape is (rows, columns): `n` rows become the SVG height and `m` columns
    # the SVG width, matching the placeholders in `MATRIX_SAVE_SVG_TEMPLATE`
    n, m = matrix.shape

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content


@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that process an attention matrix and save it as an image (PNG, SVG, or SVGZ).
    Can handle both argumentless usage and with arguments.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

    - `AssertionError` : if extra positional arguments are passed, or `fmt` is not one of the `MatrixSaveFormat` options

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    ```
    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding so the compressed text matches the
                    # plain-svg branch regardless of the platform locale
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # expose the save format on the wrapped function
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
Type alias for a function that, given an attention matrix, saves a figure
Type alias for a 2D matrix (plottable)
Type alias for a 2D matrix with 3 channels (RGB)
Type alias for a function that, given an attention matrix, returns a 2D matrix
format for saving matplotlib figures
Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure
default for whether to normalize the matrix to range [0, 1]
default colormap for saving matrices
default format for saving matrices
template for saving an n by m matrix as an svg/svgz
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Union[
    AttentionMatrixFigureFunc,
    Callable[
        [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
    ],
]:
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
       your function, which should take an attention matrix and predefined `ax` object
     - `fmt : str`
       format for saving matplotlib figures
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we wrap it to save a figure.
       if `func` is `None` (keyword-arguments usage), the decorator itself is returned instead

    # Raises:
     - `AssertionError` : if any extra positional arguments are passed

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on `fig` explicitly rather than pyplot's "current
                # figure", which `func` may have changed
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if plotting or saving failed
                plt.close(fig)

        # expose the save format on the wrapped function
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
decorator for functions which take an attention matrix and predefined ax object, making it save a figure
Parameters:
- `func : Callable[[AttentionMatrix, plt.Axes], None]`: your function, which should take an attention matrix and a predefined `ax` object
- `fmt : str`: format for saving matplotlib figures (defaults to `MATPLOTLIB_FIGURE_FMT`)
Returns:
- `AttentionMatrixFigureFunc`: your function, after we wrap it to save a figure
Usage:
@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
ax.matshow(attn_matrix, cmap="viridis")
ax.set_title("Raw Attention Pattern")
ax.axis("off")
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
     - `matrix : Matrix2D`
       input matrix
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]
       (defaults to `False`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- strings are looked up in `matplotlib.colormaps`
       (defaults to `"viridis"`)
     - `diverging_colormap : bool`
       if True and using a diverging colormap, ensures 0 values map to the center of the colormap
       (defaults to False)
     - `normalize_min : float|None`
       if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
       if `None`, then the minimum value of the matrix is used.
       if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
       (defaults to `None`)

    # Returns:
     - `Matrix2Drgb`
       `uint8` array of shape `(n, m, 3)`

    # Raises:
     - `AssertionError` : if the matrix is not 2D, is empty, is out of range for the chosen mode, or `normalize_min` is invalid
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert (
            not diverging_colormap
        ), "normalize_min cannot be used with diverging_colormap=True"
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Map the matrix into [0, 1], the float range the colormap call expects
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: everything sits at the colormap center
                # (avoids a 0/0 division)
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # the floor must not exceed the smallest entry, otherwise
                # normalized values would fall below 0
                assert (
                    min_val <= matrix.min()
                ), "normalize_min must be less than or equal to matrix min"
            else:
                min_val = matrix.min()

            value_range: float = max_val - min_val
            if value_range == 0:
                # constant matrix: map everything to the bottom of the
                # colormap instead of dividing by zero
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    else:
        if diverging_colormap:
            assert (
                matrix.min() >= -1 and matrix.max() <= 1
            ), "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            # shift [-1, 1] onto [0, 1] so that 0 lands on the colormap center
            normalized_matrix = (matrix + 1) / 2
        else:
            assert (
                matrix.min() >= 0 and matrix.max() <= 1
            ), "Matrix values must be in range [0, 1], or normalize must be True"
            normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap, drop the alpha channel, and scale to 8-bit
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix
preprocess a 2D matrix into a plottable heatmap image
Parameters:
- `matrix : Matrix2D`: input matrix
- `normalize : bool`: whether to normalize the matrix to range [0, 1] (defaults to `MATRIX_SAVE_NORMALIZE`)
- `cmap : str|Colormap`: the colormap to use for the matrix (defaults to `MATRIX_SAVE_CMAP`)
- `diverging_colormap : bool`: if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- `normalize_min : float|None`: if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?). If `None`, then the minimum value of the matrix is used. If `diverging_colormap=True` OR `normalize=False`, this must be `None`. (defaults to `None`)
Returns:
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Encode a `Matrix2Drgb` as PNG using PIL

    Two calling modes:

    - with a `buffer`: the PNG data is written into it and `None` is returned
    - without a `buffer`: the PNG data is returned as `bytes`

    # Parameters:
     - `matrix : Matrix2Drgb`
       RGB image array handed to `Image.fromarray` with `mode="RGB"`
     - `buffer : io.BytesIO | None`
       optional destination for the PNG data
       (defaults to `None`, in which case the PNG bytes are returned)

    # Returns:
     - `bytes|None`
       `bytes` if `buffer` is `None`, otherwise `None`
    """
    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # caller supplied a destination: write into it and return nothing
    if buffer is not None:
        image.save(buffer, format="PNG")
        return None

    # otherwise collect the encoded image in a scratch buffer and hand it back
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()
Convert a Matrix2Drgb to valid PNG bytes via PIL
- if `buffer` is provided, it will write the PNG bytes to the buffer and return `None`
- if `buffer` is not provided, it will return the PNG bytes
Parameters:
- `matrix : Matrix2Drgb`
- `buffer : io.BytesIO | None` (defaults to `None`, in which case it will return the PNG bytes)
Returns:
- `bytes|None`: `bytes` if `buffer` is `None`, otherwise `None`
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
       (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
       if True and using a diverging colormap, ensures 0 values map to the center of the colormap
       (defaults to False)
     - `normalize_min : float|None`
       if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
       if `None`, then the minimum value of the matrix is used
       if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
       (defaults to `None`)


    # Returns:
     - `str`
       the SVG content for the matrix
    """
    # Preprocess the matrix into an RGB image (this also validates it is 2D)
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # shape is (rows, columns): `n` rows become the SVG height and `m` columns
    # the SVG width, matching the placeholders in `MATRIX_SAVE_SVG_TEMPLATE`
    n, m = matrix.shape

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content
quickly convert a 2D matrix to an SVG image, without matplotlib
Parameters:
- `matrix : Float[np.ndarray, 'n m']`: a 2D matrix to convert to an SVG image
- `normalize : bool`: whether to normalize the matrix to range [0, 1]. If it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError` (defaults to `False`)
- `cmap : str`: the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string (defaults to `"viridis"`)
- `diverging_colormap : bool`: if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- `normalize_min : float|None`: if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?). If `None`, then the minimum value of the matrix is used. If `diverging_colormap=True` OR `normalize=False`, this must be `None`. (defaults to `None`)
Returns:
- `str`: the SVG content for the matrix
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that process an attention matrix and save it as an image (PNG, SVG, or SVGZ).
    Can handle both argumentless usage and with arguments.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

    - `AssertionError` : if extra positional arguments are passed, or `fmt` is not one of the `MatrixSaveFormat` options

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    ```
    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding so the compressed text matches the
                    # plain-svg branch regardless of the platform locale
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # expose the save format on the wrapped function
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
Decorator for functions that process an attention matrix and save it as an SVGZ image. Can handle both argumentless usage and with arguments.
Parameters:
- `func : AttentionMatrixToMatrixFunc|None`: Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
- `fmt : MatrixSaveFormat, keyword-only`: The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
- `normalize : bool, keyword-only`: Whether to normalize the matrix to range [0, 1]. Defaults to `False`.
- `cmap : str, keyword-only`: The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
- `diverging_colormap : bool`: if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
- `normalize_min : float|None`: if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?). If `None`, then the minimum value of the matrix is used. If `diverging_colormap=True` OR `normalize=False`, this must be `None`. (defaults to `None`)
Returns:
AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
- `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
- `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator, which will then be applied to the function (with arguments case)
Usage:
@save_matrix_wrapper
def identity_matrix(matrix):
return matrix
@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
return matrix * 2
@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix(matrix):
return matrix * 2