Coverage for tests/unit/test_activations.py: 80%
65 statements
coverage.py v7.6.12, created at 2025-04-06 15:09 -0600
from pathlib import Path
from unittest import mock

import numpy as np
import torch
from transformer_lens import HookedTransformer  # type: ignore[import-untyped]

from pattern_lens.activations import compute_activations, get_activations
from pattern_lens.load_activations import ActivationsMissingError

TEMP_DIR: Path = Path("tests/_temp")


class MockHookedTransformer:
    """Mock of HookedTransformer for testing compute_activations and get_activations."""

    def __init__(self, model_name="test-model", n_layers=2, n_heads=2):
        self.model_name = model_name
        self.cfg = mock.MagicMock()
        self.cfg.n_layers = n_layers
        self.cfg.n_heads = n_heads
        self.tokenizer = mock.MagicMock()
        self.tokenizer.tokenize.return_value = ["test", "tokens"]

    def eval(self):
        return self

    def run_with_cache(self, prompt_str, names_filter=None, return_type=None):  # noqa: ARG002
        """Mock run_with_cache to return fake attention patterns."""
        # Create a mock activation cache with appropriately shaped attention patterns
        cache: dict[str, torch.Tensor] = {}
        for i in range(self.cfg.n_layers):
            # [1, n_heads, n_ctx, n_ctx] tensor, where n_ctx is len(prompt_str)
            n_ctx: int = len(prompt_str)
            attn_pattern: torch.Tensor = torch.rand(
                1,
                self.cfg.n_heads,
                n_ctx,
                n_ctx,
            ).float()
            cache[f"blocks.{i}.attn.hook_pattern"] = attn_pattern

        return None, cache
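

# MockHookedTransformer above is not exercised by the tests below (they load the real
# "pythia-14m" checkpoint), so here is a minimal hedged sketch that sanity-checks the
# mock itself. The test name is hypothetical; nothing about pattern_lens internals is
# assumed, only the mock class defined above is used.
def test_mock_hooked_transformer_run_with_cache():
    """Sketch: check the keys and shapes of the fake attention patterns from the mock."""
    mock_model = MockHookedTransformer(n_layers=2, n_heads=2)
    _, cache = mock_model.run_with_cache("test prompt")

    # One hook_pattern entry per layer
    assert set(cache.keys()) == {f"blocks.{i}.attn.hook_pattern" for i in range(2)}
    for pattern in cache.values():
        # [1, n_heads, n_ctx, n_ctx] where n_ctx == len("test prompt")
        assert pattern.shape == (1, 2, len("test prompt"), len("test prompt"))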


def test_compute_activations_stack_heads():
    """Test compute_activations with stack_heads=True."""
    # Setup
    temp_dir: Path = TEMP_DIR / "test_compute_activations_stack_heads"
    model: HookedTransformer = HookedTransformer.from_pretrained("pythia-14m")
    prompt: dict[str, str] = {"text": "test prompt", "hash": "testhash123"}

    # Test with return_cache=None
    path, result = compute_activations(
        prompt=prompt,
        model=model,
        save_path=temp_dir,
        return_cache=None,
        stack_heads=True,
    )

    # Check return values
    assert (
        path
        == temp_dir
        / model.cfg.model_name
        / "prompts"
        / prompt["hash"]
        / "activations-blocks.-.attn.hook_pattern.npy"
    )
    assert result is None

    # Check that the file was created and has the correct shape
    assert path.exists()
    loaded = np.load(path)
    assert loaded.shape[:3] == (
        1,
        model.cfg.n_layers,
        model.cfg.n_heads,
    )
    assert loaded.shape[3] == loaded.shape[4]

    # Test with return_cache="numpy"
    path, result = compute_activations(
        prompt=prompt,
        model=model,
        save_path=temp_dir,
        return_cache="numpy",
        stack_heads=True,
    )

    # Check return values
    assert isinstance(result, np.ndarray)
    assert result.shape[:3] == (
        1,
        model.cfg.n_layers,
        model.cfg.n_heads,
    )
    assert result.shape[3] == result.shape[4]


def test_compute_activations_no_stack():
    """Test compute_activations with stack_heads=False."""
    # Setup
    temp_dir = TEMP_DIR / "test_compute_activations_no_stack"
    model = HookedTransformer.from_pretrained("pythia-14m")
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # Test with return_cache="numpy"
    path, result = compute_activations(
        prompt=prompt,
        model=model,
        save_path=temp_dir,
        return_cache="numpy",
        stack_heads=False,
    )

    # Check return values
    assert (
        path
        == temp_dir
        / model.cfg.model_name
        / "prompts"
        / prompt["hash"]
        / "activations.npz"
    )
    assert isinstance(result, dict)

    # Check that the keys have the expected form and values have the right shape
    for i in range(model.cfg.n_layers):
        key = f"blocks.{i}.attn.hook_pattern"
        assert key in result
        assert result[key].shape[:2] == (
            1,
            model.cfg.n_heads,
        )
        assert result[key].shape[2] == result[key].shape[3]
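

# A hedged follow-up sketch: the activations.npz written by compute_activations should
# round-trip through np.load with the same keys and shapes as the returned cache dict.
# The test name is hypothetical, and it assumes the archive is keyed by hook name
# (consistent with the asserts above); this is not verified against pattern_lens internals.
def test_compute_activations_no_stack_roundtrip():
    """Sketch: reload the saved activations.npz and compare it against the returned cache."""
    temp_dir = TEMP_DIR / "test_compute_activations_no_stack_roundtrip"
    model = HookedTransformer.from_pretrained("pythia-14m")
    prompt = {"text": "test prompt", "hash": "testhash123"}

    path, result = compute_activations(
        prompt=prompt,
        model=model,
        save_path=temp_dir,
        return_cache="numpy",
        stack_heads=False,
    )

    # Assumed: np.savez-style archive with one array per hook name
    with np.load(path) as loaded:
        for key, value in result.items():
            assert key in loaded
            assert loaded[key].shape == value.shape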


def test_get_activations_missing():
    """Test get_activations when activations don't exist."""
    temp_dir = TEMP_DIR / "test_get_activations_missing"
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # Patch the load_activations and compute_activations functions
    with (
        mock.patch("pattern_lens.activations.load_activations") as mock_load,
        mock.patch("pattern_lens.activations.compute_activations") as mock_compute,
    ):
        # Set up load_activations to fail
        mock_load.side_effect = ActivationsMissingError("Not found")

        # Set up mock model and compute_activations
        model = HookedTransformer.from_pretrained("pythia-14m")
        mock_compute.return_value = (Path("mock/path"), {"mock": "cache"})

        # Call get_activations
        path, cache = get_activations(
            prompt=prompt,
            model=model,
            save_path=temp_dir,
            return_cache="numpy",
        )

        # Check that compute_activations was called with the right arguments
        mock_compute.assert_called_once()
        args, kwargs = mock_compute.call_args
        assert kwargs["prompt"] == prompt
        assert kwargs["save_path"] == temp_dir
        assert kwargs["return_cache"] == "numpy"
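

# A hedged companion sketch for the success path: when load_activations finds existing
# activations, get_activations should not fall back to compute_activations. The test name
# is hypothetical, and the (path, cache) return value given to the mocked load_activations
# mirrors the mocked compute_activations above; it is an assumption, not verified against
# pattern_lens.load_activations.
def test_get_activations_cached():
    """Sketch: get_activations when the activations already exist on disk."""
    temp_dir = TEMP_DIR / "test_get_activations_cached"
    prompt = {"text": "test prompt", "hash": "testhash123"}

    with (
        mock.patch("pattern_lens.activations.load_activations") as mock_load,
        mock.patch("pattern_lens.activations.compute_activations") as mock_compute,
    ):
        # Set up load_activations to succeed (assumed return shape: (path, cache))
        mock_load.return_value = (Path("mock/path"), {"mock": "cache"})

        model = HookedTransformer.from_pretrained("pythia-14m")

        # Call get_activations; the cached result should be used without recomputation
        get_activations(
            prompt=prompt,
            model=model,
            save_path=temp_dir,
            return_cache="numpy",
        )

        mock_load.assert_called_once()
        mock_compute.assert_not_called()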