Coverage for pattern_lens / load_activations.py: 95%

62 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-22 18:15 -0700

1"loading activations from .npz on disk. implements some custom Exception classes" 

2 

3import base64 

4import hashlib 

5import json 

6from pathlib import Path 

7from typing import Literal, overload 

8 

9import numpy as np 

10 

11from pattern_lens.consts import ReturnCache 

12 

13 

class GetActivationsError(ValueError):
    """root of the exception hierarchy used when fetching activations

    subclasses `ValueError`, so an `except ValueError` handler also catches it
    """

18 

19 

class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """raised when an expected file for the activations is not on disk

    inherits from both `GetActivationsError` and `FileNotFoundError`, so a
    handler for either one catches it
    """

24 

25 

class ActivationsMismatchError(GetActivationsError):
    """raised when a prompt disagrees with the prompt stored on disk

    the mismatch is in the prompt text or the prompt hash

    raised by `compare_prompt_to_loaded`
    """

33 

34 

class InvalidPromptError(GetActivationsError):
    """raised when a prompt lacks the fields needed to identify it

    that is, a required `"hash"` or `"text"` field is absent

    raised by `augment_prompt_with_hash`
    """

42 

43 

def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    the `"hash"` key is always compared. the `"text"` key is compared only when
    `prompt` provides it: `augment_prompt_with_hash` accepts prompts that carry
    a `"hash"` but no `"text"`, and such prompts must not fail here with a
    `KeyError`

    # Parameters:
     - `prompt : dict`
       the prompt supplied by the caller, must contain `"hash"`
     - `prompt_loaded : dict`
       the prompt read back from disk

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match
     - `KeyError` : if `prompt` has no `"hash"`, or `prompt_loaded` lacks a key
       that is being compared
    """
    for key in ("text", "hash"):
        # hash-only prompts are valid input; there is no text to compare
        if key == "text" and "text" not in prompt:
            continue
        if prompt[key] != prompt_loaded[key]:
            msg = f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            raise ActivationsMismatchError(
                msg,
            )

63 

64 

def augment_prompt_with_hash(prompt: dict) -> dict:
    """ensure a prompt carries a `"hash"` entry, deriving it from `"text"` if absent

    a prompt that already has `"hash"` is returned untouched, even if it has no
    `"text"`. otherwise the hash is the MD5 digest of the UTF-8 encoded text,
    urlsafe-base64 encoded with the trailing `=` padding removed

    # Parameters:
     - `prompt : dict`

    # Returns:
     - `dict`
       the same dictionary object that was passed in

    # Raises:
     - `InvalidPromptError` : if the prompt has neither `"hash"` nor `"text"`

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    # nothing to do when the hash is already there
    if "hash" in prompt:
        return prompt

    if "text" not in prompt:
        msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
        raise InvalidPromptError(
            msg,
        )

    text: str = prompt["text"]
    # we don't need this to be a secure hash
    digest: bytes = hashlib.md5(text.encode()).digest()  # noqa: S324
    encoded: str = base64.urlsafe_b64encode(digest).decode()
    prompt.update(hash=encoded.rstrip("="))
    return prompt

94 

95 

@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
    ...
# only the "torch" overload carries a default: the implementation defaults to
# "torch", so a default on the "numpy" overload would make the two overloads
# overlap with incompatible return types for the three-argument call
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["numpy"],
) -> "tuple[Path, dict[str, np.ndarray]]": ...
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: ReturnCache = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are looked up under `save_path / model_name / "prompts" / prompt["hash"]`:
    `prompt.json` is read and checked against `prompt`, then `activations.npz`
    is loaded

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
       must have a `"hash"` key, or a `"text"` key from which the hash is derived
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
       (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
         the path to the activations file and the activations as a dictionary
         of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
       (raised by `compare_prompt_to_loaded`)
     - `InvalidPromptError` : if `prompt` has neither `"hash"` nor `"text"`
       (raised by `augment_prompt_with_hash`)
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    (via `augment_prompt_with_hash`)
    """
    if return_fmt not in ("torch", "numpy"):
        msg = f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        raise ValueError(
            msg,
        )
    if return_fmt == "torch":
        # imported lazily so torch is only needed when tensors are requested
        import torch  # noqa: PLC0415

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        msg = f"Prompt file {prompt_file} does not exist"
        raise ActivationsMissingError(msg)
    # NOTE(review): opened with the platform default text encoding -- confirm
    # this matches how prompt.json is written
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
        compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"

    if not activations_path.exists():
        msg = f"Activations file {activations_path} does not exist"
        raise ActivationsMissingError(msg)

    cache: dict

    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        else:
            # return_fmt was validated above, so this branch is "torch"
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache

169 

170 

def activations_exist(model_name: str, prompt: dict, save_path: Path) -> bool:
    """report whether a prompt's activations are present on disk, without loading them

    a lightweight stand-in for `load_activations` when the only question is
    whether a prompt has already been processed: nothing is read or
    decompressed, the two expected files are only tested with `.exists()`

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
       must contain a 'hash' key (call `augment_prompt_with_hash` first)
     - `save_path : Path`

    # Returns:
     - `bool`
       True if both prompt.json and activations.npz exist for this prompt

    # Raises:
     - `InvalidPromptError` : if the prompt does not have a 'hash' key
    """
    if "hash" not in prompt:
        msg = f"Prompt must have 'hash' key (call augment_prompt_with_hash first): {prompt}"
        raise InvalidPromptError(msg)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    # checked in this order; `all` stops at the first missing file
    required_files: tuple = ("prompt.json", "activations.npz")
    return all((prompt_dir / fname).exists() for fname in required_files)

199 

200 

201# def load_activations_stacked()