Coverage for tests/unit/test_activations_return.py: 100%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-06 15:09 -0600

1# tests/unit/test_activations_return.py 

2from pathlib import Path 

3from unittest import mock 

4 

5import pytest 

6import torch 

7from transformer_lens import HookedTransformer # type: ignore[import-untyped] 

8 

9from pattern_lens.activations import compute_activations, get_activations 

10 

11TEMP_DIR: Path = Path("tests/_temp") 

12 

13 

class MockHookedTransformer:
    """Lightweight stand-in for HookedTransformer.

    Provides just enough surface (``cfg``, ``tokenizer``, ``eval``,
    ``run_with_cache``) for compute_activations / get_activations tests,
    without loading any real model weights.
    """

    def __init__(self, model_name="test-model", n_layers=2, n_heads=2):
        self.model_name = model_name
        self.cfg = mock.MagicMock()
        self.cfg.n_layers = n_layers
        self.cfg.n_heads = n_heads
        self.tokenizer = mock.MagicMock()
        self.tokenizer.tokenize.return_value = ["test", "tokens"]

    def eval(self):
        # Mimic torch's fluent API: eval() returns the model itself.
        return self

    def run_with_cache(self, prompt_str, names_filter=None, return_type=None):  # noqa: ARG002
        """Return ``(None, cache)`` with random fake attention patterns.

        Each layer gets a ``[1, n_heads, n_ctx, n_ctx]`` float tensor under
        the standard ``blocks.{i}.attn.hook_pattern`` key, where ``n_ctx``
        is ``len(prompt_str)``.
        """
        n_ctx = len(prompt_str)
        cache = {
            f"blocks.{layer}.attn.hook_pattern": torch.rand(
                1,
                self.cfg.n_heads,
                n_ctx,
                n_ctx,
            ).float()
            for layer in range(self.cfg.n_layers)
        }
        return None, cache

44 

45 

def test_compute_activations_torch_return():
    """Test compute_activations with return_cache="torch"."""
    save_dir = TEMP_DIR / "test_compute_activations_torch_return"
    mock_model = MockHookedTransformer(n_layers=3, n_heads=4)
    prompt = {"text": "test prompt", "hash": "testhash123"}
    n_ctx = len(prompt["text"])

    # stack_heads=True should produce a single stacked tensor.
    _, stacked = compute_activations(
        prompt=prompt,
        model=mock_model,
        save_path=save_dir,
        return_cache="torch",
        stack_heads=True,
    )
    assert isinstance(stacked, torch.Tensor)
    expected_shape = (
        1,
        mock_model.cfg.n_layers,
        mock_model.cfg.n_heads,
        n_ctx,
        n_ctx,
    )
    assert stacked.shape == expected_shape

    # stack_heads=False should produce a dict of per-layer tensors.
    _, per_layer = compute_activations(
        prompt=prompt,
        model=mock_model,
        save_path=save_dir,
        return_cache="torch",
        stack_heads=False,
    )
    assert isinstance(per_layer, dict)
    for layer in range(mock_model.cfg.n_layers):
        hook_name = f"blocks.{layer}.attn.hook_pattern"
        assert hook_name in per_layer
        assert isinstance(per_layer[hook_name], torch.Tensor)

87 

88 

def test_compute_activations_invalid_return():
    """Test compute_activations with an invalid return_cache value.

    Uses MockHookedTransformer like the sibling tests instead of
    ``HookedTransformer.from_pretrained("pythia-14m")``: downloading real
    model weights made this unit test slow and network-dependent, and the
    invalid ``return_cache`` value must raise regardless of which model is
    passed (the sibling tests show the mock works with compute_activations).
    """
    temp_dir = TEMP_DIR / "test_compute_activations_invalid_return"
    model = MockHookedTransformer()
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # Test with an invalid return_cache value
    with pytest.raises(ValueError, match="invalid return_cache"):
        compute_activations(
            prompt=prompt,
            model=model,
            save_path=temp_dir,
            # intentionally invalid
            return_cache="invalid",  # type: ignore[call-overload]
            stack_heads=True,
        )

106 

107 

def test_get_activations_torch_return():
    """Test get_activations with return_cache="torch" and mocked load_activations."""
    save_dir = TEMP_DIR / "test_get_activations_torch_return"
    prompt = {"text": "test prompt", "hash": "testhash123"}
    mock_model = MockHookedTransformer(model_name="test-model")
    n_ctx = len(prompt["text"])

    # Patch load_activations so get_activations receives torch tensors
    # without touching the filesystem.
    with mock.patch("pattern_lens.activations.load_activations") as mock_load:
        fake_cache = {
            f"blocks.{layer}.attn.hook_pattern": torch.rand(
                1,
                2,
                n_ctx,
                n_ctx,
            )
            for layer in range(2)
        }
        mock_load.return_value = (Path("mock/path"), fake_cache)

        # Call get_activations with torch return format
        _, cache = get_activations(
            prompt=prompt,
            model=mock_model,
            save_path=save_dir,
            return_cache="torch",
        )

        # Every entry should come back as a torch tensor keyed by hook name.
        assert isinstance(cache, dict)
        for hook_name, pattern in cache.items():
            assert isinstance(hook_name, str)
            assert isinstance(pattern, torch.Tensor)

145 

146 

def test_get_activations_none_return():
    """Test get_activations with return_cache=None."""
    save_dir = TEMP_DIR / "test_get_activations_none_return"
    prompt = {"text": "test prompt", "hash": "testhash123"}
    mock_model = MockHookedTransformer(model_name="test-model")

    # Patch load_activations: the (empty) cache it returns should be
    # discarded when return_cache is None.
    with mock.patch("pattern_lens.activations.load_activations") as mock_load:
        expected_path = Path("mock/path")
        mock_load.return_value = (expected_path, {})

        # Call get_activations with None return format
        path, cache = get_activations(
            prompt=prompt,
            model=mock_model,
            save_path=save_dir,
            return_cache=None,
        )

        # The path is passed through; no cache is returned.
        assert path == expected_path
        assert cache is None