docs for pattern_lens v0.2.0
View Source on GitHub

pattern_lens.attn_figure_funcs

default figure functions

  • If you are making a PR, add your new figure function here.
  • If you are using this as a library, you can see examples here.

Note that for `pattern_lens.figures` to recognize your function, you need to use the `register_attn_figure_func` decorator, which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`.


 1"""default figure functions
 2
 3- If you are making a PR, add your new figure function here.
 4- if you are using this as a library, then you can see examples here
 5
 6
 7note that for `pattern_lens.figures` to recognize your function, you need to use the `register_attn_figure_func` decorator
 8which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`
 9
10"""
11
12
13from pattern_lens.consts import AttentionMatrix
14from pattern_lens.figure_util import (
15    AttentionMatrixFigureFunc,
16    save_matrix_wrapper,
17    Matrix2D,
18)
19
20
21ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = list()
22
23
24def register_attn_figure_func(
25    func: AttentionMatrixFigureFunc,
26) -> AttentionMatrixFigureFunc:
27    """decorator for registering attention matrix figure function
28
29    if you want to add a new figure function, you should use this decorator
30
31        # Parameters:
32         - `func : AttentionMatrixFigureFunc`
33           your function, which should take an attention matrix and path
34
35        # Returns:
36         - `AttentionMatrixFigureFunc`
37           your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`
38
39    # Usage:
40    ```python
41    @register_attn_figure_func
42    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
43        fig, ax = plt.subplots(figsize=(10, 10))
44        ax.matshow(attn_matrix, cmap="viridis")
45        ax.set_title("My New Figure Function")
46        ax.axis("off")
47        plt.savefig(path / "my_new_figure_func", format="svgz")
48        plt.close(fig)
49    ```
50
51    """
52    global ATTENTION_MATRIX_FIGURE_FUNCS
53
54    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
55
56    return func
57
58
59# def register_attn_figure_multifunc(
60#     names: list[str],
61# ) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
62
63#     def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
64
65#         @functools.wraps(func)
66#         def wrapper(*args, **kwargs):
67#             return func(*args, **kwargs)
68
69#         for name in names:
70#             setattr(wrapper, name, True)
71
72#         return register_attn_figure_func(wrapper)
73
74
75@register_attn_figure_func
76@save_matrix_wrapper(fmt="png")
77def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
78    "raw attention matrix"
79    return attn_matrix
80
81
82# some more examples:
83
84# @register_attn_figure_func
85# @matplotlib_figure_saver
86# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
87#     ax.matshow(attn_matrix, cmap="viridis")
88#     ax.set_title("Raw Attention Pattern")
89#     ax.axis("off")
90
91# @register_attn_figure_func
92# @save_matrix_wrapper(fmt="svg")
93# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
94#     return attn_matrix
95
96# @register_attn_figure_func
97# @save_matrix_wrapper(fmt="svgz")
98# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
99#     return attn_matrix

ATTENTION_MATRIX_FIGURE_FUNCS: list[typing.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]] = [<function raw>]
def register_attn_figure_func( func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]:
25def register_attn_figure_func(
26    func: AttentionMatrixFigureFunc,
27) -> AttentionMatrixFigureFunc:
28    """decorator for registering attention matrix figure function
29
30    if you want to add a new figure function, you should use this decorator
31
32        # Parameters:
33         - `func : AttentionMatrixFigureFunc`
34           your function, which should take an attention matrix and path
35
36        # Returns:
37         - `AttentionMatrixFigureFunc`
38           your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`
39
40    # Usage:
41    ```python
42    @register_attn_figure_func
43    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
44        fig, ax = plt.subplots(figsize=(10, 10))
45        ax.matshow(attn_matrix, cmap="viridis")
46        ax.set_title("My New Figure Function")
47        ax.axis("off")
48        plt.savefig(path / "my_new_figure_func", format="svgz")
49        plt.close(fig)
50    ```
51
52    """
53    global ATTENTION_MATRIX_FIGURE_FUNCS
54
55    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
56
57    return func

Decorator for registering an attention matrix figure function.

If you want to add a new figure function, you should use this decorator.

Parameters:
 - `func : AttentionMatrixFigureFunc`
   your function, which should take an attention matrix and path

Returns:
 - `AttentionMatrixFigureFunc`
   your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

Usage:

@register_attn_figure_func
def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("My New Figure Function")
    ax.axis("off")
    plt.savefig(path / "my_new_figure_func", format="svgz")
    plt.close(fig)
@register_attn_figure_func
@save_matrix_wrapper(fmt='png')
def raw( attn_matrix: jaxtyping.Float[ndarray, 'n_ctx n_ctx']) -> jaxtyping.Float[ndarray, 'n m']:
76@register_attn_figure_func
77@save_matrix_wrapper(fmt="png")
78def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
79    "raw attention matrix"
80    return attn_matrix

raw attention matrix