Coverage for tests / unit / test_load_activations.py: 100%
68 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
1# tests/unit/test_load_activations.py
2import json
3from pathlib import Path
4from unittest import mock
6import numpy as np
7import pytest
9from pattern_lens.load_activations import (
10 ActivationsMismatchError,
11 ActivationsMissingError,
12 InvalidPromptError,
13 augment_prompt_with_hash,
14 load_activations,
15)
17TEMP_DIR: Path = Path("tests/.temp")
def test_augment_prompt_with_hash():
    """`augment_prompt_with_hash` adds a stable hash and leaves existing ones alone."""
    # A prompt without a hash gets one added, as a string
    prompt = {"text": "test prompt"}
    augmented = augment_prompt_with_hash(prompt)
    assert "hash" in augmented
    assert isinstance(augmented["hash"], str)

    # The hash is deterministic: hashing again yields the same value
    original_hash = augmented["hash"]
    assert augment_prompt_with_hash(prompt)["hash"] == original_hash

    # A pre-existing hash is preserved untouched
    pre_hashed = {"text": "test prompt", "hash": "existing-hash"}
    assert augment_prompt_with_hash(pre_hashed)["hash"] == "existing-hash"

    # A prompt lacking both text and hash is rejected
    with pytest.raises(InvalidPromptError):
        augment_prompt_with_hash({"other_field": "value"})
def test_load_activations_success():
    """Activations written to disk round-trip through `load_activations`."""
    base_dir = TEMP_DIR / "test_load_activations_success"
    model_name = "test-model"
    prompt = {"text": "test prompt", "hash": "testhash123"}

    # Lay out the on-disk structure load_activations expects
    prompt_dir = base_dir / model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    # Write a fake attention-pattern array as the activations cache
    pattern = np.random.rand(1, 2, 10, 10).astype(np.float32)
    np.savez(  # type: ignore[arg-type]
        prompt_dir / "activations.npz",
        **{"blocks.0.attn.hook_pattern": pattern},
    )

    # Patch out the prompt-comparison helper so we can assert it was invoked
    with mock.patch(
        "pattern_lens.load_activations.compare_prompt_to_loaded",
    ) as mock_compare:
        path, cache = load_activations(
            model_name=model_name,
            prompt=prompt,
            save_path=base_dir,
            return_fmt="numpy",
        )

    # The returned path points at the file we wrote
    assert path == prompt_dir / "activations.npz"

    # The cache is a dict keyed by hook name, with the original shape intact
    assert isinstance(cache, dict)
    assert "blocks.0.attn.hook_pattern" in cache
    assert cache["blocks.0.attn.hook_pattern"].shape == (1, 2, 10, 10)

    # The stored prompt was compared against the requested one
    mock_compare.assert_called_once_with(prompt, prompt)
def test_load_activations_errors():
    """`load_activations` raises the right error for each missing/mismatched state."""
    base_dir = TEMP_DIR / "test_load_activations_errors"
    model_name = "test-model"
    prompt = {"text": "test prompt", "hash": "testhash123"}

    def attempt_load():
        # One-liner for the repeated load attempt under test
        load_activations(
            model_name=model_name,
            prompt=prompt,
            save_path=base_dir,
            return_fmt="numpy",
        )

    # Nothing on disk yet -> missing activations
    with pytest.raises(ActivationsMissingError):
        attempt_load()

    prompt_dir = base_dir / model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)

    # A stored prompt whose text disagrees with ours -> mismatch
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump({"text": "different prompt", "hash": prompt["hash"]}, f)
    with pytest.raises(ActivationsMismatchError):
        attempt_load()

    # Correct prompt.json but still no activations.npz -> missing activations again
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)
    with pytest.raises(ActivationsMissingError, match="Activations file"):
        attempt_load()
def test_load_activations_missing_npz_is_subclass_of_file_not_found():
    """ActivationsMissingError for missing activations.npz is also a FileNotFoundError."""
    base_dir = TEMP_DIR / "test_load_activations_missing_npz_subclass"
    model_name = "test-model"
    prompt = {"text": "test prompt", "hash": "testhash_sub"}

    # Write only prompt.json so the activations file itself is what's missing
    prompt_dir = base_dir / model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    with pytest.raises(ActivationsMissingError) as exc_info:
        load_activations(
            model_name=model_name,
            prompt=prompt,
            save_path=base_dir,
            return_fmt="numpy",
        )
    # The raised exception must participate in FileNotFoundError handling
    assert isinstance(exc_info.value, FileNotFoundError)
def test_load_activations_invalid_return_fmt():
    """An unrecognized return_fmt is rejected with a ValueError."""
    # No files are needed: the format check should fire before any I/O succeeds
    with pytest.raises(ValueError, match="Invalid return_fmt"):
        load_activations(  # type: ignore[no-matching-overload]
            model_name="test-model",
            prompt={"text": "test", "hash": "test"},
            save_path=Path("/nonexistent"),
            return_fmt="invalid",
        )