Coverage for pattern_lens/attn_figure_funcs.py: 73%
22 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
1"""default figure functions

- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can find examples here.

Note that for `pattern_lens.figures` to recognize your function, you need to register it
with the `register_attn_figure_func` (or `register_attn_figure_multifunc`) decorator,
which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`.
10"""
12import itertools
13from collections.abc import Callable, Sequence
15from pattern_lens.consts import AttentionMatrix
16from pattern_lens.figure_util import (
17 AttentionMatrixFigureFunc,
18 Matrix2D,
19 save_matrix_wrapper,
20)
# name of the attribute under which each registered figure function stores
# the tuple of figure names it produces
_FIGURE_NAMES_KEY: str = "_figure_names"

# registry of figure functions, in registration order; appended to by the
# `register_attn_figure_func` / `register_attn_figure_multifunc` decorators
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []
def get_all_figure_names() -> list[str]:
    """get all figure names

    walks `ATTENTION_MATRIX_FIGURE_FUNCS` in registration order and
    concatenates the figure names each function declares under the
    `_FIGURE_NAMES_KEY` attribute. a function lacking that attribute
    contributes a single name: its `__name__`.

    # Returns:
     - `list[str]`
       flat list of every registered figure name
    """
    names_per_func = (
        getattr(fig_func, _FIGURE_NAMES_KEY, [fig_func.__name__])
        for fig_func in ATTENTION_MATRIX_FIGURE_FUNCS
    )
    return [*itertools.chain.from_iterable(names_per_func)]
def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator.
    the function is tagged with a single figure name (its own `__name__`) and
    appended to `ATTENTION_MATRIX_FIGURE_FUNCS`. no de-duplication is done:
    decorating the same function twice registers it twice.

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
       your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, unchanged, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        # matplotlib uses the filename verbatim when `format` is given,
        # so the extension has to be part of the path
        fig.savefig(path / "my_new_figure_func.svgz", format="svgz")
        plt.close(fig)
    ```

    """
    setattr(func, _FIGURE_NAMES_KEY, (func.__name__,))
    # the registry list is mutated in place, never rebound, so no `global`
    # declaration is needed here
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
    return func
def register_attn_figure_multifunc(
    names: Sequence[str],
) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
    """decorator which registers a function as a multi-figure function

    the decorated function is registered under one figure name per entry of
    `names`, each formatted as `f"{func.__name__}.{name}"`.

    # Parameters:
     - `names : Sequence[str]`
       suffixes of the figure names for the decorated function. must be a
       sequence of strings, not a single string

    # Returns:
     - `Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]`
       a decorator that tags the function with its figure names, appends it
       to `ATTENTION_MATRIX_FIGURE_FUNCS`, and returns it unchanged

    # Raises:
     - `TypeError` : if `names` is a bare `str`
    """
    # a bare string is itself a `Sequence[str]`; without this check, passing
    # "abc" would silently register figures named `<func>.a`, `<func>.b`, ...
    if isinstance(names, str):
        msg = (
            "`names` must be a sequence of strings, not a single string: "
            f"{names!r}"
        )
        raise TypeError(msg)

    # snapshot once so a one-shot iterable is not exhausted by the first
    # decorated function, and later mutation of the caller's list is ignored
    name_suffixes: tuple[str, ...] = tuple(names)

    def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
        setattr(
            func,
            _FIGURE_NAMES_KEY,
            tuple(f"{func.__name__}.{name}" for name in name_suffixes),
        )
        # in-place mutation of the registry; no `global` declaration needed
        ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
        return func

    return decorator
# registered under the figure name "raw"; decorators apply bottom-up, so the
# registry receives whatever `save_matrix_wrapper(fmt="png")` returns
@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    "raw attention matrix"
    # identity transform: the matrix is handed back unmodified, and any
    # saving is left to `save_matrix_wrapper`
    return attn_matrix
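# some more examples (disabled; uncomment one to enable it). note that the
# matplotlib example reuses the name `raw`, which is already registered above,
# so rename it first, and it also needs `matplotlib_figure_saver` and `plt` in scope: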
103# @register_attn_figure_func
104# @matplotlib_figure_saver
105# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
106# ax.matshow(attn_matrix, cmap="viridis")
107# ax.set_title("Raw Attention Pattern")
108# ax.axis("off")
110# @register_attn_figure_func
111# @save_matrix_wrapper(fmt="svg")
112# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
113# return attn_matrix
115# @register_attn_figure_func
116# @save_matrix_wrapper(fmt="svgz")
117# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
118# return attn_matrix