Coverage for tests\unit\test_figures.py: 98%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-04 01:50 -0700

1# tests/unit/test_figures.py 

2from pathlib import Path 

3from unittest import mock 

4 

5import numpy as np 

6 

7from pattern_lens.figure_util import save_matrix_wrapper 

8from pattern_lens.figures import compute_and_save_figures, process_single_head 

9 

# Root scratch directory for files written by the tests in this module.
# It is a relative path, so it resolves against the current working
# directory (presumably the repository root when pytest is run - confirm).
TEMP_DIR: Path = Path("tests") / "_temp"

11 

12 

def test_process_single_head():
    """Test processing a single head's attention pattern.

    Replaces ``ATTENTION_MATRIX_FIGURE_FUNCS`` with one SVG figure function,
    runs ``process_single_head`` on a random 10x10 matrix, and checks that the
    function is reported as successful and that its output file was written
    into the head directory.
    """
    # Setup: this test owns its own subdirectory of TEMP_DIR. exist_ok=True
    # lets the test be re-run over leftovers from a previous run; a plain
    # mkdir(parents=True) raises FileExistsError in that case.
    temp_dir = TEMP_DIR / "test_process_single_head"
    head_dir = temp_dir / "L0" / "H0"
    head_dir.mkdir(parents=True, exist_ok=True)

    # Delete any figure left by an earlier run so the existence check at the
    # end can only pass if *this* run wrote the file.
    expected_figure = head_dir / "test_fig_func.svg"
    expected_figure.unlink(missing_ok=True)

    # Create a simple test attention matrix (seeded so failures reproduce)
    rng = np.random.default_rng(0)
    attn_pattern = rng.random((10, 10), dtype=np.float32)

    # Create a simple test figure function. The assertions below expect
    # save_matrix_wrapper(fmt="svg") to save its output as "<func name>.svg".
    @save_matrix_wrapper(fmt="svg")
    def test_fig_func(matrix):
        return matrix

    # Patch ATTENTION_MATRIX_FIGURE_FUNCS to use our test function
    with mock.patch(
        "pattern_lens.figures.ATTENTION_MATRIX_FIGURE_FUNCS",
        [test_fig_func],
    ):
        # Call process_single_head
        result = process_single_head(
            layer_idx=0,
            head_idx=0,
            attn_pattern=attn_pattern,
            save_dir=head_dir,
            force_overwrite=True,
        )

    # Check that our function was called and succeeded
    assert "test_fig_func" in result
    assert result["test_fig_func"] is True

    # Check that the figure file was created
    assert expected_figure.exists()

48 

49 

def test_process_single_head_error_handling():
    """Test that process_single_head properly handles errors in figure functions.

    A figure function that always raises ``ValueError`` is installed. The test
    expects ``process_single_head`` to return the exception object for that
    function rather than propagating it, and to write an
    ``<func name>.error.txt`` file into the head directory.
    """
    # Setup: exist_ok=True so re-runs over leftover directories don't raise
    # FileExistsError.
    temp_dir = TEMP_DIR / "test_process_single_head_error_handling"
    head_dir = temp_dir / "L0" / "H0"
    head_dir.mkdir(parents=True, exist_ok=True)

    # Delete any error file from an earlier run so the existence check at the
    # end can only pass if *this* run produced it.
    expected_error_file = head_dir / "error_fig_func.error.txt"
    expected_error_file.unlink(missing_ok=True)

    # Create a simple test attention matrix (seeded so failures reproduce)
    rng = np.random.default_rng(0)
    attn_pattern = rng.random((10, 10), dtype=np.float32)

    # Create a test figure function that raises an exception. It is not
    # wrapped; its (matrix, save_dir) signature presumably mirrors how
    # process_single_head invokes figure functions - confirm.
    def error_fig_func(matrix, save_dir):  # noqa: ARG001
        raise ValueError("Test error")

    # Patch ATTENTION_MATRIX_FIGURE_FUNCS to use our test function
    with mock.patch(
        "pattern_lens.figures.ATTENTION_MATRIX_FIGURE_FUNCS",
        [error_fig_func],
    ):
        # Call process_single_head
        result = process_single_head(
            layer_idx=0,
            head_idx=0,
            attn_pattern=attn_pattern,
            save_dir=head_dir,
            force_overwrite=True,
        )

    # Check that our function was called and failed
    assert "error_fig_func" in result
    assert isinstance(result["error_fig_func"], ValueError)

    # Check that an error file was created
    assert expected_error_file.exists()

84 

85 

def test_compute_and_save_figures():
    """Test compute_and_save_figures with a mock model config and cache.

    ``process_single_head`` and ``generate_prompts_jsonl`` are both mocked, so
    this checks only the dispatch logic. For both a dict cache keyed by hook
    name and a stacked array cache, the test expects one
    ``process_single_head`` call per (layer, head) pair and exactly one
    ``generate_prompts_jsonl`` call with the model directory.
    """
    # Setup: exist_ok=True so re-runs over leftover directories don't raise
    # FileExistsError.
    temp_dir = TEMP_DIR / "test_compute_and_save_figures"
    prompt_dir = temp_dir / "test-model" / "prompts" / "test-hash"
    prompt_dir.mkdir(parents=True, exist_ok=True)

    # Minimal stand-in for a model config; only these attributes are provided
    class MockConfig:
        def __init__(self):
            self.n_layers = 2
            self.n_heads = 2
            self.model_name = "test-model"

    # Seeded RNG so any failure is reproducible
    rng = np.random.default_rng(0)

    # Create a simple test attention matrices dict, one entry per layer.
    # Each value is presumably (batch, n_heads, n_ctx, n_ctx) - confirm.
    cache_dict = {
        "blocks.0.attn.hook_pattern": rng.random((1, 2, 5, 5), dtype=np.float32),
        "blocks.1.attn.hook_pattern": rng.random((1, 2, 5, 5), dtype=np.float32),
    }

    # Create a simple test figure function
    @save_matrix_wrapper(fmt="png")
    def test_fig_func(matrix):
        return matrix

    # Patch ATTENTION_MATRIX_FIGURE_FUNCS and process_single_head
    with (
        mock.patch(
            "pattern_lens.figures.ATTENTION_MATRIX_FIGURE_FUNCS",
            [test_fig_func],
        ),
        mock.patch("pattern_lens.figures.process_single_head") as mock_process_head,
        mock.patch("pattern_lens.figures.generate_prompts_jsonl") as mock_gen_prompts,
    ):
        mock_process_head.return_value = {"test_fig_func": True}

        # Call compute_and_save_figures with dict cache. The activations file
        # is never created on disk; the in-memory cache is passed instead.
        compute_and_save_figures(
            model_cfg=MockConfig(),
            activations_path=prompt_dir / "activations.npz",
            cache=cache_dict,
            save_path=temp_dir,
            force_overwrite=True,
        )

        # Check that process_single_head was called for each layer and head
        assert mock_process_head.call_count == 4  # 2 layers * 2 heads

        # Check that generate_prompts_jsonl was called
        mock_gen_prompts.assert_called_once_with(temp_dir / "test-model")

        # Reset mocks and test with stacked array
        mock_process_head.reset_mock()
        mock_gen_prompts.reset_mock()

        # Create a stacked array of shape [n_layers, n_heads, n_ctx, n_ctx]
        cache_array = rng.random((2, 2, 5, 5), dtype=np.float32)

        # Call compute_and_save_figures with array cache
        compute_and_save_figures(
            model_cfg=MockConfig(),
            activations_path=prompt_dir / "activations.npy",
            cache=cache_array,
            save_path=temp_dir,
            force_overwrite=True,
        )

        # Check that process_single_head was called for each layer and head
        assert mock_process_head.call_count == 4  # 2 layers * 2 heads

        # Check that generate_prompts_jsonl was called
        mock_gen_prompts.assert_called_once_with(temp_dir / "test-model")