docs for pattern_lens v0.4.0
View Source on GitHub

pattern_lens.figure_util

implements a bunch of types, default values, and templates which are useful for figure functions

notably, you can use the decorators matplotlib_figure_saver, save_matrix_wrapper to make your functions save figures


  1"""implements a bunch of types, default values, and templates which are useful for figure functions
  2
  3notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
  4"""
  5
  6import base64
  7import functools
  8import gzip
  9import io
 10from collections.abc import Callable, Sequence
 11from pathlib import Path
 12from typing import Literal, overload
 13
 14import matplotlib as mpl
 15import matplotlib.pyplot as plt
 16import numpy as np
 17from jaxtyping import Float, UInt8
 18from matplotlib.colors import Colormap
 19from PIL import Image
 20
 21from pattern_lens.consts import AttentionMatrix
 22
 23AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
 24"Type alias for a function that, given an attention matrix, saves one or more figures"
 25
 26Matrix2D = Float[np.ndarray, "n m"]
 27"Type alias for a 2D matrix (plottable)"
 28
 29Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
 30"Type alias for a 2D matrix with 3 channels (RGB)"
 31
 32AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
 33"Type alias for a function that, given an attention matrix, returns a 2D matrix"
 34
 35MATPLOTLIB_FIGURE_FMT: str = "svgz"
 36"format for saving matplotlib figures"
 37
 38MatrixSaveFormat = Literal["png", "svg", "svgz"]
 39"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"
 40
 41MATRIX_SAVE_NORMALIZE: bool = False
 42"default for whether to normalize the matrix to range [0, 1]"
 43
 44MATRIX_SAVE_CMAP: str = "viridis"
 45"default colormap for saving matrices"
 46
 47MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
 48"default format for saving matrices"
 49
 50MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
 51"template for saving an `n` by `m` matrix as an svg/svgz"
 52
 53
 54# TYPING: mypy hates it when we dont pass func=None or None as the first arg
 55@overload  # without keyword arguments, returns decorated function
 56def matplotlib_figure_saver(
 57	func: Callable[[AttentionMatrix, plt.Axes], None],
 58) -> AttentionMatrixFigureFunc: ...
 59@overload  # with keyword arguments, returns decorator
 60def matplotlib_figure_saver(
 61	func: None = None,
 62	fmt: str = MATPLOTLIB_FIGURE_FMT,
 63) -> Callable[
 64	[Callable[[AttentionMatrix, plt.Axes], None], str],
 65	AttentionMatrixFigureFunc,
 66]: ...
 67def matplotlib_figure_saver(
 68	func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
 69	fmt: str = MATPLOTLIB_FIGURE_FMT,
 70) -> (
 71	AttentionMatrixFigureFunc
 72	| Callable[
 73		[Callable[[AttentionMatrix, plt.Axes], None], str],
 74		AttentionMatrixFigureFunc,
 75	]
 76):
 77	"""decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure
 78
 79	# Parameters:
 80	- `func : Callable[[AttentionMatrix, plt.Axes], None]`
 81		your function, which should take an attention matrix and predefined `ax` object
 82	- `fmt : str`
 83		format for saving matplotlib figures
 84		(defaults to `MATPLOTLIB_FIGURE_FMT`)
 85
 86	# Returns:
 87	- `AttentionMatrixFigureFunc`
 88		your function, after we wrap it to save a figure
 89
 90	# Usage:
 91	```python
 92	@register_attn_figure_func
 93	@matplotlib_figure_saver
 94	def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
 95		ax.matshow(attn_matrix, cmap="viridis")
 96		ax.set_title("Raw Attention Pattern")
 97		ax.axis("off")
 98	```
 99
100	"""
101
102	def decorator(
103		func: Callable[[AttentionMatrix, plt.Axes], None],
104		fmt: str = fmt,
105	) -> AttentionMatrixFigureFunc:
106		@functools.wraps(func)
107		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
108			fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
109
110			fig, ax = plt.subplots(figsize=(10, 10))
111			func(attn_matrix, ax)
112			plt.tight_layout()
113			plt.savefig(fig_path)
114			plt.close(fig)
115
116		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]
117
118		return wrapped
119
120	if callable(func):
121		# Handle no-arguments case
122		return decorator(func)
123	else:
124		# Handle arguments case
125		return decorator
126
127
def matplotlib_multifigure_saver(
	names: Sequence[str],
	fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
	# decorator takes in function
	# which takes a matrix and a dictionary of axes corresponding to the names
	[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
	# returns the decorated function
	AttentionMatrixFigureFunc,
]:
	"""decorate a function such that it saves multiple figures, one for each name in `names`

	each figure is saved to `save_dir / f"{func.__name__}.{name}.{fmt}"`. duplicate entries
	in `names` are collapsed, so each distinct name gets exactly one figure. every figure
	created is closed afterwards, even if creating, drawing, or saving any of them fails.

	# Parameters:
	- `names : Sequence[str]`
		the names of the figures to save
	- `fmt : str`
		format for saving matplotlib figures
		(defaults to `MATPLOTLIB_FIGURE_FMT`)

	# Returns:
	- `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]], AttentionMatrixFigureFunc]`
		the decorator, which will then be applied to the function
		we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

	"""

	def decorator(
		func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
	) -> AttentionMatrixFigureFunc:
		func_name: str = func.__name__

		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			# set up axes and corresponding figures
			axes_dict: dict[str, plt.Axes] = {}
			figs_dict: dict[str, plt.Figure] = {}

			try:
				# Create all figures and axes inside the `try`, so a failure partway
				# through still closes the figures created so far. `dict.fromkeys`
				# drops duplicate names (keeping order): a duplicate would otherwise
				# create a figure that gets overwritten in `figs_dict` and never closed
				for name in dict.fromkeys(names):
					fig, ax = plt.subplots(figsize=(10, 10))
					figs_dict[name] = fig
					axes_dict[name] = ax

				# Run the function to make plots
				func(attn_matrix, axes_dict)

				# Save each figure
				for name, fig_ in figs_dict.items():
					fig_path: Path = save_dir / f"{func_name}.{name}.{fmt}"
					# TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout"  [union-attr]
					fig_.tight_layout()  # type: ignore[union-attr]
					# TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig"  [union-attr]
					fig_.savefig(fig_path)  # type: ignore[union-attr]
			finally:
				# Always clean up figures, even if an error occurred
				for fig in figs_dict.values():
					# TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None"  [arg-type]
					plt.close(fig)  # type: ignore[arg-type]

		# it doesn't normally have this attribute, but we're adding it
		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	return decorator
194
195
def matrix_to_image_preprocess(
	matrix: Matrix2D,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> Matrix2Drgb:
	"""preprocess a 2D matrix into a plottable heatmap image

	# Parameters:
	- `matrix : Matrix2D`
		input matrix
	- `normalize : bool`
		whether to normalize the matrix to range [0, 1]
		(defaults to `MATRIX_SAVE_NORMALIZE`)
	- `cmap : str|Colormap`
		the colormap to use for the matrix -- looked up in `matplotlib.colormaps` if it's a string
		(defaults to `MATRIX_SAVE_CMAP`)
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the value mapped to the
		bottom of the colormap (generally set this to zero). it must be less than or equal to the
		matrix minimum and strictly less than the matrix maximum.
		if `None`, then the minimum value of the matrix is used.
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
		(defaults to `None`)

	a constant matrix (or an all-zero matrix with `diverging_colormap=True`) has no range to
	normalize over; it is mapped to the bottom (resp. center) of the colormap instead of NaN.

	# Returns:
	- `Matrix2Drgb`
		uint8 RGB image with the same height and width as `matrix`

	# Raises:
	- `AssertionError` : if the matrix is not 2D, is empty, is out of range without `normalize`, or `normalize_min` is misused
	- `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
	"""
	# check dims (2 is not that magic of a value here, hence noqa)
	assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

	# check matrix is not empty
	assert matrix.size > 0, "Matrix cannot be empty"

	if normalize_min is not None:
		assert not diverging_colormap, (
			"normalize_min cannot be used with diverging_colormap=True"
		)
		assert normalize, "normalize_min cannot be used with normalize=False"

	# Normalize the matrix to range [0, 1]
	normalized_matrix: Matrix2D
	if normalize:
		if diverging_colormap:
			# For diverging colormaps, we want to center around 0
			max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
			if max_abs == 0:
				# all-zero matrix: dividing by zero would give NaN, so map every entry to the center
				normalized_matrix = np.full(matrix.shape, 0.5)
			else:
				normalized_matrix = (matrix / (2 * max_abs)) + 0.5
		else:
			max_val: float = matrix.max()
			min_val: float
			if normalize_min is not None:
				min_val = normalize_min
				assert min_val < max_val, "normalize_min must be less than matrix max"
				# normalize_min is the floor of the scale, so every entry must sit at or above it
				assert min_val <= matrix.min(), (
					"normalize_min must be less than or equal to matrix min"
				)
			else:
				min_val = matrix.min()

			value_range: float = max_val - min_val
			if value_range == 0:
				# constant matrix: dividing by zero would give NaN, so map everything to the bottom
				normalized_matrix = np.zeros(matrix.shape)
			else:
				normalized_matrix = (matrix - min_val) / value_range
	else:
		if diverging_colormap:
			assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
				"For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
			)
			normalized_matrix = matrix
		else:
			assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
				"Matrix values must be in range [0, 1], or normalize must be True"
			)
			normalized_matrix = matrix

	# get the colormap
	cmap_: Colormap
	if isinstance(cmap, str):
		cmap_ = mpl.colormaps[cmap]
	elif isinstance(cmap, Colormap):
		cmap_ = cmap
	else:
		msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
		raise TypeError(
			msg,
		)

	# Apply the colormap, drop the alpha channel, and scale to 8-bit
	rgb_matrix: Matrix2Drgb = (
		cmap_(normalized_matrix)[:, :, :3] * 255
	).astype(np.uint8)

	assert rgb_matrix.shape == (
		matrix.shape[0],
		matrix.shape[1],
		3,
	), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

	return rgb_matrix
294
295
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
	matrix: Matrix2Drgb,
	buffer: io.BytesIO | None = None,
) -> bytes | None:
	"""encode a `Matrix2Drgb` as PNG via PIL

	two calling modes:
	- with `buffer`: the PNG is written into `buffer` and `None` is returned
	- without `buffer`: the PNG is written into a fresh in-memory buffer whose bytes are returned

	# Parameters:
	- `matrix : Matrix2Drgb`
		the RGB image to encode
	- `buffer : io.BytesIO | None`
		optional destination for the PNG data
		(defaults to `None`, in which case the PNG bytes are returned)

	# Returns:
	- `bytes|None`
		the PNG bytes when `buffer` is `None`, otherwise `None`
	"""
	image: Image.Image = Image.fromarray(matrix, mode="RGB")
	write_target: io.BytesIO = buffer if buffer is not None else io.BytesIO()
	image.save(write_target, format="PNG")

	if buffer is not None:
		# the caller owns the buffer; we only wrote into it
		return None
	return write_target.getvalue()
327
328
def matrix_as_svg(
	matrix: Matrix2D,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> str:
	"""quickly convert a 2D matrix to an SVG image, without matplotlib

	the matrix is colormapped via `matrix_to_image_preprocess`, encoded as PNG, and embedded as
	a base64 data URI in `MATRIX_SAVE_SVG_TEMPLATE`. for an `(n, m)` matrix the SVG is `m`
	pixels wide and `n` pixels tall, i.e. one pixel per entry.

	# Parameters:
	- `matrix : Float[np.ndarray, 'n m']`
		a 2D matrix to convert to an SVG image
	- `normalize : bool`
		whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
		(defaults to `MATRIX_SAVE_NORMALIZE`)
	- `cmap : str|Colormap`
		the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
		(defaults to `MATRIX_SAVE_CMAP`)
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
		if `None`, then the minimum value of the matrix is used
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
		(defaults to `None`)

	# Returns:
	- `str`
		the SVG content for the matrix
	"""
	# Get the dimensions of the matrix
	assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
	# numpy shape is (rows, cols) = (n, m). the PNG is `m` pixels wide and `n` tall,
	# and the template uses `{m}` as width and `{n}` as height. unpacking as `m, n`
	# would swap them and distort non-square matrices
	n, m = matrix.shape

	# Preprocess the matrix into an RGB image
	matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
		matrix,
		normalize=normalize,
		cmap=cmap,
		diverging_colormap=diverging_colormap,
		normalize_min=normalize_min,
	)

	# Convert the RGB image to PNG bytes
	image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

	# Encode the PNG bytes as base64
	png_base64: str = base64.b64encode(image_data).decode("utf-8")

	# Generate the SVG content
	svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

	return svg_content
384
385
@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
	func: None = None,
	*args: tuple[()],
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
	func: AttentionMatrixToMatrixFunc,
	*args: tuple[()],
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
	func: AttentionMatrixToMatrixFunc | None = None,
	*args,
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> (
	AttentionMatrixFigureFunc
	| Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
	"""Decorator for functions that turn an attention matrix into a 2D matrix, saving that matrix as an image.

	Can handle both argumentless usage (`@save_matrix_wrapper`) and usage with keyword
	arguments (`@save_matrix_wrapper(fmt="png")`). The wrapped function writes
	`save_dir / f"{func.__name__}.{fmt}"`: a PNG for `fmt="png"`, an SVG with an
	embedded PNG for `fmt="svg"`, and a gzip-compressed SVG for `fmt="svgz"`.

	# Parameters:

	- `func : AttentionMatrixToMatrixFunc|None`
		Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
	- `fmt : MatrixSaveFormat, keyword-only`
		The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
	- `normalize : bool, keyword-only`
		Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
	- `cmap : str|Colormap, keyword-only`
		The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
	- `diverging_colormap : bool, keyword-only`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None, keyword-only`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
		if `None`, then the minimum value of the matrix is used
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
		(defaults to `None`)

	# Returns:

	`AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

	- `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
	- `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

	# Raises:

	- `AssertionError` if extra positional arguments are passed, or `fmt` is not one of `MatrixSaveFormat`

	# Usage:

	```python
	@save_matrix_wrapper
	def identity_matrix(matrix):
		return matrix

	@save_matrix_wrapper(normalize=True, fmt="png")
	def scale_matrix(matrix):
		return matrix * 2

	@save_matrix_wrapper(normalize=True, cmap="plasma")
	def scale_matrix_plasma(matrix):
		return matrix * 2
	```

	"""
	assert len(args) == 0, "This decorator only supports keyword arguments"

	assert (
		fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
	), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

	def decorator(
		func: Callable[[AttentionMatrix], Matrix2D],
	) -> AttentionMatrixFigureFunc:
		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
			processed_matrix: Matrix2D = func(attn_matrix)

			if fmt == "png":
				processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
					processed_matrix,
					normalize=normalize,
					cmap=cmap,
					diverging_colormap=diverging_colormap,
					normalize_min=normalize_min,
				)
				image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
				fig_path.write_bytes(image_data)

			else:
				svg_content: str = matrix_as_svg(
					processed_matrix,
					normalize=normalize,
					cmap=cmap,
					diverging_colormap=diverging_colormap,
					normalize_min=normalize_min,
				)

				if fmt == "svgz":
					# explicit encoding, matching the plain-svg branch, so the output
					# does not depend on the platform's locale encoding
					with gzip.open(fig_path, "wt", encoding="utf-8") as f:
						f.write(svg_content)

				else:
					fig_path.write_text(svg_content, encoding="utf-8")

		# record the save format on the wrapper so callers can inspect it
		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	if callable(func):
		# Handle no-arguments case
		return decorator(func)
	else:
		# Handle arguments case
		return decorator

AttentionMatrixFigureFunc = collections.abc.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]

Type alias for a function that, given an attention matrix, saves one or more figures

Matrix2D = <class 'jaxtyping.Float[ndarray, 'n m']'>

Type alias for a 2D matrix (plottable)

Matrix2Drgb = <class 'jaxtyping.UInt8[ndarray, 'n m rgb=3']'>

Type alias for a 2D matrix with 3 channels (RGB)

AttentionMatrixToMatrixFunc = collections.abc.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]

Type alias for a function that, given an attention matrix, returns a 2D matrix

MATPLOTLIB_FIGURE_FMT: str = 'svgz'

format for saving matplotlib figures

MatrixSaveFormat = typing.Literal['png', 'svg', 'svgz']

Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure

MATRIX_SAVE_NORMALIZE: bool = False

default for whether to normalize the matrix to range [0, 1]

MATRIX_SAVE_CMAP: str = 'viridis'

default colormap for saving matrices

MATRIX_SAVE_FMT: Literal['png', 'svg', 'svgz'] = 'svgz'

default format for saving matrices

MATRIX_SAVE_SVG_TEMPLATE: str = '<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>'

template for saving an n by m matrix as an svg/svgz

def matplotlib_figure_saver( func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], None] | None = None, fmt: str = 'svgz') -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None] | Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], None], str], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]:
 68def matplotlib_figure_saver(
 69	func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
 70	fmt: str = MATPLOTLIB_FIGURE_FMT,
 71) -> (
 72	AttentionMatrixFigureFunc
 73	| Callable[
 74		[Callable[[AttentionMatrix, plt.Axes], None], str],
 75		AttentionMatrixFigureFunc,
 76	]
 77):
 78	"""decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure
 79
 80	# Parameters:
 81	- `func : Callable[[AttentionMatrix, plt.Axes], None]`
 82		your function, which should take an attention matrix and predefined `ax` object
 83	- `fmt : str`
 84		format for saving matplotlib figures
 85		(defaults to `MATPLOTLIB_FIGURE_FMT`)
 86
 87	# Returns:
 88	- `AttentionMatrixFigureFunc`
 89		your function, after we wrap it to save a figure
 90
 91	# Usage:
 92	```python
 93	@register_attn_figure_func
 94	@matplotlib_figure_saver
 95	def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
 96		ax.matshow(attn_matrix, cmap="viridis")
 97		ax.set_title("Raw Attention Pattern")
 98		ax.axis("off")
 99	```
100
101	"""
102
103	def decorator(
104		func: Callable[[AttentionMatrix, plt.Axes], None],
105		fmt: str = fmt,
106	) -> AttentionMatrixFigureFunc:
107		@functools.wraps(func)
108		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
109			fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
110
111			fig, ax = plt.subplots(figsize=(10, 10))
112			func(attn_matrix, ax)
113			plt.tight_layout()
114			plt.savefig(fig_path)
115			plt.close(fig)
116
117		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]
118
119		return wrapped
120
121	if callable(func):
122		# Handle no-arguments case
123		return decorator(func)
124	else:
125		# Handle arguments case
126		return decorator

decorator for functions which take an attention matrix and predefined ax object, making it save a figure

Parameters:

  • func : Callable[[AttentionMatrix, plt.Axes], None] your function, which should take an attention matrix and predefined ax object
  • fmt : str format for saving matplotlib figures (defaults to MATPLOTLIB_FIGURE_FMT)

Returns:

  • AttentionMatrixFigureFunc your function, after we wrap it to save a figure

Usage:

@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
def matplotlib_multifigure_saver( names: Sequence[str], fmt: str = 'svgz') -> Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], dict[str, matplotlib.axes._axes.Axes]], None]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]:
def matplotlib_multifigure_saver(
	names: Sequence[str],
	fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
	# decorator takes in function
	# which takes a matrix and a dictionary of axes corresponding to the names
	[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
	# returns the decorated function
	AttentionMatrixFigureFunc,
]:
	"""decorate a function such that it saves multiple figures, one for each name in `names`

	each figure is saved to `save_dir / f"{func.__name__}.{name}.{fmt}"`. duplicate entries
	in `names` are collapsed, so each distinct name gets exactly one figure. every figure
	created is closed afterwards, even if creating, drawing, or saving any of them fails.

	# Parameters:
	- `names : Sequence[str]`
		the names of the figures to save
	- `fmt : str`
		format for saving matplotlib figures
		(defaults to `MATPLOTLIB_FIGURE_FMT`)

	# Returns:
	- `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]], AttentionMatrixFigureFunc]`
		the decorator, which will then be applied to the function
		we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

	"""

	def decorator(
		func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
	) -> AttentionMatrixFigureFunc:
		func_name: str = func.__name__

		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			# set up axes and corresponding figures
			axes_dict: dict[str, plt.Axes] = {}
			figs_dict: dict[str, plt.Figure] = {}

			try:
				# Create all figures and axes inside the `try`, so a failure partway
				# through still closes the figures created so far. `dict.fromkeys`
				# drops duplicate names (keeping order): a duplicate would otherwise
				# create a figure that gets overwritten in `figs_dict` and never closed
				for name in dict.fromkeys(names):
					fig, ax = plt.subplots(figsize=(10, 10))
					figs_dict[name] = fig
					axes_dict[name] = ax

				# Run the function to make plots
				func(attn_matrix, axes_dict)

				# Save each figure
				for name, fig_ in figs_dict.items():
					fig_path: Path = save_dir / f"{func_name}.{name}.{fmt}"
					# TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout"  [union-attr]
					fig_.tight_layout()  # type: ignore[union-attr]
					# TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig"  [union-attr]
					fig_.savefig(fig_path)  # type: ignore[union-attr]
			finally:
				# Always clean up figures, even if an error occurred
				for fig in figs_dict.values():
					# TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None"  [arg-type]
					plt.close(fig)  # type: ignore[arg-type]

		# it doesn't normally have this attribute, but we're adding it
		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	return decorator

decorate a function such that it saves multiple figures, one for each name in names

Parameters:

  • names : Sequence[str] the names of the figures to save
  • fmt : str format for saving matplotlib figures (defaults to MATPLOTLIB_FIGURE_FMT)

Returns:

  • Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None], AttentionMatrixFigureFunc] the decorator, which will then be applied to the function we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names
def matrix_to_image_preprocess( matrix: jaxtyping.Float[ndarray, 'n m'], normalize: bool = False, cmap: str | matplotlib.colors.Colormap = 'viridis', diverging_colormap: bool = False, normalize_min: float | None = None) -> jaxtyping.UInt8[ndarray, 'n m rgb=3']:
def matrix_to_image_preprocess(
	matrix: Matrix2D,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> Matrix2Drgb:
	"""preprocess a 2D matrix into a plottable heatmap image

	# Parameters:
	- `matrix : Matrix2D`
		input matrix
	- `normalize : bool`
		whether to normalize the matrix to range [0, 1]
		(defaults to `MATRIX_SAVE_NORMALIZE`)
	- `cmap : str|Colormap`
		the colormap to use for the matrix -- looked up in `matplotlib.colormaps` if it's a string
		(defaults to `MATRIX_SAVE_CMAP`)
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the value mapped to the
		bottom of the colormap (generally set this to zero). it must be less than or equal to the
		matrix minimum and strictly less than the matrix maximum.
		if `None`, then the minimum value of the matrix is used.
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
		(defaults to `None`)

	a constant matrix (or an all-zero matrix with `diverging_colormap=True`) has no range to
	normalize over; it is mapped to the bottom (resp. center) of the colormap instead of NaN.

	# Returns:
	- `Matrix2Drgb`
		uint8 RGB image with the same height and width as `matrix`

	# Raises:
	- `AssertionError` : if the matrix is not 2D, is empty, is out of range without `normalize`, or `normalize_min` is misused
	- `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
	"""
	# check dims (2 is not that magic of a value here, hence noqa)
	assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

	# check matrix is not empty
	assert matrix.size > 0, "Matrix cannot be empty"

	if normalize_min is not None:
		assert not diverging_colormap, (
			"normalize_min cannot be used with diverging_colormap=True"
		)
		assert normalize, "normalize_min cannot be used with normalize=False"

	# Normalize the matrix to range [0, 1]
	normalized_matrix: Matrix2D
	if normalize:
		if diverging_colormap:
			# For diverging colormaps, we want to center around 0
			max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
			if max_abs == 0:
				# all-zero matrix: dividing by zero would give NaN, so map every entry to the center
				normalized_matrix = np.full(matrix.shape, 0.5)
			else:
				normalized_matrix = (matrix / (2 * max_abs)) + 0.5
		else:
			max_val: float = matrix.max()
			min_val: float
			if normalize_min is not None:
				min_val = normalize_min
				assert min_val < max_val, "normalize_min must be less than matrix max"
				# normalize_min is the floor of the scale, so every entry must sit at or above it
				assert min_val <= matrix.min(), (
					"normalize_min must be less than or equal to matrix min"
				)
			else:
				min_val = matrix.min()

			value_range: float = max_val - min_val
			if value_range == 0:
				# constant matrix: dividing by zero would give NaN, so map everything to the bottom
				normalized_matrix = np.zeros(matrix.shape)
			else:
				normalized_matrix = (matrix - min_val) / value_range
	else:
		if diverging_colormap:
			assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
				"For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
			)
			normalized_matrix = matrix
		else:
			assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
				"Matrix values must be in range [0, 1], or normalize must be True"
			)
			normalized_matrix = matrix

	# get the colormap
	cmap_: Colormap
	if isinstance(cmap, str):
		cmap_ = mpl.colormaps[cmap]
	elif isinstance(cmap, Colormap):
		cmap_ = cmap
	else:
		msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
		raise TypeError(
			msg,
		)

	# Apply the colormap, drop the alpha channel, and scale to 8-bit
	rgb_matrix: Matrix2Drgb = (
		cmap_(normalized_matrix)[:, :, :3] * 255
	).astype(np.uint8)

	assert rgb_matrix.shape == (
		matrix.shape[0],
		matrix.shape[1],
		3,
	), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

	return rgb_matrix

preprocess a 2D matrix into a plottable heatmap image

Parameters:

  • matrix : Matrix2D input matrix
  • normalize : bool whether to normalize the matrix to range [0, 1] (defaults to MATRIX_SAVE_NORMALIZE)
  • cmap : str|Colormap the colormap to use for the matrix (defaults to MATRIX_SAVE_CMAP)
  • diverging_colormap : bool if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
  • normalize_min : float|None if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero?). if None, then the minimum value of the matrix is used. if diverging_colormap=True OR normalize=False, this must be None. (defaults to None)

Returns:

  • Matrix2Drgb the colormapped uint8 RGB image, with the same height and width as the input matrix

def matrix2drgb_to_png_bytes( matrix: jaxtyping.UInt8[ndarray, 'n m rgb=3'], buffer: _io.BytesIO | None = None) -> bytes | None:
def matrix2drgb_to_png_bytes(
	matrix: Matrix2Drgb,
	buffer: io.BytesIO | None = None,
) -> bytes | None:
	"""encode a `Matrix2Drgb` as PNG via PIL

	two calling modes:
	- with `buffer`: the PNG is written into `buffer` and `None` is returned
	- without `buffer`: the PNG is written into a fresh in-memory buffer whose bytes are returned

	# Parameters:
	- `matrix : Matrix2Drgb`
		the RGB image to encode
	- `buffer : io.BytesIO | None`
		optional destination for the PNG data
		(defaults to `None`, in which case the PNG bytes are returned)

	# Returns:
	- `bytes|None`
		the PNG bytes when `buffer` is `None`, otherwise `None`
	"""
	image: Image.Image = Image.fromarray(matrix, mode="RGB")
	write_target: io.BytesIO = buffer if buffer is not None else io.BytesIO()
	image.save(write_target, format="PNG")

	if buffer is not None:
		# the caller owns the buffer; we only wrote into it
		return None
	return write_target.getvalue()

Convert a Matrix2Drgb to valid PNG bytes via PIL

  • if buffer is provided, it will write the PNG bytes to the buffer and return None
  • if buffer is not provided, it will return the PNG bytes

Parameters:

  • matrix : Matrix2Drgb
  • buffer : io.BytesIO | None (defaults to None, in which case it will return the PNG bytes)

Returns:

  • bytes|None bytes if buffer is None, otherwise None
def matrix_as_svg( matrix: jaxtyping.Float[ndarray, 'n m'], normalize: bool = False, cmap: str | matplotlib.colors.Colormap = 'viridis', diverging_colormap: bool = False, normalize_min: float | None = None) -> str:
def matrix_as_svg(
	matrix: Matrix2D,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> str:
	"""quickly convert a 2D matrix to an SVG image, without matplotlib

	The matrix is turned into an RGB image via `matrix_to_image_preprocess`,
	PNG-encoded via `matrix2drgb_to_png_bytes`, base64-encoded, and embedded
	into `MATRIX_SAVE_SVG_TEMPLATE`.

	# Parameters:
	- `matrix : Float[np.ndarray, 'n m']`
		a 2D matrix to convert to an SVG image
	- `normalize : bool`
		whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
		(defaults to `MATRIX_SAVE_NORMALIZE`)
	- `cmap : str | Colormap`
		the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
		(defaults to `MATRIX_SAVE_CMAP`)
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
		if `None`, then the minimum value of the matrix is used.
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
		(defaults to `None`)

	# Returns:
	- `str`
		the SVG content for the matrix

	# Raises:
	- `AssertionError` if `matrix` is not 2D
	"""
	assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
	dim_0, dim_1 = matrix.shape

	# colormap (and optionally normalize) the matrix into RGB, then PNG-encode it
	png_bytes: bytes = matrix2drgb_to_png_bytes(
		matrix_to_image_preprocess(
			matrix,
			normalize=normalize,
			cmap=cmap,
			diverging_colormap=diverging_colormap,
			normalize_min=normalize_min,
		),
	)

	# the template receives the first and second dimensions as `m` and `n`
	return MATRIX_SAVE_SVG_TEMPLATE.format(
		m=dim_0,
		n=dim_1,
		png_base64=base64.b64encode(png_bytes).decode("utf-8"),
	)

quickly convert a 2D matrix to an SVG image, without matplotlib

Parameters:

  • matrix : Float[np.ndarray, 'n m'] a 2D matrix to convert to an SVG image
  • normalize : bool whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be True or it will raise an AssertionError (defaults to False)
  • cmap : str|Colormap the colormap to use for the matrix -- will look up in matplotlib.colormaps if it's a string (defaults to "viridis")
  • diverging_colormap : bool if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
  • normalize_min : float|None if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero?). If None, then the minimum value of the matrix is used. If diverging_colormap=True OR normalize=False, this must be None. (defaults to None)

Returns:

  • str the SVG content for the matrix
def save_matrix_wrapper( func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']] | None = None, *args, fmt: Literal['png', 'svg', 'svgz'] = 'svgz', normalize: bool = False, cmap: str | matplotlib.colors.Colormap = 'viridis', diverging_colormap: bool = False, normalize_min: float | None = None) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None] | Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]:
def save_matrix_wrapper(
	func: AttentionMatrixToMatrixFunc | None = None,
	*args,
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> (
	AttentionMatrixFigureFunc
	| Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
	"""Decorator for functions that process an attention matrix and save the result as an image.

	The decorated function maps an attention matrix to a 2D matrix. The wrapper
	renders that matrix and writes it to `save_dir / f"{func.__name__}.{fmt}"`
	as PNG, SVG, or gzip-compressed SVG (SVGZ), depending on `fmt`. The chosen
	format is also recorded on the wrapper as the attribute `figure_save_fmt`.

	Can handle both argumentless usage and with arguments.

	# Parameters:

	- `func : AttentionMatrixToMatrixFunc|None`
		Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
	- `fmt : MatrixSaveFormat, keyword-only`
		The format to save the matrix as: `"png"`, `"svg"`, or `"svgz"`. Defaults to `MATRIX_SAVE_FMT`.
	- `normalize : bool, keyword-only`
		Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
	- `cmap : str|Colormap, keyword-only`
		The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
	- `diverging_colormap : bool, keyword-only`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None, keyword-only`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
		if `None`, then the minimum value of the matrix is used.
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
		(defaults to `None`)

	# Returns:

	`AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

	- `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
	- `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator, which will then be applied to the decorated function (with arguments case)

	# Raises:

	- `AssertionError` if extra positional arguments are passed, or if `fmt` is not one of the `MatrixSaveFormat` values

	# Usage:

	```python
	@save_matrix_wrapper
	def identity_matrix(matrix):
		return matrix

	@save_matrix_wrapper(normalize=True, fmt="png")
	def scale_matrix(matrix):
		return matrix * 2

	@save_matrix_wrapper(normalize=True, cmap="plasma")
	def scale_matrix_plasma(matrix):
		return matrix * 2
	```

	"""
	assert len(args) == 0, "This decorator only supports keyword arguments"

	# validate the format once, at decoration time, rather than on every call
	assert (
		fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
	), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

	def decorator(
		func: Callable[[AttentionMatrix], Matrix2D],
	) -> AttentionMatrixFigureFunc:
		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
			processed_matrix: Matrix2D = func(attn_matrix)

			if fmt == "png":
				processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
					processed_matrix,
					normalize=normalize,
					cmap=cmap,
					diverging_colormap=diverging_colormap,
					normalize_min=normalize_min,
				)
				image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
				fig_path.write_bytes(image_data)

			else:
				svg_content: str = matrix_as_svg(
					processed_matrix,
					normalize=normalize,
					cmap=cmap,
					diverging_colormap=diverging_colormap,
					normalize_min=normalize_min,
				)

				if fmt == "svgz":
					# explicit utf-8: gzip text mode otherwise uses the locale's
					# preferred encoding, unlike the plain "svg" branch below
					with gzip.open(fig_path, "wt", encoding="utf-8") as f:
						f.write(svg_content)

				else:
					fig_path.write_text(svg_content, encoding="utf-8")

		# record the output format on the wrapper so callers can inspect it
		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	if callable(func):
		# Handle no-arguments case: `@save_matrix_wrapper`
		return decorator(func)
	else:
		# Handle arguments case: `@save_matrix_wrapper(...)`
		return decorator

Decorator for functions that process an attention matrix and save the result as an image (PNG, SVG, or SVGZ, depending on fmt).

Can handle both argumentless usage and with arguments.

Parameters:

  • func : AttentionMatrixToMatrixFunc|None Either the function to decorate (in the no-arguments case) or None when used with arguments.
  • fmt : MatrixSaveFormat, keyword-only The format to save the matrix as. Defaults to MATRIX_SAVE_FMT.
  • normalize : bool, keyword-only Whether to normalize the matrix to range [0, 1]. Defaults to False.
  • cmap : str|Colormap, keyword-only The colormap to use for the matrix. Defaults to MATRIX_SAVE_CMAP.
  • diverging_colormap : bool if True and using a diverging colormap, ensures 0 values map to the center of the colormap (defaults to False)
  • normalize_min : float|None if a float, then for normalize=True and diverging_colormap=False, the minimum value to normalize to (generally set this to zero?). If None, then the minimum value of the matrix is used. If diverging_colormap=True OR normalize=False, this must be None. (defaults to None)

Returns:

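AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]

  • AttentionMatrixFigureFunc if func is an AttentionMatrixToMatrixFunc (no-arguments case)
  • Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc] if func is None -- returns the decorator, which is then applied to the decorated function (with-arguments case)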

Usage:

@save_matrix_wrapper
def identity_matrix(matrix):
        return matrix

@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
        return matrix * 2

@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix_plasma(matrix):
        return matrix * 2