docs for pattern_lens v0.6.0
View Source on GitHub

pattern_lens.figure_util

implements a bunch of types, default values, and templates which are useful for figure functions

notably, you can use the decorators matplotlib_figure_saver and save_matrix_wrapper to make your functions save figures


  1"""implements a bunch of types, default values, and templates which are useful for figure functions
  2
  3notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
  4"""
  5
  6import base64
  7import functools
  8import gzip
  9import io
 10from collections.abc import Callable, Sequence
 11from pathlib import Path
 12from typing import Literal, overload
 13
 14import matplotlib as mpl
 15import matplotlib.pyplot as plt
 16import numpy as np
 17from jaxtyping import Float, UInt8
 18from matplotlib.colors import Colormap
 19from PIL import Image
 20
 21from pattern_lens.consts import AttentionMatrix
 22
 23AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
 24"Type alias for a function that, given an attention matrix, saves one or more figures"
 25
 26Matrix2D = Float[np.ndarray, "n m"]
 27"Type alias for a 2D matrix (plottable)"
 28
 29Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
 30"Type alias for a 2D matrix with 3 channels (RGB)"
 31
 32AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
 33"Type alias for a function that, given an attention matrix, returns a 2D matrix"
 34
 35MATPLOTLIB_FIGURE_FMT: str = "svgz"
 36"format for saving matplotlib figures"
 37
 38MatrixSaveFormat = Literal["png", "svg", "svgz"]
 39"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"
 40
 41MATRIX_SAVE_NORMALIZE: bool = False
 42"default for whether to normalize the matrix to range [0, 1]"
 43
 44MATRIX_SAVE_CMAP: str = "viridis"
 45"default colormap for saving matrices"
 46
 47MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
 48"default format for saving matrices"
 49
 50MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
 51"template for saving an `n` by `m` matrix as an svg/svgz"
 52
 53
 54# TYPING: mypy hates it when we dont pass func=None or None as the first arg
 55@overload  # without keyword arguments, returns decorated function
 56def matplotlib_figure_saver(
 57	func: Callable[[AttentionMatrix, plt.Axes], None],
 58) -> AttentionMatrixFigureFunc: ...
 59@overload  # with keyword arguments, returns decorator
 60def matplotlib_figure_saver(
 61	func: None = None,
 62	fmt: str = MATPLOTLIB_FIGURE_FMT,
 63) -> Callable[
 64	[Callable[[AttentionMatrix, plt.Axes], None]],
 65	AttentionMatrixFigureFunc,
 66]: ...
 67def matplotlib_figure_saver(
 68	func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
 69	fmt: str = MATPLOTLIB_FIGURE_FMT,
 70) -> (
 71	AttentionMatrixFigureFunc
 72	| Callable[
 73		[Callable[[AttentionMatrix, plt.Axes], None]],
 74		AttentionMatrixFigureFunc,
 75	]
 76):
 77	"""decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure
 78
 79	# Parameters:
 80	- `func : Callable[[AttentionMatrix, plt.Axes], None]`
 81		your function, which should take an attention matrix and predefined `ax` object
 82	- `fmt : str`
 83		format for saving matplotlib figures
 84		(defaults to `MATPLOTLIB_FIGURE_FMT`)
 85
 86	# Returns:
 87	- `AttentionMatrixFigureFunc`
 88		your function, after we wrap it to save a figure
 89
 90	# Usage:
 91	```python
 92	@register_attn_figure_func
 93	@matplotlib_figure_saver
 94	def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
 95		ax.matshow(attn_matrix, cmap="viridis")
 96		ax.set_title("Raw Attention Pattern")
 97		ax.axis("off")
 98	```
 99
100	"""
101
102	def decorator(
103		func: Callable[[AttentionMatrix, plt.Axes], None],
104		fmt: str = fmt,
105	) -> AttentionMatrixFigureFunc:
106		@functools.wraps(func)
107		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
108			fig_path: Path = (
109				save_dir / f"{getattr(func, '__name__', '<unknown>')}.{fmt}"
110			)
111
112			fig, ax = plt.subplots(figsize=(10, 10))
113			func(attn_matrix, ax)
114			plt.tight_layout()
115			plt.savefig(fig_path)
116			plt.close(fig)
117
118		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]
119
120		return wrapped
121
122	if callable(func):
123		# Handle no-arguments case
124		return decorator(func)
125	else:
126		# Handle arguments case
127		return decorator
128
129
def matplotlib_multifigure_saver(
	names: Sequence[str],
	fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
	# decorator takes in function
	# which takes a matrix and a dictionary of axes corresponding to the names
	[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
	# returns the decorated function
	AttentionMatrixFigureFunc,
]:
	"""build a decorator which makes a plotting function save one figure per entry of `names`

	the wrapped function takes `(attn_matrix, save_dir)`. it creates one figure (`figsize=(10, 10)`)
	per name, hands your function a dict mapping each name to that figure's axes,
	then saves every figure to `save_dir / f"{func.__name__}.{name}.{fmt}"`.
	all figures are closed afterwards, even if your function or a save raises.
	the wrapped function also gets a `figure_save_fmt` attribute set to `fmt`.

	# Parameters:
	- `names : Sequence[str]`
		the names of the figures to save
	- `fmt : str`
		format for saving matplotlib figures, used as the file extension
		(defaults to `MATPLOTLIB_FIGURE_FMT`)

	# Returns:
	- `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc]`
		the decorator, which will then be applied to the function
		we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

	"""

	def decorator(
		func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
	) -> AttentionMatrixFigureFunc:
		base_name: str = getattr(func, "__name__", "<unknown>")

		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			figures: dict[str, plt.Figure] = {}
			axes_by_name: dict[str, plt.Axes] = {}

			# one figure/axes pair per requested name, created in order
			for fig_name in names:
				new_fig, new_ax = plt.subplots(figsize=(10, 10))
				axes_by_name[fig_name] = new_ax
				figures[fig_name] = new_fig

			try:
				# let the user function draw onto the prepared axes
				func(attn_matrix, axes_by_name)

				# write every figure to its own file
				for fig_name, figure in figures.items():
					out_path: Path = save_dir / f"{base_name}.{fig_name}.{fmt}"
					# TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout"  [union-attr]
					figure.tight_layout()  # type: ignore[union-attr]
					# TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig"  [union-attr]
					figure.savefig(out_path)  # type: ignore[union-attr]
			finally:
				# release every figure regardless of success or failure
				for figure in figures.values():
					# TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None"  [arg-type]
					plt.close(figure)  # type: ignore[arg-type]

		# not a standard function attribute, we attach it ourselves
		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	return decorator
196
197
def matrix_to_image_preprocess(
	matrix: Matrix2D,
	normalize: bool = False,
	cmap: str | Colormap = "viridis",
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> Matrix2Drgb:
	"""preprocess a 2D matrix into a plottable heatmap image

	# Parameters:
	- `matrix : Matrix2D`
		input matrix, must be 2D and non-empty
	- `normalize : bool`
		whether to normalize the matrix to range [0, 1].
		a constant matrix (zero value range) is mapped to all `0.0`,
		or to all `0.5` for an all-zero matrix with `diverging_colormap=True`
		(defaults to `False`)
	- `cmap : str|Colormap`
		the colormap to use for the matrix -- looked up in `matplotlib.colormaps` if it's a string
		(defaults to `"viridis"`)
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
		must be strictly less than the matrix max, and no greater than the matrix min.
		if `None`, then the minimum value of the matrix is used.
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
		(defaults to `None`)

	# Returns:
	- `Matrix2Drgb`
		uint8 array of shape `(n, m, 3)`

	# Raises:
	- `AssertionError` : if the matrix is not 2D or is empty, if `normalize_min` is used with an
		incompatible option or is out of range, or if `normalize=False` and the values are outside
		the allowed range
	- `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
	"""
	# check dims (2 is not that magic of a value here, hence noqa)
	assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

	# check matrix is not empty
	assert matrix.size > 0, "Matrix cannot be empty"

	if normalize_min is not None:
		assert not diverging_colormap, (
			"normalize_min cannot be used with diverging_colormap=True"
		)
		assert normalize, "normalize_min cannot be used with normalize=False"

	# Normalize the matrix to range [0, 1]
	normalized_matrix: Matrix2D
	if normalize:
		if diverging_colormap:
			# For diverging colormaps, we want to center around 0
			max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
			if max_abs == 0:
				# all-zero matrix: avoid 0/0, everything sits at the colormap center
				normalized_matrix = np.full(matrix.shape, 0.5)
			else:
				normalized_matrix = (matrix / (2 * max_abs)) + 0.5
		else:
			max_val: float = matrix.max()
			min_val: float
			if normalize_min is not None:
				min_val = normalize_min
				assert min_val < max_val, "normalize_min must be less than matrix max"
				# a floor above the true minimum would push values below 0 after scaling
				assert min_val <= matrix.min(), (
					"normalize_min must be less than or equal to matrix min"
				)
			else:
				min_val = matrix.min()

			value_range: float = max_val - min_val
			if value_range == 0:
				# constant matrix: avoid 0/0 (NaN), map everything to the bottom of the colormap
				normalized_matrix = np.zeros(matrix.shape)
			else:
				normalized_matrix = (matrix - min_val) / value_range
	else:
		if diverging_colormap:
			assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
				"For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
			)
			# NOTE(review): values are passed to the colormap unchanged here, so 0 is not
			# shifted to 0.5 in this branch -- confirm whether that is intended
			normalized_matrix = matrix
		else:
			assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
				"Matrix values must be in range [0, 1], or normalize must be True"
			)
			normalized_matrix = matrix

	# get the colormap
	cmap_: Colormap
	if isinstance(cmap, str):
		cmap_ = mpl.colormaps[cmap]
	elif isinstance(cmap, Colormap):
		cmap_ = cmap
	else:
		msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
		raise TypeError(
			msg,
		)

	# Apply the colormap, drop the alpha channel, and scale to 8-bit
	rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
		np.uint8,
	)

	assert rgb_matrix.shape == (
		matrix.shape[0],
		matrix.shape[1],
		3,
	), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

	return rgb_matrix
296
297
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
	matrix: Matrix2Drgb,
	buffer: io.BytesIO | None = None,
) -> bytes | None:
	"""Encode a `Matrix2Drgb` as PNG using PIL

	- with a `buffer`, the PNG bytes are written into it and `None` is returned
	- without a `buffer`, the PNG bytes are returned directly

	# Parameters:
	- `matrix : Matrix2Drgb`
		the RGB image to encode
	- `buffer : io.BytesIO | None`
		optional destination for the PNG bytes
		(defaults to `None`, in which case the PNG bytes are returned)

	# Returns:
	- `bytes|None`
		`bytes` if `buffer` is `None`, otherwise `None`
	"""
	png_image: Image.Image = Image.fromarray(matrix, mode="RGB")

	# caller supplied a destination: write into it and return nothing
	if buffer is not None:
		png_image.save(buffer, format="PNG")
		return None

	# otherwise encode into a scratch buffer and hand back its full contents
	scratch = io.BytesIO()
	png_image.save(scratch, format="PNG")
	return scratch.getvalue()
329
330
def matrix_as_svg(
	matrix: Matrix2D,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> str:
	"""quickly convert a 2D matrix to an SVG image, without matplotlib

	the matrix is turned into an RGB image, encoded as PNG, and embedded as base64 into `MATRIX_SAVE_SVG_TEMPLATE`

	# Parameters:
	- `matrix : Float[np.ndarray, 'n m']`
		a 2D matrix to convert to an SVG image (`n` rows, `m` columns)
	- `normalize : bool`
		whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
		(defaults to `MATRIX_SAVE_NORMALIZE`)
	- `cmap : str|Colormap`
		the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
		(defaults to `MATRIX_SAVE_CMAP`)
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
		if `None`, then the minimum value of the matrix is used
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
		(defaults to `None`)


	# Returns:
	- `str`
		the SVG content for the matrix, with width `m` and height `n`
	"""
	# Get the dimensions of the matrix
	assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
	# shape is (rows, columns): rows are the image height `n`, columns the width `m`
	n, m = matrix.shape

	# Preprocess the matrix into an RGB image
	matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
		matrix,
		normalize=normalize,
		cmap=cmap,
		diverging_colormap=diverging_colormap,
		normalize_min=normalize_min,
	)

	# Convert the RGB image to PNG bytes
	image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

	# Encode the PNG bytes as base64
	png_base64: str = base64.b64encode(image_data).decode("utf-8")

	# Generate the SVG content
	svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

	return svg_content
386
387
@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
	func: None = None,
	*args: tuple[()],
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
	func: AttentionMatrixToMatrixFunc,
	*args: tuple[()],
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
	func: AttentionMatrixToMatrixFunc | None = None,
	*args,
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> (
	AttentionMatrixFigureFunc
	| Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
	"""Decorator for functions that process an attention matrix, saving the result as a PNG, SVG, or SVGZ image.

	Can handle both argumentless usage and with arguments.

	The wrapped function takes `(attn_matrix, save_dir)`, calls your function on the attention matrix,
	and writes the resulting matrix to `save_dir / f"{func.__name__}.{fmt}"`.
	The wrapped function also gets a `figure_save_fmt` attribute set to `fmt`.

	# Parameters:

	- `func : AttentionMatrixToMatrixFunc|None`
		Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
	- `fmt : MatrixSaveFormat, keyword-only`
		The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
	- `normalize : bool, keyword-only`
		Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
	- `cmap : str|Colormap, keyword-only`
		The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
	- `diverging_colormap : bool, keyword-only`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None, keyword-only`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
		if `None`, then the minimum value of the matrix is used
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
		(defaults to `None`)

	# Returns:

	`AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

	- `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
	- `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

	# Raises:

	- `AssertionError` : if any positional argument besides `func` is passed, or if `fmt` is not a valid `MatrixSaveFormat`

	# Usage:

	```python
	@save_matrix_wrapper
	def identity_matrix(matrix):
		return matrix

	@save_matrix_wrapper(normalize=True, fmt="png")
	def scale_matrix(matrix):
		return matrix * 2

	@save_matrix_wrapper(normalize=True, cmap="plasma")
	def scale_matrix(matrix):
		return matrix * 2
	```

	"""
	assert len(args) == 0, "This decorator only supports keyword arguments"

	assert (
		fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
	), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

	def decorator(
		func: Callable[[AttentionMatrix], Matrix2D],
	) -> AttentionMatrixFigureFunc:
		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			fig_path: Path = (
				save_dir / f"{getattr(func, '__name__', '<unknown>')}.{fmt}"
			)
			processed_matrix: Matrix2D = func(attn_matrix)

			if fmt == "png":
				# raw PNG: colormap the matrix and write the encoded bytes directly
				processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
					processed_matrix,
					normalize=normalize,
					cmap=cmap,
					diverging_colormap=diverging_colormap,
					normalize_min=normalize_min,
				)
				image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
				fig_path.write_bytes(image_data)

			else:
				# svg / svgz: both share the same SVG text, svgz is just gzipped
				svg_content: str = matrix_as_svg(
					processed_matrix,
					normalize=normalize,
					cmap=cmap,
					diverging_colormap=diverging_colormap,
					normalize_min=normalize_min,
				)

				if fmt == "svgz":
					# explicit encoding so the output does not depend on the locale,
					# matching the plain-svg branch below
					with gzip.open(fig_path, "wt", encoding="utf-8") as f:
						f.write(svg_content)

				else:
					fig_path.write_text(svg_content, encoding="utf-8")

		# it doesn't normally have this attribute, but we're adding it
		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	if callable(func):
		# Handle no-arguments case
		return decorator(func)
	else:
		# Handle arguments case
		return decorator

AttentionMatrixFigureFunc = collections.abc.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None]

Type alias for a function that, given an attention matrix, saves one or more figures

Matrix2D = <class 'jaxtyping.Float[ndarray, 'n m']'>

Type alias for a 2D matrix (plottable)

Matrix2Drgb = <class 'jaxtyping.UInt8[ndarray, 'n m rgb=3']'>

Type alias for a 2D matrix with 3 channels (RGB)

AttentionMatrixToMatrixFunc = collections.abc.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]

Type alias for a function that, given an attention matrix, returns a 2D matrix

MATPLOTLIB_FIGURE_FMT: str = 'svgz'

format for saving matplotlib figures

MatrixSaveFormat = typing.Literal['png', 'svg', 'svgz']

Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure

MATRIX_SAVE_NORMALIZE: bool = False

default for whether to normalize the matrix to range [0, 1]

MATRIX_SAVE_CMAP: str = 'viridis'

default colormap for saving matrices

MATRIX_SAVE_FMT: Literal['png', 'svg', 'svgz'] = 'svgz'

default format for saving matrices

MATRIX_SAVE_SVG_TEMPLATE: str = '<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>'

template for saving an n by m matrix as an svg/svgz

def matplotlib_figure_saver( func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], None] | None = None, fmt: str = 'svgz') -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None] | Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], None]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None]]:
 68def matplotlib_figure_saver(
 69	func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
 70	fmt: str = MATPLOTLIB_FIGURE_FMT,
 71) -> (
 72	AttentionMatrixFigureFunc
 73	| Callable[
 74		[Callable[[AttentionMatrix, plt.Axes], None]],
 75		AttentionMatrixFigureFunc,
 76	]
 77):
 78	"""decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure
 79
 80	# Parameters:
 81	- `func : Callable[[AttentionMatrix, plt.Axes], None]`
 82		your function, which should take an attention matrix and predefined `ax` object
 83	- `fmt : str`
 84		format for saving matplotlib figures
 85		(defaults to `MATPLOTLIB_FIGURE_FMT`)
 86
 87	# Returns:
 88	- `AttentionMatrixFigureFunc`
 89		your function, after we wrap it to save a figure
 90
 91	# Usage:
 92	```python
 93	@register_attn_figure_func
 94	@matplotlib_figure_saver
 95	def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
 96		ax.matshow(attn_matrix, cmap="viridis")
 97		ax.set_title("Raw Attention Pattern")
 98		ax.axis("off")
 99	```
100
101	"""
102
103	def decorator(
104		func: Callable[[AttentionMatrix, plt.Axes], None],
105		fmt: str = fmt,
106	) -> AttentionMatrixFigureFunc:
107		@functools.wraps(func)
108		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
109			fig_path: Path = (
110				save_dir / f"{getattr(func, '__name__', '<unknown>')}.{fmt}"
111			)
112
113			fig, ax = plt.subplots(figsize=(10, 10))
114			func(attn_matrix, ax)
115			plt.tight_layout()
116			plt.savefig(fig_path)
117			plt.close(fig)
118
119		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]
120
121		return wrapped
122
123	if callable(func):
124		# Handle no-arguments case
125		return decorator(func)
126	else:
127		# Handle arguments case
128		return decorator

decorator for functions which take an attention matrix and predefined ax object, making it save a figure

Parameters:

  • func : Callable[[AttentionMatrix, plt.Axes], None] your function, which should take an attention matrix and predefined ax object
  • fmt : str format for saving matplotlib figures (defaults to MATPLOTLIB_FIGURE_FMT)

Returns:

  • AttentionMatrixFigureFunc your function, after we wrap it to save a figure

Usage:

@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
def matplotlib_multifigure_saver( names: Sequence[str], fmt: str = 'svgz') -> Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], dict[str, matplotlib.axes._axes.Axes]], None]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None]]:
def matplotlib_multifigure_saver(
	names: Sequence[str],
	fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
	# decorator takes in function
	# which takes a matrix and a dictionary of axes corresponding to the names
	[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
	# returns the decorated function
	AttentionMatrixFigureFunc,
]:
	"""build a decorator which makes a plotting function save one figure per entry of `names`

	the wrapped function takes `(attn_matrix, save_dir)`. it creates one figure (`figsize=(10, 10)`)
	per name, hands your function a dict mapping each name to that figure's axes,
	then saves every figure to `save_dir / f"{func.__name__}.{name}.{fmt}"`.
	all figures are closed afterwards, even if your function or a save raises.
	the wrapped function also gets a `figure_save_fmt` attribute set to `fmt`.

	# Parameters:
	- `names : Sequence[str]`
		the names of the figures to save
	- `fmt : str`
		format for saving matplotlib figures, used as the file extension
		(defaults to `MATPLOTLIB_FIGURE_FMT`)

	# Returns:
	- `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc]`
		the decorator, which will then be applied to the function
		we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

	"""

	def decorator(
		func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
	) -> AttentionMatrixFigureFunc:
		base_name: str = getattr(func, "__name__", "<unknown>")

		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			figures: dict[str, plt.Figure] = {}
			axes_by_name: dict[str, plt.Axes] = {}

			# one figure/axes pair per requested name, created in order
			for fig_name in names:
				new_fig, new_ax = plt.subplots(figsize=(10, 10))
				axes_by_name[fig_name] = new_ax
				figures[fig_name] = new_fig

			try:
				# let the user function draw onto the prepared axes
				func(attn_matrix, axes_by_name)

				# write every figure to its own file
				for fig_name, figure in figures.items():
					out_path: Path = save_dir / f"{base_name}.{fig_name}.{fmt}"
					# TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout"  [union-attr]
					figure.tight_layout()  # type: ignore[union-attr]
					# TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig"  [union-attr]
					figure.savefig(out_path)  # type: ignore[union-attr]
			finally:
				# release every figure regardless of success or failure
				for figure in figures.values():
					# TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None"  [arg-type]
					plt.close(figure)  # type: ignore[arg-type]

		# not a standard function attribute, we attach it ourselves
		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	return decorator

decorate a function such that it saves multiple figures, one for each name in names

Parameters:

  • names : Sequence[str] the names of the figures to save
  • fmt : str format for saving matplotlib figures (defaults to MATPLOTLIB_FIGURE_FMT)

Returns:

  • Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc] the decorator, which will then be applied to the function we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names
def matrix_to_image_preprocess( matrix: jaxtyping.Float[ndarray, 'n m'], normalize: bool = False, cmap: str | matplotlib.colors.Colormap = 'viridis', diverging_colormap: bool = False, normalize_min: float | None = None) -> jaxtyping.UInt8[ndarray, 'n m rgb=3']:
def matrix_to_image_preprocess(
	matrix: Matrix2D,
	normalize: bool = False,
	cmap: str | Colormap = "viridis",
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> Matrix2Drgb:
	"""preprocess a 2D matrix into a plottable heatmap image

	# Parameters:
	- `matrix : Matrix2D`
		input matrix, must be 2D and non-empty
	- `normalize : bool`
		whether to normalize the matrix to range [0, 1].
		a constant matrix (zero value range) is mapped to all `0.0`,
		or to all `0.5` for an all-zero matrix with `diverging_colormap=True`
		(defaults to `False`)
	- `cmap : str|Colormap`
		the colormap to use for the matrix -- looked up in `matplotlib.colormaps` if it's a string
		(defaults to `"viridis"`)
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
		must be strictly less than the matrix max, and no greater than the matrix min.
		if `None`, then the minimum value of the matrix is used.
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
		(defaults to `None`)

	# Returns:
	- `Matrix2Drgb`
		uint8 array of shape `(n, m, 3)`

	# Raises:
	- `AssertionError` : if the matrix is not 2D or is empty, if `normalize_min` is used with an
		incompatible option or is out of range, or if `normalize=False` and the values are outside
		the allowed range
	- `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
	"""
	# check dims (2 is not that magic of a value here, hence noqa)
	assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

	# check matrix is not empty
	assert matrix.size > 0, "Matrix cannot be empty"

	if normalize_min is not None:
		assert not diverging_colormap, (
			"normalize_min cannot be used with diverging_colormap=True"
		)
		assert normalize, "normalize_min cannot be used with normalize=False"

	# Normalize the matrix to range [0, 1]
	normalized_matrix: Matrix2D
	if normalize:
		if diverging_colormap:
			# For diverging colormaps, we want to center around 0
			max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
			if max_abs == 0:
				# all-zero matrix: avoid 0/0, everything sits at the colormap center
				normalized_matrix = np.full(matrix.shape, 0.5)
			else:
				normalized_matrix = (matrix / (2 * max_abs)) + 0.5
		else:
			max_val: float = matrix.max()
			min_val: float
			if normalize_min is not None:
				min_val = normalize_min
				assert min_val < max_val, "normalize_min must be less than matrix max"
				# a floor above the true minimum would push values below 0 after scaling
				assert min_val <= matrix.min(), (
					"normalize_min must be less than or equal to matrix min"
				)
			else:
				min_val = matrix.min()

			value_range: float = max_val - min_val
			if value_range == 0:
				# constant matrix: avoid 0/0 (NaN), map everything to the bottom of the colormap
				normalized_matrix = np.zeros(matrix.shape)
			else:
				normalized_matrix = (matrix - min_val) / value_range
	else:
		if diverging_colormap:
			assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
				"For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
			)
			# NOTE(review): values are passed to the colormap unchanged here, so 0 is not
			# shifted to 0.5 in this branch -- confirm whether that is intended
			normalized_matrix = matrix
		else:
			assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
				"Matrix values must be in range [0, 1], or normalize must be True"
			)
			normalized_matrix = matrix

	# get the colormap
	cmap_: Colormap
	if isinstance(cmap, str):
		cmap_ = mpl.colormaps[cmap]
	elif isinstance(cmap, Colormap):
		cmap_ = cmap
	else:
		msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
		raise TypeError(
			msg,
		)

	# Apply the colormap, drop the alpha channel, and scale to 8-bit
	rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
		np.uint8,
	)

	assert rgb_matrix.shape == (
		matrix.shape[0],
		matrix.shape[1],
		3,
	), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

	return rgb_matrix

preprocess a 2D matrix into a plottable heatmap image

Parameters:

  • matrix : Matrix2D input matrix
  • normalize : bool whether to normalize the matrix to range [0, 1] (defaults to MATRIX_SAVE_NORMALIZE)
  • cmap : str|Colormap the colormap to use for the matrix (defaults to MATRIX_SAVE_CMAP)
  • diverging_colormap : bool if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
  • normalize_min : float|None if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero?). if None, then the minimum value of the matrix is used. if diverging_colormap=True OR normalize=False, this must be None. (defaults to None)

Returns:

  • Matrix2Drgb

def matrix2drgb_to_png_bytes( matrix: jaxtyping.UInt8[ndarray, 'n m rgb=3'], buffer: _io.BytesIO | None = None) -> bytes | None:
def matrix2drgb_to_png_bytes(
	matrix: Matrix2Drgb,
	buffer: io.BytesIO | None = None,
) -> bytes | None:
	"""Encode a `Matrix2Drgb` as PNG using PIL

	- with a `buffer`, the PNG bytes are written into it and `None` is returned
	- without a `buffer`, the PNG bytes are returned directly

	# Parameters:
	- `matrix : Matrix2Drgb`
		the RGB image to encode
	- `buffer : io.BytesIO | None`
		optional destination for the PNG bytes
		(defaults to `None`, in which case the PNG bytes are returned)

	# Returns:
	- `bytes|None`
		`bytes` if `buffer` is `None`, otherwise `None`
	"""
	png_image: Image.Image = Image.fromarray(matrix, mode="RGB")

	# caller supplied a destination: write into it and return nothing
	if buffer is not None:
		png_image.save(buffer, format="PNG")
		return None

	# otherwise encode into a scratch buffer and hand back its full contents
	scratch = io.BytesIO()
	png_image.save(scratch, format="PNG")
	return scratch.getvalue()

Convert a Matrix2Drgb to valid PNG bytes via PIL

  • if buffer is provided, it will write the PNG bytes to the buffer and return None
  • if buffer is not provided, it will return the PNG bytes

Parameters:

  • matrix : Matrix2Drgb
  • buffer : io.BytesIO | None (defaults to None, in which case it will return the PNG bytes)

Returns:

  • bytes|None bytes if buffer is None, otherwise None
def matrix_as_svg( matrix: jaxtyping.Float[ndarray, 'n m'], normalize: bool = False, cmap: str | matplotlib.colors.Colormap = 'viridis', diverging_colormap: bool = False, normalize_min: float | None = None) -> str:
def matrix_as_svg(
	matrix: Matrix2D,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> str:
	"""render a 2D matrix as SVG text by embedding a base64-encoded PNG, without matplotlib figures

	the matrix is converted to RGB with `matrix_to_image_preprocess`, encoded to PNG with
	`matrix2drgb_to_png_bytes`, base64-encoded, and substituted into `MATRIX_SAVE_SVG_TEMPLATE`

	# Parameters:
	- `matrix : Float[np.ndarray, 'n m']`
		the 2D matrix to render
	- `normalize : bool`
		whether to normalize the matrix to range [0, 1]. if the matrix is not in the range [0, 1], this must be `True` or an `AssertionError` will be raised
		(defaults to `MATRIX_SAVE_NORMALIZE`)
	- `cmap : str|Colormap`
		the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
		(defaults to `MATRIX_SAVE_CMAP`)
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
		if `None`, then the minimum value of the matrix is used.
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
		(defaults to `None`)

	# Returns:
	- `str`
		the SVG content for the matrix

	# Raises:
	- `AssertionError` : if `matrix` is not 2-dimensional
	"""
	assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
	# first and second axis lengths fill the template's `m` and `n` fields respectively
	dim_first, dim_second = matrix.shape

	# matrix -> RGB image -> PNG bytes
	png_bytes: bytes = matrix2drgb_to_png_bytes(
		matrix_to_image_preprocess(
			matrix,
			normalize=normalize,
			cmap=cmap,
			diverging_colormap=diverging_colormap,
			normalize_min=normalize_min,
		),
	)

	# PNG bytes -> base64 text, embedded into the SVG template
	return MATRIX_SAVE_SVG_TEMPLATE.format(
		m=dim_first,
		n=dim_second,
		png_base64=base64.b64encode(png_bytes).decode("utf-8"),
	)

quickly convert a 2D matrix to an SVG image, without matplotlib

Parameters:

  • matrix : Float[np.ndarray, 'n m'] a 2D matrix to convert to an SVG image
  • normalize : bool whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be True or it will raise an AssertionError (defaults to False)
  • cmap : str|Colormap the colormap to use for the matrix -- will look up in matplotlib.colormaps if it's a string (defaults to "viridis")
  • diverging_colormap : bool if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
  • normalize_min : float|None if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero?). if None, then the minimum value of the matrix is used. if diverging_colormap=True OR normalize=False, this must be None. (defaults to None)

Returns:

  • str the SVG content for the matrix
def save_matrix_wrapper( func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']] | None = None, *args, fmt: Literal['png', 'svg', 'svgz'] = 'svgz', normalize: bool = False, cmap: str | matplotlib.colors.Colormap = 'viridis', diverging_colormap: bool = False, normalize_min: float | None = None) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None] | Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None]]:
def save_matrix_wrapper(
	func: AttentionMatrixToMatrixFunc | None = None,
	*args,
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> (
	AttentionMatrixFigureFunc
	| Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
	"""Decorator for functions that process an attention matrix, saving the result as an image file.

	The decorated function maps an attention matrix to a 2D matrix. The wrapper calls it,
	renders the result in the requested `fmt` (`"png"`, `"svg"`, or `"svgz"`), and writes the
	file to `save_dir / f"{func.__name__}.{fmt}"`. The chosen format is also exposed on the
	wrapper as the `figure_save_fmt` attribute.

	Can handle both argumentless usage and with arguments.

	# Parameters:

	- `func : AttentionMatrixToMatrixFunc|None`
		Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
	- `fmt : MatrixSaveFormat, keyword-only`
		The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
	- `normalize : bool, keyword-only`
		Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
	- `cmap : str|Colormap, keyword-only`
		The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
	- `diverging_colormap : bool, keyword-only`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None, keyword-only`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
		if `None`, then the minimum value of the matrix is used.
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
		(defaults to `None`)

	# Returns:

	`AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

	- `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
	- `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

	# Raises:

	- `AssertionError` : if any extra positional arguments are passed, or if `fmt` is not one of the `MatrixSaveFormat` options

	# Usage:

	```python
	@save_matrix_wrapper
	def identity_matrix(matrix):
		return matrix

	@save_matrix_wrapper(normalize=True, fmt="png")
	def scale_matrix(matrix):
		return matrix * 2

	@save_matrix_wrapper(normalize=True, cmap="plasma")
	def scale_matrix(matrix):
		return matrix * 2
	```

	"""
	assert len(args) == 0, "This decorator only supports keyword arguments"

	assert (
		fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
	), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

	def decorator(
		func: Callable[[AttentionMatrix], Matrix2D],
	) -> AttentionMatrixFigureFunc:
		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			# output file is named after the decorated function, with the format as extension
			fig_path: Path = (
				save_dir / f"{getattr(func, '__name__', '<unknown>')}.{fmt}"
			)
			processed_matrix: Matrix2D = func(attn_matrix)

			if fmt == "png":
				# raw PNG: colormap the matrix to RGB and write the encoded bytes directly
				processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
					processed_matrix,
					normalize=normalize,
					cmap=cmap,
					diverging_colormap=diverging_colormap,
					normalize_min=normalize_min,
				)
				image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
				fig_path.write_bytes(image_data)

			else:
				# "svg" and "svgz" share the same SVG text; only the on-disk encoding differs
				svg_content: str = matrix_as_svg(
					processed_matrix,
					normalize=normalize,
					cmap=cmap,
					diverging_colormap=diverging_colormap,
					normalize_min=normalize_min,
				)

				if fmt == "svgz":
					# pin the encoding: text-mode `gzip.open` would otherwise use the
					# locale default, unlike the explicitly UTF-8 "svg" branch below
					with gzip.open(fig_path, "wt", encoding="utf-8") as f:
						f.write(svg_content)

				else:
					fig_path.write_text(svg_content, encoding="utf-8")

		# record the output format on the wrapper itself
		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	if callable(func):
		# Handle no-arguments case
		return decorator(func)
	else:
		# Handle arguments case
		return decorator

Decorator for functions that process an attention matrix and save the result as an image file (PNG, SVG, or SVGZ, depending on fmt).

Can handle both argumentless usage and with arguments.

Parameters:

  • func : AttentionMatrixToMatrixFunc|None Either the function to decorate (in the no-arguments case) or None when used with arguments.
  • fmt : MatrixSaveFormat, keyword-only The format to save the matrix as. Defaults to MATRIX_SAVE_FMT.
  • normalize : bool, keyword-only Whether to normalize the matrix to range [0, 1]. Defaults to False.
  • cmap : str|Colormap, keyword-only The colormap to use for the matrix. Defaults to MATRIX_SAVE_CMAP.
  • diverging_colormap : bool if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
  • normalize_min : float|None if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero?). if None, then the minimum value of the matrix is used. if diverging_colormap=True OR normalize=False, this must be None. (defaults to None)

Returns:

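AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]

  • AttentionMatrixFigureFunc if func is AttentionMatrixToMatrixFunc (no arguments case)
  • Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc] if func is None -- returns the decorator which will then be applied to the function (with arguments case)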

Usage:

@save_matrix_wrapper
def identity_matrix(matrix):
        return matrix

@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
        return matrix * 2

@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix(matrix):
        return matrix * 2