docs for pattern_lens v0.2.0
View Source on GitHub

pattern_lens.load_activations

loading activations from .npz on disk. implements some custom Exception classes


  1"loading activations from .npz on disk. implements some custom Exception classes"
  2
  3import base64
  4import hashlib
  5import json
  6from pathlib import Path
  7from typing import Literal, overload
  8
  9import numpy as np
 10
 11
 12class GetActivationsError(ValueError):
 13    """base class for errors in getting activations"""
 14
 15    pass
 16
 17
 18class ActivationsMissingError(GetActivationsError, FileNotFoundError):
 19    """error for missing activations -- can't find the activations file"""
 20
 21    pass
 22
 23
 24class ActivationsMismatchError(GetActivationsError):
 25    """error for mismatched activations -- the prompt text or hash do not match
 26
 27    raised by `compare_prompt_to_loaded`
 28    """
 29
 30    pass
 31
 32
 33class InvalidPromptError(GetActivationsError):
 34    """error for invalid prompt -- the prompt does not have fields "hash" or "text"
 35
 36    raised by `augment_prompt_with_hash`
 37    """
 38
 39    pass
 40
 41
 42def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
 43    """compare a prompt to a loaded prompt, raise an error if they do not match
 44
 45    # Parameters:
 46     - `prompt : dict`
 47     - `prompt_loaded : dict`
 48
 49    # Returns:
 50     - `None`
 51
 52    # Raises:
 53     - `ActivationsMismatchError` : if the prompt text or hash do not match
 54    """
 55    for key in ("text", "hash"):
 56        if prompt[key] != prompt_loaded[key]:
 57            raise ActivationsMismatchError(
 58                f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
 59            )
 60
 61
 62def augment_prompt_with_hash(prompt: dict) -> dict:
 63    """if a prompt does not have a hash, add one
 64
 65    not having a "text" field is allowed, but only if "hash" is present
 66
 67    # Parameters:
 68     - `prompt : dict`
 69
 70    # Returns:
 71     - `dict`
 72
 73    # Modifies:
 74    the input `prompt` dictionary, if it does not have a `"hash"` key
 75    """
 76    if "hash" not in prompt:
 77        if "text" not in prompt:
 78            raise InvalidPromptError(
 79                f"Prompt does not have 'text' field or 'hash' field: {prompt}"
 80            )
 81        prompt_str: str = prompt["text"]
 82        prompt_hash: str = (
 83            base64.urlsafe_b64encode(hashlib.md5(prompt_str.encode()).digest())
 84            .decode()
 85            .rstrip("=")
 86        )
 87        prompt.update(hash=prompt_hash)
 88    return prompt
 89
 90
 91@overload
 92def load_activations(
 93    model_name: str,
 94    prompt: dict,
 95    save_path: Path,
 96    return_fmt: Literal["torch"] = "torch",
 97) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
 98    ...
 99@overload
100def load_activations(
101    model_name: str,
102    prompt: dict,
103    save_path: Path,
104    return_fmt: Literal["numpy"] = "numpy",
105) -> "tuple[Path, dict[str, np.ndarray]]": ...
106def load_activations(
107    model_name: str,
108    prompt: dict,
109    save_path: Path,
110    return_fmt: Literal["torch", "numpy"] = "torch",
111) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
112    """load activations for a prompt and model, from an npz file
113
114    # Parameters:
115     - `model_name : str`
116     - `prompt : dict`
117     - `save_path : Path`
118     - `return_fmt : Literal["torch", "numpy"]`
119       (defaults to `"torch"`)
120
121    # Returns:
122     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
123         the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`
124
125    # Raises:
126     - `ActivationsMissingError` : if the activations file is missing
127     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
128    """
129
130    if return_fmt not in ("torch", "numpy"):
131        raise ValueError(
132            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
133        )
134    if return_fmt == "torch":
135        import torch
136
137    augment_prompt_with_hash(prompt)
138
139    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
140    prompt_file: Path = prompt_dir / "prompt.json"
141    if not prompt_file.exists():
142        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
143    with open(prompt_dir / "prompt.json", "r") as f:
144        prompt_loaded: dict = json.load(f)
145        compare_prompt_to_loaded(prompt, prompt_loaded)
146
147    activations_path: Path = prompt_dir / "activations.npz"
148
149    cache: dict
150
151    with np.load(activations_path) as npz_data:
152        if return_fmt == "numpy":
153            cache = {k: v for k, v in npz_data.items()}
154        elif return_fmt == "torch":
155            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}
156
157    return activations_path, cache

class GetActivationsError(builtins.ValueError):
class GetActivationsError(ValueError):
    """base class for errors in getting activations

    every custom exception defined in this module derives from this class, so
    callers can catch it (or the builtin `ValueError`) to handle all of them
    """

base class for errors in getting activations

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class ActivationsMissingError(GetActivationsError, builtins.FileNotFoundError):
class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """error for missing activations -- can't find the activations file

    also a subclass of the builtin `FileNotFoundError`, so code that catches
    the builtin will catch this too
    """

error for missing activations -- can't find the activations file

Inherited Members
builtins.ValueError
ValueError
builtins.OSError
errno
strerror
filename
filename2
winerror
characters_written
builtins.BaseException
with_traceback
add_note
args
class ActivationsMismatchError(GetActivationsError):
class ActivationsMismatchError(GetActivationsError):
    """error for mismatched activations -- the prompt text or hash do not match

    raised by `compare_prompt_to_loaded` when the prompt given by the caller
    disagrees with the prompt it is checked against
    """

error for mismatched activations -- the prompt text or hash do not match

raised by compare_prompt_to_loaded

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class InvalidPromptError(GetActivationsError):
class InvalidPromptError(GetActivationsError):
    """error for invalid prompt -- the prompt does not have fields "hash" or "text"

    i.e. the prompt has neither field, so no hash can be found or computed.
    raised by `augment_prompt_with_hash`
    """

error for invalid prompt -- the prompt does not have fields "hash" or "text"

raised by augment_prompt_with_hash

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    the `"text"` and `"hash"` fields are compared. a field that `prompt` itself
    does not have is skipped, since `augment_prompt_with_hash` allows a prompt
    identified only by its `"hash"`. a field that `prompt` has but
    `prompt_loaded` lacks counts as a mismatch.

    # Parameters:
     - `prompt : dict`
       the prompt supplied by the caller
     - `prompt_loaded : dict`
       the prompt to check against (in `load_activations`, the parsed
       contents of `prompt.json`)

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match
    """
    for key in ("text", "hash"):
        if key not in prompt:
            # nothing to compare -- e.g. a prompt given only by its hash
            continue
        if key not in prompt_loaded or prompt[key] != prompt_loaded[key]:
            raise ActivationsMismatchError(
                f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            )

compare a prompt to a loaded prompt, raise an error if they do not match

Parameters:

  • prompt : dict
  • prompt_loaded : dict

Returns:

  • None

Raises:

  • ActivationsMismatchError : if the prompt text or hash do not match

def augment_prompt_with_hash(prompt: dict) -> dict:
def augment_prompt_with_hash(prompt: dict) -> dict:
    """ensure `prompt` carries a `"hash"`, deriving it from `"text"` if absent

    the hash is the MD5 digest of the UTF-8 encoded text, encoded as URL-safe
    base64 with the trailing `=` padding stripped. not having a `"text"` field
    is allowed, but only if `"hash"` is present

    # Parameters:
     - `prompt : dict`

    # Returns:
     - `dict`
       the same object that was passed in

    # Raises:
     - `InvalidPromptError` : if `prompt` has neither `"hash"` nor `"text"`

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    if "hash" in prompt:
        # an existing hash is kept as-is and never recomputed
        return prompt
    if "text" not in prompt:
        raise InvalidPromptError(
            f"Prompt does not have 'text' field or 'hash' field: {prompt}"
        )
    text: str = prompt["text"]
    digest: bytes = hashlib.md5(text.encode()).digest()
    prompt["hash"] = base64.urlsafe_b64encode(digest).decode().rstrip("=")
    return prompt

if a prompt does not have a hash, add one

not having a "text" field is allowed, but only if "hash" is present

Parameters:

  • prompt : dict

Returns:

  • dict

Modifies:

the input prompt dictionary, if it does not have a "hash" key

def load_activations( model_name: str, prompt: dict, save_path: pathlib.Path, return_fmt: Literal['torch', 'numpy'] = 'torch') -> tuple[pathlib.Path, dict[str, torch.Tensor] | dict[str, numpy.ndarray]]:
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch", "numpy"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    looks in `save_path / model_name / "prompts" / <prompt hash>` for a
    `prompt.json`, which is checked against `prompt`, and an `activations.npz`

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
       must have a `"hash"` or a `"text"` field; a missing hash is computed
       from the text and written into this dict (see `augment_prompt_with_hash`)
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
       (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
         the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
     - `InvalidPromptError` : if `prompt` has neither `"hash"` nor `"text"`
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
    """
    if return_fmt not in ("torch", "numpy"):
        raise ValueError(
            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        )
    if return_fmt == "torch":
        # imported only when needed, so the numpy path never imports torch
        import torch

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
    # JSON text is UTF-8; don't depend on the platform's default encoding
    with open(prompt_file, "r", encoding="utf-8") as f:
        prompt_loaded: dict = json.load(f)
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # without this check np.load would raise a bare FileNotFoundError,
    # breaking the documented ActivationsMissingError contract
    if not activations_path.exists():
        raise ActivationsMissingError(
            f"Activations file {activations_path} does not exist"
        )

    cache: dict
    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        else:
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache

load activations for a prompt and model, from an npz file

Parameters:

  • model_name : str
  • prompt : dict
  • save_path : Path
  • return_fmt : Literal["torch", "numpy"] (defaults to "torch")

Returns:

  • tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]] the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on return_fmt

Raises:

  • ActivationsMissingError : if the activations file is missing
  • ValueError : if return_fmt is not "torch" or "numpy"