docs for pattern_lens v0.2.0
View Source on GitHub

pattern_lens.figure_util

implements a bunch of types, default values, and templates which are useful for figure functions

notably, you can use the decorators matplotlib_figure_saver, save_matrix_wrapper to make your functions save figures


  1"""implements a bunch of types, default values, and templates which are useful for figure functions
  2
  3notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
  4"""
  5
  6from pathlib import Path
  7from typing import Callable, Literal, overload, Union
  8import functools
  9import base64
 10import gzip
 11import io
 12
 13from PIL import Image
 14import numpy as np
 15from jaxtyping import Float, UInt8
 16import matplotlib
 17import matplotlib.pyplot as plt
 18from matplotlib.colors import Colormap
 19
 20from pattern_lens.consts import AttentionMatrix
 21
# --- type aliases -------------------------------------------------------------

# the `Path` argument is a directory: the wrappers in this module build the
# output file name themselves as `<decorated function name>.<fmt>` inside it
AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves a figure"

Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

# --- defaults -----------------------------------------------------------------

# used as the file extension of figures written by `matplotlib_figure_saver`
MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

# "png" writes the raw PNG bytes; "svg" and "svgz" embed the PNG in an SVG
# wrapper, with "svgz" additionally gzip-compressing the SVG text
MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

# looked up in `matplotlib.colormaps` when passed around as a string
MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

# placeholders: `{m}` -> width, `{n}` -> height, `{png_base64}` -> base64-encoded PNG payload
MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix as an svg/svgz"
 51
 52
 53@overload  # without keyword arguments, returns decorated function
 54def matplotlib_figure_saver(
 55    func: Callable[[AttentionMatrix, plt.Axes], None],
 56    *args,
 57    fmt: str = MATPLOTLIB_FIGURE_FMT,
 58) -> AttentionMatrixFigureFunc: ...
 59@overload  # with keyword arguments, returns decorator
 60def matplotlib_figure_saver(
 61    func: None = None,
 62    *args,
 63    fmt: str = MATPLOTLIB_FIGURE_FMT,
 64) -> Callable[
 65    [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
 66]: ...
 67def matplotlib_figure_saver(
 68    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
 69    *args,
 70    fmt: str = MATPLOTLIB_FIGURE_FMT,
 71) -> Union[
 72    AttentionMatrixFigureFunc,
 73    Callable[
 74        [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
 75    ],
 76]:
 77    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure
 78
 79    # Parameters:
 80     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
 81       your function, which should take an attention matrix and predefined `ax` object
 82     - `fmt : str`
 83       format for saving matplotlib figures
 84       (defaults to `MATPLOTLIB_FIGURE_FMT`)
 85
 86    # Returns:
 87     - `AttentionMatrixFigureFunc`
 88       your function, after we wrap it to save a figure
 89
 90    # Usage:
 91    ```python
 92    @register_attn_figure_func
 93    @matplotlib_figure_saver
 94    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
 95        ax.matshow(attn_matrix, cmap="viridis")
 96        ax.set_title("Raw Attention Pattern")
 97        ax.axis("off")
 98    ```
 99
100    """
101
102    assert len(args) == 0, "This decorator only supports keyword arguments"
103
104    def decorator(
105        func: Callable[[AttentionMatrix, plt.Axes], None],
106        fmt: str = fmt,
107    ) -> AttentionMatrixFigureFunc:
108        @functools.wraps(func)
109        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
110            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
111
112            fig, ax = plt.subplots(figsize=(10, 10))
113            func(attn_matrix, ax)
114            plt.tight_layout()
115            plt.savefig(fig_path)
116            plt.close(fig)
117
118        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]
119
120        return wrapped
121
122    if callable(func):
123        # Handle no-arguments case
124        return decorator(func)
125    else:
126        # Handle arguments case
127        return decorator
128
129
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
     - `matrix : Matrix2D`
        input matrix
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1]
       (defaults to `False`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix, looked up in `matplotlib.colormaps` if it's a string
       (defaults to `"viridis"`)
     - `diverging_colormap : bool`
        if True, values are mapped so that 0 lands on the center of the colormap.
        without `normalize`, the matrix must already be in range [-1, 1]
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the value mapped to the
        bottom of the colormap (generally set this to zero?). must be at most the matrix minimum and
        strictly below the matrix maximum.
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:
     - `Matrix2Drgb`
        uint8 array of shape `(n, m, 3)`

    # Raises:
     - `AssertionError` : if the matrix is not 2D, is empty, is out of range for the chosen mode,
       or `normalize_min` is used in an unsupported combination
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert (
            not diverging_colormap
        ), "normalize_min cannot be used with diverging_colormap=True"
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: avoid 0/0, every entry sits on the colormap center
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # the floor must not exceed any entry, otherwise entries would
                # normalize to negative values and be clipped by the colormap
                assert (
                    min_val <= matrix.min()
                ), "normalize_min must be less than or equal to matrix min"
            else:
                min_val = matrix.min()

            value_range: float = max_val - min_val
            if value_range == 0:
                # constant matrix: avoid 0/0, map everything to the colormap bottom
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    else:
        if diverging_colormap:
            assert (
                matrix.min() >= -1 and matrix.max() <= 1
            ), "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            # colormaps take floats in [0, 1]: shift [-1, 1] onto that range so
            # that 0 maps to the center instead of negatives being clipped
            normalized_matrix = (matrix / 2) + 0.5
        else:
            assert (
                matrix.min() >= 0 and matrix.max() <= 1
            ), "Matrix values must be in range [0, 1], or normalize must be True"
            normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap (RGBA floats), drop the alpha channel, scale to bytes
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix
227
228
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Encode a `Matrix2Drgb` as PNG using PIL

    two calling modes:

    - with a `buffer`: the PNG bytes are written into it and `None` is returned
    - without a `buffer`: the PNG bytes are returned directly

    # Parameters:
     - `matrix : Matrix2Drgb`
       RGB image data to encode
     - `buffer : io.BytesIO | None`
       optional destination for the encoded PNG
       (defaults to `None`, in which case it will return the PNG bytes)

    # Returns:
     - `bytes|None`
       `bytes` if `buffer` is `None`, otherwise `None`
    """

    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # caller supplied a destination: write there and return nothing
    if buffer is not None:
        image.save(buffer, format="PNG")
        return None

    # no destination: encode into a scratch buffer and return its full contents
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()
260
261
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is colormapped, encoded as a PNG, and embedded base64-encoded into
    `MATRIX_SAVE_SVG_TEMPLATE`, with one SVG unit per matrix entry.

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix (`n` rows, `m` columns) to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
       (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)


    # Returns:
     - `str`
       the SVG content for the matrix
    """
    # Get the dimensions of the matrix: `shape` is (rows, columns), and the
    # template uses `{m}` as the width (columns) and `{n}` as the height (rows)
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content
316
317
@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that process an attention matrix and save the result as an image (PNG, SVG, or SVGZ).
    Can handle both argumentless usage and with arguments.

    The wrapped function is called as `wrapped(attn_matrix, save_dir)` and writes
    `save_dir / "<func.__name__>.<fmt>"`. The chosen format is also exposed on the
    wrapped function as the attribute `figure_save_fmt`.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

    - `AssertionError` : if extra positional arguments are passed, or `fmt` is not one of `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    ```
    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                # raw PNG: colormap the matrix and write the encoded bytes directly
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                # svg / svgz: same image, embedded in an SVG wrapper
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding so the compressed and uncompressed
                    # variants are written identically on every platform
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator

AttentionMatrixFigureFunc = typing.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]

Type alias for a function that, given an attention matrix, saves a figure

Matrix2D = <class 'jaxtyping.Float[ndarray, 'n m']'>

Type alias for a 2D matrix (plottable)

Matrix2Drgb = <class 'jaxtyping.UInt8[ndarray, 'n m rgb=3']'>

Type alias for a 2D matrix with 3 channels (RGB)

AttentionMatrixToMatrixFunc = typing.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]

Type alias for a function that, given an attention matrix, returns a 2D matrix

MATPLOTLIB_FIGURE_FMT: str = 'svgz'

format for saving matplotlib figures

MatrixSaveFormat = typing.Literal['png', 'svg', 'svgz']

Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure

MATRIX_SAVE_NORMALIZE: bool = False

default for whether to normalize the matrix to range [0, 1]

MATRIX_SAVE_CMAP: str = 'viridis'

default colormap for saving matrices

MATRIX_SAVE_FMT: Literal['png', 'svg', 'svgz'] = 'svgz'

default format for saving matrices

MATRIX_SAVE_SVG_TEMPLATE: str = '<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>'

template for saving an n by m matrix as an svg/svgz

def matplotlib_figure_saver( func: Optional[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], NoneType]] = None, *args, fmt: str = 'svgz') -> Union[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType], Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], NoneType], str], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]]]:
 68def matplotlib_figure_saver(
 69    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
 70    *args,
 71    fmt: str = MATPLOTLIB_FIGURE_FMT,
 72) -> Union[
 73    AttentionMatrixFigureFunc,
 74    Callable[
 75        [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
 76    ],
 77]:
 78    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure
 79
 80    # Parameters:
 81     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
 82       your function, which should take an attention matrix and predefined `ax` object
 83     - `fmt : str`
 84       format for saving matplotlib figures
 85       (defaults to `MATPLOTLIB_FIGURE_FMT`)
 86
 87    # Returns:
 88     - `AttentionMatrixFigureFunc`
 89       your function, after we wrap it to save a figure
 90
 91    # Usage:
 92    ```python
 93    @register_attn_figure_func
 94    @matplotlib_figure_saver
 95    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
 96        ax.matshow(attn_matrix, cmap="viridis")
 97        ax.set_title("Raw Attention Pattern")
 98        ax.axis("off")
 99    ```
100
101    """
102
103    assert len(args) == 0, "This decorator only supports keyword arguments"
104
105    def decorator(
106        func: Callable[[AttentionMatrix, plt.Axes], None],
107        fmt: str = fmt,
108    ) -> AttentionMatrixFigureFunc:
109        @functools.wraps(func)
110        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
111            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
112
113            fig, ax = plt.subplots(figsize=(10, 10))
114            func(attn_matrix, ax)
115            plt.tight_layout()
116            plt.savefig(fig_path)
117            plt.close(fig)
118
119        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]
120
121        return wrapped
122
123    if callable(func):
124        # Handle no-arguments case
125        return decorator(func)
126    else:
127        # Handle arguments case
128        return decorator

decorator for functions which take an attention matrix and predefined ax object, making it save a figure

Parameters:

  • func : Callable[[AttentionMatrix, plt.Axes], None] your function, which should take an attention matrix and predefined ax object
  • fmt : str format for saving matplotlib figures (defaults to MATPLOTLIB_FIGURE_FMT)

Returns:

  • AttentionMatrixFigureFunc your function, after we wrap it to save a figure

Usage:

@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("Raw Attention Pattern")
    ax.axis("off")
def matrix_to_image_preprocess( matrix: jaxtyping.Float[ndarray, 'n m'], normalize: bool = False, cmap: str | matplotlib.colors.Colormap = 'viridis', diverging_colormap: bool = False, normalize_min: float | None = None) -> jaxtyping.UInt8[ndarray, 'n m rgb=3']:
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = False,
    cmap: str | Colormap = "viridis",
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
     - `matrix : Matrix2D`
        input matrix
     - `normalize : bool`
        whether to normalize the matrix to range [0, 1]
       (defaults to `False`)
     - `cmap : str|Colormap`
        the colormap to use for the matrix, looked up in `matplotlib.colormaps` if it's a string
       (defaults to `"viridis"`)
     - `diverging_colormap : bool`
        if True, values are mapped so that 0 lands on the center of the colormap.
        without `normalize`, the matrix must already be in range [-1, 1]
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the value mapped to the
        bottom of the colormap (generally set this to zero?). must be at most the matrix minimum and
        strictly below the matrix maximum.
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)

    # Returns:
     - `Matrix2Drgb`
        uint8 array of shape `(n, m, 3)`

    # Raises:
     - `AssertionError` : if the matrix is not 2D, is empty, is out of range for the chosen mode,
       or `normalize_min` is used in an unsupported combination
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    if normalize_min is not None:
        assert (
            not diverging_colormap
        ), "normalize_min cannot be used with diverging_colormap=True"
        assert normalize, "normalize_min cannot be used with normalize=False"

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        if diverging_colormap:
            # For diverging colormaps, we want to center around 0
            max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
            if max_abs == 0:
                # all-zero matrix: avoid 0/0, every entry sits on the colormap center
                normalized_matrix = np.full(matrix.shape, 0.5)
            else:
                normalized_matrix = (matrix / (2 * max_abs)) + 0.5
        else:
            max_val: float = matrix.max()
            min_val: float
            if normalize_min is not None:
                min_val = normalize_min
                assert min_val < max_val, "normalize_min must be less than matrix max"
                # the floor must not exceed any entry, otherwise entries would
                # normalize to negative values and be clipped by the colormap
                assert (
                    min_val <= matrix.min()
                ), "normalize_min must be less than or equal to matrix min"
            else:
                min_val = matrix.min()

            value_range: float = max_val - min_val
            if value_range == 0:
                # constant matrix: avoid 0/0, map everything to the colormap bottom
                normalized_matrix = np.zeros(matrix.shape)
            else:
                normalized_matrix = (matrix - min_val) / value_range
    else:
        if diverging_colormap:
            assert (
                matrix.min() >= -1 and matrix.max() <= 1
            ), "For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
            # colormaps take floats in [0, 1]: shift [-1, 1] onto that range so
            # that 0 maps to the center instead of negatives being clipped
            normalized_matrix = (matrix / 2) + 0.5
        else:
            assert (
                matrix.min() >= 0 and matrix.max() <= 1
            ), "Matrix values must be in range [0, 1], or normalize must be True"
            normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap (RGBA floats), drop the alpha channel, scale to bytes
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix

preprocess a 2D matrix into a plottable heatmap image

Parameters:

  • matrix : Matrix2D input matrix
  • normalize : bool whether to normalize the matrix to range [0, 1] (defaults to MATRIX_SAVE_NORMALIZE)
  • cmap : str|Colormap the colormap to use for the matrix (defaults to MATRIX_SAVE_CMAP)
  • diverging_colormap : bool if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
  • normalize_min : float|None if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero?). if None, then the minimum value of the matrix is used. if diverging_colormap=True OR normalize=False, this must be None. (defaults to None)

Returns:

  • Matrix2Drgb

def matrix2drgb_to_png_bytes( matrix: jaxtyping.UInt8[ndarray, 'n m rgb=3'], buffer: _io.BytesIO | None = None) -> bytes | None:
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Encode a `Matrix2Drgb` as PNG using PIL

    two calling modes:

    - with a `buffer`: the PNG bytes are written into it and `None` is returned
    - without a `buffer`: the PNG bytes are returned directly

    # Parameters:
     - `matrix : Matrix2Drgb`
       RGB image data to encode
     - `buffer : io.BytesIO | None`
       optional destination for the encoded PNG
       (defaults to `None`, in which case it will return the PNG bytes)

    # Returns:
     - `bytes|None`
       `bytes` if `buffer` is `None`, otherwise `None`
    """

    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # caller supplied a destination: write there and return nothing
    if buffer is not None:
        image.save(buffer, format="PNG")
        return None

    # no destination: encode into a scratch buffer and return its full contents
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()

Convert a Matrix2Drgb to valid PNG bytes via PIL

  • if buffer is provided, it will write the PNG bytes to the buffer and return None
  • if buffer is not provided, it will return the PNG bytes

Parameters:

  • matrix : Matrix2Drgb
  • buffer : io.BytesIO | None (defaults to None, in which case it will return the PNG bytes)

Returns:

  • bytes|None bytes if buffer is None, otherwise None
def matrix_as_svg( matrix: jaxtyping.Float[ndarray, 'n m'], normalize: bool = False, cmap: str | matplotlib.colors.Colormap = 'viridis', diverging_colormap: bool = False, normalize_min: float | None = None) -> str:
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is colormapped, encoded as a PNG, and embedded base64-encoded into
    `MATRIX_SAVE_SVG_TEMPLATE`, with one SVG unit per matrix entry.

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix (`n` rows, `m` columns) to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
       (defaults to `MATRIX_SAVE_CMAP`)
     - `diverging_colormap : bool`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
        if `None`, then the minimum value of the matrix is used.
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
        (defaults to `None`)


    # Returns:
     - `str`
       the SVG content for the matrix
    """
    # Get the dimensions of the matrix: `shape` is (rows, columns), and the
    # template uses `{m}` as the width (columns) and `{n}` as the height (rows)
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix,
        normalize=normalize,
        cmap=cmap,
        diverging_colormap=diverging_colormap,
        normalize_min=normalize_min,
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content

quickly convert a 2D matrix to an SVG image, without matplotlib

Parameters:

  • matrix : Float[np.ndarray, 'n m'] a 2D matrix to convert to an SVG image
  • normalize : bool whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be True or it will raise an AssertionError (defaults to False)
  • cmap : str the colormap to use for the matrix -- will look up in matplotlib.colormaps if it's a string (defaults to "viridis")
  • diverging_colormap : bool if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
  • normalize_min : float|None if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero?). if None, then the minimum value of the matrix is used. if diverging_colormap=True OR normalize=False, this must be None. (defaults to None)

Returns:

  • str the SVG content for the matrix
def save_matrix_wrapper( func: Optional[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]] = None, *args, fmt: Literal['png', 'svg', 'svgz'] = 'svgz', normalize: bool = False, cmap: str | matplotlib.colors.Colormap = 'viridis', diverging_colormap: bool = False, normalize_min: float | None = None) -> Union[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType], Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]]]:
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
    diverging_colormap: bool = False,
    normalize_min: float | None = None,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that process an attention matrix and save the result as an image
    (PNG, SVG, or gzip-compressed SVG, depending on `fmt`).
    Can be used both without arguments and with keyword arguments.

    The decorated function becomes an `AttentionMatrixFigureFunc`: calling it with
    `(attn_matrix, save_dir)` runs the original function on `attn_matrix` and writes the
    resulting matrix to `save_dir / f"{func.__name__}.{fmt}"`. The chosen format is also
    exposed on the decorated function as the attribute `figure_save_fmt`.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
     - `diverging_colormap : bool, keyword-only`
        if True and using a diverging colormap, ensures 0 values map to the center of the colormap
        (defaults to False)
     - `normalize_min : float|None, keyword-only`
        if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
        if `None`, then the minimum value of the matrix is used
        if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
        (defaults to `None`)

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

     - `AssertionError` if any extra positional arguments are passed, or if `fmt` is not one of the
       values allowed by `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    ```
    """

    # `*args` exists only to catch accidental positional configuration;
    # everything after `func` has to be passed by keyword
    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            # output file is named after the decorated function, with the format as extension
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                # raster output: colormap the matrix and write the PNG bytes directly
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                # "svg" and "svgz" share the same SVG text; they differ only in compression
                svg_content: str = matrix_as_svg(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                    diverging_colormap=diverging_colormap,
                    normalize_min=normalize_min,
                )

                if fmt == "svgz":
                    # explicit encoding so the compressed output does not depend on the
                    # platform default and matches the plain "svg" branch below
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # record the save format on the decorated function itself
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator

Decorator for functions that process an attention matrix and save it as an image (PNG, SVG, or SVGZ, depending on fmt). Can be used both without arguments and with keyword arguments.

Parameters:

  • func : AttentionMatrixToMatrixFunc|None Either the function to decorate (in the no-arguments case) or None when used with arguments.
  • fmt : MatrixSaveFormat, keyword-only The format to save the matrix as. Defaults to MATRIX_SAVE_FMT.
  • normalize : bool, keyword-only Whether to normalize the matrix to range [0, 1]. Defaults to False.
  • cmap : str|Colormap, keyword-only The colormap to use for the matrix. Defaults to MATRIX_SAVE_CMAP.
  • diverging_colormap : bool if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
  • normalize_min : float|None if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero?); if None, then the minimum value of the matrix is used; if diverging_colormap=True OR normalize=False, this must be None (defaults to None)

Returns:

AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]

Usage:

@save_matrix_wrapper
def identity_matrix(matrix):
    return matrix

@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
    return matrix * 2

@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix(matrix):
    return matrix * 2