docs for pattern_lens v0.6.0
View Source on GitHub

pattern_lens.attn_figure_funcs

default figure functions

  • If you are making a PR, add your new figure function here.
  • If you are using this as a library, you can find examples here.

Note that for pattern_lens.figures to recognize your function, you need to use the register_attn_figure_func decorator (or register_attn_figure_multifunc), which adds your function to ATTENTION_MATRIX_FIGURE_FUNCS.


  1"""default figure functions
  2
  3- If you are making a PR, add your new figure function here.
  4- If you are using this as a library, you can find examples here.
  5
  6
  7Note that for `pattern_lens.figures` to recognize your function, you need to use the `register_attn_figure_func` decorator
  8(or `register_attn_figure_multifunc`), which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`.
  9
 10"""
 11
 12import itertools
 13from collections.abc import Callable, Sequence
 14
 15from pattern_lens.consts import AttentionMatrix
 16from pattern_lens.figure_util import (
 17	AttentionMatrixFigureFunc,
 18	Matrix2D,
 19	save_matrix_wrapper,
 20)
 21
 22_FIGURE_NAMES_KEY: str = "_figure_names"
 23
 24ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = list()
 25
 26
 27def get_all_figure_names() -> list[str]:
 28	"""get all figure names"""
 29	return list(
 30		itertools.chain.from_iterable(
 31			getattr(
 32				func,
 33				_FIGURE_NAMES_KEY,
 34				[getattr(func, "__name__", "<unknown>")],
 35			)
 36			for func in ATTENTION_MATRIX_FIGURE_FUNCS
 37		),
 38	)
 39
 40
 41def register_attn_figure_func(
 42	func: AttentionMatrixFigureFunc,
 43) -> AttentionMatrixFigureFunc:
 44	"""decorator for registering attention matrix figure function
 45
 46	if you want to add a new figure function, you should use this decorator
 47
 48	# Parameters:
 49	- `func : AttentionMatrixFigureFunc`
 50		your function, which should take an attention matrix and path
 51
 52	# Returns:
 53	- `AttentionMatrixFigureFunc`
 54		your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`
 55
 56	# Usage:
 57	```python
 58	@register_attn_figure_func
 59	def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
 60		fig, ax = plt.subplots(figsize=(10, 10))
 61		ax.matshow(attn_matrix, cmap="viridis")
 62		ax.set_title("My New Figure Function")
 63		ax.axis("off")
 64		plt.savefig(path / "my_new_figure_func", format="svgz")
 65		plt.close(fig)
 66	```
 67
 68	"""
 69	setattr(func, _FIGURE_NAMES_KEY, (getattr(func, "__name__", "<unknown>"),))
 70	global ATTENTION_MATRIX_FIGURE_FUNCS  # noqa: PLW0602
 71	ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
 72
 73	return func
 74
 75
 76def register_attn_figure_multifunc(
 77	names: Sequence[str],
 78) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
 79	"decorator which registers a function as a multi-figure function"
 80
 81	def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
 82		func_name: str = getattr(func, "__name__", "<unknown>")
 83		setattr(
 84			func,
 85			_FIGURE_NAMES_KEY,
 86			tuple([f"{func_name}.{name}" for name in names]),
 87		)
 88		global ATTENTION_MATRIX_FIGURE_FUNCS  # noqa: PLW0602
 89		ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
 90		return func
 91
 92	return decorator
 93
 94
 95@register_attn_figure_func
 96@save_matrix_wrapper(fmt="png", normalize=True, cmap="Blues")
 97def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
 98	"raw attention matrix"
 99	return attn_matrix
100
101
102# some more examples:
103
104# @register_attn_figure_func
105# @matplotlib_figure_saver
106# def raw_matplotlib(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
107#     ax.matshow(attn_matrix, cmap="viridis")
108#     ax.set_title("Raw Attention Pattern")
109#     ax.axis("off")
110
111# @register_attn_figure_func
112# @save_matrix_wrapper(fmt="svg")
113# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
114#     return attn_matrix
115
116# @register_attn_figure_func
117# @save_matrix_wrapper(fmt="svgz")
118# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
119#     return attn_matrix

ATTENTION_MATRIX_FIGURE_FUNCS: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib._local.Path], None]] = [<function raw>]
def get_all_figure_names() -> list[str]:
	"""get all figure names

	walks `ATTENTION_MATRIX_FIGURE_FUNCS` in registration order and collects the
	figure names each function stores under `_FIGURE_NAMES_KEY`. a function that
	lacks that attribute contributes its `__name__` (or `"<unknown>"`).

	# Returns:
	- `list[str]`
		flattened list of every registered figure name
	"""
	collected: list[str] = []
	for fig_func in ATTENTION_MATRIX_FIGURE_FUNCS:
		fallback_names: list[str] = [getattr(fig_func, "__name__", "<unknown>")]
		collected.extend(getattr(fig_func, _FIGURE_NAMES_KEY, fallback_names))
	return collected

get all figure names

def register_attn_figure_func(
	func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
	"""decorator for registering attention matrix figure function

	if you want to add a new figure function, you should use this decorator.
	it stores the function's own name as its single figure name (under
	`_FIGURE_NAMES_KEY`) and appends the function to `ATTENTION_MATRIX_FIGURE_FUNCS`.

	# Parameters:
	- `func : AttentionMatrixFigureFunc`
		your function, which should take an attention matrix and path

	# Returns:
	- `AttentionMatrixFigureFunc`
		the same function object, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

	# Usage:
	```python
	@register_attn_figure_func
	def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
		fig, ax = plt.subplots(figsize=(10, 10))
		ax.matshow(attn_matrix, cmap="viridis")
		ax.set_title("My New Figure Function")
		ax.axis("off")
		plt.savefig(path / "my_new_figure_func.svgz", format="svgz")
		plt.close(fig)
	```

	"""
	# a single-figure function is registered under exactly one name: its own
	setattr(func, _FIGURE_NAMES_KEY, (getattr(func, "__name__", "<unknown>"),))
	# the module-level list is mutated in place, never rebound, so no `global`
	# statement is needed here
	ATTENTION_MATRIX_FIGURE_FUNCS.append(func)

	return func

decorator for registering attention matrix figure function

if you want to add a new figure function, you should use this decorator

Parameters:

  • func : AttentionMatrixFigureFunc — your function, which should take an attention matrix and a path

Returns:

  • AttentionMatrixFigureFunc — your function, after we add it to ATTENTION_MATRIX_FIGURE_FUNCS

Usage:

@register_attn_figure_func
def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        plt.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)
def register_attn_figure_multifunc(
	names: Sequence[str],
) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
	"""decorator which registers a function as a multi-figure function

	the decorated function is registered under several figure names: each entry
	of `names` becomes `"{func.__name__}.{name}"`, stored on the function under
	`_FIGURE_NAMES_KEY`, and the function is appended to
	`ATTENTION_MATRIX_FIGURE_FUNCS`.

	# Parameters:
	- `names : Sequence[str]`
		the sub-figure names to register the decorated function under

	# Returns:
	- `Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]`
		a decorator that registers the function and returns it unchanged

	# Raises:
	- `TypeError`
		if `names` is a single `str`; it would otherwise be iterated character
		by character, registering one figure name per character
	"""
	# a bare string is technically a Sequence[str], so type checkers won't catch it
	if isinstance(names, str):
		raise TypeError(
			f"`names` must be a sequence of strings, not a single string: {names!r}",
		)
	# snapshot the names now so later mutation of the caller's sequence
	# (or exhaustion of a one-shot iterable) cannot change what gets registered
	names_tuple: tuple[str, ...] = tuple(names)

	def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
		func_name: str = getattr(func, "__name__", "<unknown>")
		setattr(
			func,
			_FIGURE_NAMES_KEY,
			tuple(f"{func_name}.{name}" for name in names_tuple),
		)
		# the module-level list is mutated in place, never rebound: no `global` needed
		ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
		return func

	return decorator

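Decorator factory which registers a function as a multi-figure function: each entry of names becomes a figure name of the form {func_name}.{name}, and the function is added to ATTENTION_MATRIX_FIGURE_FUNCS.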

@register_attn_figure_func
@save_matrix_wrapper(fmt="png", normalize=True, cmap="Blues")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
	"""raw attention matrix

	identity figure function: hands the attention matrix back unchanged, leaving
	all rendering to `save_matrix_wrapper`, which is configured here with
	`fmt="png"`, `normalize=True` and `cmap="Blues"`.
	"""
	matrix: Matrix2D = attn_matrix
	return matrix

raw attention matrix