Coverage for tests\unit\test_figures.py: 98%
55 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-04 01:50 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-04 01:50 -0700
1# tests/unit/test_figures.py
2from pathlib import Path
3from unittest import mock
5import numpy as np
7from pattern_lens.figure_util import save_matrix_wrapper
8from pattern_lens.figures import compute_and_save_figures, process_single_head
# Scratch root for everything these tests write to disk. It is a relative path,
# so it resolves against the current working directory of the pytest process.
TEMP_DIR: Path = Path("tests") / "_temp"
def test_process_single_head():
    """Test processing a single head's attention pattern.

    Patches ``ATTENTION_MATRIX_FIGURE_FUNCS`` to a single SVG-saving figure
    function, runs ``process_single_head`` on a random 10x10 matrix, and checks
    that the function is reported as successful and its ``.svg`` output exists.
    """
    # Setup
    temp_dir = TEMP_DIR / "test_process_single_head"
    head_dir = temp_dir / "L0" / "H0"
    # exist_ok=True: nothing in this test removes the directory, so a second
    # run would otherwise fail with FileExistsError
    head_dir.mkdir(parents=True, exist_ok=True)

    expected_file = head_dir / "test_fig_func.svg"
    # Drop any output left by a previous run so the existence check below
    # really verifies that *this* run wrote the file
    expected_file.unlink(missing_ok=True)

    # Create a simple test attention matrix (seeded for reproducible failures)
    rng = np.random.default_rng(0)
    attn_pattern = rng.random((10, 10)).astype(np.float32)

    # Create a simple test figure function
    @save_matrix_wrapper(fmt="svg")
    def test_fig_func(matrix):
        return matrix

    # Patch ATTENTION_MATRIX_FIGURE_FUNCS to use our test function
    with mock.patch(
        "pattern_lens.figures.ATTENTION_MATRIX_FIGURE_FUNCS",
        [test_fig_func],
    ):
        # Call process_single_head
        result = process_single_head(
            layer_idx=0,
            head_idx=0,
            attn_pattern=attn_pattern,
            save_dir=head_dir,
            force_overwrite=True,
        )

    # Check that our function was called and succeeded
    assert "test_fig_func" in result
    assert result["test_fig_func"] is True

    # Check that the figure file was created
    assert expected_file.exists()
def test_process_single_head_error_handling():
    """Test that process_single_head properly handles errors in figure functions.

    Patches ``ATTENTION_MATRIX_FIGURE_FUNCS`` to a function that always raises
    ``ValueError`` and checks two things: the exception object is returned in
    the result under the function's name, and an ``<name>.error.txt`` file is
    written into the head directory.
    """
    # Setup
    temp_dir = TEMP_DIR / "test_process_single_head_error_handling"
    head_dir = temp_dir / "L0" / "H0"
    # exist_ok=True: nothing in this test removes the directory, so a second
    # run would otherwise fail with FileExistsError
    head_dir.mkdir(parents=True, exist_ok=True)

    error_file = head_dir / "error_fig_func.error.txt"
    # Remove an error file left by a previous run so the final assertion
    # checks that *this* run produced it
    error_file.unlink(missing_ok=True)

    # Create a simple test attention matrix (seeded for reproducible failures)
    rng = np.random.default_rng(0)
    attn_pattern = rng.random((10, 10)).astype(np.float32)

    # Create a test figure function that raises an exception
    def error_fig_func(matrix, save_dir):  # noqa: ARG001
        raise ValueError("Test error")

    # Patch ATTENTION_MATRIX_FIGURE_FUNCS to use our test function
    with mock.patch(
        "pattern_lens.figures.ATTENTION_MATRIX_FIGURE_FUNCS",
        [error_fig_func],
    ):
        # Call process_single_head
        result = process_single_head(
            layer_idx=0,
            head_idx=0,
            attn_pattern=attn_pattern,
            save_dir=head_dir,
            force_overwrite=True,
        )

    # Check that our function was called and failed
    assert "error_fig_func" in result
    assert isinstance(result["error_fig_func"], ValueError)

    # Check that an error file was created
    assert error_file.exists()
def test_compute_and_save_figures():
    """Test compute_and_save_figures with a mock model config and cache.

    ``process_single_head`` and ``generate_prompts_jsonl`` are mocked out, so
    this only checks dispatch. There must be one ``process_single_head`` call
    per (layer, head) pair, i.e. 2 * 2 = 4. There must also be exactly one
    ``generate_prompts_jsonl`` call on ``<save_path>/<model_name>``.

    Two cache forms are exercised: a dict keyed by hook name and a stacked
    array of shape ``[n_layers, n_heads, n_ctx, n_ctx]``.
    """
    # Setup
    temp_dir = TEMP_DIR / "test_compute_and_save_figures"
    prompt_dir = temp_dir / "test-model" / "prompts" / "test-hash"
    # exist_ok=True: nothing in this test removes the directory, so a second
    # run would otherwise fail with FileExistsError
    prompt_dir.mkdir(parents=True, exist_ok=True)

    # Minimal stand-in for the model config; only these attributes are provided
    class MockConfig:
        def __init__(self):
            self.n_layers = 2
            self.n_heads = 2
            self.model_name = "test-model"

    # Seeded so any failure is reproducible
    rng = np.random.default_rng(0)

    # Dict cache: one array of shape (1, 2, 5, 5) per layer, keyed by hook name
    cache_dict = {
        "blocks.0.attn.hook_pattern": rng.random((1, 2, 5, 5)).astype(np.float32),
        "blocks.1.attn.hook_pattern": rng.random((1, 2, 5, 5)).astype(np.float32),
    }

    # Simple test figure function. Its body presumably never executes here,
    # because process_single_head (which applies the figure funcs) is mocked.
    @save_matrix_wrapper(fmt="png")
    def test_fig_func(matrix):
        return matrix

    # Patch ATTENTION_MATRIX_FIGURE_FUNCS and process_single_head
    with (
        mock.patch(
            "pattern_lens.figures.ATTENTION_MATRIX_FIGURE_FUNCS",
            [test_fig_func],
        ),
        mock.patch("pattern_lens.figures.process_single_head") as mock_process_head,
        mock.patch("pattern_lens.figures.generate_prompts_jsonl") as mock_gen_prompts,
    ):
        mock_process_head.return_value = {"test_fig_func": True}

        # Call compute_and_save_figures with dict cache
        compute_and_save_figures(
            model_cfg=MockConfig(),
            activations_path=prompt_dir / "activations.npz",
            cache=cache_dict,
            save_path=temp_dir,
            force_overwrite=True,
        )

        # Check that process_single_head was called for each layer and head
        assert mock_process_head.call_count == 4  # 2 layers * 2 heads

        # Check that generate_prompts_jsonl was called
        mock_gen_prompts.assert_called_once_with(temp_dir / "test-model")

        # Reset mocks and test with stacked array
        mock_process_head.reset_mock()
        mock_gen_prompts.reset_mock()

        # Create a stacked array of shape [n_layers, n_heads, n_ctx, n_ctx]
        cache_array = rng.random((2, 2, 5, 5)).astype(np.float32)

        # Call compute_and_save_figures with array cache
        compute_and_save_figures(
            model_cfg=MockConfig(),
            activations_path=prompt_dir / "activations.npy",
            cache=cache_array,
            save_path=temp_dir,
            force_overwrite=True,
        )

        # Check that process_single_head was called for each layer and head
        assert mock_process_head.call_count == 4  # 2 layers * 2 heads

        # Check that generate_prompts_jsonl was called
        mock_gen_prompts.assert_called_once_with(temp_dir / "test-model")