Coverage for tests / unit / test_load_activations.py: 100%

68 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:15 -0700

1# tests/unit/test_load_activations.py 

import json
import shutil
from pathlib import Path
from unittest import mock

import numpy as np
import pytest

from pattern_lens.load_activations import (
    ActivationsMismatchError,
    ActivationsMissingError,
    InvalidPromptError,
    augment_prompt_with_hash,
    load_activations,
)

16 

# Scratch root for on-disk fixtures, resolved relative to the current working
# directory (presumably the repo root when pytest runs — confirm). Each test
# below writes into its own named subdirectory of this root.
TEMP_DIR: Path = Path("tests", ".temp")

18 

19 

def test_augment_prompt_with_hash():
    """Test that ``augment_prompt_with_hash`` adds a stable hash and keeps existing ones.

    Covers:
    - a prompt with only ``text`` gets a string ``hash``, identical across
      independent calls with equal text and different for different text;
    - a prompt that already carries a ``hash`` keeps it unchanged;
    - a prompt with neither ``text`` nor ``hash`` raises ``InvalidPromptError``.
    """
    # A prompt without a hash gets one added
    result = augment_prompt_with_hash({"text": "test prompt"})
    assert "hash" in result
    assert isinstance(result["hash"], str)
    first_hash = result["hash"]

    # Determinism: hash a *fresh* dict with the same text. Reusing the first
    # dict would make this check vacuous if the function mutates its argument
    # in place -- the second call would just see (and keep) the hash that the
    # first call added.
    again = augment_prompt_with_hash({"text": "test prompt"})
    assert again["hash"] == first_hash

    # Different text must not collide: the hash is used as the per-prompt
    # directory name under ``<save_path>/<model>/prompts/``.
    other = augment_prompt_with_hash({"text": "a different prompt"})
    assert other["hash"] != first_hash

    # A prompt that already has a hash keeps it unchanged
    result = augment_prompt_with_hash({"text": "test prompt", "hash": "existing-hash"})
    assert result["hash"] == "existing-hash"

    # A prompt with neither text nor hash is rejected
    with pytest.raises(InvalidPromptError):
        augment_prompt_with_hash({"other_field": "value"})

47 

48 

def test_load_activations_success():
    """Round-trip a hand-written activations.npz through ``load_activations``.

    Builds the layout ``<save_path>/<model>/prompts/<hash>/`` containing a
    prompt.json and an activations.npz with a single attention-pattern array.
    Then checks the returned path, the returned cache, and that the patched
    ``compare_prompt_to_loaded`` hook was called once with ``(prompt, prompt)``.
    """
    base_dir = TEMP_DIR / "test_load_activations_success"
    model = "test-model"
    prompt = {"text": "test prompt", "hash": "testhash123"}
    hook_name = "blocks.0.attn.hook_pattern"
    expected_shape = (1, 2, 10, 10)

    hash_dir = base_dir / model / "prompts" / prompt["hash"]
    hash_dir.mkdir(parents=True, exist_ok=True)
    npz_path = hash_dir / "activations.npz"

    # Fixture files: the stored prompt plus one random float32 array
    (hash_dir / "prompt.json").write_text(json.dumps(prompt))
    arrays = {hook_name: np.random.rand(*expected_shape).astype(np.float32)}
    np.savez(npz_path, **arrays)  # type: ignore[arg-type]

    # Patch out the prompt comparison so only the loading path is exercised
    patch_target = "pattern_lens.load_activations.compare_prompt_to_loaded"
    with mock.patch(patch_target) as compare_spy:
        loaded_path, cache = load_activations(
            model_name=model,
            prompt=prompt,
            save_path=base_dir,
            return_fmt="numpy",
        )

    assert loaded_path == npz_path
    assert isinstance(cache, dict)
    assert hook_name in cache
    assert cache[hook_name].shape == expected_shape
    compare_spy.assert_called_once_with(prompt, prompt)

91 

92 

def test_load_activations_errors():
    """Test error handling in ``load_activations``.

    Goes through the failure modes in order:
    1. nothing saved for the prompt yet -> ``ActivationsMissingError``;
    2. a stored prompt.json whose content differs from the requested prompt
       (same hash) -> ``ActivationsMismatchError``;
    3. a matching prompt.json with no activations.npz beside it ->
       ``ActivationsMissingError`` mentioning "Activations file".
    """
    temp_dir = TEMP_DIR / "test_load_activations_errors"
    model_name = "test-model"
    prompt = {"text": "test prompt", "hash": "testhash123"}
    prompt_dir = temp_dir / model_name / "prompts" / prompt["hash"]

    # TEMP_DIR persists between runs, and this test leaves a matching
    # prompt.json behind. Without this wipe, every rerun would satisfy case 1
    # via the missing-activations path (case 3) instead of the
    # missing-prompt path it is meant to cover.
    shutil.rmtree(temp_dir, ignore_errors=True)
    assert not prompt_dir.exists()

    # 1. Missing prompt file
    with pytest.raises(ActivationsMissingError):
        load_activations(
            model_name=model_name,
            prompt=prompt,
            save_path=temp_dir,
            return_fmt="numpy",
        )

    prompt_dir.mkdir(parents=True, exist_ok=True)

    # 2. Stored prompt.json has different content under the same hash
    different_prompt = {"text": "different prompt", "hash": prompt["hash"]}
    with open(prompt_dir / "prompt.json", "w", encoding="utf-8") as f:
        json.dump(different_prompt, f)

    with pytest.raises(ActivationsMismatchError):
        load_activations(
            model_name=model_name,
            prompt=prompt,
            save_path=temp_dir,
            return_fmt="numpy",
        )

    # 3. Fix the prompt file; activations.npz is still absent
    with open(prompt_dir / "prompt.json", "w", encoding="utf-8") as f:
        json.dump(prompt, f)

    with pytest.raises(ActivationsMissingError, match="Activations file"):
        load_activations(
            model_name=model_name,
            prompt=prompt,
            save_path=temp_dir,
            return_fmt="numpy",
        )

139 

140 

def test_load_activations_missing_npz_is_subclass_of_file_not_found():
    """A missing activations.npz raises ActivationsMissingError, catchable as FileNotFoundError."""
    base_dir = TEMP_DIR / "test_load_activations_missing_npz_subclass"
    model = "test-model"
    prompt = {"text": "test prompt", "hash": "testhash_sub"}

    hash_dir = base_dir / model / "prompts" / prompt["hash"]
    hash_dir.mkdir(parents=True, exist_ok=True)
    # Only prompt.json is written; activations.npz is deliberately absent.
    (hash_dir / "prompt.json").write_text(json.dumps(prompt))

    # Catch through the builtin base class, then pin the concrete type.
    with pytest.raises(FileNotFoundError) as excinfo:
        load_activations(
            model_name=model,
            prompt=prompt,
            save_path=base_dir,
            return_fmt="numpy",
        )
    assert isinstance(excinfo.value, ActivationsMissingError)

162 

163 

def test_load_activations_invalid_return_fmt():
    """An unrecognised ``return_fmt`` is rejected with a ValueError.

    ``save_path`` points at a nonexistent directory, so the call is expected
    to fail on the format argument rather than on any filesystem lookup.
    """
    bogus_fmt = "invalid"
    with pytest.raises(ValueError, match="Invalid return_fmt"):
        load_activations(  # type: ignore[no-matching-overload]
            model_name="test-model",
            prompt={"text": "test", "hash": "test"},
            save_path=Path("/nonexistent"),
            return_fmt=bogus_fmt,
        )