Coverage for tests/unit/test_activations_return.py: 100%

63 statements  

coverage.py v7.13.4, created at 2026-02-22 18:15 -0700

# tests/unit/test_activations_return.py
from pathlib import Path
from unittest import mock

import pytest
import torch
from transformer_lens import HookedTransformer  # type: ignore[import-untyped]

from pattern_lens.activations import compute_activations, get_activations

TEMP_DIR: Path = Path("tests/.temp")
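
# Hypothetical cleanup fixture (not in the original suite): clear tests/.temp
# after each test so repeated runs don't accumulate stale activation files.
# Assumes no test relies on artifacts persisting across tests.
@pytest.fixture(autouse=True)
def _clean_temp_dir():
    import shutil  # stdlib; imported locally to keep this sketch self-contained

    yield
    if TEMP_DIR.exists():
        shutil.rmtree(TEMP_DIR)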

class MockHookedTransformer:
    """Mock of HookedTransformer for testing compute_activations and get_activations."""

    def __init__(self, model_name="test-model", n_layers=2, n_heads=2):
        self.model_name = model_name
        self.cfg = mock.MagicMock()
        self.cfg.n_layers = n_layers
        self.cfg.n_heads = n_heads
        self.tokenizer = mock.MagicMock()
        self.tokenizer.tokenize.return_value = ["test", "tokens"]

    def eval(self):
        return self

    def run_with_cache(self, prompt_str, names_filter=None, return_type=None):  # noqa: ARG002
        """Mock run_with_cache to return fake attention patterns."""
        # Create a mock activation cache with appropriately shaped attention patterns
        cache = {}
        for i in range(self.cfg.n_layers):
            # [1, n_heads, n_ctx, n_ctx] tensor, where n_ctx is len(prompt_str)
            n_ctx = len(prompt_str)
            attn_pattern = torch.rand(
                1,
                self.cfg.n_heads,
                n_ctx,
                n_ctx,
            ).float()
            cache[f"blocks.{i}.attn.hook_pattern"] = attn_pattern

        # The real run_with_cache returns (logits, cache); the logits are unused
        # by compute_activations, so None stands in for them here.
        return None, cache

def test_compute_activations_torch_return():
    """Test compute_activations with return_cache="torch"."""
    # Setup
    temp_dir = TEMP_DIR / "test_compute_activations_torch_return"
    model = MockHookedTransformer(n_layers=3, n_heads=4)
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # Test with stack_heads=True
    _path, result = compute_activations(  # type: ignore[call-overload]
        prompt=prompt,
        model=model,
        save_path=temp_dir,
        return_cache="torch",
        stack_heads=True,
    )

    # Check return values
    assert isinstance(result, torch.Tensor)
    assert result.shape == (
        1,
        model.cfg.n_layers,
        model.cfg.n_heads,
        len(prompt["text"]),
        len(prompt["text"]),
    )

    # Test with stack_heads=False
    _path, result = compute_activations(  # type: ignore[call-overload]
        prompt=prompt,
        model=model,
        save_path=temp_dir,
        return_cache="torch",
        stack_heads=False,
    )

    # Check return values
    assert isinstance(result, dict)
    for i in range(model.cfg.n_layers):
        key = f"blocks.{i}.attn.hook_pattern"
        assert key in result
        assert isinstance(result[key], torch.Tensor)

def test_compute_activations_invalid_return():
    """Test compute_activations with an invalid return_cache value."""
    # Setup
    temp_dir = TEMP_DIR / "test_compute_activations_invalid_return"
    model = HookedTransformer.from_pretrained("pythia-14m")
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # Test with an invalid return_cache value
    with pytest.raises(ValueError, match="invalid return_cache"):
        compute_activations(  # type: ignore[call-overload]
            prompt=prompt,
            model=model,
            save_path=temp_dir,
            # intentionally invalid
            return_cache="invalid",
            stack_heads=True,
        )

def test_get_activations_torch_return():
    """Test get_activations with return_cache="torch" and mocked load_activations."""
    temp_dir = TEMP_DIR / "test_get_activations_torch_return"
    prompt = {"text": "test prompt", "hash": "testhash123"}
    model = MockHookedTransformer(model_name="test-model")

    # Create a mock for load_activations that returns torch tensors
    with mock.patch("pattern_lens.activations.load_activations") as mock_load:
        mock_cache = {
            "blocks.0.attn.hook_pattern": torch.rand(
                1,
                2,
                len(prompt["text"]),
                len(prompt["text"]),
            ),
            "blocks.1.attn.hook_pattern": torch.rand(
                1,
                2,
                len(prompt["text"]),
                len(prompt["text"]),
            ),
        }
        mock_load.return_value = (Path("mock/path"), mock_cache)

        # Call get_activations with torch return format
        _path, cache = get_activations(  # type: ignore[call-overload]
            prompt=prompt,
            model=model,
            save_path=temp_dir,
            return_cache="torch",
        )

        # Check that we got torch tensors back
        assert isinstance(cache, dict)
        for key, value in cache.items():
            assert isinstance(key, str)
            assert isinstance(value, torch.Tensor)

def test_get_activations_none_return():
    """Test that get_activations with return_cache=None uses the fast existence-check path.

    When return_cache=None and the activations already exist on disk,
    get_activations should use activations_exist (a cheap file-existence check)
    instead of load_activations (which would decompress the full .npz).
    """
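    # Hypothetical sketch of the dispatch under test (the real logic lives in
    # pattern_lens.activations; argument lists here are elided placeholders):
    #
    #   if return_cache is None and activations_exist(...):
    #       return expected_npz_path, None      # no .npz decompression needed
    #   path, cache = load_activations(...)     # slow path: full load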

    temp_dir = TEMP_DIR / "test_get_activations_none_return"
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # Pass model as a string -- get_activations accepts HookedTransformer | str,
    # and the fast path (return_cache=None) only needs the model name for path construction.
    with mock.patch("pattern_lens.activations.activations_exist", return_value=True):
        path, cache = get_activations(  # type: ignore[call-overload]
            prompt=prompt,
            model="test-model",
            save_path=temp_dir,
            return_cache=None,
        )

    # Check that we got the reconstructed path but no cache
    expected_path = (
        temp_dir / "test-model" / "prompts" / prompt["hash"] / "activations.npz"
    )
    assert path == expected_path
    assert cache is None
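

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the suite): the calling convention
# these tests pin down. Model name and paths are placeholders, not
# pattern_lens documentation.
#
#   model = HookedTransformer.from_pretrained("pythia-14m")
#   path, patterns = compute_activations(
#       prompt={"text": "some prompt", "hash": "somehash"},
#       model=model,
#       save_path=Path("attn_data"),
#       return_cache="torch",
#       stack_heads=True,
#   )
#   # patterns.shape == (1, n_layers, n_heads, n_ctx, n_ctx)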