Coverage for pattern_lens\load_activations.py: 84%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-16 20:39 -0700

1"loading activations from .npz on disk. implements some custom Exception classes" 

2 

3import base64 

4import hashlib 

5import json 

6from pathlib import Path 

7from typing import Literal, overload 

8 

9import numpy as np 

10 

11 

class GetActivationsError(ValueError):
    """base class for every error raised while getting activations

    derives from `ValueError`, so a caller catching `ValueError` also catches
    all of the more specific activation errors defined in this module
    """

16 

17 

class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """raised when a file needed to load activations cannot be found on disk

    inherits from both `GetActivationsError` and `FileNotFoundError`, so it
    can be caught either as an activations error or as a missing-file error
    """

22 

23 

class ActivationsMismatchError(GetActivationsError):
    """raised when a stored prompt disagrees with the requested prompt

    signals that the `"text"` or `"hash"` of the prompt found on disk differs
    from the one asked for; raised by `compare_prompt_to_loaded`
    """

31 

32 

class InvalidPromptError(GetActivationsError):
    """raised when a prompt has neither a `"hash"` nor a `"text"` field

    without at least one of those fields the prompt cannot be located on
    disk; raised by `augment_prompt_with_hash`
    """

40 

41 

def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    `"hash"` is always compared. `"text"` is compared only when `prompt` has
    it: a hash-only prompt (which `augment_prompt_with_hash` permits) is
    verified on its hash alone instead of crashing with a `KeyError`.

    # Parameters:
     - `prompt : dict`
       the requested prompt; must contain `"hash"`, may omit `"text"`
     - `prompt_loaded : dict`
       the prompt as read back from disk

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match,
       including when `prompt_loaded` lacks a key that `prompt` has
     - `KeyError` : if `prompt` has no `"hash"` key
    """
    for key in ("text", "hash"):
        # "text" is optional on the requested prompt; "hash" is required
        if key == "text" and key not in prompt:
            continue
        # .get(): a key missing from the stored prompt is a mismatch, not a crash
        if prompt[key] != prompt_loaded.get(key):
            raise ActivationsMismatchError(
                f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            )

60 

61 

def augment_prompt_with_hash(prompt: dict) -> dict:
    """ensure a prompt carries a `"hash"`, computing it from `"text"` if absent

    the hash is the URL-safe base64 encoding of the MD5 digest of the
    encoded `"text"`, with trailing `=` padding removed. a prompt that already
    has a `"hash"` is returned untouched, even if it has no `"text"`.

    # Parameters:
     - `prompt : dict`

    # Returns:
     - `dict`
       the same `prompt` object that was passed in

    # Modifies:
     the input `prompt` dictionary, if it does not have a `"hash"` key

    # Raises:
     - `InvalidPromptError` : if `prompt` has neither `"hash"` nor `"text"`
    """
    if "hash" in prompt:
        return prompt

    if "text" not in prompt:
        raise InvalidPromptError(
            f"Prompt does not have 'text' field or 'hash' field: {prompt}"
        )

    text: str = prompt["text"]
    digest: bytes = hashlib.md5(text.encode()).digest()
    encoded: str = base64.urlsafe_b64encode(digest).decode()
    prompt["hash"] = encoded.rstrip("=")
    return prompt

89 

90 

@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
    ...


@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    # no default here: the runtime default is "torch", so "numpy" must be explicit
    return_fmt: Literal["numpy"],
) -> "tuple[Path, dict[str, np.ndarray]]": ...


def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch", "numpy"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are looked up under
    `save_path / model_name / "prompts" / prompt["hash"]`, which must contain
    `prompt.json` (checked against `prompt`) and `activations.npz`.

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
       must have a `"hash"` or a `"text"` key. if `"hash"` is missing it is
       computed from `"text"` and added to this dict in place
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
       (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
       the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
     - `InvalidPromptError` : if `prompt` has neither `"hash"` nor `"text"`
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
    """

    if return_fmt not in ("torch", "numpy"):
        raise ValueError(
            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        )
    if return_fmt == "torch":
        # imported only on this branch; torch is not imported at module level
        import torch

    # mutates `prompt`: adds "hash" when only "text" was given
    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # report a missing npz with the documented error type instead of letting
    # np.load raise a bare FileNotFoundError
    if not activations_path.exists():
        raise ActivationsMissingError(
            f"Activations file {activations_path} does not exist"
        )

    cache: dict
    # arrays are read inside the `with` so the npz handle is still open
    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        else:
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache