Coverage for tests/unit/test_activations_batched.py: 97% (347 statements), coverage.py v7.13.4

1"""Tests for batched activation computation. 

2 

3Tests verify: 

4- compute_activations_batched produces correct shapes per-prompt 

5- Trimming correctly removes padding (variable-length prompts) 

6- Batched results are equivalent to single-prompt results 

7- Prompt metadata (prompt.json) is correctly saved 

8- activations_exist helper function 

9- Cache skipping in activations_main 

10""" 

11 

12import json 

13import shutil 

14from pathlib import Path 

15from unittest import mock 

16 

17import numpy as np 

18import pytest 

19import torch 

20 

21from pattern_lens.activations import ( 

22 activations_main, 

23 compute_activations, 

24 compute_activations_batched, 

25) 

26from pattern_lens.load_activations import ( 

27 InvalidPromptError, 

28 activations_exist, 

29 augment_prompt_with_hash, 

30) 

31 

32TEMP_DIR: Path = Path("tests/.temp") 
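
# On-disk layout exercised throughout these tests (see the file-path and
# metadata tests below):
#   {save_path}/{model_name}/prompts/{hash}/prompt.json
#   {save_path}/{model_name}/prompts/{hash}/activations.npz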


class MockHookedTransformerBatched:
    """Mock of HookedTransformer that supports both single and batched input.

    Tokenization rule: each character in the text becomes one token, plus a BOS
    token, so "hello" -> seq_len=6 (1 BOS + 5 chars).

    For batched input, shorter sequences are right-padded with zeros.
    Real (non-padding) positions are filled with deterministic random values
    seeded by the text content, so the same text produces the same attention
    patterns regardless of batch composition.
    """

    def __init__(
        self,
        model_name: str = "test-model",
        n_layers: int = 2,
        n_heads: int = 2,
    ):
        self.model_name = model_name
        self.cfg = mock.MagicMock()
        self.cfg.n_layers = n_layers
        self.cfg.n_heads = n_heads
        self.cfg.model_name = model_name
        self.tokenizer = mock.MagicMock()
        # tokenize returns subword strings (no BOS); matches HuggingFace tokenizer behavior
        self.tokenizer.tokenize = lambda text: list(text)  # noqa: PLW0108 -- split text into individual characters

    def _seq_len(self, text: str) -> int:
        """Actual model sequence length for a text (includes BOS)."""
        return len(text) + 1

    def parameters(self):
        """Return a dummy parameter for activations_main's numel/device checks."""
        return [torch.zeros(1)]

    def eval(self):
        """Return self, as torch.nn.Module.eval() does."""
        return self

    def to_tokens(self, text):
        """Return token IDs. Includes BOS, so seq_len = len(text) + 1."""
        if isinstance(text, str):
            return torch.zeros(1, self._seq_len(text), dtype=torch.long)
        else:
            seq_lens = [self._seq_len(t) for t in text]
            max_len = max(seq_lens)
            return torch.zeros(len(text), max_len, dtype=torch.long)
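
    # e.g. to_tokens(["ab", "a"]) -> shape (2, 3): the batch is padded to the
    # longest sequence (BOS + 2 chars)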

    def _make_deterministic_attn(self, text: str, seq_len: int) -> torch.Tensor:
        """Generate deterministic attention values for a text.

        Returns shape [n_heads, seq_len, seq_len] with values seeded by text content.
        """

        gen = torch.Generator()
        gen.manual_seed(hash(text) % (2**31))
        return torch.rand(self.cfg.n_heads, seq_len, seq_len, generator=gen)

    def run_with_cache(
        self,
        input,  # noqa: A002 -- matches TransformerLens API signature
        names_filter=None,  # noqa: ARG002
        return_type=None,  # noqa: ARG002
        padding_side="right",  # noqa: ARG002
    ):
        """Mock run_with_cache supporting both a single string and a list of strings.

        For batched input, pads to max length. Real positions get deterministic
        values based on text content. Padding positions are 0.
        """
        if isinstance(input, list):
            texts = input
            batch_size = len(texts)
        else:
            texts = [input]
            batch_size = 1

        seq_lens = [self._seq_len(t) for t in texts]
        max_len = max(seq_lens)

        cache: dict[str, torch.Tensor] = {}
        for layer in range(self.cfg.n_layers):
            attn = torch.zeros(batch_size, self.cfg.n_heads, max_len, max_len)
            for b, (text, seq_len) in enumerate(zip(texts, seq_lens, strict=True)):
                attn[b, :, :seq_len, :seq_len] = self._make_deterministic_attn(
                    text,
                    seq_len,
                )
            cache[f"blocks.{layer}.attn.hook_pattern"] = attn
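        # The real HookedTransformer.run_with_cache returns (output, cache);
        # these tests only consume the cache, so None stands in for the output.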

        return None, cache


def _make_prompts() -> list[dict]:
    """Return fresh copies of test prompts (dicts get mutated during processing)."""
    return [
        {"text": "hi", "hash": "hash_short"},
        {"text": "hello world", "hash": "hash_medium"},
        {"text": "the quick brown fox jumps", "hash": "hash_long"},
    ]


def _expected_seq_len(text: str) -> int:
    """Expected model sequence length: len(text) + 1 for BOS."""
    return len(text) + 1


# ============================================================================
# Test: activations_exist
# ============================================================================


def test_activations_exist_both_present():
    """activations_exist returns True when both prompt.json and activations.npz exist."""
    temp_dir = TEMP_DIR / "test_activations_exist_both"
    model_name = "test-model"
    prompt = {"text": "test", "hash": "exist_hash"}

    prompt_dir = temp_dir / model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)

    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)
    np.savez(prompt_dir / "activations.npz", dummy=np.zeros(1))

    assert activations_exist(model_name, prompt, temp_dir) is True


def test_activations_exist_missing_npz():
    """activations_exist returns False when activations.npz is missing."""
    temp_dir = TEMP_DIR / "test_activations_exist_no_npz"
    model_name = "test-model"
    prompt = {"text": "test", "hash": "exist_no_npz"}

    prompt_dir = temp_dir / model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)

    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    assert activations_exist(model_name, prompt, temp_dir) is False


def test_activations_exist_missing_json():
    """activations_exist returns False when prompt.json is missing."""
    temp_dir = TEMP_DIR / "test_activations_exist_no_json"
    model_name = "test-model"
    prompt = {"text": "test", "hash": "exist_no_json"}

    prompt_dir = temp_dir / model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)

    np.savez(prompt_dir / "activations.npz", dummy=np.zeros(1))

    assert activations_exist(model_name, prompt, temp_dir) is False


def test_activations_exist_missing_dir():
    """activations_exist returns False when the directory doesn't exist."""
    temp_dir = TEMP_DIR / "test_activations_exist_no_dir"
    prompt = {"text": "test", "hash": "exist_no_dir"}

    assert activations_exist("test-model", prompt, temp_dir) is False
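
# Together, the four tests above pin down activations_exist as checking that
# both prompt.json and activations.npz exist under
# save_path / model_name / "prompts" / prompt["hash"]
# (an assumption inferred from observable behavior, not from its source).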


# ============================================================================
# Test: compute_activations_batched shapes
# ============================================================================


def test_compute_activations_batched_shapes():
    """Each prompt's saved .npz has attention patterns with the correct unpadded shape."""
    temp_dir = TEMP_DIR / "test_batched_shapes"
    model = MockHookedTransformerBatched()
    prompts = _make_prompts()

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    assert len(paths) == 3

    for prompt, path in zip(prompts, paths, strict=True):
        assert path.exists(), f"Missing file: {path}"

        expected_seq_len = _expected_seq_len(prompt["text"])
        with np.load(path) as data:
            for layer in range(model.cfg.n_layers):
                key = f"blocks.{layer}.attn.hook_pattern"
                assert key in data, f"Missing key {key} in {path}"
                arr = data[key]
                expected_shape = (
                    1,
                    model.cfg.n_heads,
                    expected_seq_len,
                    expected_seq_len,
                )
                assert arr.shape == expected_shape, (
                    f"Wrong shape for {prompt['text']!r}: {arr.shape} != {expected_shape}"
                )


def test_compute_activations_batched_no_padding_leaks():
    """Verify that padding values (zeros) don't appear in saved data for real positions.

    The mock fills real positions with random values > 0 (with very high probability)
    and padding positions with exactly 0. If trimming is wrong, we'd see zeros
    in places that should have random values.
    """
    temp_dir = TEMP_DIR / "test_batched_no_padding_leaks"
    model = MockHookedTransformerBatched()
    # Use prompts with very different lengths to ensure padding exists
    prompts = [
        {"text": "ab", "hash": "hash_tiny"},  # seq_len=3
        {"text": "a" * 50, "hash": "hash_big"},  # seq_len=51
    ]

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # Check the short prompt's data; if padding leaked, it would have
    # shape (1, 2, 51, 51) instead of (1, 2, 3, 3)
    with np.load(paths[0]) as data:
        arr = data["blocks.0.attn.hook_pattern"]
        assert arr.shape == (1, 2, 3, 3), (
            f"Padding leaked into short prompt: shape={arr.shape}"
        )
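

# Illustrative sketch (not part of the suite, and an assumption about shape
# conventions rather than pattern_lens's actual code): trimming cuts each
# prompt's slice of the padded [batch, n_heads, max_len, max_len] cache back
# to its true sequence length before saving.
def _trimming_sketch() -> None:
    padded = torch.zeros(2, 2, 5, 5)
    padded[0, :, :3, :3] = 1.0  # prompt 0: seq_len=3, everything else is padding
    padded[1, :, :5, :5] = 1.0  # prompt 1: fills the whole window
    trimmed = [padded[b : b + 1, :, :n, :n] for b, n in enumerate([3, 5])]
    assert [tuple(t.shape) for t in trimmed] == [(1, 2, 3, 3), (1, 2, 5, 5)]
    assert all(bool((t != 0).all()) for t in trimmed)  # no padding zeros survive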


# ============================================================================
# Test: batched vs single equivalence
# ============================================================================


def test_batched_vs_single_equivalence():
    """Batched results must be identical to single-prompt results.

    Process the same prompts both individually and as a batch.
    With right-padding and a causal mask, real positions should have identical values.
    """
    temp_dir_single = TEMP_DIR / "test_equivalence_single"
    temp_dir_batch = TEMP_DIR / "test_equivalence_batch"
    model = MockHookedTransformerBatched()

    prompts_single = _make_prompts()
    prompts_batch = _make_prompts()

    # Process individually
    single_paths = []
    for p in prompts_single:
        augment_prompt_with_hash(p)
        path, _ = compute_activations(  # ty: ignore[no-matching-overload]
            prompt=p,
            model=model,  # type: ignore[arg-type]
            save_path=temp_dir_single,
            return_cache=None,
        )
        single_paths.append(path)

    # Process as a batch
    batch_paths = compute_activations_batched(
        prompts=prompts_batch,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir_batch,
    )

    # Compare
    assert len(single_paths) == len(batch_paths)
    for i, (single_path, batch_path) in enumerate(
        zip(single_paths, batch_paths, strict=True),
    ):
        with np.load(single_path) as single_data, np.load(batch_path) as batch_data:
            single_keys = set(single_data.keys())
            batch_keys = set(batch_data.keys())
            assert single_keys == batch_keys, (
                f"Prompt {i}: key mismatch: {single_keys} != {batch_keys}"
            )

            for key in single_keys:
                single_arr = single_data[key]
                batch_arr = batch_data[key]
                assert single_arr.shape == batch_arr.shape, (
                    f"Prompt {i}, key {key}: shape mismatch: "
                    f"{single_arr.shape} != {batch_arr.shape}"
                )
                np.testing.assert_allclose(
                    single_arr,
                    batch_arr,
                    rtol=0,
                    atol=0,
                    err_msg=f"Prompt {i}, key {key}: values differ",
                )


def test_batched_single_prompt_equivalence():
    """A batch of size 1 must produce identical results to single-prompt compute."""
    temp_dir_single = TEMP_DIR / "test_single_equiv_single"
    temp_dir_batch = TEMP_DIR / "test_single_equiv_batch"
    model = MockHookedTransformerBatched()

    prompt_single = {"text": "hello world", "hash": "hash_1prompt"}
    prompt_batch = {"text": "hello world", "hash": "hash_1prompt"}

    augment_prompt_with_hash(prompt_single)
    single_path, _ = compute_activations(  # ty: ignore[no-matching-overload]
        prompt=prompt_single,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir_single,
        return_cache=None,
    )

    batch_paths = compute_activations_batched(
        prompts=[prompt_batch],
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir_batch,
    )

    with np.load(single_path) as single_data, np.load(batch_paths[0]) as batch_data:
        for key in single_data:
            np.testing.assert_allclose(
                single_data[key],
                batch_data[key],
                rtol=0,
                atol=0,
                err_msg=f"Key {key}: single vs batch-of-1 differ",
            )


# ============================================================================
# Test: prompt metadata
# ============================================================================


def test_compute_activations_batched_saves_prompt_metadata():
    """Each prompt in the batch gets its own prompt.json with correct fields."""
    temp_dir = TEMP_DIR / "test_batched_metadata"
    model = MockHookedTransformerBatched()
    prompts = _make_prompts()

    compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    for prompt in prompts:
        prompt_dir = temp_dir / model.cfg.model_name / "prompts" / prompt["hash"]
        prompt_json_path = prompt_dir / "prompt.json"
        assert prompt_json_path.exists(), f"Missing prompt.json for {prompt['hash']}"

        with open(prompt_json_path) as f:
            saved = json.load(f)

        assert saved["text"] == prompt["text"]
        assert saved["hash"] == prompt["hash"]
        assert "tokens" in saved
        assert "n_tokens" in saved
        # n_tokens should be len(text) since tokenizer.tokenize returns list(text)
        assert saved["n_tokens"] == len(prompt["text"])
        # tokens should be the list of characters
        assert saved["tokens"] == list(prompt["text"])
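
# For reference, the expected prompt.json for the first prompt under this mock
# tokenizer (fields as asserted above; the real file may carry more keys):
#   {"text": "hi", "hash": "hash_short", "tokens": ["h", "i"], "n_tokens": 2}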


# ============================================================================
# Test: file structure and path correctness
# ============================================================================


def test_compute_activations_batched_file_paths():
    """Saved files follow the expected directory structure."""
    temp_dir = TEMP_DIR / "test_batched_paths"
    model = MockHookedTransformerBatched(model_name="my-model")
    prompts = _make_prompts()

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    for prompt, path in zip(prompts, paths, strict=True):
        expected = (
            temp_dir / "my-model" / "prompts" / prompt["hash"] / "activations.npz"
        )
        assert path == expected, f"Wrong path: {path} != {expected}"


# ============================================================================
# Test: empty batch assertion
# ============================================================================


def test_compute_activations_batched_empty_raises():
    """An empty prompt list should raise AssertionError."""
    model = MockHookedTransformerBatched()
    with pytest.raises(AssertionError, match="prompts must not be empty"):
        compute_activations_batched(
            prompts=[],
            model=model,  # type: ignore[arg-type]
            save_path=TEMP_DIR / "test_batched_empty",
        )


# ============================================================================
# Test: variable-length trimming with extreme size differences
# ============================================================================


def test_batched_extreme_length_difference():
    """Test with prompts whose lengths differ by 10x+ to stress-test trimming."""
    temp_dir = TEMP_DIR / "test_batched_extreme_lengths"
    model = MockHookedTransformerBatched()

    prompts = [
        {"text": "x", "hash": "hash_1char"},  # seq_len=2
        {"text": "y" * 100, "hash": "hash_100char"},  # seq_len=101
    ]

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # Short prompt: shape should be [1, 2, 2, 2], NOT [1, 2, 101, 101]
    with np.load(paths[0]) as data:
        arr = data["blocks.0.attn.hook_pattern"]
        assert arr.shape == (1, 2, 2, 2)

    # Long prompt: shape should be [1, 2, 101, 101]
    with np.load(paths[1]) as data:
        arr = data["blocks.0.attn.hook_pattern"]
        assert arr.shape == (1, 2, 101, 101)


def test_batched_same_length_prompts():
    """When all prompts have the same length, no trimming is needed; it should still work."""
    temp_dir = TEMP_DIR / "test_batched_same_length"
    model = MockHookedTransformerBatched()

    prompts = [
        {"text": "abc", "hash": "hash_abc"},
        {"text": "def", "hash": "hash_def"},
        {"text": "ghi", "hash": "hash_ghi"},
    ]

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # All should have seq_len=4 (3 chars + BOS)
    for path in paths:
        with np.load(path) as data:
            arr = data["blocks.0.attn.hook_pattern"]
            assert arr.shape == (1, 2, 4, 4)


# ============================================================================
# Test: activations_exist integration with compute_activations_batched
# ============================================================================


def test_activations_exist_after_batched_compute():
    """activations_exist returns True for all prompts after batched computation."""
    temp_dir = TEMP_DIR / "test_exist_after_batched"
    model = MockHookedTransformerBatched()
    prompts = _make_prompts()

    compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    for prompt in prompts:
        assert activations_exist(model.cfg.model_name, prompt, temp_dir), (
            f"activations_exist returned False for {prompt['hash']}"
        )


# ============================================================================
# Test: saved attention values are nonzero (trimmed correctly, not padding)
# ============================================================================


def test_batched_trimmed_values_are_nonzero():
    """Verify saved attention values are all nonzero.

    The mock fills real positions with torch.rand (uniform [0, 1)) and padding
    with exactly 0. If trimming is wrong and we include padding positions,
    we'd see exact zeros in the saved data. Since rand() producing exactly 0.0
    is astronomically unlikely for the small tensors in this test, any zero
    indicates a trimming bug.
    """
    temp_dir = TEMP_DIR / "test_batched_nonzero_values"
    model = MockHookedTransformerBatched()
    # Deliberately different lengths so padding exists in the batch
    prompts = [
        {"text": "ab", "hash": "hash_nz_short"},  # seq_len=3
        {"text": "a" * 20, "hash": "hash_nz_long"},  # seq_len=21
    ]

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    for prompt, path in zip(prompts, paths, strict=True):
        with np.load(path) as data:
            for key in data:
                arr = data[key]
                assert np.all(arr != 0.0), (
                    f"Found zeros in {key} for prompt {prompt['text']!r} "
                    f"(shape {arr.shape}); this suggests padding was not trimmed."
                )


# ============================================================================
# Test: pre-computed seq_lens parameter
# ============================================================================


def test_batched_with_precomputed_seq_lens():
    """Passing seq_lens explicitly produces the same result as computing internally."""
    temp_dir_auto = TEMP_DIR / "test_seq_lens_auto"
    temp_dir_manual = TEMP_DIR / "test_seq_lens_manual"
    model = MockHookedTransformerBatched()

    prompts_auto = _make_prompts()
    prompts_manual = _make_prompts()

    # Auto-computed seq_lens
    paths_auto = compute_activations_batched(
        prompts=prompts_auto,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir_auto,
    )

    # Manually pre-computed seq_lens: to_tokens on a single string returns
    # shape (1, seq_len), so shape[1] is that prompt's true length
    manual_seq_lens = [model.to_tokens(p["text"]).shape[1] for p in prompts_manual]
    paths_manual = compute_activations_batched(
        prompts=prompts_manual,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir_manual,
        seq_lens=manual_seq_lens,
    )

    for auto_path, manual_path in zip(paths_auto, paths_manual, strict=True):
        with np.load(auto_path) as auto_data, np.load(manual_path) as manual_data:
            for key in auto_data:
                np.testing.assert_array_equal(
                    auto_data[key],
                    manual_data[key],
                    err_msg=f"Key {key}: auto vs manual seq_lens differ",
                )


def test_batched_seq_lens_length_mismatch_raises():
    """Passing seq_lens with the wrong length raises AssertionError."""
    model = MockHookedTransformerBatched()
    prompts = _make_prompts()

    with pytest.raises(AssertionError, match="seq_lens length mismatch"):
        compute_activations_batched(
            prompts=prompts,
            model=model,  # type: ignore[arg-type]
            save_path=TEMP_DIR / "test_seq_lens_mismatch",
            seq_lens=[5, 10],  # 2 lengths for 3 prompts
        )


# ============================================================================
# Test: activations_main cache-skip path (force=False)
# ============================================================================

# Patch targets needed to run activations_main with a mock model (kept for
# reference; _run_activations_main_mocked below applies them inline)
_ACTIVATIONS_MAIN_PATCHES = [
    "pattern_lens.activations.HookedTransformer",
    "pattern_lens.activations.load_text_data",
    "pattern_lens.activations.write_html_index",
    "pattern_lens.activations.generate_models_jsonl",
    "pattern_lens.activations.generate_prompts_jsonl",
    "pattern_lens.activations.asdict",
    "pattern_lens.activations.json_serialize",
]


def _make_5_prompts() -> list[dict]:
    """Return 5 prompts of varying lengths for cache-skip tests."""
    return [
        {"text": "aa"},
        {"text": "bbbb"},
        {"text": "cccccc"},
        {"text": "dddddddd"},
        {"text": "eeeeeeeeee"},
    ]


def _run_activations_main_mocked(
    save_path: Path,
    prompts: list[dict],
    force: bool,
    batch_size: int = 32,
) -> mock.MagicMock:
    """Run activations_main with all heavy dependencies mocked.

    Returns the mock for compute_activations_batched so callers can inspect calls.
    """
    mock_model = MockHookedTransformerBatched()

    with (
        mock.patch("pattern_lens.activations.HookedTransformer") as mock_ht_cls,
        mock.patch("pattern_lens.activations.load_text_data", return_value=prompts),
        mock.patch("pattern_lens.activations.write_html_index"),
        mock.patch("pattern_lens.activations.generate_models_jsonl"),
        mock.patch("pattern_lens.activations.generate_prompts_jsonl"),
        mock.patch("pattern_lens.activations.asdict", return_value={}),
        mock.patch("pattern_lens.activations.json_serialize", return_value={}),

        mock.patch(
            "pattern_lens.activations.compute_activations_batched",
            wraps=compute_activations_batched,
        ) as spy_batched,
    ):
        mock_ht_cls.from_pretrained.return_value = mock_model

        activations_main(
            model_name="test-model",
            save_path=str(save_path),
            prompts_path="dummy.jsonl",
            raw_prompts=True,
            min_chars=0,
            max_chars=9999,
            force=force,
            n_samples=len(prompts),
            no_index_html=True,
            device=torch.device("cpu"),
            batch_size=batch_size,
        )

    return spy_batched


def test_activations_main_partial_cache_skip():
    """With force=False, only uncached prompts are computed."""
    temp_dir = TEMP_DIR / "test_main_partial_cache"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    model = MockHookedTransformerBatched()
    all_prompts = _make_5_prompts()

    # Pre-compute activations for the first 2 prompts
    pre_cached = [dict(p) for p in all_prompts[:2]]
    for p in pre_cached:
        augment_prompt_with_hash(p)
    compute_activations_batched(
        prompts=pre_cached,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # Verify they're cached
    for p in pre_cached:
        assert activations_exist("test-model", p, temp_dir)

    # Run activations_main with force=False; it should skip the 2 cached prompts
    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in all_prompts],
        force=False,
        batch_size=32,
    )

    # compute_activations_batched should have been called with only the 3 uncached prompts
    assert spy.call_count == 1
    called_prompts = spy.call_args[1]["prompts"]
    assert len(called_prompts) == 3

    # All 5 should now be cached
    for p in all_prompts:
        augment_prompt_with_hash(p)
        assert activations_exist("test-model", p, temp_dir), (
            f"prompt {p['text']!r} not cached after activations_main"
        )


def test_activations_main_full_cache_skip():
    """With force=False and all prompts cached, compute_activations_batched is never called."""
    temp_dir = TEMP_DIR / "test_main_full_cache"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    model = MockHookedTransformerBatched()
    all_prompts = _make_5_prompts()

    # Pre-compute activations for ALL prompts
    pre_cached = [dict(p) for p in all_prompts]
    for p in pre_cached:
        augment_prompt_with_hash(p)
    compute_activations_batched(
        prompts=pre_cached,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # Run activations_main with force=False; it should skip everything
    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in all_prompts],
        force=False,
    )

    # compute_activations_batched should never have been called
    assert spy.call_count == 0


# ============================================================================
# Test: names_filter as a callable (not regex)
# ============================================================================


def test_names_filter_callable():
    """Passing a plain callable as names_filter exercises the non-regex branch."""
    temp_dir = TEMP_DIR / "test_names_filter_callable"
    model = MockHookedTransformerBatched()
    prompts = _make_prompts()

    def my_filter(key: str) -> bool:
        return "hook_pattern" in key

    paths = compute_activations_batched(
        prompts=prompts,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
        names_filter=my_filter,
    )

    assert len(paths) == 3
    for prompt, path in zip(prompts, paths, strict=True):
        assert path.exists()
        expected_seq_len: int = _expected_seq_len(prompt["text"])
        with np.load(path) as data:
            for layer in range(model.cfg.n_layers):
                key = f"blocks.{layer}.attn.hook_pattern"
                assert key in data
                assert data[key].shape == (
                    1,
                    model.cfg.n_heads,
                    expected_seq_len,
                    expected_seq_len,
                )


# ============================================================================
# Test: activations_exist raises when hash is missing
# ============================================================================


def test_activations_exist_requires_hash():
    """activations_exist raises InvalidPromptError when the prompt has no hash."""
    with pytest.raises(InvalidPromptError, match="must have 'hash' key"):
        activations_exist("test-model", {"text": "no hash here"}, TEMP_DIR)


# ============================================================================
# Test: compute_activations_batched input validation
# ============================================================================


def test_compute_activations_batched_missing_text_raises():
    """A prompt without a 'text' key raises AssertionError."""
    model = MockHookedTransformerBatched()
    with pytest.raises(AssertionError, match="text"):
        compute_activations_batched(
            prompts=[{"hash": "abc"}],
            model=model,  # type: ignore[arg-type]
            save_path=TEMP_DIR / "test_missing_text",
        )


def test_compute_activations_batched_missing_hash_raises():
    """A prompt without a 'hash' key raises AssertionError."""
    model = MockHookedTransformerBatched()
    with pytest.raises(AssertionError, match="hash"):
        compute_activations_batched(
            prompts=[{"text": "hello"}],
            model=model,  # type: ignore[arg-type]
            save_path=TEMP_DIR / "test_missing_hash",
        )


# ============================================================================
# Test: activations_main with batch_size=1
# ============================================================================


def test_activations_main_batch_size_1():
    """batch_size=1 processes each prompt individually (one call per prompt)."""
    temp_dir = TEMP_DIR / "test_main_batch_size_1"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    all_prompts = _make_5_prompts()

    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in all_prompts],
        force=True,
        batch_size=1,
    )

    # With batch_size=1 and 5 prompts, it should be called 5 times (one prompt each)
    assert spy.call_count == 5
    for call in spy.call_args_list:
        assert len(call[1]["prompts"]) == 1

    # All 5 should now be cached
    for p in all_prompts:
        augment_prompt_with_hash(p)
        assert activations_exist("test-model", p, temp_dir)


# ============================================================================
# Test: activations_main with force=True recomputes cached prompts
# ============================================================================


def test_activations_main_force_recomputes():
    """With force=True, all prompts are recomputed even if already cached."""
    temp_dir = TEMP_DIR / "test_main_force_recompute"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    model = MockHookedTransformerBatched()
    all_prompts = _make_5_prompts()

    # Pre-compute activations for ALL prompts
    pre_cached = [dict(p) for p in all_prompts]
    for p in pre_cached:
        augment_prompt_with_hash(p)
    compute_activations_batched(
        prompts=pre_cached,
        model=model,  # type: ignore[arg-type]
        save_path=temp_dir,
    )

    # Verify they're all cached
    for p in pre_cached:
        assert activations_exist("test-model", p, temp_dir)

    # Run with force=True; all 5 should be recomputed
    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in all_prompts],
        force=True,
    )

    # compute_activations_batched should have been called with all 5 prompts in total
    total_computed = sum(len(call[1]["prompts"]) for call in spy.call_args_list)
    assert total_computed == 5


# ============================================================================
# Test: activations_main sorts prompts by length (longest first)
# ============================================================================


def test_activations_main_sorts_by_length():
    """Prompts are sorted longest-first within each batch for padding efficiency."""
    temp_dir = TEMP_DIR / "test_main_sorts_by_length"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)

    # Deliberately pass prompts in SHORT-first order
    prompts = [
        {"text": "a"},  # shortest
        {"text": "bb"},
        {"text": "ccc"},
        {"text": "dddd"},
        {"text": "eeeee"},  # longest
    ]

    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in prompts],
        force=True,
        batch_size=100,  # large enough to fit all prompts in one batch
    )

    # All prompts land in one call; check that they're sorted longest-first
    assert spy.call_count == 1
    called_prompts = spy.call_args[1]["prompts"]
    called_texts = [p["text"] for p in called_prompts]
    called_lengths = [len(t) for t in called_texts]
    assert called_lengths == sorted(called_lengths, reverse=True), (
        f"Prompts not sorted longest-first: {called_texts}"
    )

    # seq_lens must be sorted in the same order as the prompts
    called_seq_lens = spy.call_args[1]["seq_lens"]
    assert called_seq_lens == sorted(called_seq_lens, reverse=True), (
        f"seq_lens not sorted longest-first: {called_seq_lens}"
    )


# ============================================================================
# Test: activations_main splits into multiple batches correctly
# ============================================================================


def test_activations_main_multiple_batches():
    """batch_size=2 with 5 prompts produces 3 batched calls (2+2+1)."""
    temp_dir = TEMP_DIR / "test_main_multi_batch"
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    all_prompts = _make_5_prompts()

    spy = _run_activations_main_mocked(
        save_path=temp_dir,
        prompts=[dict(p) for p in all_prompts],
        force=True,
        batch_size=2,
    )

    # 5 prompts / batch_size=2 => 3 calls: 2, 2, 1
    assert spy.call_count == 3
    batch_sizes = [len(call[1]["prompts"]) for call in spy.call_args_list]
    assert batch_sizes == [2, 2, 1], f"Unexpected batch sizes: {batch_sizes}"

    # All 5 should be cached
    for p in all_prompts:
        augment_prompt_with_hash(p)
        assert activations_exist("test-model", p, temp_dir)