pattern_lens.attn_figure_funcs
default figure functions
- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can find examples here.
Note that for `pattern_lens.figures` to recognize your function, you need to decorate it with `register_attn_figure_func`,
which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`.
"""default figure functions

- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can find examples here.


Note that for `pattern_lens.figures` to recognize your function, you need to use the
`register_attn_figure_func` (or `register_attn_figure_multifunc`) decorator, which
adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`.

"""

import itertools
from collections.abc import Callable, Sequence

from pattern_lens.consts import AttentionMatrix
from pattern_lens.figure_util import (
    AttentionMatrixFigureFunc,
    Matrix2D,
    save_matrix_wrapper,
)

# attribute name, set on every registered function, holding the tuple of
# figure names that function produces
_FIGURE_NAMES_KEY: str = "_figure_names"

# registry of figure functions, populated by the decorators below
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []


def get_all_figure_names() -> list[str]:
    """get all figure names

    Concatenates, in registration order, the figure names stored on each
    function in `ATTENTION_MATRIX_FIGURE_FUNCS` under `_FIGURE_NAMES_KEY`.
    A function without that attribute contributes its own `__name__`.
    """
    return list(
        itertools.chain.from_iterable(
            getattr(
                func,
                _FIGURE_NAMES_KEY,
                [func.__name__],
            )
            for func in ATTENTION_MATRIX_FIGURE_FUNCS
        ),
    )


def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
        your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
        your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        plt.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)
    ```

    """
    # a single-figure function produces exactly one figure, named after itself
    setattr(func, _FIGURE_NAMES_KEY, (func.__name__,))
    # the registry is mutated in place, never rebound, so no `global` is needed
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)

    return func


def register_attn_figure_multifunc(
    names: Sequence[str],
) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
    """decorator which registers a function as a multi-figure function

    The decorated function's figures are named `f"{func.__name__}.{name}"` for
    each `name` in `names`.

    # Parameters:
     - `names : Sequence[str]`
        suffixes of the figures produced by the decorated function

    # Returns:
     - `Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]`
        a decorator which tags the function with its figure names, appends it
        to `ATTENTION_MATRIX_FIGURE_FUNCS`, and returns it unchanged

    # Raises:
     - `TypeError` : if `names` is a single `str` -- a string is itself a
        `Sequence[str]` and would otherwise be split into one name per character
    """
    if isinstance(names, str):
        msg = (
            "`names` must be a sequence of figure names, not a single string: "
            f"{names!r}"
        )
        raise TypeError(msg)
    # snapshot now so later mutation of `names` (or a one-shot iterator) cannot
    # change what gets registered
    name_suffixes: tuple[str, ...] = tuple(names)

    def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
        setattr(
            func,
            _FIGURE_NAMES_KEY,
            tuple(f"{func.__name__}.{name}" for name in name_suffixes),
        )
        ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
        return func

    return decorator


@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    "raw attention matrix"
    return attn_matrix


# some more examples:

# @register_attn_figure_func
# @matplotlib_figure_saver
# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
#     ax.matshow(attn_matrix, cmap="viridis")
#     ax.set_title("Raw Attention Pattern")
#     ax.axis("off")

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svg")
# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svgz")
# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix
ATTENTION_MATRIX_FIGURE_FUNCS: list[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]] =
[<function raw>]
def
get_all_figure_names() -> list[str]:
def get_all_figure_names() -> list[str]:
    """Collect the figure names of every registered figure function.

    Functions are visited in registration order. Each one contributes the
    names stored on it under `_FIGURE_NAMES_KEY`, or its own `__name__`
    when that attribute is absent.
    """
    collected: list[str] = []
    for figure_func in ATTENTION_MATRIX_FIGURE_FUNCS:
        fallback = [figure_func.__name__]
        collected.extend(getattr(figure_func, _FIGURE_NAMES_KEY, fallback))
    return collected
get all figure names
def
register_attn_figure_func( func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]:
def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """register `func` as a single-figure attention matrix figure function

    decorate every new figure function with this so it gets picked up

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
        the figure function; it should take an attention matrix and a path

    # Returns:
     - `AttentionMatrixFigureFunc`
        the very same `func` object, after it has been tagged with its figure
        name and appended to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        plt.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)
    ```

    """
    # one function -> one figure, named after the function itself
    single_name: tuple[str, ...] = (func.__name__,)
    setattr(func, _FIGURE_NAMES_KEY, single_name)
    # in-place mutation of the module-level list; no rebinding, so no `global`
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
    return func
decorator for registering attention matrix figure function
if you want to add a new figure function, you should use this decorator
Parameters:
func : AttentionMatrixFigureFunc — your function, which should take an attention matrix and a path
Returns:
AttentionMatrixFigureFunc — your function, after we add it to ATTENTION_MATRIX_FIGURE_FUNCS
Usage:
@register_attn_figure_func
def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
fig, ax = plt.subplots(figsize=(10, 10))
ax.matshow(attn_matrix, cmap="viridis")
ax.set_title("My New Figure Function")
ax.axis("off")
plt.savefig(path / "my_new_figure_func", format="svgz")
plt.close(fig)
def
register_attn_figure_multifunc( names: Sequence[str]) -> Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], None]]:
def register_attn_figure_multifunc(
    names: Sequence[str],
) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
    """decorator which registers a function as a multi-figure function

    The decorated function's figures are named `f"{func.__name__}.{name}"` for
    each `name` in `names`. These names are stored on the function under
    `_FIGURE_NAMES_KEY`, which is where `get_all_figure_names` reads them.

    # Parameters:
     - `names : Sequence[str]`
        suffixes of the figures produced by the decorated function

    # Returns:
     - `Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]`
        a decorator which tags the function with its figure names, appends it
        to `ATTENTION_MATRIX_FIGURE_FUNCS`, and returns it unchanged

    # Raises:
     - `TypeError` : if `names` is a single `str` -- a string is itself a
        `Sequence[str]` and would otherwise be split into one name per character
    """
    if isinstance(names, str):
        msg = (
            "`names` must be a sequence of figure names, not a single string: "
            f"{names!r}"
        )
        raise TypeError(msg)
    # snapshot now so later mutation of `names` (or a one-shot iterator) cannot
    # change what gets registered
    name_suffixes: tuple[str, ...] = tuple(names)

    def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
        setattr(
            func,
            _FIGURE_NAMES_KEY,
            tuple(f"{func.__name__}.{name}" for name in name_suffixes),
        )
        # registry is mutated in place, never rebound, so no `global` is needed
        ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
        return func

    return decorator
decorator which registers a function as a multi-figure function
@register_attn_figure_func
@save_matrix_wrapper(fmt='png')
def
raw( attn_matrix: jaxtyping.Float[ndarray, 'n_ctx n_ctx']) -> jaxtyping.Float[ndarray, 'n m']:
@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    "raw attention matrix"
    # identity transform: the input is handed back untouched. Saving is left to
    # `save_matrix_wrapper` (presumably writing it in the `png` format requested
    # above -- confirm in `pattern_lens.figure_util`).
    unchanged = attn_matrix
    return unchanged
raw attention matrix