Coverage for pattern_lens\load_activations.py: 84%
51 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 20:39 -0700
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-16 20:39 -0700
1"loading activations from .npz on disk. implements some custom Exception classes"
3import base64
4import hashlib
5import json
6from pathlib import Path
7from typing import Literal, overload
9import numpy as np
class GetActivationsError(ValueError):
    """root of the exception hierarchy for failures while fetching activations

    subclasses `ValueError`, so a handler for `ValueError` also catches
    every error derived from this class
    """
class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """error for missing activations -- a required file could not be found on disk

    inherits from both `GetActivationsError` and `FileNotFoundError`,
    so handlers written for either of those will catch it
    """
class ActivationsMismatchError(GetActivationsError):
    """error for mismatched activations -- the stored prompt disagrees with the requested one

    the disagreement is in the prompt's `"text"` or `"hash"` field

    raised by `compare_prompt_to_loaded`
    """
class InvalidPromptError(GetActivationsError):
    """error for invalid prompt -- the prompt has neither a `"hash"` nor a `"text"` field

    raised by `augment_prompt_with_hash`
    """
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """check that a prompt agrees with the prompt loaded from disk

    the `"text"` field is checked first, then `"hash"`; the first field
    that differs is the one reported

    # Parameters:
     - `prompt : dict`
        the prompt being requested, must have `"text"` and `"hash"` keys
     - `prompt_loaded : dict`
        the prompt read back from disk, must have `"text"` and `"hash"` keys

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match
     - `KeyError` : if either dict lacks a key that gets compared
    """
    # the generator is lazy, so fields after the first mismatch are never read
    key = next(
        (
            field
            for field in ("text", "hash")
            if prompt[field] != prompt_loaded[field]
        ),
        None,
    )
    if key is None:
        return
    raise ActivationsMismatchError(
        f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
    )
def augment_prompt_with_hash(prompt: dict) -> dict:
    """make sure a prompt carries a `"hash"` field, computing one from `"text"` if needed

    a prompt that already has a `"hash"` is returned untouched, even if it
    has no `"text"` field

    the computed hash is the MD5 digest of the encoded text, written as
    url-safe base64 with the trailing `=` padding removed

    # Parameters:
     - `prompt : dict`

    # Returns:
     - `dict`
        the same dictionary object that was passed in

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key

    # Raises:
     - `InvalidPromptError` : if the prompt has neither a `"hash"` nor a `"text"` field
    """
    # nothing to do when the caller already supplied a hash
    if "hash" in prompt:
        return prompt

    if "text" not in prompt:
        raise InvalidPromptError(
            f"Prompt does not have 'text' field or 'hash' field: {prompt}"
        )

    text: str = prompt["text"]
    digest: bytes = hashlib.md5(text.encode()).digest()
    encoded: str = base64.urlsafe_b64encode(digest).decode()
    prompt.update(hash=encoded.rstrip("="))
    return prompt
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
    ...
# no default on this overload: a call that omits `return_fmt` must resolve
# only to the "torch" overload above, otherwise the two overloads overlap
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["numpy"],
) -> "tuple[Path, dict[str, np.ndarray]]": ...
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch", "numpy"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are looked up under `save_path / model_name / "prompts" / prompt["hash"]`:
    `prompt.json` is read and compared against `prompt`, then
    `activations.npz` is loaded

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must have a `"hash"` or a `"text"` field; a missing `"hash"` is
        added to this dictionary in place by `augment_prompt_with_hash`
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
       (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
         the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
     - `InvalidPromptError` : if `prompt` has neither a `"hash"` nor a `"text"` field
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
    """
    if return_fmt not in ("torch", "numpy"):
        raise ValueError(
            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        )
    if return_fmt == "torch":
        # torch is imported only on this path; the numpy path never touches it
        import torch

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # fail with the documented error type instead of the bare
    # FileNotFoundError that `np.load` would raise on a missing file
    if not activations_path.exists():
        raise ActivationsMissingError(
            f"Activations file {activations_path} does not exist"
        )

    cache: dict

    # arrays are pulled out of the archive before the `with` block closes it
    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        else:
            # `return_fmt` was validated above, so this branch is "torch"
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache