Coverage for pattern_lens / load_activations.py: 95%
62 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
1"loading activations from .npz on disk. implements some custom Exception classes"
3import base64
4import hashlib
5import json
6from pathlib import Path
7from typing import Literal, overload
9import numpy as np
11from pattern_lens.consts import ReturnCache
class GetActivationsError(ValueError):
    """root of the custom exception hierarchy for loading stored activations

    the more specific errors in this module derive from this class, so callers
    can catch them all at once. because it subclasses `ValueError`, existing
    `except ValueError` handlers also catch it.
    """
class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """error for missing activations -- an expected file is not on disk

    raised by `load_activations` when the stored `prompt.json` or the
    `activations.npz` file for a prompt does not exist. it also subclasses
    `FileNotFoundError`, so `except FileNotFoundError` (and `except OSError`)
    handlers catch it as well.
    """
class ActivationsMismatchError(GetActivationsError):
    """the prompt stored on disk disagrees with the prompt that was requested

    signals that the text or hash of the requested prompt differs from the
    one saved alongside the activations. raised by `compare_prompt_to_loaded`.
    """
class InvalidPromptError(GetActivationsError):
    """error for invalid prompt -- the prompt dict lacks the fields needed to identify it

    raised by `augment_prompt_with_hash` when the prompt has neither a `"hash"`
    nor a `"text"` field. also raised by `activations_exist` when the prompt has
    no `"hash"` field.
    """
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    the `"hash"` values are always compared. the `"text"` values are compared
    only when `prompt` has a `"text"` key. `augment_prompt_with_hash` accepts
    hash-only prompts, and for those the hash alone identifies the prompt.

    # Parameters:
     - `prompt : dict`
       the requested prompt; must contain `"hash"`
     - `prompt_loaded : dict`
       the prompt read back from disk

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match,
       including when a compared key is absent from `prompt_loaded`
     - `KeyError` : if `prompt` has no `"hash"` key
    """
    for key in ("text", "hash"):
        if key == "text" and key not in prompt:
            # hash-only prompt: nothing to compare for the text, and the
            # hash comparison on the next iteration still validates identity
            continue
        # `.get` so a key missing from the stored file reports as a mismatch
        # rather than escaping as a bare KeyError
        if prompt[key] != prompt_loaded.get(key):
            msg = f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            raise ActivationsMismatchError(
                msg,
            )
def augment_prompt_with_hash(prompt: dict) -> dict:
    """ensure `prompt` carries a `"hash"` key, deriving it from `"text"` when absent

    the derived hash is the md5 digest of the utf-8 encoded text, urlsafe-base64
    encoded with the trailing `=` padding stripped. a prompt that already has a
    `"hash"` is returned untouched, and it need not have a `"text"` field.

    # Parameters:
     - `prompt : dict`
       must contain `"hash"`, or else `"text"` (a `str`) to derive it from

    # Returns:
     - `dict`
       the very same `prompt` object (not a copy)

    # Modifies:
       `prompt` in place, adding a `"hash"` key when it is missing

    # Raises:
     - `InvalidPromptError` : if `prompt` has neither `"hash"` nor `"text"`
    """
    if "hash" in prompt:
        return prompt

    if "text" not in prompt:
        msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
        raise InvalidPromptError(msg)

    text: str = prompt["text"]
    # md5 serves only as a cheap content fingerprint here, not for security
    digest: bytes = hashlib.md5(text.encode()).digest()  # noqa: S324
    encoded: str = base64.urlsafe_b64encode(digest).decode()
    prompt["hash"] = encoded.rstrip("=")
    return prompt
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
    ...


@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    # no default here: the default ("torch") belongs to the overload above;
    # giving this one a default too makes the two overloads overlap
    return_fmt: Literal["numpy"],
) -> "tuple[Path, dict[str, np.ndarray]]": ...


def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: ReturnCache = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are looked up under `save_path / model_name / "prompts" / <hash>`,
    which must contain `prompt.json` and `activations.npz`. the stored
    prompt is checked against `prompt` before the activations are read.

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
       must contain `"hash"` or `"text"`; a `"hash"` is added if missing
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
       (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
       the path to the activations file and the activations as a dictionary
       of numpy arrays or torch tensors, depending on `return_fmt`

    # Modifies:
       `prompt`, via `augment_prompt_with_hash`, if it lacks a `"hash"` key

    # Raises:
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
     - `InvalidPromptError` : if `prompt` has neither `"hash"` nor `"text"`
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
    """
    if return_fmt not in ("torch", "numpy"):
        msg = f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        raise ValueError(
            msg,
        )
    if return_fmt == "torch":
        # deferred import: the numpy path never imports torch
        import torch  # noqa: PLC0415

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        msg = f"Prompt file {prompt_file} does not exist"
        raise ActivationsMissingError(msg)
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding
    with open(prompt_file, "r", encoding="utf-8") as f:
        prompt_loaded: dict = json.load(f)
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    if not activations_path.exists():
        msg = f"Activations file {activations_path} does not exist"
        raise ActivationsMissingError(msg)

    cache: dict
    # every array is read inside the `with`, before the npz file is closed
    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        else:  # "torch" -- the only other value allowed by the check above
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache
def activations_exist(model_name: str, prompt: dict, save_path: Path) -> bool:
    """report whether a prompt's stored prompt file and activations are both on disk

    this only tests for file existence and never opens or decompresses the
    `.npz`. use it instead of `load_activations` when the caller only needs
    to know whether a prompt has already been processed.

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
       must already carry a `'hash'` key (see `augment_prompt_with_hash`)
     - `save_path : Path`

    # Returns:
     - `bool`
       True only when both `prompt.json` and `activations.npz` are present
       under `save_path / model_name / "prompts" / <hash>`

    # Raises:
     - `InvalidPromptError` : if `prompt` has no `'hash'` key
    """
    if "hash" not in prompt:
        raise InvalidPromptError(
            f"Prompt must have 'hash' key (call augment_prompt_with_hash first): {prompt}"
        )

    base_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    # `all` short-circuits exactly like a chained `and`
    return all(
        (base_dir / file_name).exists()
        for file_name in ("prompt.json", "activations.npz")
    )
201# def load_activations_stacked()