Coverage for tests/unit/test_figures.py: 98%
54 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
1# tests/unit/test_figures.py
2from pathlib import Path
3from unittest import mock
5import numpy as np
7from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
8from pattern_lens.figure_util import save_matrix_wrapper
9from pattern_lens.figures import compute_and_save_figures, process_single_head
# Scratch root (relative to the current working directory) for files the
# tests below write; each test creates its own subdirectory underneath it.
TEMP_DIR: Path = Path("tests", "_temp")
def test_process_single_head():
    """Test processing a single head's attention pattern.

    Runs ``process_single_head`` with the stock attention figure functions plus
    a trivial ``save_matrix_wrapper``-decorated function. It then checks that
    the trivial function is reported as successful (``True`` under its name in
    the result) and that its SVG was written into ``head_dir``.
    """
    # Setup -- ``exist_ok`` lets the test be re-run without cleaning TEMP_DIR
    temp_dir = TEMP_DIR / "test_process_single_head"
    head_dir = temp_dir / "L0" / "H0"
    head_dir.mkdir(parents=True, exist_ok=True)
    expected_file = head_dir / "test_fig_func.svg"
    # Drop output left over from a previous run so the existence check
    # at the end reflects *this* run only.
    expected_file.unlink(missing_ok=True)

    # Create a simple test attention matrix (seeded for reproducibility)
    rng = np.random.default_rng(0)
    attn_pattern = rng.random((10, 10)).astype(np.float32)

    # Create a simple test figure function
    @save_matrix_wrapper(fmt="svg")
    def test_fig_func(matrix):
        return matrix

    # Run the stock figure functions alongside our test function
    result = process_single_head(
        layer_idx=0,
        head_idx=0,
        attn_pattern=attn_pattern,
        save_dir=head_dir,
        figure_funcs=[*ATTENTION_MATRIX_FIGURE_FUNCS, test_fig_func],
        force_overwrite=True,
    )

    # Check that our function was called and succeeded
    assert "test_fig_func" in result
    assert result["test_fig_func"] is True

    # Check that the figure file was created
    assert expected_file.exists()
def test_process_single_head_error_handling():
    """Test that process_single_head properly handles errors in figure functions.

    A figure function that raises must not propagate out of
    ``process_single_head``. Instead, the exception object is expected under
    the function's name in the result, and an ``<name>.error.txt`` file is
    expected in ``save_dir``.
    """
    # Setup -- ``exist_ok`` lets the test be re-run without cleaning TEMP_DIR
    temp_dir = TEMP_DIR / "test_process_single_head_error_handling"
    head_dir = temp_dir / "L0" / "H0"
    head_dir.mkdir(parents=True, exist_ok=True)
    error_file = head_dir / "error_fig_func.error.txt"
    # A stale error file from a previous run would make the final assertion
    # pass even if no error file is written now.
    error_file.unlink(missing_ok=True)

    # Create a simple test attention matrix (seeded for reproducibility)
    rng = np.random.default_rng(0)
    attn_pattern = rng.random((10, 10)).astype(np.float32)

    # Create a test figure function that raises an exception
    def error_fig_func(matrix, save_dir):  # noqa: ARG001
        raise ValueError("Test error")

    # Run only the failing function
    result = process_single_head(
        layer_idx=0,
        head_idx=0,
        attn_pattern=attn_pattern,
        save_dir=head_dir,
        figure_funcs=[error_fig_func],
        force_overwrite=True,
    )

    # Check that our function was called and failed
    assert "error_fig_func" in result
    assert isinstance(result["error_fig_func"], ValueError)

    # Check that an error file was created
    assert error_file.exists()
def test_compute_and_save_figures():
    """Test compute_and_save_figures with a mock model config and cache.

    ``process_single_head`` and ``generate_prompts_jsonl`` are mocked out, so
    this checks only the dispatch logic, for both cache formats used here:
    a dict of per-layer ``hook_pattern`` arrays, and a single stacked array.
    For each format, ``process_single_head`` must be called once per
    (layer, head) pair, and ``generate_prompts_jsonl`` must be called exactly
    once with the model directory.
    """
    # Setup -- ``exist_ok`` lets the test be re-run without cleaning TEMP_DIR
    temp_dir = TEMP_DIR / "test_compute_and_save_figures"
    prompt_dir = temp_dir / "test-model" / "prompts" / "test-hash"
    prompt_dir.mkdir(parents=True, exist_ok=True)

    n_layers, n_heads, n_ctx = 2, 2, 5

    class MockConfig:
        """Minimal stand-in for a model config (n_layers, n_heads, model_name)."""

        def __init__(self):
            self.n_layers = n_layers
            self.n_heads = n_heads
            self.model_name = "test-model"

    rng = np.random.default_rng(0)  # seeded for reproducibility

    # dict cache: one array per layer with shape (1, n_heads, n_ctx, n_ctx);
    # the leading axis of size 1 is presumably the batch -- TODO confirm
    cache_dict = {
        f"blocks.{layer}.attn.hook_pattern": rng.random(
            (1, n_heads, n_ctx, n_ctx)
        ).astype(np.float32)
        for layer in range(n_layers)
    }
    # stacked cache: shape [n_layers, n_heads, n_ctx, n_ctx]
    cache_array = rng.random((n_layers, n_heads, n_ctx, n_ctx)).astype(np.float32)

    # Create a simple test figure function
    @save_matrix_wrapper(fmt="png")
    def test_fig_func(matrix):
        return matrix

    figure_funcs = [test_fig_func]

    # Patch ATTENTION_MATRIX_FIGURE_FUNCS and process_single_head
    with (
        mock.patch(
            "pattern_lens.figures.ATTENTION_MATRIX_FIGURE_FUNCS",
            figure_funcs,
        ),
        mock.patch("pattern_lens.figures.process_single_head") as mock_process_head,
        mock.patch("pattern_lens.figures.generate_prompts_jsonl") as mock_gen_prompts,
    ):
        mock_process_head.return_value = {"test_fig_func": True}

        for cache, activations_path in (
            (cache_dict, prompt_dir / "activations.npz"),
            (cache_array, prompt_dir / "activations.npy"),
        ):
            cache_kind = type(cache).__name__
            # reset_mock() keeps return_value, so the stub result survives
            mock_process_head.reset_mock()
            mock_gen_prompts.reset_mock()

            compute_and_save_figures(
                model_cfg=MockConfig(),
                activations_path=activations_path,
                cache=cache,
                # Pass the test list explicitly. This module's own
                # ``ATTENTION_MATRIX_FIGURE_FUNCS`` name was bound at import
                # time, so the mock.patch above does not affect it.
                figure_funcs=figure_funcs,
                save_path=temp_dir,
                force_overwrite=True,
            )

            # Check that process_single_head was called for each layer and head
            assert mock_process_head.call_count == n_layers * n_heads, (
                f"unexpected process_single_head call count for {cache_kind} cache"
            )

            # Check that generate_prompts_jsonl was called once for the model dir
            mock_gen_prompts.assert_called_once_with(temp_dir / "test-model")