Coverage for pattern_lens / attn_figure_funcs.py: 70%

23 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:15 -0700

1"""default figure functions 

2 

3- If you are making a PR, add your new figure function here. 

4- if you are using this as a library, then you can see examples here 

5 

6 

7note that for `pattern_lens.figures` to recognize your function, you need to use the `register_attn_figure_func` decorator 

8which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS` 

9 

10""" 

11 

12import itertools 

13from collections.abc import Callable, Sequence 

14 

15from pattern_lens.consts import AttentionMatrix 

16from pattern_lens.figure_util import ( 

17 AttentionMatrixFigureFunc, 

18 Matrix2D, 

19 save_matrix_wrapper, 

20) 

21 

# name of the attribute the registration decorators in this module set on a
# figure function to hold the figure names it produces
_FIGURE_NAMES_KEY: str = "_figure_names"

# registry of figure functions; the registration decorators in this module
# append to it, and `get_all_figure_names` reads from it
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []

25 

26 

def get_all_figure_names() -> list[str]:
    """list the figure names of every function in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Returns:
     - `list[str]`
        the figure names of each registered function, concatenated in
        registration order. a function without the `_FIGURE_NAMES_KEY`
        attribute contributes its `__name__` (or `"<unknown>"` if it has none)
    """
    collected: list[str] = []
    for func in ATTENTION_MATRIX_FIGURE_FUNCS:
        # fallback for functions that were put in the registry without
        # going through one of the registration decorators
        fallback: list[str] = [getattr(func, "__name__", "<unknown>")]
        collected.extend(getattr(func, _FIGURE_NAMES_KEY, fallback))
    return collected

39 

40 

def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator that registers a single-figure attention matrix function

    use this decorator on any new figure function you add. it tags `func`
    with a one-element tuple of figure names (its `__name__`, or
    `"<unknown>"` if it has none) under the `_FIGURE_NAMES_KEY` attribute,
    then appends it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
        your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
        the same function object, now present in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        plt.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)
    ```

    """
    figure_name: str = getattr(func, "__name__", "<unknown>")
    # tag first, register second: if the attribute cannot be set, the
    # function never ends up in the registry
    setattr(func, _FIGURE_NAMES_KEY, (figure_name,))
    # the registry list is mutated in place, so no `global` statement is needed
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
    return func

74 

75 

def register_attn_figure_multifunc(
    names: Sequence[str],
) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
    """decorator factory which registers a function as a multi-figure function

    # Parameters:
     - `names : Sequence[str]`
        the sub-figure names. each one is recorded as
        `"{function name}.{name}"` in the function's figure names

    # Returns:
     - `Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]`
        a decorator that tags the function with those figure names under
        the `_FIGURE_NAMES_KEY` attribute, appends it to
        `ATTENTION_MATRIX_FIGURE_FUNCS`, and returns it unchanged
    """

    def _register(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
        prefix: str = getattr(func, "__name__", "<unknown>")
        qualified_names: tuple[str, ...] = tuple(
            f"{prefix}.{name}" for name in names
        )
        setattr(func, _FIGURE_NAMES_KEY, qualified_names)
        # the registry list is mutated in place, so no `global` statement is needed
        ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
        return func

    return _register

93 

94 

# decorators apply bottom-up: `save_matrix_wrapper` wraps the matrix-returning
# function first, and the wrapped result is what gets added to
# `ATTENTION_MATRIX_FIGURE_FUNCS`
# NOTE(review): presumably the wrapper saves the returned matrix as a
# normalized PNG using the "Blues" colormap -- confirm in `figure_util`
@register_attn_figure_func
@save_matrix_wrapper(fmt="png", normalize=True, cmap="Blues")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    """raw attention matrix"""
    # identity: the attention matrix is passed through unmodified
    return attn_matrix

100 

101 

102# some more examples: 

103 

104# @register_attn_figure_func 

105# @matplotlib_figure_saver 

106# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None: 

107# ax.matshow(attn_matrix, cmap="viridis") 

108# ax.set_title("Raw Attention Pattern") 

109# ax.axis("off") 

110 

111# @register_attn_figure_func 

112# @save_matrix_wrapper(fmt="svg") 

113# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D: 

114# return attn_matrix 

115 

116# @register_attn_figure_func 

117# @save_matrix_wrapper(fmt="svgz") 

118# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D: 

119# return attn_matrix