docs for pattern_lens v0.6.0
View Source on GitHub

pattern_lens.load_activations

loading activations from .npz on disk. implements some custom Exception classes


  1"loading activations from .npz on disk. implements some custom Exception classes"
  2
  3import base64
  4import hashlib
  5import json
  6from pathlib import Path
  7from typing import Literal, overload
  8
  9import numpy as np
 10
 11from pattern_lens.consts import ReturnCache
 12
 13
 14class GetActivationsError(ValueError):
 15	"""base class for errors in getting activations"""
 16
 17	pass
 18
 19
 20class ActivationsMissingError(GetActivationsError, FileNotFoundError):
 21	"""error for missing activations -- can't find the activations file"""
 22
 23	pass
 24
 25
 26class ActivationsMismatchError(GetActivationsError):
 27	"""error for mismatched activations -- the prompt text or hash do not match
 28
 29	raised by `compare_prompt_to_loaded`
 30	"""
 31
 32	pass
 33
 34
 35class InvalidPromptError(GetActivationsError):
 36	"""error for invalid prompt -- the prompt does not have fields "hash" or "text"
 37
 38	raised by `augment_prompt_with_hash`
 39	"""
 40
 41	pass
 42
 43
 44def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
 45	"""compare a prompt to a loaded prompt, raise an error if they do not match
 46
 47	# Parameters:
 48	- `prompt : dict`
 49	- `prompt_loaded : dict`
 50
 51	# Returns:
 52	- `None`
 53
 54	# Raises:
 55	- `ActivationsMismatchError` : if the prompt text or hash do not match
 56	"""
 57	for key in ("text", "hash"):
 58		if prompt[key] != prompt_loaded[key]:
 59			msg = f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
 60			raise ActivationsMismatchError(
 61				msg,
 62			)
 63
 64
 65def augment_prompt_with_hash(prompt: dict) -> dict:
 66	"""if a prompt does not have a hash, add one
 67
 68	not having a "text" field is allowed, but only if "hash" is present
 69
 70	# Parameters:
 71	- `prompt : dict`
 72
 73	# Returns:
 74	- `dict`
 75
 76	# Modifies:
 77	the input `prompt` dictionary, if it does not have a `"hash"` key
 78	"""
 79	if "hash" not in prompt:
 80		if "text" not in prompt:
 81			msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
 82			raise InvalidPromptError(
 83				msg,
 84			)
 85		prompt_str: str = prompt["text"]
 86		prompt_hash: str = (
 87			# we don't need this to be a secure hash
 88			base64.urlsafe_b64encode(hashlib.md5(prompt_str.encode()).digest())  # noqa: S324
 89			.decode()
 90			.rstrip("=")
 91		)
 92		prompt.update(hash=prompt_hash)
 93	return prompt
 94
 95
 96@overload
 97def load_activations(
 98	model_name: str,
 99	prompt: dict,
100	save_path: Path,
101	return_fmt: Literal["torch"] = "torch",
102) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
103	...
104@overload
105def load_activations(
106	model_name: str,
107	prompt: dict,
108	save_path: Path,
109	return_fmt: Literal["numpy"] = "numpy",
110) -> "tuple[Path, dict[str, np.ndarray]]": ...
111def load_activations(
112	model_name: str,
113	prompt: dict,
114	save_path: Path,
115	return_fmt: ReturnCache = "torch",
116) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
117	"""load activations for a prompt and model, from an npz file
118
119	# Parameters:
120	- `model_name : str`
121	- `prompt : dict`
122	- `save_path : Path`
123	- `return_fmt : Literal["torch", "numpy"]`
124		(defaults to `"torch"`)
125
126	# Returns:
127	- `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
128		the path to the activations file and the activations as a dictionary
129		of numpy arrays or torch tensors, depending on `return_fmt`
130
131	# Raises:
132	- `ActivationsMissingError` : if the activations file is missing
133	- `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
134	"""
135	if return_fmt not in ("torch", "numpy"):
136		msg = f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
137		raise ValueError(
138			msg,
139		)
140	if return_fmt == "torch":
141		import torch  # noqa: PLC0415
142
143	augment_prompt_with_hash(prompt)
144
145	prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
146	prompt_file: Path = prompt_dir / "prompt.json"
147	if not prompt_file.exists():
148		msg = f"Prompt file {prompt_file} does not exist"
149		raise ActivationsMissingError(msg)
150	with open(prompt_dir / "prompt.json", "r") as f:
151		prompt_loaded: dict = json.load(f)
152		compare_prompt_to_loaded(prompt, prompt_loaded)
153
154	activations_path: Path = prompt_dir / "activations.npz"
155
156	if not activations_path.exists():
157		msg = f"Activations file {activations_path} does not exist"
158		raise ActivationsMissingError(msg)
159
160	cache: dict
161
162	with np.load(activations_path) as npz_data:
163		if return_fmt == "numpy":
164			cache = dict(npz_data.items())
165		elif return_fmt == "torch":
166			cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}
167
168	return activations_path, cache
169
170
def activations_exist(model_name: str, prompt: dict, save_path: Path) -> bool:
	"""cheaply test whether activations for a prompt are already on disk

	unlike `load_activations`, nothing is read or decompressed -- this only
	performs two `.exists()` checks, which is all that is needed when deciding
	whether a prompt has already been processed.

	# Parameters:
	- `model_name : str`
	- `prompt : dict`
		must contain a 'hash' key (call `augment_prompt_with_hash` first)
	- `save_path : Path`

	# Returns:
	- `bool`
		True if both prompt.json and activations.npz exist for this prompt

	# Raises:
	- `InvalidPromptError` : if the prompt does not have a 'hash' key
	"""
	if "hash" not in prompt:
		msg = f"Prompt must have 'hash' key (call augment_prompt_with_hash first): {prompt}"
		raise InvalidPromptError(msg)
	base_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
	# `all` short-circuits, matching the original `and` evaluation order
	return all(
		(base_dir / fname).exists() for fname in ("prompt.json", "activations.npz")
	)
199
200
201# def load_activations_stacked()

class GetActivationsError(builtins.ValueError):
15class GetActivationsError(ValueError):
16	"""base class for errors in getting activations"""
17
18	pass

base class for errors in getting activations

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class ActivationsMissingError(GetActivationsError, builtins.FileNotFoundError):
21class ActivationsMissingError(GetActivationsError, FileNotFoundError):
22	"""error for missing activations -- can't find the activations file"""
23
24	pass

error for missing activations -- can't find the activations file

Inherited Members
builtins.ValueError
ValueError
builtins.OSError
errno
strerror
filename
filename2
characters_written
builtins.BaseException
with_traceback
add_note
args
class ActivationsMismatchError(GetActivationsError):
27class ActivationsMismatchError(GetActivationsError):
28	"""error for mismatched activations -- the prompt text or hash do not match
29
30	raised by `compare_prompt_to_loaded`
31	"""
32
33	pass

error for mismatched activations -- the prompt text or hash do not match

raised by compare_prompt_to_loaded

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class InvalidPromptError(GetActivationsError):
36class InvalidPromptError(GetActivationsError):
37	"""error for invalid prompt -- the prompt does not have fields "hash" or "text"
38
39	raised by `augment_prompt_with_hash`
40	"""
41
42	pass

error for invalid prompt -- the prompt does not have fields "hash" or "text"

raised by augment_prompt_with_hash

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
45def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
46	"""compare a prompt to a loaded prompt, raise an error if they do not match
47
48	# Parameters:
49	- `prompt : dict`
50	- `prompt_loaded : dict`
51
52	# Returns:
53	- `None`
54
55	# Raises:
56	- `ActivationsMismatchError` : if the prompt text or hash do not match
57	"""
58	for key in ("text", "hash"):
59		if prompt[key] != prompt_loaded[key]:
60			msg = f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
61			raise ActivationsMismatchError(
62				msg,
63			)

compare a prompt to a loaded prompt, raise an error if they do not match

Parameters:

  • prompt : dict
  • prompt_loaded : dict

Returns:

  • None

Raises:

  • ActivationsMismatchError : if the prompt text or hash do not match

def augment_prompt_with_hash(prompt: dict) -> dict:
66def augment_prompt_with_hash(prompt: dict) -> dict:
67	"""if a prompt does not have a hash, add one
68
69	not having a "text" field is allowed, but only if "hash" is present
70
71	# Parameters:
72	- `prompt : dict`
73
74	# Returns:
75	- `dict`
76
77	# Modifies:
78	the input `prompt` dictionary, if it does not have a `"hash"` key
79	"""
80	if "hash" not in prompt:
81		if "text" not in prompt:
82			msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
83			raise InvalidPromptError(
84				msg,
85			)
86		prompt_str: str = prompt["text"]
87		prompt_hash: str = (
88			# we don't need this to be a secure hash
89			base64.urlsafe_b64encode(hashlib.md5(prompt_str.encode()).digest())  # noqa: S324
90			.decode()
91			.rstrip("=")
92		)
93		prompt.update(hash=prompt_hash)
94	return prompt

if a prompt does not have a hash, add one

not having a "text" field is allowed, but only if "hash" is present

Parameters:

  • prompt : dict

Returns:

  • dict

Modifies:

the input prompt dictionary, if it does not have a "hash" key

def load_activations( model_name: str, prompt: dict, save_path: pathlib._local.Path, return_fmt: Optional[Literal['numpy', 'torch']] = 'torch') -> tuple[pathlib._local.Path, dict[str, torch.Tensor] | dict[str, numpy.ndarray]]:
112def load_activations(
113	model_name: str,
114	prompt: dict,
115	save_path: Path,
116	return_fmt: ReturnCache = "torch",
117) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
118	"""load activations for a prompt and model, from an npz file
119
120	# Parameters:
121	- `model_name : str`
122	- `prompt : dict`
123	- `save_path : Path`
124	- `return_fmt : Literal["torch", "numpy"]`
125		(defaults to `"torch"`)
126
127	# Returns:
128	- `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
129		the path to the activations file and the activations as a dictionary
130		of numpy arrays or torch tensors, depending on `return_fmt`
131
132	# Raises:
133	- `ActivationsMissingError` : if the activations file is missing
134	- `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
135	"""
136	if return_fmt not in ("torch", "numpy"):
137		msg = f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
138		raise ValueError(
139			msg,
140		)
141	if return_fmt == "torch":
142		import torch  # noqa: PLC0415
143
144	augment_prompt_with_hash(prompt)
145
146	prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
147	prompt_file: Path = prompt_dir / "prompt.json"
148	if not prompt_file.exists():
149		msg = f"Prompt file {prompt_file} does not exist"
150		raise ActivationsMissingError(msg)
151	with open(prompt_dir / "prompt.json", "r") as f:
152		prompt_loaded: dict = json.load(f)
153		compare_prompt_to_loaded(prompt, prompt_loaded)
154
155	activations_path: Path = prompt_dir / "activations.npz"
156
157	if not activations_path.exists():
158		msg = f"Activations file {activations_path} does not exist"
159		raise ActivationsMissingError(msg)
160
161	cache: dict
162
163	with np.load(activations_path) as npz_data:
164		if return_fmt == "numpy":
165			cache = dict(npz_data.items())
166		elif return_fmt == "torch":
167			cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}
168
169	return activations_path, cache

load activations for a prompt and model, from an npz file

Parameters:

  • model_name : str
  • prompt : dict
  • save_path : Path
  • return_fmt : Literal["torch", "numpy"] (defaults to "torch")

Returns:

  • tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]] the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on return_fmt

Raises:

  • ActivationsMissingError : if the activations file is missing
  • ValueError : if return_fmt is not "torch" or "numpy"
def activations_exist(model_name: str, prompt: dict, save_path: pathlib._local.Path) -> bool:
172def activations_exist(model_name: str, prompt: dict, save_path: Path) -> bool:
173	"""check if activations exist on disk without loading them
174
175	cheap alternative to calling `load_activations` when you only need to know
176	whether a prompt has been processed. `load_activations` decompresses the full
177	`.npz` into numpy arrays, which is wasteful when the data is immediately
178	discarded. this function just checks `.exists()` on the two expected files.
179
180	# Parameters:
181	- `model_name : str`
182	- `prompt : dict`
183		must contain a 'hash' key (call `augment_prompt_with_hash` first)
184	- `save_path : Path`
185
186	# Returns:
187	- `bool`
188		True if both prompt.json and activations.npz exist for this prompt
189
190	# Raises:
191	- `InvalidPromptError` : if the prompt does not have a 'hash' key
192	"""
193	if "hash" not in prompt:
194		msg = f"Prompt must have 'hash' key (call augment_prompt_with_hash first): {prompt}"
195		raise InvalidPromptError(msg)
196	prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
197	return (prompt_dir / "prompt.json").exists() and (
198		prompt_dir / "activations.npz"
199	).exists()

check if activations exist on disk without loading them

cheap alternative to calling load_activations when you only need to know whether a prompt has been processed. load_activations decompresses the full .npz into numpy arrays, which is wasteful when the data is immediately discarded. this function just checks .exists() on the two expected files.

Parameters:

  • model_name : str
  • prompt : dict — must contain a 'hash' key (call augment_prompt_with_hash first)
  • save_path : Path

Returns:

  • bool True if both prompt.json and activations.npz exist for this prompt

Raises:

  • InvalidPromptError : if the prompt does not have a 'hash' key