pattern_lens.attn_figure_funcs
default figure functions
- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can see examples here.
Note that for pattern_lens.figures to recognize your function, you need to use the register_attn_figure_func decorator
(or register_attn_figure_multifunc for a multi-figure function), which adds your function to ATTENTION_MATRIX_FIGURE_FUNCS.
"""default figure functions

- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can see examples here.


note that for `pattern_lens.figures` to recognize your function, you need to use the `register_attn_figure_func` decorator
(or `register_attn_figure_multifunc` for a multi-figure function), which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`

"""

import itertools
from collections.abc import Callable, Sequence

from pattern_lens.consts import AttentionMatrix
from pattern_lens.figure_util import (
    AttentionMatrixFigureFunc,
    Matrix2D,
    save_matrix_wrapper,
)

# attribute name under which each registered function stores its figure names
_FIGURE_NAMES_KEY: str = "_figure_names"

# registry of figure functions, appended to by the decorators below
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []


def get_all_figure_names() -> list[str]:
    """get all figure names

    Flattens the figure names of every function in `ATTENTION_MATRIX_FIGURE_FUNCS`
    into one list, in registration order. A function without a `_FIGURE_NAMES_KEY`
    attribute contributes its `__name__` (or `"<unknown>"`).
    """
    return list(
        itertools.chain.from_iterable(
            getattr(
                func,
                _FIGURE_NAMES_KEY,
                [getattr(func, "__name__", "<unknown>")],
            )
            for func in ATTENTION_MATRIX_FIGURE_FUNCS
        ),
    )


def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    Sets the `_FIGURE_NAMES_KEY` attribute of `func` to a one-element tuple holding
    its `__name__` (or `"<unknown>"`), then appends `func` to
    `ATTENTION_MATRIX_FIGURE_FUNCS`.

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
        your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
        the same function object, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        # matplotlib uses the filename verbatim, so include the extension yourself
        fig.savefig(path / "my_new_figure_func.svgz", format="svgz")
        plt.close(fig)
    ```

    """
    setattr(func, _FIGURE_NAMES_KEY, (getattr(func, "__name__", "<unknown>"),))
    # no `global` needed: the module-level list is mutated in place, never rebound
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)

    return func


def register_attn_figure_multifunc(
    names: Sequence[str],
) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
    """decorator which registers a function as a multi-figure function

    The decorated function's figure names (stored under `_FIGURE_NAMES_KEY`) become
    `"<func_name>.<name>"` for each entry of `names`, in order, and the function is
    appended to `ATTENTION_MATRIX_FIGURE_FUNCS`.

    # Parameters:
     - `names : Sequence[str]`
        the sub-figure names; a sequence of strings, not a single string

    # Returns:
     - `Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]`
        a decorator that registers and returns the function it is applied to

    # Raises:
     - `TypeError` : if `names` is a single `str`
    """
    # a bare `str` satisfies `Sequence[str]` but would be split into characters,
    # producing names like `func.a`, `func.b`, ...
    if isinstance(names, str):
        msg = f"`names` must be a sequence of strings, not a single string: {names!r}"
        raise TypeError(msg)

    # snapshot once so iterators are not exhausted after the first decoration
    names_tuple: tuple[str, ...] = tuple(names)

    def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
        func_name: str = getattr(func, "__name__", "<unknown>")
        setattr(
            func,
            _FIGURE_NAMES_KEY,
            tuple(f"{func_name}.{name}" for name in names_tuple),
        )
        ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
        return func

    return decorator


@register_attn_figure_func
@save_matrix_wrapper(fmt="png", normalize=True, cmap="Blues")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    "raw attention matrix"
    return attn_matrix


# some more examples:

# NOTE: uncommenting this one as-is would register a second figure named `raw`;
# rename it first.
# @register_attn_figure_func
# @matplotlib_figure_saver
# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
#     ax.matshow(attn_matrix, cmap="viridis")
#     ax.set_title("Raw Attention Pattern")
#     ax.axis("off")

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svg")
# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svgz")
# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix
ATTENTION_MATRIX_FIGURE_FUNCS: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]] =
[<function raw>]
def get_all_figure_names() -> list[str]:
def get_all_figure_names() -> list[str]:
    """Collect the figure names of every registered figure function.

    Walks `ATTENTION_MATRIX_FIGURE_FUNCS` in registration order and concatenates
    each function's `_FIGURE_NAMES_KEY` attribute. A function lacking that attribute
    contributes a single entry: its `__name__`, or `"<unknown>"` if it has none.
    """
    collected: list[str] = []
    for fig_func in ATTENTION_MATRIX_FIGURE_FUNCS:
        fallback: list[str] = [getattr(fig_func, "__name__", "<unknown>")]
        collected.extend(getattr(fig_func, _FIGURE_NAMES_KEY, fallback))
    return collected
Get all figure names across every function in ATTENTION_MATRIX_FIGURE_FUNCS, flattened into a single list.
def register_attn_figure_func(func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]:
def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    Sets the `_FIGURE_NAMES_KEY` attribute of `func` to a one-element tuple holding
    its `__name__` (or `"<unknown>"`), then appends `func` to
    `ATTENTION_MATRIX_FIGURE_FUNCS`.

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
        your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
        the same function object, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        # matplotlib uses the filename verbatim, so include the extension yourself
        fig.savefig(path / "my_new_figure_func.svgz", format="svgz")
        plt.close(fig)
    ```

    """
    setattr(func, _FIGURE_NAMES_KEY, (getattr(func, "__name__", "<unknown>"),))
    # no `global` needed: the module-level list is mutated in place, never rebound
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)

    return func
Decorator for registering an attention matrix figure function.
If you want to add a new figure function, you should use this decorator.
Parameters:
func : AttentionMatrixFigureFunc — your function, which should take an attention matrix and a path
Returns:
AttentionMatrixFigureFunc — your function, after we add it to ATTENTION_MATRIX_FIGURE_FUNCS
Usage:
@register_attn_figure_func
def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
    """Example figure function: save `attn_matrix` as an SVGZ heatmap under `path`."""
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("My New Figure Function")
    ax.axis("off")
    # matplotlib uses the filename verbatim and does not append an extension
    # when `format` is given, so include ".svgz" explicitly
    fig.savefig(path / "my_new_figure_func.svgz", format="svgz")
    plt.close(fig)
def register_attn_figure_multifunc(names: Sequence[str]) -> Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]:
def register_attn_figure_multifunc(
    names: Sequence[str],
) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
    """decorator which registers a function as a multi-figure function

    The decorated function's figure names (stored under `_FIGURE_NAMES_KEY`) become
    `"<func_name>.<name>"` for each entry of `names`, in order, and the function is
    appended to `ATTENTION_MATRIX_FIGURE_FUNCS`.

    # Parameters:
     - `names : Sequence[str]`
        the sub-figure names; a sequence of strings, not a single string

    # Returns:
     - `Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]`
        a decorator that registers and returns the function it is applied to

    # Raises:
     - `TypeError` : if `names` is a single `str`
    """
    # a bare `str` satisfies `Sequence[str]` but would be split into characters,
    # producing names like `func.a`, `func.b`, ...
    if isinstance(names, str):
        msg = f"`names` must be a sequence of strings, not a single string: {names!r}"
        raise TypeError(msg)

    # snapshot once so iterators are not exhausted after the first decoration
    names_tuple: tuple[str, ...] = tuple(names)

    def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
        func_name: str = getattr(func, "__name__", "<unknown>")
        setattr(
            func,
            _FIGURE_NAMES_KEY,
            tuple(f"{func_name}.{name}" for name in names_tuple),
        )
        # no `global` needed: the module-level list is mutated in place
        ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
        return func

    return decorator
Decorator which registers a function as a multi-figure function; its figure names become <func_name>.<name> for each entry in names.
@register_attn_figure_func
@save_matrix_wrapper(fmt='png', normalize=True, cmap='Blues')
def raw(attn_matrix: jaxtyping.Float[ndarray, 'n_ctx n_ctx']) -> jaxtyping.Float[ndarray, 'n m']:
# Baseline figure: the attention matrix itself, passed through unchanged.
# `save_matrix_wrapper` is configured here with fmt="png", normalize=True and
# cmap="Blues"; how those options are applied lives in pattern_lens.figure_util.
@register_attn_figure_func
@save_matrix_wrapper(fmt="png", normalize=True, cmap="Blues")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    """raw attention matrix"""
    unchanged: Matrix2D = attn_matrix
    return unchanged
raw attention matrix