pattern_lens.load_activations
loading activations from .npz on disk. implements some custom Exception classes
"""loading activations from .npz on disk. implements some custom Exception classes"""

import base64
import hashlib
import json
from pathlib import Path
from typing import Literal, overload

import numpy as np


class GetActivationsError(ValueError):
    """base class for errors in getting activations"""


class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """error for missing activations -- can't find the activations file

    also a `FileNotFoundError`, so generic "file is missing" handlers catch it
    """


class ActivationsMismatchError(GetActivationsError):
    """error for mismatched activations -- the prompt text or hash do not match

    raised by `compare_prompt_to_loaded`
    """


class InvalidPromptError(GetActivationsError):
    """error for invalid prompt -- the prompt does not have fields "hash" or "text"

    raised by `augment_prompt_with_hash`
    """


def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    `"hash"` is always compared. `"text"` is compared only when `prompt` has it,
    because `augment_prompt_with_hash` accepts prompts that carry only a hash.

    # Parameters:
     - `prompt : dict`
        the prompt being looked up
     - `prompt_loaded : dict`
        the prompt read back from disk

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match,
        or if `prompt_loaded` lacks a key that is being compared
    """
    keys: tuple = ("text", "hash") if "text" in prompt else ("hash",)
    for key in keys:
        if key not in prompt_loaded or prompt[key] != prompt_loaded[key]:
            raise ActivationsMismatchError(
                f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            )


def augment_prompt_with_hash(prompt: dict) -> dict:
    """if a prompt does not have a hash, add one

    not having a "text" field is allowed, but only if "hash" is present

    the hash is the md5 digest of the utf-8 encoded text, url-safe base64
    encoded, with the trailing `=` padding removed

    # Parameters:
     - `prompt : dict`

    # Returns:
     - `dict`
        the same dictionary object that was passed in

    # Raises:
     - `InvalidPromptError` : if the prompt has neither a `"hash"` nor a `"text"` key

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    if "hash" not in prompt:
        if "text" not in prompt:
            raise InvalidPromptError(
                f"Prompt does not have 'text' field or 'hash' field: {prompt}"
            )
        prompt_str: str = prompt["text"]
        prompt_hash: str = (
            base64.urlsafe_b64encode(hashlib.md5(prompt_str.encode()).digest())
            .decode()
            .rstrip("=")
        )
        prompt.update(hash=prompt_hash)
    return prompt


@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
    ...


# no default here: omitting `return_fmt` selects "torch", so only the torch
# overload may match a call that leaves the argument out
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["numpy"],
) -> "tuple[Path, dict[str, np.ndarray]]": ...


def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch", "numpy"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are read from `save_path / model_name / "prompts" / prompt["hash"]`,
    which must contain both `prompt.json` and `activations.npz`

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must have a `"hash"` or a `"text"` key
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
     - `InvalidPromptError` : if `prompt` has neither `"hash"` nor `"text"`
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`

    # Modifies:
    the input `prompt` dictionary, which gets a `"hash"` key if it had none
    """
    if return_fmt not in ("torch", "numpy"):
        raise ValueError(
            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        )
    if return_fmt == "torch":
        # torch is only imported when torch tensors are requested
        import torch

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
        compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # without this check `np.load` would raise a plain FileNotFoundError,
    # which is not a GetActivationsError
    if not activations_path.exists():
        raise ActivationsMissingError(
            f"Activations file {activations_path} does not exist"
        )

    cache: dict

    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = {k: v for k, v in npz_data.items()}
        elif return_fmt == "torch":
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache
class
GetActivationsError(builtins.ValueError):
class GetActivationsError(ValueError):
    """root exception for failures while getting activations

    derives from `ValueError`, so handlers for that builtin also catch it
    """
base class for errors in getting activations
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """raised when a file needed to load the activations cannot be found

    it is both a `GetActivationsError` and a `FileNotFoundError`, so either
    kind of handler catches it
    """
error for missing activations -- can't find the activations file
Inherited Members
- builtins.ValueError
- ValueError
- builtins.OSError
- errno
- strerror
- filename
- filename2
- winerror
- characters_written
- builtins.BaseException
- with_traceback
- add_note
- args
class ActivationsMismatchError(GetActivationsError):
    """raised when the prompt on disk disagrees with the requested prompt

    the disagreement is in the prompt text or the prompt hash;
    raised by `compare_prompt_to_loaded`
    """
error for mismatched activations -- the prompt text or hash do not match
raised by compare_prompt_to_loaded
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
class InvalidPromptError(GetActivationsError):
    """raised for a prompt that has neither a "hash" nor a "text" field

    raised by `augment_prompt_with_hash`
    """
error for invalid prompt -- the prompt does not have fields "hash" or "text"
raised by augment_prompt_with_hash
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
def
compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    `"hash"` is always compared. `"text"` is compared only when `prompt` has it,
    so a prompt that carries only a hash can still be checked.

    # Parameters:
     - `prompt : dict`
        the prompt being looked up
     - `prompt_loaded : dict`
        the prompt to check it against

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match,
        or if `prompt_loaded` lacks a key that is being compared
    """
    keys: tuple = ("text", "hash") if "text" in prompt else ("hash",)
    for key in keys:
        # a key absent from the loaded prompt counts as a mismatch, not a KeyError
        if key not in prompt_loaded or prompt[key] != prompt_loaded[key]:
            raise ActivationsMismatchError(
                f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            )
compare a prompt to a loaded prompt, raise an error if they do not match
Parameters:
- `prompt : dict`
- `prompt_loaded : dict`
Returns:
None
Raises:
ActivationsMismatchError: if the prompt text or hash do not match
def
augment_prompt_with_hash(prompt: dict) -> dict:
def augment_prompt_with_hash(prompt: dict) -> dict:
    """make sure a prompt carries a `"hash"` entry, computing one from its text if needed

    a prompt without `"text"` is accepted as long as it already has `"hash"`.
    the computed hash is the md5 digest of the encoded text, url-safe base64
    encoded, with trailing `=` padding stripped.

    # Parameters:
     - `prompt : dict`

    # Returns:
     - `dict`
        the same dictionary object that was passed in

    # Raises:
     - `InvalidPromptError` : if the prompt has neither `"hash"` nor `"text"`

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    # nothing to do when the hash is already there
    if "hash" in prompt:
        return prompt

    if "text" not in prompt:
        raise InvalidPromptError(
            f"Prompt does not have 'text' field or 'hash' field: {prompt}"
        )

    text: str = prompt["text"]
    digest: bytes = hashlib.md5(text.encode()).digest()
    encoded: str = base64.urlsafe_b64encode(digest).decode()
    prompt["hash"] = encoded.rstrip("=")
    return prompt
if a prompt does not have a hash, add one
not having a "text" field is allowed, but only if "hash" is present
Parameters:
prompt : dict
Returns:
dict
Modifies:
the input prompt dictionary, if it does not have a "hash" key
def
load_activations( model_name: str, prompt: dict, save_path: pathlib.Path, return_fmt: Literal['torch', 'numpy'] = 'torch') -> tuple[pathlib.Path, dict[str, torch.Tensor] | dict[str, numpy.ndarray]]:
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch", "numpy"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are read from `save_path / model_name / "prompts" / prompt["hash"]`,
    which must contain both `prompt.json` and `activations.npz`

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        passed through `augment_prompt_with_hash`, then checked against the
        stored `prompt.json` with `compare_prompt_to_loaded`
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
     - `InvalidPromptError` : if `prompt` has neither `"hash"` nor `"text"`
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`

    # Modifies:
    the input `prompt` dictionary, which gets a `"hash"` key if it had none
    """
    if return_fmt not in ("torch", "numpy"):
        raise ValueError(
            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        )
    if return_fmt == "torch":
        # torch is only imported when torch tensors are requested
        import torch

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
        compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # without this check `np.load` would raise a plain FileNotFoundError,
    # which is not a GetActivationsError
    if not activations_path.exists():
        raise ActivationsMissingError(
            f"Activations file {activations_path} does not exist"
        )

    cache: dict

    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = {k: v for k, v in npz_data.items()}
        elif return_fmt == "torch":
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache
load activations for a prompt and model, from an npz file
Parameters:
- `model_name : str`
- `prompt : dict`
- `save_path : Path`
- `return_fmt : Literal["torch", "numpy"]` (defaults to `"torch"`)
Returns:
- `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`: the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`
Raises:
- `ActivationsMissingError`: if the activations file is missing
- `ValueError`: if `return_fmt` is not `"torch"` or `"numpy"`