Coverage for tests/unit/test_activations_batched.py: 97%
347 statements
coverage.py v7.13.4, created at 2026-02-22 18:15 -0700
1"""Tests for batched activation computation.
3Tests verify:
4- compute_activations_batched produces correct shapes per-prompt
5- Trimming correctly removes padding (variable-length prompts)
6- Batched results are equivalent to single-prompt results
7- Prompt metadata (prompt.json) is correctly saved
8- activations_exist helper function
9- Cache skipping in activations_main
10"""
12import json
13import shutil
14from pathlib import Path
15from unittest import mock
17import numpy as np
18import pytest
19import torch
21from pattern_lens.activations import (
22 activations_main,
23 compute_activations,
24 compute_activations_batched,
25)
26from pattern_lens.load_activations import (
27 InvalidPromptError,
28 activations_exist,
29 augment_prompt_with_hash,
30)
32TEMP_DIR: Path = Path("tests/.temp")

class MockHookedTransformerBatched:
    """Mock of HookedTransformer that supports both single and batched input.

    Tokenization rule: each character in the text becomes one token, plus a BOS token.
    So "hello" -> seq_len=6 (1 BOS + 5 chars).

    For batched input, shorter sequences are right-padded with zeros.
    Real (non-padding) positions are filled with deterministic random values
    seeded by the text content, so the same text produces the same attention
    patterns regardless of batch composition.
    """

    def __init__(
        self,
        model_name: str = "test-model",
        n_layers: int = 2,
        n_heads: int = 2,
    ):
        self.model_name = model_name
        self.cfg = mock.MagicMock()
        self.cfg.n_layers = n_layers
        self.cfg.n_heads = n_heads
        self.cfg.model_name = model_name
        self.tokenizer = mock.MagicMock()
        # tokenize returns subword strings (no BOS) — matches HuggingFace tokenizer behavior
        self.tokenizer.tokenize = lambda text: list(text)  # noqa: PLW0108 -- split text into individual characters

    def _seq_len(self, text: str) -> int:
        """Actual model sequence length for a text (includes BOS)."""
        return len(text) + 1

    def parameters(self):
        """Return a dummy parameter for activations_main's numel/device checks."""
        return [torch.zeros(1)]

    def eval(self):
        return self

    def to_tokens(self, text):
        """Return token IDs. Includes BOS, so seq_len = len(text) + 1."""
        if isinstance(text, str):
            return torch.zeros(1, self._seq_len(text), dtype=torch.long)
        else:
            seq_lens = [self._seq_len(t) for t in text]
            max_len = max(seq_lens)
            return torch.zeros(len(text), max_len, dtype=torch.long)

    def _make_deterministic_attn(self, text: str, seq_len: int) -> torch.Tensor:
        """Generate deterministic attention values for a text.

        Returns shape [n_heads, seq_len, seq_len] with values seeded by text content.
        """
        gen = torch.Generator()
        gen.manual_seed(hash(text) % (2**31))
        return torch.rand(self.cfg.n_heads, seq_len, seq_len, generator=gen)

    def run_with_cache(
        self,
        input,  # noqa: A002 -- matches TransformerLens API signature
        names_filter=None,  # noqa: ARG002
        return_type=None,  # noqa: ARG002
        padding_side="right",  # noqa: ARG002
    ):
        """Mock run_with_cache supporting both single string and list of strings.

        For batched input, pads to max length. Real positions get deterministic
        values based on text content. Padding positions are 0.
        """
        if isinstance(input, list):
            texts = input
            batch_size = len(texts)
        else:
            texts = [input]
            batch_size = 1

        seq_lens = [self._seq_len(t) for t in texts]
        max_len = max(seq_lens)

        cache: dict[str, torch.Tensor] = {}
        for layer in range(self.cfg.n_layers):
            attn = torch.zeros(batch_size, self.cfg.n_heads, max_len, max_len)
            for b, (text, seq_len) in enumerate(zip(texts, seq_lens, strict=True)):
                attn[b, :, :seq_len, :seq_len] = self._make_deterministic_attn(
                    text,
                    seq_len,
                )
            cache[f"blocks.{layer}.attn.hook_pattern"] = attn

        return None, cache
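
# A minimal usage sketch (not part of the test suite) of the mock's padding
# contract described in the class docstring: batching "hi" (seq_len=3) with
# "hello" (seq_len=6) pads the shorter prompt to length 6, and its padding
# region stays exactly zero.
#
#     _model = MockHookedTransformerBatched()
#     _, _cache = _model.run_with_cache(["hi", "hello"])
#     _attn = _cache["blocks.0.attn.hook_pattern"]  # shape [2, n_heads=2, 6, 6]
#     assert _attn.shape == (2, 2, 6, 6)
#     assert torch.all(_attn[0, :, 3:, :] == 0)  # rows beyond "hi"'s real length are padding
#     assert torch.all(_attn[0, :, :, 3:] == 0)  # columns too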

def _make_prompts() -> list[dict]:
    """Return fresh copies of test prompts (dicts get mutated during processing)."""
    return [
        {"text": "hi", "hash": "hash_short"},
        {"text": "hello world", "hash": "hash_medium"},
        {"text": "the quick brown fox jumps", "hash": "hash_long"},
    ]


def _expected_seq_len(text: str) -> int:
    """Expected model sequence length: len(text) + 1 for BOS."""
    return len(text) + 1

# ============================================================================
# Test: activations_exist
# ============================================================================


def test_activations_exist_both_present():
    """activations_exist returns True when both prompt.json and activations.npz exist."""
    temp_dir = TEMP_DIR / "test_activations_exist_both"
    model_name = "test-model"
    prompt = {"text": "test", "hash": "exist_hash"}

    prompt_dir = temp_dir / model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)

    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)
    np.savez(prompt_dir / "activations.npz", dummy=np.zeros(1))

    assert activations_exist(model_name, prompt, temp_dir) is True


def test_activations_exist_missing_npz():
    """activations_exist returns False when activations.npz is missing."""
    temp_dir = TEMP_DIR / "test_activations_exist_no_npz"
    model_name = "test-model"
    prompt = {"text": "test", "hash": "exist_no_npz"}

    prompt_dir = temp_dir / model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)

    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    assert activations_exist(model_name, prompt, temp_dir) is False


def test_activations_exist_missing_json():
    """activations_exist returns False when prompt.json is missing."""
    temp_dir = TEMP_DIR / "test_activations_exist_no_json"
    model_name = "test-model"
    prompt = {"text": "test", "hash": "exist_no_json"}

    prompt_dir = temp_dir / model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)

    np.savez(prompt_dir / "activations.npz", dummy=np.zeros(1))

    assert activations_exist(model_name, prompt, temp_dir) is False


def test_activations_exist_missing_dir():
    """activations_exist returns False when the directory doesn't exist."""
    temp_dir = TEMP_DIR / "test_activations_exist_no_dir"
    prompt = {"text": "test", "hash": "exist_no_dir"}

    assert activations_exist("test-model", prompt, temp_dir) is False
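
# The four tests above pin down activations_exist's contract. A plausible
# reference implementation (an assumption about pattern_lens internals, not the
# library's actual code) would simply check for both files:
#
#     def _activations_exist_sketch(model_name: str, prompt: dict, save_path: Path) -> bool:
#         if "hash" not in prompt:
#             raise InvalidPromptError("prompt must have 'hash' key")
#         prompt_dir = save_path / model_name / "prompts" / prompt["hash"]
#         return (prompt_dir / "prompt.json").is_file() and (prompt_dir / "activations.npz").is_file()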

# ============================================================================
# Test: compute_activations_batched shapes
# ============================================================================


def test_compute_activations_batched_shapes():
    """Each prompt's saved .npz has attention patterns with the correct unpadded shape."""
    temp_dir = TEMP_DIR / "test_batched_shapes"
    model = MockHookedTransformerBatched()
    prompts = _make_prompts()

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    assert len(paths) == 3

    for prompt, path in zip(prompts, paths, strict=True):
        assert path.exists(), f"Missing file: {path}"

        expected_seq_len = _expected_seq_len(prompt["text"])
        with np.load(path) as data:
            for layer in range(model.cfg.n_layers):
                key = f"blocks.{layer}.attn.hook_pattern"
                assert key in data, f"Missing key {key} in {path}"
                arr = data[key]
                expected_shape = (
                    1,
                    model.cfg.n_heads,
                    expected_seq_len,
                    expected_seq_len,
                )
                assert arr.shape == expected_shape, (
                    f"Wrong shape for {prompt['text']!r}: {arr.shape} != {expected_shape}"
                )

def test_compute_activations_batched_no_padding_leaks():
    """Verify that padding values (zeros) don't appear in saved data for real positions.

    The mock fills real positions with random values > 0 (with very high probability)
    and padding positions with exactly 0. If trimming is wrong, we'd see zeros
    in places that should have random values.
    """
    temp_dir = TEMP_DIR / "test_batched_no_padding_leaks"
    model = MockHookedTransformerBatched()
    # Use prompts with very different lengths to ensure padding exists
    prompts = [
        {"text": "ab", "hash": "hash_tiny"},  # seq_len=3
        {"text": "a" * 50, "hash": "hash_big"},  # seq_len=51
    ]

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # Check the short prompt's data — if padding leaked, it would have
    # shape (1, 2, 51, 51) instead of (1, 2, 3, 3)
    with np.load(paths[0]) as data:
        arr = data["blocks.0.attn.hook_pattern"]
        assert arr.shape == (1, 2, 3, 3), (
            f"Padding leaked into short prompt: shape={arr.shape}"
        )

# ============================================================================
# Test: batched vs single equivalence
# ============================================================================


def test_batched_vs_single_equivalence():
    """Batched results must be identical to single-prompt results.

    Process the same prompts both individually and as a batch.
    With right-padding + causal mask, real positions should have identical values.
    """
    temp_dir_single = TEMP_DIR / "test_equivalence_single"
    temp_dir_batch = TEMP_DIR / "test_equivalence_batch"
    model = MockHookedTransformerBatched()

    prompts_single = _make_prompts()
    prompts_batch = _make_prompts()

    # Process individually
    single_paths = []
    for p in prompts_single:
        augment_prompt_with_hash(p)
        path, _ = compute_activations(  # ty: ignore[no-matching-overload]
            prompt=p,
            model=model,  # type: ignore[arg-type]
            save_path=temp_dir_single,
            return_cache=None,
        )
        single_paths.append(path)

    # Process as batch
    batch_paths = compute_activations_batched(
        prompts=prompts_batch,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir_batch,
    )

    # Compare
    assert len(single_paths) == len(batch_paths)
    for i, (single_path, batch_path) in enumerate(
        zip(single_paths, batch_paths, strict=True),
    ):
        with np.load(single_path) as single_data, np.load(batch_path) as batch_data:
            single_keys = set(single_data.keys())
            batch_keys = set(batch_data.keys())
            assert single_keys == batch_keys, (
                f"Prompt {i}: key mismatch: {single_keys} != {batch_keys}"
            )

            for key in single_keys:
                single_arr = single_data[key]
                batch_arr = batch_data[key]
                assert single_arr.shape == batch_arr.shape, (
                    f"Prompt {i}, key {key}: shape mismatch: "
                    f"{single_arr.shape} != {batch_arr.shape}"
                )
                np.testing.assert_allclose(
                    single_arr,
                    batch_arr,
                    rtol=0,
                    atol=0,
                    err_msg=f"Prompt {i}, key {key}: values differ",
                )
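
# Why exact equivalence can hold: the batched path presumably slices each
# prompt's block out of the padded batch before saving, along the lines of the
# sketch below (an assumption about compute_activations_batched's internals,
# not its actual code):
#
#     for b, (prompt, seq_len) in enumerate(zip(prompts, seq_lens, strict=True)):
#         trimmed = {
#             key: attn[b : b + 1, :, :seq_len, :seq_len].numpy()
#             for key, attn in cache.items()
#         }
#         np.savez(prompt_dir / "activations.npz", **trimmed)
#
# With right-padding, that slice is exactly what the single-prompt path produces.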

def test_batched_single_prompt_equivalence():
    """Batch of size 1 must produce identical results to single-prompt compute."""
    temp_dir_single = TEMP_DIR / "test_single_equiv_single"
    temp_dir_batch = TEMP_DIR / "test_single_equiv_batch"
    model = MockHookedTransformerBatched()

    prompt_single = {"text": "hello world", "hash": "hash_1prompt"}
    prompt_batch = {"text": "hello world", "hash": "hash_1prompt"}

    augment_prompt_with_hash(prompt_single)
    single_path, _ = compute_activations(  # ty: ignore[no-matching-overload]
        prompt=prompt_single,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir_single,
        return_cache=None,
    )

    batch_paths = compute_activations_batched(
        prompts=[prompt_batch],
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir_batch,
    )

    with np.load(single_path) as single_data, np.load(batch_paths[0]) as batch_data:
        for key in single_data:
            np.testing.assert_allclose(
                single_data[key],
                batch_data[key],
                rtol=0,
                atol=0,
                err_msg=f"Key {key}: single vs batch-of-1 differ",
            )

# ============================================================================
# Test: prompt metadata
# ============================================================================


def test_compute_activations_batched_saves_prompt_metadata():
    """Each prompt in the batch gets its own prompt.json with correct fields."""
    temp_dir = TEMP_DIR / "test_batched_metadata"
    model = MockHookedTransformerBatched()
    prompts = _make_prompts()

    compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    for prompt in prompts:
        prompt_dir = temp_dir / model.cfg.model_name / "prompts" / prompt["hash"]
        prompt_json_path = prompt_dir / "prompt.json"
        assert prompt_json_path.exists(), f"Missing prompt.json for {prompt['hash']}"

        with open(prompt_json_path) as f:
            saved = json.load(f)

        assert saved["text"] == prompt["text"]
        assert saved["hash"] == prompt["hash"]
        assert "tokens" in saved
        assert "n_tokens" in saved
        # n_tokens should be len(text) since tokenizer.tokenize returns list(text)
        assert saved["n_tokens"] == len(prompt["text"])
        # tokens should be the list of characters
        assert saved["tokens"] == list(prompt["text"])

# ============================================================================
# Test: file structure and path correctness
# ============================================================================


def test_compute_activations_batched_file_paths():
    """Saved files follow the expected directory structure."""
    temp_dir = TEMP_DIR / "test_batched_paths"
    model = MockHookedTransformerBatched(model_name="my-model")
    prompts = _make_prompts()

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    for prompt, path in zip(prompts, paths, strict=True):
        expected = (
            temp_dir / "my-model" / "prompts" / prompt["hash"] / "activations.npz"
        )
        assert path == expected, f"Wrong path: {path} != {expected}"
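
# On-disk layout exercised by the metadata and path tests above (a summary of
# what the assertions check; the exact tree is defined by pattern_lens):
#
#     <save_path>/
#         <model_name>/
#             prompts/
#                 <prompt hash>/
#                     prompt.json
#                     activations.npz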

# ============================================================================
# Test: empty batch assertion
# ============================================================================


def test_compute_activations_batched_empty_raises():
    """Empty prompt list should raise AssertionError."""
    model = MockHookedTransformerBatched()
    with pytest.raises(AssertionError, match="prompts must not be empty"):
        compute_activations_batched(
            prompts=[],
            model=model,  # type: ignore[arg-type]
            save_path=TEMP_DIR / "test_batched_empty",
        )

# ============================================================================
# Test: variable-length trimming with extreme size differences
# ============================================================================


def test_batched_extreme_length_difference():
    """Test with prompts whose lengths differ by 10x+ to stress-test trimming."""
    temp_dir = TEMP_DIR / "test_batched_extreme_lengths"
    model = MockHookedTransformerBatched()

    prompts = [
        {"text": "x", "hash": "hash_1char"},  # seq_len=2
        {"text": "y" * 100, "hash": "hash_100char"},  # seq_len=101
    ]

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # Short prompt: shape should be [1, 2, 2, 2], NOT [1, 2, 101, 101]
    with np.load(paths[0]) as data:
        arr = data["blocks.0.attn.hook_pattern"]
        assert arr.shape == (1, 2, 2, 2)

    # Long prompt: shape should be [1, 2, 101, 101]
    with np.load(paths[1]) as data:
        arr = data["blocks.0.attn.hook_pattern"]
        assert arr.shape == (1, 2, 101, 101)


def test_batched_same_length_prompts():
    """When all prompts have the same length, no trimming needed — should still work."""
    temp_dir = TEMP_DIR / "test_batched_same_length"
    model = MockHookedTransformerBatched()

    prompts = [
        {"text": "abc", "hash": "hash_abc"},
        {"text": "def", "hash": "hash_def"},
        {"text": "ghi", "hash": "hash_ghi"},
    ]

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # All should have seq_len=4 (3 chars + BOS)
    for path in paths:
        with np.load(path) as data:
            arr = data["blocks.0.attn.hook_pattern"]
            assert arr.shape == (1, 2, 4, 4)

# ============================================================================
# Test: activations_exist integration with compute_activations_batched
# ============================================================================


def test_activations_exist_after_batched_compute():
    """activations_exist returns True for all prompts after batched computation."""
    temp_dir = TEMP_DIR / "test_exist_after_batched"
    model = MockHookedTransformerBatched()
    prompts = _make_prompts()

    compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    for prompt in prompts:
        assert activations_exist(model.cfg.model_name, prompt, temp_dir), (
            f"activations_exist returned False for {prompt['hash']}"
        )

# ============================================================================
# Test: saved attention values are nonzero (trimmed correctly, not padding)
# ============================================================================


def test_batched_trimmed_values_are_nonzero():
    """Verify saved attention values are all nonzero.

    The mock fills real positions with torch.rand (uniform [0,1)) and padding
    with exactly 0. If trimming is wrong and we include padding positions,
    we'd see exact zeros in the saved data. Since rand() producing exactly 0.0
    is astronomically unlikely for the small tensors in this test, any zero
    indicates a trimming bug.
    """
    temp_dir = TEMP_DIR / "test_batched_nonzero_values"
    model = MockHookedTransformerBatched()
    # Deliberately different lengths so padding exists in the batch
    prompts = [
        {"text": "ab", "hash": "hash_nz_short"},  # seq_len=3
        {"text": "a" * 20, "hash": "hash_nz_long"},  # seq_len=21
    ]

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    for prompt, path in zip(prompts, paths, strict=True):
        with np.load(path) as data:
            for key in data:
                arr = data[key]
                assert np.all(arr != 0.0), (
                    f"Found zeros in {key} for prompt {prompt['text']!r} "
                    f"(shape {arr.shape}). This suggests padding was not trimmed."
                )

# ============================================================================
# Test: pre-computed seq_lens parameter
# ============================================================================


def test_batched_with_precomputed_seq_lens():
    """Passing seq_lens explicitly produces the same result as computing internally."""
    temp_dir_auto = TEMP_DIR / "test_seq_lens_auto"
    temp_dir_manual = TEMP_DIR / "test_seq_lens_manual"
    model = MockHookedTransformerBatched()

    prompts_auto = _make_prompts()
    prompts_manual = _make_prompts()

    # Auto-computed seq_lens
    paths_auto = compute_activations_batched(
        prompts=prompts_auto,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir_auto,
    )

    # Manually pre-computed seq_lens
    manual_seq_lens = [model.to_tokens(p["text"]).shape[1] for p in prompts_manual]
    paths_manual = compute_activations_batched(
        prompts=prompts_manual,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir_manual,
        seq_lens=manual_seq_lens,
    )

    for auto_path, manual_path in zip(paths_auto, paths_manual, strict=True):
        with np.load(auto_path) as auto_data, np.load(manual_path) as manual_data:
            for key in auto_data:
                np.testing.assert_array_equal(
                    auto_data[key],
                    manual_data[key],
                    err_msg=f"Key {key}: auto vs manual seq_lens differ",
                )


def test_batched_seq_lens_length_mismatch_raises():
    """Passing seq_lens with the wrong length raises AssertionError."""
    model = MockHookedTransformerBatched()
    prompts = _make_prompts()

    with pytest.raises(AssertionError, match="seq_lens length mismatch"):
        compute_activations_batched(
            prompts=prompts,
            model=model,  # type: ignore[arg-type]
            save_path=TEMP_DIR / "test_seq_lens_mismatch",
            seq_lens=[5, 10],  # 2 lengths for 3 prompts
        )

# ============================================================================
# Test: activations_main cache-skip path (force=False)
# ============================================================================

# Patches needed to run activations_main with a mock model
_ACTIVATIONS_MAIN_PATCHES = [
    "pattern_lens.activations.HookedTransformer",
    "pattern_lens.activations.load_text_data",
    "pattern_lens.activations.write_html_index",
    "pattern_lens.activations.generate_models_jsonl",
    "pattern_lens.activations.generate_prompts_jsonl",
    "pattern_lens.activations.asdict",
    "pattern_lens.activations.json_serialize",
]


def _make_5_prompts() -> list[dict]:
    """5 prompts of varying lengths for cache-skip tests."""
    return [
        {"text": "aa"},
        {"text": "bbbb"},
        {"text": "cccccc"},
        {"text": "dddddddd"},
        {"text": "eeeeeeeeee"},
    ]


def _run_activations_main_mocked(
    save_path: Path,
    prompts: list[dict],
    force: bool,
    batch_size: int = 32,
) -> mock.MagicMock:
    """Run activations_main with all heavy dependencies mocked.

    Returns the mock for compute_activations_batched so callers can inspect calls.
    """
    mock_model = MockHookedTransformerBatched()

    with (
        mock.patch("pattern_lens.activations.HookedTransformer") as mock_ht_cls,
        mock.patch("pattern_lens.activations.load_text_data", return_value=prompts),
        mock.patch("pattern_lens.activations.write_html_index"),
        mock.patch("pattern_lens.activations.generate_models_jsonl"),
        mock.patch("pattern_lens.activations.generate_prompts_jsonl"),
        mock.patch("pattern_lens.activations.asdict", return_value={}),
        mock.patch("pattern_lens.activations.json_serialize", return_value={}),
        mock.patch(
            "pattern_lens.activations.compute_activations_batched",
            wraps=compute_activations_batched,
        ) as spy_batched,
    ):
        mock_ht_cls.from_pretrained.return_value = mock_model

        activations_main(
            model_name="test-model",
            save_path=str(save_path),
            prompts_path="dummy.jsonl",
            raw_prompts=True,
            min_chars=0,
            max_chars=9999,
            force=force,
            n_samples=len(prompts),
            no_index_html=True,
            device=torch.device("cpu"),
            batch_size=batch_size,
        )

    return spy_batched

def test_activations_main_partial_cache_skip():
    """With force=False, only uncached prompts are computed."""
    temp_dir = TEMP_DIR / "test_main_partial_cache"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    model = MockHookedTransformerBatched()
    all_prompts = _make_5_prompts()

    # Pre-compute activations for the first 2 prompts
    pre_cached = [dict(p) for p in all_prompts[:2]]
    for p in pre_cached:
        augment_prompt_with_hash(p)
    compute_activations_batched(
        prompts=pre_cached,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # Verify they're cached
    for p in pre_cached:
        assert activations_exist("test-model", p, temp_dir)

    # Run activations_main with force=False — should skip the 2 cached ones
    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in all_prompts],
        force=False,
        batch_size=32,
    )

    # compute_activations_batched should have been called with only 3 uncached prompts
    assert spy.call_count == 1
    called_prompts = spy.call_args[1]["prompts"]
    assert len(called_prompts) == 3

    # All 5 should now be cached
    for p in all_prompts:
        augment_prompt_with_hash(p)
        assert activations_exist("test-model", p, temp_dir), (
            f"prompt {p['text']!r} not cached after activations_main"
        )

def test_activations_main_full_cache_skip():
    """With force=False and all prompts cached, compute_activations_batched is never called."""
    temp_dir = TEMP_DIR / "test_main_full_cache"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    model = MockHookedTransformerBatched()
    all_prompts = _make_5_prompts()

    # Pre-compute activations for ALL prompts
    pre_cached = [dict(p) for p in all_prompts]
    for p in pre_cached:
        augment_prompt_with_hash(p)
    compute_activations_batched(
        prompts=pre_cached,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # Run activations_main with force=False — should skip everything
    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in all_prompts],
        force=False,
    )

    # compute_activations_batched should never have been called
    assert spy.call_count == 0
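
# The cache-skip behaviour pinned down by the two tests above is presumably a
# filter along these lines inside activations_main (a sketch under that
# assumption, not the actual pattern_lens code; variable names are illustrative):
#
#     for p in prompts:
#         augment_prompt_with_hash(p)
#     to_compute = [
#         p
#         for p in prompts
#         if force or not activations_exist(model_name, p, save_path)
#     ]
#     # to_compute is then sorted and split into batches of `batch_size`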

# ============================================================================
# Test: names_filter as a callable (not regex)
# ============================================================================


def test_names_filter_callable():
    """Passing a plain callable as names_filter exercises the non-regex branch."""
    temp_dir = TEMP_DIR / "test_names_filter_callable"
    model = MockHookedTransformerBatched()
    prompts = _make_prompts()

    def my_filter(key: str) -> bool:
        return "hook_pattern" in key

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
        names_filter=my_filter,
    )

    assert len(paths) == 3
    for prompt, path in zip(prompts, paths, strict=True):
        assert path.exists()
        expected_seq_len: int = _expected_seq_len(prompt["text"])
        with np.load(path) as data:
            for layer in range(model.cfg.n_layers):
                key = f"blocks.{layer}.attn.hook_pattern"
                assert key in data
                assert data[key].shape == (
                    1,
                    model.cfg.n_heads,
                    expected_seq_len,
                    expected_seq_len,
                )

# ============================================================================
# Test: activations_exist raises when hash is missing
# ============================================================================


def test_activations_exist_requires_hash():
    """activations_exist raises InvalidPromptError when the prompt has no hash."""
    with pytest.raises(InvalidPromptError, match="must have 'hash' key"):
        activations_exist("test-model", {"text": "no hash here"}, TEMP_DIR)

# ============================================================================
# Test: compute_activations_batched input validation
# ============================================================================


def test_compute_activations_batched_missing_text_raises():
    """A prompt without a 'text' key raises AssertionError."""
    model = MockHookedTransformerBatched()
    with pytest.raises(AssertionError, match="text"):
        compute_activations_batched(
            prompts=[{"hash": "abc"}],
            model=model,  # type: ignore[arg-type]
            save_path=TEMP_DIR / "test_missing_text",
        )


def test_compute_activations_batched_missing_hash_raises():
    """A prompt without a 'hash' key raises AssertionError."""
    model = MockHookedTransformerBatched()
    with pytest.raises(AssertionError, match="hash"):
        compute_activations_batched(
            prompts=[{"text": "hello"}],
            model=model,  # type: ignore[arg-type]
            save_path=TEMP_DIR / "test_missing_hash",
        )

# ============================================================================
# Test: activations_main with batch_size=1
# ============================================================================


def test_activations_main_batch_size_1():
    """batch_size=1 processes each prompt individually (one call per prompt)."""
    temp_dir = TEMP_DIR / "test_main_batch_size_1"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    all_prompts = _make_5_prompts()

    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in all_prompts],
        force=True,
        batch_size=1,
    )

    # With batch_size=1 and 5 prompts, should be called 5 times (one prompt each)
    assert spy.call_count == 5
    for call in spy.call_args_list:
        assert len(call[1]["prompts"]) == 1

    # All 5 should now be cached
    for p in all_prompts:
        augment_prompt_with_hash(p)
        assert activations_exist("test-model", p, temp_dir)

# ============================================================================
# Test: activations_main with force=True recomputes cached prompts
# ============================================================================


def test_activations_main_force_recomputes():
    """With force=True, all prompts are recomputed even if already cached."""
    temp_dir = TEMP_DIR / "test_main_force_recompute"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    model = MockHookedTransformerBatched()
    all_prompts = _make_5_prompts()

    # Pre-compute activations for ALL prompts
    pre_cached = [dict(p) for p in all_prompts]
    for p in pre_cached:
        augment_prompt_with_hash(p)
    compute_activations_batched(
        prompts=pre_cached,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # Verify they're all cached
    for p in pre_cached:
        assert activations_exist("test-model", p, temp_dir)

    # Run with force=True — should recompute all 5
    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in all_prompts],
        force=True,
    )

    # compute_activations_batched should have been called with all 5 prompts
    total_computed = sum(len(call[1]["prompts"]) for call in spy.call_args_list)
    assert total_computed == 5

# ============================================================================
# Test: activations_main sorts prompts by length (longest first)
# ============================================================================


def test_activations_main_sorts_by_length():
    """Prompts are sorted longest-first within each batch for padding efficiency."""
    temp_dir = TEMP_DIR / "test_main_sorts_by_length"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)

    # Deliberately pass prompts in SHORT-first order
    prompts = [
        {"text": "a"},  # shortest
        {"text": "bb"},
        {"text": "ccc"},
        {"text": "dddd"},
        {"text": "eeeee"},  # longest
    ]

    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in prompts],
        force=True,
        batch_size=100,  # large enough to fit all in one batch
    )

    # All prompts in one call — check they're sorted longest-first
    assert spy.call_count == 1
    called_prompts = spy.call_args[1]["prompts"]
    called_texts = [p["text"] for p in called_prompts]
    called_lengths = [len(t) for t in called_texts]
    assert called_lengths == sorted(called_lengths, reverse=True), (
        f"Prompts not sorted longest-first: {called_texts}"
    )

    # seq_lens must be sorted in the same order as prompts
    called_seq_lens = spy.call_args[1]["seq_lens"]
    assert called_seq_lens == sorted(called_seq_lens, reverse=True), (
        f"seq_lens not sorted longest-first: {called_seq_lens}"
    )
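
# The longest-first sorting checked by this test, and the chunking checked by
# the final test below, are presumably something like the following sketch
# (an assumption about activations_main's internals; names are illustrative only):
#
#     order = sorted(range(len(to_compute)), key=lambda i: seq_lens[i], reverse=True)
#     sorted_prompts = [to_compute[i] for i in order]
#     sorted_seq_lens = [seq_lens[i] for i in order]
#     for start in range(0, len(sorted_prompts), batch_size):
#         compute_activations_batched(
#             prompts=sorted_prompts[start : start + batch_size],
#             model=model,
#             save_path=save_path,
#             seq_lens=sorted_seq_lens[start : start + batch_size],
#         )
#
# Longest-first ordering keeps prompts of similar length in the same batch,
# which minimizes wasted padding.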

# ============================================================================
# Test: activations_main splits into multiple batches correctly
# ============================================================================


def test_activations_main_multiple_batches():
    """batch_size=2 with 5 prompts produces 3 batched calls (2+2+1)."""
    temp_dir = TEMP_DIR / "test_main_multi_batch"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    all_prompts = _make_5_prompts()

    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in all_prompts],
        force=True,
        batch_size=2,
    )

    # 5 prompts / batch_size=2 => 3 calls: 2, 2, 1
    assert spy.call_count == 3
    batch_sizes = [len(call[1]["prompts"]) for call in spy.call_args_list]
    assert batch_sizes == [2, 2, 1], f"Unexpected batch sizes: {batch_sizes}"

    # All 5 should be cached
    for p in all_prompts:
        augment_prompt_with_hash(p)
        assert activations_exist("test-model", p, temp_dir)