pattern_lens.attn_figure_funcs
default figure functions
- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can see examples here.
Note that for pattern_lens.figures to recognize your function, you need to use the register_attn_figure_func decorator,
which adds your function to ATTENTION_MATRIX_FIGURE_FUNCS.
"""default figure functions

- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can see examples here.


Note that for `pattern_lens.figures` to recognize your function, you need to
use the `register_attn_figure_func` decorator, which adds your function to
`ATTENTION_MATRIX_FIGURE_FUNCS`.

"""


from pattern_lens.consts import AttentionMatrix
from pattern_lens.figure_util import (
    AttentionMatrixFigureFunc,
    save_matrix_wrapper,
    Matrix2D,
)


# Registry of figure functions. Filled in by the `register_attn_figure_func`
# decorator as each decorated function is defined; per the module docstring,
# this list is what `pattern_lens.figures` uses to find figure functions.
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []


def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
        your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
        your function, unchanged, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`
        (no deduplication is done: registering the same function twice adds it twice)

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        # `savefig` uses the filename verbatim when `format` is given,
        # so include the extension yourself
        plt.savefig(path / "my_new_figure_func.svgz", format="svgz")
        plt.close(fig)
    ```

    """
    # No `global` statement needed: the registry list is mutated in place,
    # never rebound, so the module-level name is already the one we append to.
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
    return func


# NOTE: unfinished draft, kept for reference only. As written, the outer
# function never returns `decorator`, and `Callable` / `functools` are not
# imported in this module, so it cannot simply be uncommented.
# def register_attn_figure_multifunc(
#     names: list[str],
# ) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
#
#     def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
#
#         @functools.wraps(func)
#         def wrapper(*args, **kwargs):
#             return func(*args, **kwargs)
#
#         for name in names:
#             setattr(wrapper, name, True)
#
#         return register_attn_figure_func(wrapper)


@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    "raw attention matrix"
    return attn_matrix


# some more examples:

# @register_attn_figure_func
# @matplotlib_figure_saver
# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
#     ax.matshow(attn_matrix, cmap="viridis")
#     ax.set_title("Raw Attention Pattern")
#     ax.axis("off")

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svg")
# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svgz")
# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix
ATTENTION_MATRIX_FIGURE_FUNCS: list[typing.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]] =
[<function raw>]
def
register_attn_figure_func( func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]:
def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
        your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
        your function, unchanged, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`
        (no deduplication is done: registering the same function twice adds it twice)

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        # `savefig` uses the filename verbatim when `format` is given,
        # so include the extension yourself
        plt.savefig(path / "my_new_figure_func.svgz", format="svgz")
        plt.close(fig)
    ```

    """
    # No `global` statement needed: the registry list is mutated in place,
    # never rebound, so the module-level name is already the one we append to.
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
    return func
decorator for registering attention matrix figure function
if you want to add a new figure function, you should use this decorator
# Parameters:
- `func : AttentionMatrixFigureFunc`
your function, which should take an attention matrix and path
# Returns:
- `AttentionMatrixFigureFunc`
your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`
# Usage:
```python
@register_attn_figure_func
def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("My New Figure Function")
    ax.axis("off")
    plt.savefig(path / "my_new_figure_func", format="svgz")
    plt.close(fig)
```
@register_attn_figure_func
@save_matrix_wrapper(fmt='png')
def
raw( attn_matrix: jaxtyping.Float[ndarray, 'n_ctx n_ctx']) -> jaxtyping.Float[ndarray, 'n m']:
@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    """raw attention matrix

    Identity transform: the attention matrix is handed back untouched.
    Producing the figure is left to `save_matrix_wrapper(fmt="png")`
    (presumably it saves the returned matrix as an image in that format --
    confirm against its docs), and `register_attn_figure_func` appends the
    wrapped function to `ATTENTION_MATRIX_FIGURE_FUNCS`.
    """
    unchanged: Matrix2D = attn_matrix
    return unchanged
raw attention matrix