Coverage for pattern_lens/load_activations.py: 89%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-06 15:09 -0600

1"loading activations from .npz on disk. implements some custom Exception classes" 

2 

3import base64 

4import hashlib 

5import json 

6from pathlib import Path 

7from typing import Literal, overload 

8 

9import numpy as np 

10 

11from pattern_lens.consts import ReturnCache 

12 

13 

class GetActivationsError(ValueError):
    """root of the exception hierarchy for activation-loading failures"""

18 

19 

class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """raised when the activations file cannot be found on disk

    doubles as a `FileNotFoundError` so callers may catch either type
    """

24 

25 

class ActivationsMismatchError(GetActivationsError):
    """raised when a loaded prompt disagrees with the requested one (text or hash)

    raised by `compare_prompt_to_loaded`
    """

33 

34 

class InvalidPromptError(GetActivationsError):
    """raised when a prompt dict carries neither a "hash" nor a "text" field

    raised by `augment_prompt_with_hash`
    """

42 

43 

def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """verify that `prompt` and `prompt_loaded` agree on their "text" and "hash" fields

    # Parameters:
    - `prompt : dict`
    - `prompt_loaded : dict`

    # Returns:
    - `None`

    # Raises:
    - `ActivationsMismatchError` : if the prompt text or hash do not match
    """
    for field in ("text", "hash"):
        # guard clause: skip fields that agree, fail fast on the first mismatch
        if prompt[field] == prompt_loaded[field]:
            continue
        raise ActivationsMismatchError(
            f"Prompt file does not match prompt at key {field}:\n{prompt}\n{prompt_loaded}",
        )

63 

64 

def augment_prompt_with_hash(prompt: dict) -> dict:
    """ensure `prompt` has a "hash" field, deriving one from "text" when absent

    a missing "text" field is tolerated, but only when "hash" is already present

    # Parameters:
    - `prompt : dict`

    # Returns:
    - `dict`

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key

    # Raises:
    - `InvalidPromptError` : if neither "hash" nor "text" is present
    """
    # already hashed -- nothing to do
    if "hash" in prompt:
        return prompt
    if "text" not in prompt:
        msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
        raise InvalidPromptError(
            msg,
        )
    # md5 is fine here -- the hash is an identifier, not a security boundary
    digest: bytes = hashlib.md5(prompt["text"].encode()).digest()  # noqa: S324
    prompt["hash"] = base64.urlsafe_b64encode(digest).decode().rstrip("=")
    return prompt

94 

95 

@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]": ...  # type: ignore[name-defined] # noqa: F821
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["numpy"] = "numpy",
) -> "tuple[Path, dict[str, np.ndarray]]": ...
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: ReturnCache = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    # Parameters:
    - `model_name : str`
    - `prompt : dict`
        must contain "hash", or "text" from which a hash will be computed
        (mutated in place by `augment_prompt_with_hash` in that case)
    - `save_path : Path`
    - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
    - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary
        of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
    - `ActivationsMissingError` : if the prompt file or activations file is missing
    - `ActivationsMismatchError` : if the stored prompt's text/hash do not match
    - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
    """
    if return_fmt not in ("torch", "numpy"):
        msg = f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        raise ValueError(
            msg,
        )
    if return_fmt == "torch":
        # deferred import so numpy-only callers don't need torch installed
        import torch

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        msg = f"Prompt file {prompt_file} does not exist"
        raise ActivationsMissingError(msg)
    # fix: reuse `prompt_file` instead of re-building `prompt_dir / "prompt.json"`
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # raise our own error (an ActivationsMissingError is also a FileNotFoundError,
    # so existing callers still work) instead of letting np.load raise a bare
    # FileNotFoundError -- matches the documented "Raises" contract
    if not activations_path.exists():
        msg = f"Activations file {activations_path} does not exist"
        raise ActivationsMissingError(msg)

    cache: dict

    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        elif return_fmt == "torch":
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache

165 

166 

167# def load_activations_stacked()