Coverage for tests/unit/test_activations.py: 80% (65 statements), coverage.py v7.13.4

from pathlib import Path
from unittest import mock

import numpy as np
import torch
from transformer_lens import HookedTransformer  # type: ignore[import-untyped]

from pattern_lens.activations import compute_activations, get_activations
from pattern_lens.load_activations import ActivationsMissingError

TEMP_DIR: Path = Path("tests/.temp")


class MockHookedTransformer:
    """Mock of HookedTransformer for testing compute_activations and get_activations."""

    def __init__(self, model_name="test-model", n_layers=2, n_heads=2):
        self.model_name = model_name
        self.cfg = mock.MagicMock()
        self.cfg.n_layers = n_layers
        self.cfg.n_heads = n_heads
        self.tokenizer = mock.MagicMock()
        self.tokenizer.tokenize.return_value = ["test", "tokens"]

    def eval(self):
        return self

    def run_with_cache(self, prompt_str, names_filter=None, return_type=None):  # noqa: ARG002
        """Mock run_with_cache to return fake attention patterns."""
        # Create a mock activation cache with appropriately shaped attention patterns
        cache: dict[str, torch.Tensor] = {}
        for i in range(self.cfg.n_layers):
            # [1, n_heads, n_ctx, n_ctx] tensor, where n_ctx is len(prompt_str)
            n_ctx: int = len(prompt_str)
            attn_pattern: torch.Tensor = torch.rand(
                1,
                self.cfg.n_heads,
                n_ctx,
                n_ctx,
            ).float()
            cache[f"blocks.{i}.attn.hook_pattern"] = attn_pattern
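        # The real HookedTransformer.run_with_cache returns an (output, cache)
        # pair, and with return_type=None the output is None, which the mock
        # reproduces here.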

        return None, cache
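

# MockHookedTransformer is not exercised by the tests below; the sketch here is
# illustrative and not part of the original suite. It assumes compute_activations
# touches only model.model_name, model.cfg, model.tokenizer, and
# model.run_with_cache, which is exactly what the mock provides (test name and
# save_path below are hypothetical).
def test_compute_activations_with_mock_sketch():
    """Hypothetical sketch: drive compute_activations with the mock model."""
    model = MockHookedTransformer(n_layers=2, n_heads=2)
    prompt = {"text": "abc", "hash": "mockhash123"}
    _path, result = compute_activations(
        prompt=prompt,
        model=model,  # type: ignore[arg-type]
        save_path=TEMP_DIR / "test_compute_activations_with_mock_sketch",
        return_cache="numpy",
        stack_heads=True,
    )
    # The mock sets n_ctx = len("abc") == 3, so the stacked cache should be
    # (batch=1, n_layers=2, n_heads=2, n_ctx=3, n_ctx=3)
    assert result.shape == (1, 2, 2, 3, 3)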


def test_compute_activations_stack_heads():
    """Test compute_activations with stack_heads=True."""
    # Setup
    temp_dir: Path = TEMP_DIR / "test_compute_activations_stack_heads"
    model: HookedTransformer = HookedTransformer.from_pretrained("pythia-14m")
    prompt: dict[str, str] = {"text": "test prompt", "hash": "testhash123"}

    # Test with return_cache=None
    path, result = compute_activations(
        prompt=prompt,
        model=model,
        save_path=temp_dir,
        return_cache=None,
        stack_heads=True,
    )

    # Check return values
    assert (
        path
        == temp_dir
        / model.cfg.model_name
        / "prompts"
        / prompt["hash"]
        / "activations-blocks.-.attn.hook_pattern.npy"
    )
    assert result is None

    # Check the file was created and has correct shape
    assert path.exists()
    loaded = np.load(path)
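    # Stacked layout: (batch=1, n_layers, n_heads, n_ctx, n_ctx); the trailing
    # dims are square because an attention pattern maps query positions to key
    # positions over the same token sequence.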

    assert loaded.shape[:3] == (
        1,
        model.cfg.n_layers,
        model.cfg.n_heads,
    )
    assert loaded.shape[3] == loaded.shape[4]

    # Test with return_cache="numpy"
    path, result = compute_activations(
        prompt=prompt,
        model=model,
        save_path=temp_dir,
        return_cache="numpy",
        stack_heads=True,
    )

    # Check return values
    assert isinstance(result, np.ndarray)
    assert result.shape[:3] == (
        1,
        model.cfg.n_layers,
        model.cfg.n_heads,
    )
    assert result.shape[3] == result.shape[4]


def test_compute_activations_no_stack():
    """Test compute_activations with stack_heads=False."""
    # Setup
    temp_dir = TEMP_DIR / "test_compute_activations_no_stack"
    model = HookedTransformer.from_pretrained("pythia-14m")
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # Test with return_cache="numpy"
    path, result = compute_activations(
        prompt=prompt,
        model=model,
        save_path=temp_dir,
        return_cache="numpy",
        stack_heads=False,
    )

    # Check return values
    assert (
        path
        == temp_dir
        / model.cfg.model_name
        / "prompts"
        / prompt["hash"]
        / "activations.npz"
    )
    assert isinstance(result, dict)

    # Check that the keys have the expected form and values have the right shape
    for i in range(model.cfg.n_layers):
        key = f"blocks.{i}.attn.hook_pattern"
        assert key in result
        assert result[key].shape[:2] == (
            1,
            model.cfg.n_heads,
        )
        assert result[key].shape[2] == result[key].shape[3]
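    # Unlike stack_heads=True, each layer keeps its own (1, n_heads, n_ctx, n_ctx)
    # entry in the .npz archive, keyed by its TransformerLens hook name.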


def test_get_activations_missing():
    """Test get_activations when activations don't exist."""
    temp_dir = TEMP_DIR / "test_get_activations_missing"
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # Patch the load_activations and compute_activations functions
    with (
        mock.patch("pattern_lens.activations.load_activations") as mock_load,
        mock.patch("pattern_lens.activations.compute_activations") as mock_compute,
    ):
        # Set up load_activations to fail
        mock_load.side_effect = ActivationsMissingError("Not found")

        # Set up mock model and compute_activations
        model = HookedTransformer.from_pretrained("pythia-14m")
        mock_compute.return_value = (Path("mock/path"), {"mock": "cache"})

        # Call get_activations
        _path, _cache = get_activations(
            prompt=prompt,
            model=model,
            save_path=temp_dir,
            return_cache="numpy",
        )

        # Check that compute_activations was called with the right arguments
        mock_compute.assert_called_once()
        _args, kwargs = mock_compute.call_args
        assert kwargs["prompt"] == prompt
        assert kwargs["save_path"] == temp_dir
        assert kwargs["return_cache"] == "numpy"
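

# Sketch of the complementary cache-hit path (illustrative, not part of the
# original suite): when load_activations succeeds, get_activations should return
# its result without recomputing. This assumes load_activations returns a
# (path, cache) tuple that get_activations passes straight through.
def test_get_activations_cached_sketch():
    """Hypothetical sketch: get_activations short-circuits on a cache hit."""
    prompt = {"text": "test prompt", "hash": "testhash123"}
    with (
        mock.patch("pattern_lens.activations.load_activations") as mock_load,
        mock.patch("pattern_lens.activations.compute_activations") as mock_compute,
    ):
        mock_load.return_value = (Path("mock/path"), {"mock": "cache"})
        model = HookedTransformer.from_pretrained("pythia-14m")
        path, cache = get_activations(
            prompt=prompt,
            model=model,
            save_path=TEMP_DIR / "test_get_activations_cached_sketch",
            return_cache="numpy",
        )
        # compute_activations must not run when the cached load succeeds
        mock_compute.assert_not_called()
        assert path == Path("mock/path")
        assert cache == {"mock": "cache"}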