Coverage for pattern_lens\attn_figure_funcs.py: 100%

10 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-16 20:39 -0700

"""Default attention-matrix figure functions.

- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can find example figure functions here.

Note that for `pattern_lens.figures` to recognize your function, you need to decorate it
with `register_attn_figure_func`, which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`.
"""

11 

12 

13from pattern_lens.consts import AttentionMatrix 

14from pattern_lens.figure_util import ( 

15 AttentionMatrixFigureFunc, 

16 save_matrix_wrapper, 

17 Matrix2D, 

18) 

19 

20 

# Module-level registry of attention-matrix figure functions.
# `register_attn_figure_func` appends each decorated function to this list;
# `pattern_lens.figures` relies on it to discover which figures to generate.
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []

22 

23 

def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
       your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, unchanged (not wrapped), after we append it to
       `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Notes:
     - no deduplication is performed: registering the same function twice
       appends it to `ATTENTION_MATRIX_FIGURE_FUNCS` twice

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        fig.savefig(path / "my_new_figure_func.svgz", format="svgz")
        plt.close(fig)
    ```
    """
    # `append` mutates the module-level list in place and never rebinds the
    # name, so no `global` declaration is needed here.
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
    return func

57 

58 

59# def register_attn_figure_multifunc( 

60# names: list[str], 

61# ) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]: 

62 

63# def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc: 

64 

65# @functools.wraps(func) 

66# def wrapper(*args, **kwargs): 

67# return func(*args, **kwargs) 

68 

69# for name in names: 

70# setattr(wrapper, name, True) 

71 

72# return register_attn_figure_func(wrapper) 

73 

74 

@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    """Identity figure function: the raw attention matrix, unmodified.

    All the work happens in the stacked decorators.
    `save_matrix_wrapper(fmt="png")` presumably saves the returned matrix as a
    PNG (see `pattern_lens.figure_util`). `register_attn_figure_func` adds the
    decorated result to `ATTENTION_MATRIX_FIGURE_FUNCS`.
    """
    unmodified: Matrix2D = attn_matrix
    return unmodified

80 

81 

82# some more examples: 

83 

84# @register_attn_figure_func 

85# @matplotlib_figure_saver 

86# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None: 

87# ax.matshow(attn_matrix, cmap="viridis") 

88# ax.set_title("Raw Attention Pattern") 

89# ax.axis("off") 

90 

91# @register_attn_figure_func 

92# @save_matrix_wrapper(fmt="svg") 

93# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D: 

94# return attn_matrix 

95 

96# @register_attn_figure_func 

97# @save_matrix_wrapper(fmt="svgz") 

98# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D: 

99# return attn_matrix