Coverage for pattern_lens\attn_figure_funcs.py: 100%
10 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 20:39 -0700
"""default figure functions

- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can find examples here.

Note that for `pattern_lens.figures` to recognize your function, you need to
decorate it with `register_attn_figure_func`, which adds your function to
`ATTENTION_MATRIX_FIGURE_FUNCS`.

"""
13from pattern_lens.consts import AttentionMatrix
14from pattern_lens.figure_util import (
15 AttentionMatrixFigureFunc,
16 save_matrix_wrapper,
17 Matrix2D,
18)
# Registry of figure functions. `register_attn_figure_func` appends to this
# list, so it holds every decorated function in decoration order.
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []
def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
        your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
        your function, unchanged (still directly callable), after we add it
        to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Notes:
     - functions are appended in the order they are decorated
     - there is no de-duplication: decorating the same function twice
       registers it twice

    # Usage:
    ```python
    from pathlib import Path

    import matplotlib.pyplot as plt

    from pattern_lens.consts import AttentionMatrix
    from pattern_lens.attn_figure_funcs import register_attn_figure_func

    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        plt.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)
    ```

    """
    # No `global` statement is needed here: the module-level registry is only
    # mutated in place via `.append` and is never rebound, so the name already
    # resolves to the module-level list.
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
    return func
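# NOTE: disabled sketch of a decorator that registers a figure function and
# tags it with boolean attributes named in `names`. Before enabling it,
# `functools` and `typing.Callable` must be imported (neither is imported above).
#
# def register_attn_figure_multifunc(
#     names: list[str],
# ) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
#
#     def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
#
#         @functools.wraps(func)
#         def wrapper(*args, **kwargs):
#             return func(*args, **kwargs)
#
#         for name in names:
#             setattr(wrapper, name, True)
#
#         return register_attn_figure_func(wrapper)
#
#     return decorator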
@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    "raw attention matrix"
    # Identity transform: no processing happens in this function itself.
    # Decorators apply bottom-up, so `save_matrix_wrapper(fmt="png")` wraps
    # this function first, and that wrapped result is what
    # `register_attn_figure_func` adds to the registry.
    unchanged: Matrix2D = attn_matrix
    return unchanged
# Some more examples (disabled):

# NOTE: rename before enabling. A function named `raw` is already defined
# above, so this would rebind that name (both versions would still be
# registered). `matplotlib_figure_saver` and `plt` are also not imported here.
# @register_attn_figure_func
# @matplotlib_figure_saver
# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
#     ax.matshow(attn_matrix, cmap="viridis")
#     ax.set_title("Raw Attention Pattern")
#     ax.axis("off")

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svg")
# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svgz")
# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix