Coverage for pattern_lens/attn_figure_funcs.py: 73%

22 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-06 15:09 -0600

1"""default figure functions 

2 

3- If you are making a PR, add your new figure function here. 

4- if you are using this as a library, then you can see examples here 

5 

6 

7note that for `pattern_lens.figures` to recognize your function, you need to use the `register_attn_figure_func` decorator 

8which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS` 

9 

10""" 

11 

12import itertools 

13from collections.abc import Callable, Sequence 

14 

15from pattern_lens.consts import AttentionMatrix 

16from pattern_lens.figure_util import ( 

17 AttentionMatrixFigureFunc, 

18 Matrix2D, 

19 save_matrix_wrapper, 

20) 

21 

# attribute name under which the registration decorators in this module
# store the tuple of figure names on each registered function
_FIGURE_NAMES_KEY: str = "_figure_names"

# registry of figure functions; the registration decorators in this module
# append to it, and `get_all_figure_names` reads from it
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []

25 

26 

def get_all_figure_names() -> list[str]:
    """list the figure names of every registered figure function

    walks `ATTENTION_MATRIX_FIGURE_FUNCS` in order and flattens the names stored
    on each function under the `_FIGURE_NAMES_KEY` attribute. a function which
    does not carry that attribute contributes its own `__name__` instead.

    # Returns:
     - `list[str]`
       all figure names, in the order the functions appear in the registry
    """
    # one iterable of names per registered function
    names_per_func = (
        getattr(figure_func, _FIGURE_NAMES_KEY, [figure_func.__name__])
        for figure_func in ATTENTION_MATRIX_FIGURE_FUNCS
    )
    return list(itertools.chain.from_iterable(names_per_func))

39 

40 

def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    registering has two side effects: the one-element tuple `(func.__name__,)` is
    stored on `func` under the `_FIGURE_NAMES_KEY` attribute, and `func` is appended
    to `ATTENTION_MATRIX_FIGURE_FUNCS`. the function object itself is returned unwrapped.

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
       your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        plt.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)
    ```

    """
    # a single-figure function produces exactly one figure, named after itself
    setattr(func, _FIGURE_NAMES_KEY, (func.__name__,))
    # `.append` mutates the registry in place and never rebinds the module-level
    # name, so no `global` declaration is needed here
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)

    return func

74 

75 

def register_attn_figure_multifunc(
    names: Sequence[str],
) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
    """decorator which registers a function as a multi-figure function

    unlike `register_attn_figure_func`, the decorated function is recorded as
    producing one figure per entry of `names`. each figure name is prefixed with
    the function's own name, giving `"{func.__name__}.{name}"`.

    # Parameters:
     - `names : Sequence[str]`
       names of the individual figures, without the function-name prefix

    # Returns:
     - `Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]`
       decorator which stores the prefixed names on the function under the
       `_FIGURE_NAMES_KEY` attribute, appends the function to
       `ATTENTION_MATRIX_FIGURE_FUNCS`, and returns the function unwrapped
    """

    def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
        setattr(
            func,
            _FIGURE_NAMES_KEY,
            # generator feeds `tuple` directly; no throwaway list is built
            tuple(f"{func.__name__}.{name}" for name in names),
        )
        # in-place mutation of the registry; the module-level name is never
        # rebound, so no `global` declaration is needed
        ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
        return func

    return decorator

92 

93 

@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    """raw attention matrix"""
    # no transformation: the matrix is handed back unchanged. the decorators above
    # take care of the rest -- `save_matrix_wrapper(fmt="png")` presumably writes the
    # returned matrix out as a png (see `pattern_lens.figure_util`), and
    # `register_attn_figure_func` adds the result to the registry
    return attn_matrix

99 

100 

101# some more examples: 

102 

103# @register_attn_figure_func 

104# @matplotlib_figure_saver 

105# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None: 

106# ax.matshow(attn_matrix, cmap="viridis") 

107# ax.set_title("Raw Attention Pattern") 

108# ax.axis("off") 

109 

110# @register_attn_figure_func 

111# @save_matrix_wrapper(fmt="svg") 

112# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D: 

113# return attn_matrix 

114 

115# @register_attn_figure_func 

116# @save_matrix_wrapper(fmt="svgz") 

117# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D: 

118# return attn_matrix