pattern_lens.load_activations
Loading activations from `.npz` files on disk; also implements some custom exception classes.
"loading activations from .npz on disk. implements some custom Exception classes"

import base64
import hashlib
import json
from pathlib import Path
from typing import Literal, overload

import numpy as np

from pattern_lens.consts import ReturnCache


class GetActivationsError(ValueError):
    """base class for errors in getting activations"""

    pass


class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """error for missing activations -- can't find the activations file

    also raised by `load_activations` when the prompt's `prompt.json` is missing
    """

    pass


class ActivationsMismatchError(GetActivationsError):
    """error for mismatched activations -- the prompt text or hash do not match

    raised by `compare_prompt_to_loaded`
    """

    pass


class InvalidPromptError(GetActivationsError):
    """error for invalid prompt -- the prompt does not have fields "hash" or "text"

    raised by `augment_prompt_with_hash`, `compare_prompt_to_loaded`, and
    `activations_exist`
    """

    pass


def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    only the `"text"` and `"hash"` keys are compared. a key absent from `prompt`
    is skipped, because `augment_prompt_with_hash` allows hash-only prompts. a
    key present in `prompt` but absent from `prompt_loaded` is a mismatch.

    # Parameters:
     - `prompt : dict`
        the requested prompt; must have at least one of "text" / "hash"
     - `prompt_loaded : dict`
        the prompt read back from disk

    # Returns:
     - `None`

    # Raises:
     - `InvalidPromptError` : if `prompt` has neither a "text" nor a "hash" key
     - `ActivationsMismatchError` : if the prompt text or hash do not match
    """
    if "text" not in prompt and "hash" not in prompt:
        msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
        raise InvalidPromptError(msg)
    for key in ("text", "hash"):
        if key not in prompt:
            # hash-only prompts are valid, so there is nothing to compare here
            continue
        if key not in prompt_loaded or prompt[key] != prompt_loaded[key]:
            msg = f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            raise ActivationsMismatchError(
                msg,
            )


def augment_prompt_with_hash(prompt: dict) -> dict:
    """if a prompt does not have a hash, add one

    not having a "text" field is allowed, but only if "hash" is present

    the hash is the unpadded url-safe base64 encoding of the md5 digest of the
    utf-8 encoded prompt text. an existing "hash" is left untouched.

    # Parameters:
     - `prompt : dict`

    # Returns:
     - `dict`
        the same `prompt` object that was passed in

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key

    # Raises:
     - `InvalidPromptError` : if `prompt` has neither a "hash" nor a "text" key
    """
    if "hash" not in prompt:
        if "text" not in prompt:
            msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
            raise InvalidPromptError(
                msg,
            )
        prompt_str: str = prompt["text"]
        # md5 is only a content fingerprint here, not a security primitive;
        # `usedforsecurity=False` keeps it usable on FIPS-restricted builds
        digest: bytes = hashlib.md5(
            prompt_str.encode("utf-8"),
            usedforsecurity=False,
        ).digest()
        prompt_hash: str = base64.urlsafe_b64encode(digest).decode().rstrip("=")
        prompt.update(hash=prompt_hash)
    return prompt


@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
    ...
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    # no default: the runtime default is "torch", so omitting `return_fmt`
    # must resolve to the torch overload above
    return_fmt: Literal["numpy"],
) -> "tuple[Path, dict[str, np.ndarray]]": ...
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: ReturnCache = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    reads `save_path / model_name / "prompts" / prompt["hash"]`, checks that its
    `prompt.json` matches `prompt`, then loads `activations.npz` from it

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must have a "hash" or "text" key; a "hash" is added in place if missing
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary
        of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
     - `InvalidPromptError` : if `prompt` has neither a "hash" nor a "text" key
     - `ActivationsMissingError` : if the prompt file or activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
    """
    if return_fmt not in ("torch", "numpy"):
        msg = f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        raise ValueError(
            msg,
        )
    if return_fmt == "torch":
        import torch  # noqa: PLC0415

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        msg = f"Prompt file {prompt_file} does not exist"
        raise ActivationsMissingError(msg)
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale encoding
    with open(prompt_file, "r", encoding="utf-8") as f:
        prompt_loaded: dict = json.load(f)
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"

    if not activations_path.exists():
        msg = f"Activations file {activations_path} does not exist"
        raise ActivationsMissingError(msg)

    cache: dict

    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        else:
            # return_fmt was validated above, so this is the "torch" case
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache


def activations_exist(model_name: str, prompt: dict, save_path: Path) -> bool:
    """check if activations exist on disk without loading them

    cheap alternative to calling `load_activations` when you only need to know
    whether a prompt has been processed. `load_activations` decompresses the full
    `.npz` into numpy arrays, which is wasteful when the data is immediately
    discarded. this function just checks `.exists()` on the two expected files.

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must contain a 'hash' key (call `augment_prompt_with_hash` first)
     - `save_path : Path`

    # Returns:
     - `bool`
        True if both prompt.json and activations.npz exist for this prompt

    # Raises:
     - `InvalidPromptError` : if the prompt does not have a 'hash' key
    """
    if "hash" not in prompt:
        msg = f"Prompt must have 'hash' key (call augment_prompt_with_hash first): {prompt}"
        raise InvalidPromptError(msg)
    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    return (prompt_dir / "prompt.json").exists() and (
        prompt_dir / "activations.npz"
    ).exists()


# def load_activations_stacked()
class GetActivationsError(ValueError):
    """Root of the exception hierarchy for failures while fetching stored activations.

    Derives from `ValueError`, so handlers written as `except ValueError` also
    catch every error in this family.
    """
Base class for errors in getting activations.
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """Signals that a file expected for a prompt's activations is not on disk.

    It also derives from `FileNotFoundError`, so callers may catch it either as
    a `GetActivationsError` or as an ordinary missing-file error.
    """
error for missing activations -- can't find the activations file
Inherited Members
- builtins.ValueError
- ValueError
- builtins.OSError
- errno
- strerror
- filename
- filename2
- characters_written
- builtins.BaseException
- with_traceback
- add_note
- args
class ActivationsMismatchError(GetActivationsError):
    """Signals that the prompt stored on disk disagrees with the requested prompt.

    Emitted by `compare_prompt_to_loaded` when the "text" or "hash" value of the
    requested prompt differs from the one read back from disk.
    """
error for mismatched activations -- the prompt text or hash do not match
raised by compare_prompt_to_loaded
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
class InvalidPromptError(GetActivationsError):
    """Signals a malformed prompt dict that lacks the required identifying fields.

    Emitted by `augment_prompt_with_hash` when a prompt has neither "hash" nor
    "text", and by `activations_exist` when a prompt has no "hash".
    """
error for invalid prompt -- the prompt does not have fields "hash" or "text"
raised by augment_prompt_with_hash
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    only the `"text"` and `"hash"` keys are compared. a key absent from `prompt`
    is skipped, because `augment_prompt_with_hash` allows hash-only prompts. a
    key present in `prompt` but absent from `prompt_loaded` is a mismatch.

    # Parameters:
     - `prompt : dict`
        the requested prompt; must have at least one of "text" / "hash"
     - `prompt_loaded : dict`
        the prompt read back from disk

    # Returns:
     - `None`

    # Raises:
     - `InvalidPromptError` : if `prompt` has neither a "text" nor a "hash" key
     - `ActivationsMismatchError` : if the prompt text or hash do not match
    """
    if "text" not in prompt and "hash" not in prompt:
        msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
        raise InvalidPromptError(msg)
    for key in ("text", "hash"):
        if key not in prompt:
            # hash-only prompts are valid, so there is nothing to compare here
            continue
        if key not in prompt_loaded or prompt[key] != prompt_loaded[key]:
            msg = f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            raise ActivationsMismatchError(
                msg,
            )
compare a prompt to a loaded prompt, raise an error if they do not match
Parameters:
prompt : dict
prompt_loaded : dict
Returns:
None
Raises:
ActivationsMismatchError: if the prompt text or hash do not match
def augment_prompt_with_hash(prompt: dict) -> dict:
    """if a prompt does not have a hash, add one

    not having a "text" field is allowed, but only if "hash" is present

    the hash is the unpadded url-safe base64 encoding of the md5 digest of the
    utf-8 encoded prompt text. an existing "hash" is left untouched.

    # Parameters:
     - `prompt : dict`

    # Returns:
     - `dict`
        the same `prompt` object that was passed in

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key

    # Raises:
     - `InvalidPromptError` : if `prompt` has neither a "hash" nor a "text" key
    """
    if "hash" not in prompt:
        if "text" not in prompt:
            msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
            raise InvalidPromptError(
                msg,
            )
        prompt_str: str = prompt["text"]
        # md5 is only a content fingerprint here, not a security primitive;
        # `usedforsecurity=False` keeps it usable on FIPS-restricted builds
        digest: bytes = hashlib.md5(
            prompt_str.encode("utf-8"),
            usedforsecurity=False,
        ).digest()
        prompt_hash: str = base64.urlsafe_b64encode(digest).decode().rstrip("=")
        prompt.update(hash=prompt_hash)
    return prompt
if a prompt does not have a hash, add one
not having a "text" field is allowed, but only if "hash" is present
Parameters:
prompt : dict
Returns:
dict
Modifies:
the input prompt dictionary, if it does not have a "hash" key
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: ReturnCache = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    reads `save_path / model_name / "prompts" / prompt["hash"]`, checks that its
    `prompt.json` matches `prompt`, then loads `activations.npz` from it

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must have a "hash" or "text" key; a "hash" is added in place if missing
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary
        of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
     - `InvalidPromptError` : if `prompt` has neither a "hash" nor a "text" key
     - `ActivationsMissingError` : if the prompt file or activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
    """
    if return_fmt not in ("torch", "numpy"):
        msg = f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        raise ValueError(
            msg,
        )
    if return_fmt == "torch":
        import torch  # noqa: PLC0415

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        msg = f"Prompt file {prompt_file} does not exist"
        raise ActivationsMissingError(msg)
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale encoding
    with open(prompt_file, "r", encoding="utf-8") as f:
        prompt_loaded: dict = json.load(f)
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"

    if not activations_path.exists():
        msg = f"Activations file {activations_path} does not exist"
        raise ActivationsMissingError(msg)

    cache: dict

    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        else:
            # return_fmt was validated above, so this is the "torch" case
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache
load activations for a prompt and model, from an npz file
Parameters:
model_name : str
prompt : dict
save_path : Path
return_fmt : Literal["torch", "numpy"] (defaults to "torch")
Returns:
tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]
the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on return_fmt
Raises:
ActivationsMissingError: if the activations file is missing
ValueError: if return_fmt is not "torch" or "numpy"
def activations_exist(model_name: str, prompt: dict, save_path: Path) -> bool:
    """Report whether a prompt's activations are already on disk, without reading them.

    Use this instead of `load_activations` when the caller only needs to know
    whether a prompt has been processed. `load_activations` decompresses the whole
    `.npz`; this never opens it. It only tests that `prompt.json` and
    `activations.npz` exist in the prompt's directory.

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must already carry a 'hash' key (see `augment_prompt_with_hash`)
     - `save_path : Path`

    # Returns:
     - `bool`
        True only when both `prompt.json` and `activations.npz` are present

    # Raises:
     - `InvalidPromptError` : if `prompt` lacks a 'hash' key
    """
    if "hash" not in prompt:
        msg = f"Prompt must have 'hash' key (call augment_prompt_with_hash first): {prompt}"
        raise InvalidPromptError(msg)
    target_dir: Path = save_path.joinpath(model_name, "prompts", prompt["hash"])
    return all(
        (target_dir / fname).exists() for fname in ("prompt.json", "activations.npz")
    )
check if activations exist on disk without loading them
cheap alternative to calling load_activations when you only need to know
whether a prompt has been processed. load_activations decompresses the full
.npz into numpy arrays, which is wasteful when the data is immediately
discarded. this function just checks .exists() on the two expected files.
Parameters:
model_name : str
prompt : dict
must contain a 'hash' key (call augment_prompt_with_hash first)
save_path : Path
Returns:
bool
True if both prompt.json and activations.npz exist for this prompt
Raises:
InvalidPromptError: if the prompt does not have a 'hash' key