Coverage for tests / unit / test_activations_return.py: 100%
63 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
1# tests/unit/test_activations_return.py
2from pathlib import Path
3from unittest import mock
5import pytest
6import torch
7from transformer_lens import HookedTransformer # type: ignore[import-untyped]
9from pattern_lens.activations import compute_activations, get_activations
11TEMP_DIR: Path = Path("tests/.temp")
class MockHookedTransformer:
    """Mock of HookedTransformer for testing compute_activations and get_activations.

    Provides only the surface the code under test touches: ``model_name``,
    ``cfg.n_layers`` / ``cfg.n_heads``, a ``tokenizer`` with a canned
    ``tokenize`` result, ``eval()``, and ``run_with_cache()``.
    """

    def __init__(self, model_name="test-model", n_layers=2, n_heads=2):
        self.model_name = model_name
        self.cfg = mock.MagicMock()
        self.cfg.n_layers = n_layers
        self.cfg.n_heads = n_heads
        self.tokenizer = mock.MagicMock()
        self.tokenizer.tokenize.return_value = ["test", "tokens"]

    def eval(self):
        """Mimic ``torch.nn.Module.eval()`` by returning self."""
        return self

    def run_with_cache(self, prompt_str, names_filter=None, return_type=None):  # noqa: ARG002
        """Mock run_with_cache to return fake attention patterns.

        Returns ``(None, cache)`` where the cache maps
        ``blocks.{i}.attn.hook_pattern`` to a random
        ``[1, n_heads, n_ctx, n_ctx]`` float tensor, with ``n_ctx``
        taken as ``len(prompt_str)`` (one "token" per character).
        """
        # n_ctx is the same for every layer -- hoist it out of the loop
        # (torch.rand already yields float32, so no .float() cast is needed)
        n_ctx: int = len(prompt_str)
        cache: dict = {
            f"blocks.{i}.attn.hook_pattern": torch.rand(
                1,
                self.cfg.n_heads,
                n_ctx,
                n_ctx,
            )
            for i in range(self.cfg.n_layers)
        }
        return None, cache
def test_compute_activations_torch_return():
    """Test compute_activations with return_cache="torch".

    Covers both stacking modes: stack_heads=True returns a single stacked
    tensor, stack_heads=False returns a dict of per-layer tensors.
    """
    save_dir = TEMP_DIR / "test_compute_activations_torch_return"
    mock_model = MockHookedTransformer(n_layers=3, n_heads=4)
    prompt = {"text": "test prompt", "hash": "testhash123"}
    n_tokens = len(prompt["text"])

    # stack_heads=True -> one tensor shaped [1, layers, heads, ctx, ctx]
    _path, stacked = compute_activations(  # type: ignore[call-overload]
        prompt=prompt,
        model=mock_model,
        save_path=save_dir,
        return_cache="torch",
        stack_heads=True,
    )
    assert isinstance(stacked, torch.Tensor)
    assert stacked.shape == (
        1,
        mock_model.cfg.n_layers,
        mock_model.cfg.n_heads,
        n_tokens,
        n_tokens,
    )

    # stack_heads=False -> dict keyed by hook name, one tensor per layer
    _path, per_layer = compute_activations(  # type: ignore[call-overload]
        prompt=prompt,
        model=mock_model,
        save_path=save_dir,
        return_cache="torch",
        stack_heads=False,
    )
    assert isinstance(per_layer, dict)
    for layer_idx in range(mock_model.cfg.n_layers):
        hook_name = f"blocks.{layer_idx}.attn.hook_pattern"
        assert hook_name in per_layer
        assert isinstance(per_layer[hook_name], torch.Tensor)
def test_compute_activations_invalid_return():
    """Test compute_activations with an invalid return_cache value.

    Uses MockHookedTransformer -- consistent with the other tests in this
    file -- instead of downloading a real pretrained model, keeping this
    unit test fast and network-free.

    NOTE(review): assumes the return_cache validation in compute_activations
    does not depend on the model being a real HookedTransformer -- confirm
    against pattern_lens.activations.
    """
    temp_dir = TEMP_DIR / "test_compute_activations_invalid_return"
    model = MockHookedTransformer()
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # Test with an invalid return_cache value
    with pytest.raises(ValueError, match="invalid return_cache"):
        compute_activations(  # type: ignore[call-overload]
            prompt=prompt,
            model=model,
            save_path=temp_dir,
            # intentionally invalid
            return_cache="invalid",
            stack_heads=True,
        )
def test_get_activations_torch_return():
    """Test get_activations with return_cache="torch" and mocked load_activations."""
    save_dir = TEMP_DIR / "test_get_activations_torch_return"
    prompt = {"text": "test prompt", "hash": "testhash123"}
    mock_model = MockHookedTransformer(model_name="test-model")
    ctx_len = len(prompt["text"])

    # Fake cache of torch attention patterns for two layers, two heads each
    fake_cache = {
        f"blocks.{layer}.attn.hook_pattern": torch.rand(1, 2, ctx_len, ctx_len)
        for layer in range(2)
    }

    with mock.patch("pattern_lens.activations.load_activations") as mock_load:
        mock_load.return_value = (Path("mock/path"), fake_cache)

        # Call get_activations with torch return format
        _path, cache = get_activations(  # type: ignore[call-overload]
            prompt=prompt,
            model=mock_model,
            save_path=save_dir,
            return_cache="torch",
        )

    # Check that we got torch tensors back
    assert isinstance(cache, dict)
    for hook_name, pattern in cache.items():
        assert isinstance(hook_name, str)
        assert isinstance(pattern, torch.Tensor)
def test_get_activations_none_return():
    """Test get_activations with return_cache=None uses the fast existence-check path.

    When return_cache=None and the activations already exist on disk,
    get_activations should use activations_exist (cheap file-existence check)
    instead of load_activations (which would decompress the full .npz).
    """
    save_dir = TEMP_DIR / "test_get_activations_none_return"
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # model may be passed as a plain string: the fast path (return_cache=None)
    # only needs the model name to reconstruct the on-disk path
    with mock.patch("pattern_lens.activations.activations_exist", return_value=True):
        result_path, result_cache = get_activations(  # type: ignore[call-overload]
            prompt=prompt,
            model="test-model",
            save_path=save_dir,
            return_cache=None,
        )

    # Path is reconstructed; no cache is loaded
    assert result_cache is None
    assert result_path == (
        save_dir / "test-model" / "prompts" / prompt["hash"] / "activations.npz"
    )