pattern_lens.load_activations
loading activations from .npz on disk. implements some custom Exception classes
"""loading activations from .npz on disk. implements some custom Exception classes"""

import base64
import hashlib
import json
from pathlib import Path
from typing import Literal, overload

import numpy as np

from pattern_lens.consts import ReturnCache


class GetActivationsError(ValueError):
    """base class for errors in getting activations"""


class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """error for missing activations -- can't find the prompt or activations file

    raised by `load_activations`
    """


class ActivationsMismatchError(GetActivationsError):
    """error for mismatched activations -- the prompt text or hash do not match

    raised by `compare_prompt_to_loaded`
    """


class InvalidPromptError(GetActivationsError):
    """error for invalid prompt -- the prompt does not have fields "hash" or "text"

    raised by `augment_prompt_with_hash`
    """


def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    the `"hash"` key is always compared. the `"text"` key is compared only when
    `prompt` has one, since `augment_prompt_with_hash` allows hash-only prompts.

    # Parameters:
     - `prompt : dict`
        the prompt the caller asked for, must have a `"hash"` key
     - `prompt_loaded : dict`
        the prompt read back from disk

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match,
        or if `prompt_loaded` lacks a key that is being compared
     - `KeyError` : if `prompt` has no `"hash"` key
    """
    for key in ("text", "hash"):
        # a prompt identified by hash alone has no text to compare
        if key == "text" and key not in prompt:
            continue
        expected = prompt[key]
        # a stored prompt without the key cannot match either
        if key not in prompt_loaded or prompt_loaded[key] != expected:
            msg = f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            raise ActivationsMismatchError(
                msg,
            )


def augment_prompt_with_hash(prompt: dict) -> dict:
    """if a prompt does not have a hash, add one

    not having a "text" field is allowed, but only if "hash" is present

    # Parameters:
     - `prompt : dict`

    # Returns:
     - `dict`
        the same `prompt` object that was passed in

    # Raises:
     - `InvalidPromptError` : if the prompt has neither a `"hash"` nor a `"text"` key

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    if "hash" not in prompt:
        if "text" not in prompt:
            msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
            raise InvalidPromptError(
                msg,
            )
        prompt_str: str = prompt["text"]
        prompt_hash: str = (
            # we don't need this to be a secure hash
            base64.urlsafe_b64encode(hashlib.md5(prompt_str.encode()).digest())  # noqa: S324
            .decode()
            .rstrip("=")
        )
        prompt.update(hash=prompt_hash)
    return prompt


@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
    ...
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    # no default here: omitting `return_fmt` selects the "torch" overload above
    return_fmt: Literal["numpy"],
) -> "tuple[Path, dict[str, np.ndarray]]": ...
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: ReturnCache = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are looked up under `save_path / model_name / "prompts" / prompt["hash"]`

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must have a `"hash"` or a `"text"` key
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary
        of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
     - `InvalidPromptError` : if `prompt` has neither a `"hash"` nor a `"text"` key
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    if return_fmt not in ("torch", "numpy"):
        msg = f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        raise ValueError(
            msg,
        )
    if return_fmt == "torch":
        # torch is only imported when tensors are requested
        import torch

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        msg = f"Prompt file {prompt_file} does not exist"
        raise ActivationsMissingError(msg)
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # without this check a missing file surfaces as a bare FileNotFoundError from numpy
    if not activations_path.exists():
        msg = f"Activations file {activations_path} does not exist"
        raise ActivationsMissingError(msg)

    cache: dict

    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        elif return_fmt == "torch":
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache


# def load_activations_stacked()
class
GetActivationsError(builtins.ValueError):
class GetActivationsError(ValueError):
    """root of the error hierarchy for failures while getting activations

    subclasses `ValueError`, so `except ValueError` also catches it
    """
base class for errors in getting activations
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """the activations file could not be found

    inherits from both `GetActivationsError` and `FileNotFoundError`,
    so a handler for either one catches it
    """
error for missing activations -- can't find the activations file
Inherited Members
- builtins.ValueError
- ValueError
- builtins.OSError
- errno
- strerror
- filename
- filename2
- characters_written
- builtins.BaseException
- with_traceback
- add_note
- args
class ActivationsMismatchError(GetActivationsError):
    """the prompt text or hash differs from the one that was loaded

    raised by `compare_prompt_to_loaded`
    """
error for mismatched activations -- the prompt text or hash do not match
raised by compare_prompt_to_loaded
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
class InvalidPromptError(GetActivationsError):
    """the prompt is unusable -- it has neither a "hash" nor a "text" field

    raised by `augment_prompt_with_hash`
    """
error for invalid prompt -- the prompt does not have fields "hash" or "text"
raised by augment_prompt_with_hash
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
def
compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    the `"hash"` key is always compared. the `"text"` key is compared only when
    `prompt` has one, since `augment_prompt_with_hash` allows hash-only prompts.

    # Parameters:
     - `prompt : dict`
        the prompt the caller asked for, must have a `"hash"` key
     - `prompt_loaded : dict`
        the prompt read back from disk

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match,
        or if `prompt_loaded` lacks a key that is being compared
     - `KeyError` : if `prompt` has no `"hash"` key
    """
    for key in ("text", "hash"):
        # a prompt identified by hash alone has no text to compare
        if key == "text" and key not in prompt:
            continue
        expected = prompt[key]
        # a stored prompt without the key cannot match either
        if key not in prompt_loaded or prompt_loaded[key] != expected:
            msg = f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            raise ActivationsMismatchError(
                msg,
            )
compare a prompt to a loaded prompt, raise an error if they do not match
Parameters:
- prompt : dict
- prompt_loaded : dict
Returns:
None
Raises:
ActivationsMismatchError: if the prompt text or hash do not match
def
augment_prompt_with_hash(prompt: dict) -> dict:
def augment_prompt_with_hash(prompt: dict) -> dict:
    """make sure a prompt has a `"hash"` entry, deriving it from `"text"` if needed

    a prompt that already has a `"hash"` is returned untouched, even when it
    has no `"text"` field

    # Parameters:
     - `prompt : dict`

    # Returns:
     - `dict`
        the same `prompt` object that was passed in

    # Raises:
     - `InvalidPromptError` : if the prompt has neither a `"hash"` nor a `"text"` key

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    # nothing to do when the hash is already there
    if "hash" in prompt:
        return prompt

    if "text" not in prompt:
        msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
        raise InvalidPromptError(msg)

    text: str = prompt["text"]
    # md5 is used only as a short identifier, it does not need to be a secure hash
    digest: bytes = hashlib.md5(text.encode()).digest()  # noqa: S324
    encoded: str = base64.urlsafe_b64encode(digest).decode()
    # drop the base64 padding
    prompt["hash"] = encoded.rstrip("=")
    return prompt
if a prompt does not have a hash, add one
not having a "text" field is allowed, but only if "hash" is present
Parameters:
prompt : dict
Returns:
dict
Modifies:
the input prompt dictionary, if it does not have a "hash" key
def
load_activations( model_name: str, prompt: dict, save_path: pathlib.Path, return_fmt: Literal[None, 'numpy', 'torch'] = 'torch') -> tuple[pathlib.Path, dict[str, torch.Tensor] | dict[str, numpy.ndarray]]:
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: ReturnCache = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are looked up under `save_path / model_name / "prompts" / prompt["hash"]`

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must have a `"hash"` or a `"text"` key
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary
        of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
     - `InvalidPromptError` : if `prompt` has neither a `"hash"` nor a `"text"` key
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    if return_fmt not in ("torch", "numpy"):
        msg = f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        raise ValueError(
            msg,
        )
    if return_fmt == "torch":
        # torch is only imported when tensors are requested
        import torch

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        msg = f"Prompt file {prompt_file} does not exist"
        raise ActivationsMissingError(msg)
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # without this check a missing file surfaces as a bare FileNotFoundError from numpy
    if not activations_path.exists():
        msg = f"Activations file {activations_path} does not exist"
        raise ActivationsMissingError(msg)

    cache: dict

    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        elif return_fmt == "torch":
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache
load activations for a prompt and model, from an npz file
Parameters:
- model_name : str
- prompt : dict
- save_path : Path
- return_fmt : Literal["torch", "numpy"] (defaults to "torch")
Returns:
- tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]: the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on return_fmt
Raises:
- ActivationsMissingError: if the activations file is missing
- ValueError: if return_fmt is not "torch" or "numpy"