Coverage for tests/unit/test_figures.py: 98%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-06 15:09 -0600

1# tests/unit/test_figures.py 

2from pathlib import Path 

3from unittest import mock 

4 

5import numpy as np 

6 

7from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS 

8from pattern_lens.figure_util import save_matrix_wrapper 

9from pattern_lens.figures import compute_and_save_figures, process_single_head 

10 

11TEMP_DIR: Path = Path("tests/_temp") 

12 

13 

def test_process_single_head():
    """Test processing a single head's attention pattern.

    Calls ``process_single_head`` with the built-in figure functions plus a
    local ``test_fig_func`` and checks that the result reports ``True`` for
    ``test_fig_func`` and that ``test_fig_func.svg`` exists in the head
    directory afterwards.
    """
    # Setup
    temp_dir = TEMP_DIR / "test_process_single_head"
    head_dir = temp_dir / "L0" / "H0"
    # exist_ok=True keeps the test re-runnable: without it a second run fails
    # with FileExistsError before reaching any assertion.
    head_dir.mkdir(parents=True, exist_ok=True)

    # Remove any figure left behind by an earlier run so the existence check
    # at the end can only pass if *this* run wrote the file.
    expected_figure = head_dir / "test_fig_func.svg"
    expected_figure.unlink(missing_ok=True)

    # Create a simple test attention matrix
    attn_pattern = np.random.rand(10, 10).astype(np.float32)

    # Create a simple test figure function; its __name__ is what the
    # assertions below look up, so it must stay "test_fig_func".
    @save_matrix_wrapper(fmt="svg")
    def test_fig_func(matrix):
        return matrix

    # Run the built-in figure functions together with our test function
    result = process_single_head(
        layer_idx=0,
        head_idx=0,
        attn_pattern=attn_pattern,
        save_dir=head_dir,
        figure_funcs=[*ATTENTION_MATRIX_FIGURE_FUNCS, test_fig_func],
        force_overwrite=True,
    )

    # Check that our function was called and succeeded
    assert "test_fig_func" in result
    assert result["test_fig_func"] is True

    # Check that the figure file was created
    assert expected_figure.exists()

45 

46 

def test_process_single_head_error_handling():
    """Test that process_single_head properly handles errors in figure functions.

    Passes a figure function that always raises ``ValueError`` and checks
    that the exception instance is returned under the function's name and
    that ``error_fig_func.error.txt`` exists in the head directory.
    """
    # Setup
    temp_dir = TEMP_DIR / "test_process_single_head_error_handling"
    head_dir = temp_dir / "L0" / "H0"
    # exist_ok=True keeps the test re-runnable: without it a second run fails
    # with FileExistsError before reaching any assertion.
    head_dir.mkdir(parents=True, exist_ok=True)

    # Remove any error file left behind by an earlier run so the existence
    # check at the end can only pass if *this* run wrote it.
    expected_error_file = head_dir / "error_fig_func.error.txt"
    expected_error_file.unlink(missing_ok=True)

    # Create a simple test attention matrix
    attn_pattern = np.random.rand(10, 10).astype(np.float32)

    # Create a test figure function that raises an exception
    def error_fig_func(matrix, save_dir):  # noqa: ARG001
        raise ValueError("Test error")

    # Run only the failing figure function
    result = process_single_head(
        layer_idx=0,
        head_idx=0,
        attn_pattern=attn_pattern,
        save_dir=head_dir,
        figure_funcs=[error_fig_func],
        force_overwrite=True,
    )

    # Check that our function was called and failed
    assert "error_fig_func" in result
    assert isinstance(result["error_fig_func"], ValueError)

    # Check that an error file was created
    assert expected_error_file.exists()

77 

78 

def test_compute_and_save_figures():
    """Test compute_and_save_figures with a mock model config and cache.

    ``process_single_head`` and ``generate_prompts_jsonl`` are mocked inside
    ``pattern_lens.figures``. The test runs ``compute_and_save_figures`` once
    with a dict cache and once with a stacked array cache, and checks for
    each that ``process_single_head`` was called once per (layer, head) pair
    and that ``generate_prompts_jsonl`` was called once with the model
    directory.
    """
    # Setup
    temp_dir = TEMP_DIR / "test_compute_and_save_figures"
    prompt_dir = temp_dir / "test-model" / "prompts" / "test-hash"
    # exist_ok=True keeps the test re-runnable: without it a second run fails
    # with FileExistsError before reaching any assertion.
    prompt_dir.mkdir(parents=True, exist_ok=True)

    # Create a mock model config
    class MockConfig:
        def __init__(self):
            self.n_layers = 2
            self.n_heads = 2
            self.model_name = "test-model"

    model_cfg = MockConfig()
    expected_head_calls = model_cfg.n_layers * model_cfg.n_heads

    # Create a simple test attention matrices dict
    cache_dict = {
        "blocks.0.attn.hook_pattern": np.random.rand(1, 2, 5, 5).astype(np.float32),
        "blocks.1.attn.hook_pattern": np.random.rand(1, 2, 5, 5).astype(np.float32),
    }

    # Create a stacked array of shape [n_layers, n_heads, n_ctx, n_ctx]
    cache_array = np.random.rand(2, 2, 5, 5).astype(np.float32)

    # Each case pairs an activations file name with the cache passed for it.
    cache_cases = [
        ("activations.npz", cache_dict),
        ("activations.npy", cache_array),
    ]

    # Create a simple test figure function
    @save_matrix_wrapper(fmt="png")
    def test_fig_func(matrix):
        return matrix

    # Patch ATTENTION_MATRIX_FIGURE_FUNCS and process_single_head
    with (
        mock.patch(
            "pattern_lens.figures.ATTENTION_MATRIX_FIGURE_FUNCS",
            [test_fig_func],
        ),
        mock.patch("pattern_lens.figures.process_single_head") as mock_process_head,
        mock.patch("pattern_lens.figures.generate_prompts_jsonl") as mock_gen_prompts,
    ):
        mock_process_head.return_value = {"test_fig_func": True}

        for activations_name, cache in cache_cases:
            # Start every case from a clean call history
            mock_process_head.reset_mock()
            mock_gen_prompts.reset_mock()

            # NOTE: the ATTENTION_MATRIX_FIGURE_FUNCS name used here is this
            # module's own import and is not replaced by the patch above,
            # which only rebinds the attribute on pattern_lens.figures.
            compute_and_save_figures(
                model_cfg=model_cfg,
                activations_path=prompt_dir / activations_name,
                cache=cache,
                figure_funcs=ATTENTION_MATRIX_FIGURE_FUNCS,
                save_path=temp_dir,
                force_overwrite=True,
            )

            # Check that process_single_head was called for each layer and head
            assert mock_process_head.call_count == expected_head_calls

            # Check that generate_prompts_jsonl was called
            mock_gen_prompts.assert_called_once_with(temp_dir / "test-model")