Coverage for tests\unit\test_activations_return.py: 100%
65 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-04 02:14 -0700
1# tests/unit/test_activations_return.py
2from pathlib import Path
3from unittest import mock
5import pytest
6import torch
7from transformer_lens import HookedTransformer # type: ignore[import-untyped]
9from pattern_lens.activations import compute_activations, get_activations
# Scratch root for test artifacts; each test writes into its own subdirectory.
TEMP_DIR: Path = Path("tests/_temp")
class MockHookedTransformer:
    """Lightweight stand-in for HookedTransformer.

    Provides just enough surface (``cfg``, ``tokenizer``, ``eval``,
    ``run_with_cache``) for compute_activations and get_activations tests,
    without loading any real model weights.
    """

    def __init__(self, model_name="test-model", n_layers=2, n_heads=2):
        self.model_name = model_name
        # cfg only needs n_layers / n_heads attributes; MagicMock absorbs the rest.
        self.cfg = mock.MagicMock()
        self.cfg.n_layers = n_layers
        self.cfg.n_heads = n_heads
        self.tokenizer = mock.MagicMock()
        self.tokenizer.tokenize.return_value = ["test", "tokens"]

    def eval(self):
        """No-op eval-mode switch; returns self like the real model."""
        return self

    def run_with_cache(self, prompt_str, names_filter=None, return_type=None):  # noqa: ARG002
        """Return ``(None, cache)`` with random attention patterns per layer.

        Each ``blocks.{i}.attn.hook_pattern`` entry is a float tensor of shape
        ``[1, n_heads, n_ctx, n_ctx]`` where ``n_ctx == len(prompt_str)``.
        """
        n_ctx = len(prompt_str)
        cache = {
            f"blocks.{layer}.attn.hook_pattern": torch.rand(
                1,
                self.cfg.n_heads,
                n_ctx,
                n_ctx,
            ).float()
            for layer in range(self.cfg.n_layers)
        }
        return None, cache
def test_compute_activations_torch_return():
    """Test compute_activations with return_cache="torch"."""
    save_dir = TEMP_DIR / "test_compute_activations_torch_return"
    mock_model = MockHookedTransformer(n_layers=3, n_heads=4)
    prompt = {"text": "test prompt", "hash": "testhash123"}
    n_ctx = len(prompt["text"])

    # stack_heads=True: expect a single stacked tensor back
    path, result = compute_activations(
        prompt=prompt,
        model=mock_model,
        save_path=save_dir,
        return_cache="torch",
        stack_heads=True,
    )
    assert isinstance(result, torch.Tensor)
    expected_shape = (
        1,
        mock_model.cfg.n_layers,
        mock_model.cfg.n_heads,
        n_ctx,
        n_ctx,
    )
    assert result.shape == expected_shape

    # stack_heads=False: expect a dict keyed by hook name
    path, result = compute_activations(
        prompt=prompt,
        model=mock_model,
        save_path=save_dir,
        return_cache="torch",
        stack_heads=False,
    )
    assert isinstance(result, dict)
    for layer in range(mock_model.cfg.n_layers):
        hook_name = f"blocks.{layer}.attn.hook_pattern"
        assert hook_name in result
        assert isinstance(result[hook_name], torch.Tensor)
def test_compute_activations_invalid_return():
    """Test compute_activations with an invalid return_cache value.

    Uses the file's MockHookedTransformer instead of downloading a real
    pretrained model ("pythia-14m"): the invalid-argument check should not
    depend on real weights (the mock works for valid return_cache values in
    the other tests here), and a real model load makes this test slow and
    network-dependent. NOTE(review): assumes compute_activations validates
    return_cache without needing real-model-specific behavior — confirm.
    """
    temp_dir = TEMP_DIR / "test_compute_activations_invalid_return"
    model = MockHookedTransformer()
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # An unrecognized return_cache value must raise ValueError.
    with pytest.raises(ValueError, match="invalid return_cache"):
        compute_activations(
            prompt=prompt,
            model=model,
            save_path=temp_dir,
            # intentionally invalid
            return_cache="invalid",  # type: ignore[call-overload]
            stack_heads=True,
        )
def test_get_activations_torch_return():
    """Test get_activations with return_cache="torch" and mocked load_activations."""
    save_dir = TEMP_DIR / "test_get_activations_torch_return"
    prompt = {"text": "test prompt", "hash": "testhash123"}
    mock_model = MockHookedTransformer(model_name="test-model")
    n_ctx = len(prompt["text"])

    # Patch load_activations so get_activations receives pre-built torch tensors.
    with mock.patch("pattern_lens.activations.load_activations") as mock_load:
        fake_cache = {
            f"blocks.{layer}.attn.hook_pattern": torch.rand(1, 2, n_ctx, n_ctx)
            for layer in (0, 1)
        }
        mock_load.return_value = (Path("mock/path"), fake_cache)

        # Request the torch return format
        path, cache = get_activations(
            prompt=prompt,
            model=mock_model,
            save_path=save_dir,
            return_cache="torch",
        )

        # Every entry must come back as a torch tensor keyed by a string hook name.
        assert isinstance(cache, dict)
        for hook_name, pattern in cache.items():
            assert isinstance(hook_name, str)
            assert isinstance(pattern, torch.Tensor)
def test_get_activations_none_return():
    """Test get_activations with return_cache=None."""
    save_dir = TEMP_DIR / "test_get_activations_none_return"
    prompt = {"text": "test prompt", "hash": "testhash123"}
    mock_model = MockHookedTransformer(model_name="test-model")

    # Patch load_activations to hand back a known path; the cache is irrelevant here.
    with mock.patch("pattern_lens.activations.load_activations") as mock_load:
        expected_path = Path("mock/path")
        mock_load.return_value = (expected_path, {})  # Cache will be ignored

        # return_cache=None should yield the path and suppress the cache
        path, cache = get_activations(
            prompt=prompt,
            model=mock_model,
            save_path=save_dir,
            return_cache=None,
        )

        assert path == expected_path
        assert cache is None