Coverage for src/tracekit/inference/alignment.py: 99%

286 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Sequence alignment algorithms for binary message comparison. 

2 

3Requirements addressed: PSI-003 

4 

5This module applies sequence alignment algorithms to compare binary messages 

6for identifying common structures and variations. 

7 

8Key capabilities: 

9- Needleman-Wunsch for global alignment 

10- Smith-Waterman for local alignment 

11- Multiple sequence alignment 

12- Conserved/variable region detection 

13""" 

14 

15from dataclasses import dataclass 

16from typing import Any, Literal 

17 

18import numpy as np 

19from numpy.typing import NDArray 

20 

21 

22@dataclass 

23class AlignmentResult: 

24 """Result of sequence alignment. 

25 

26 : Alignment result representation. 

27 

28 Attributes: 

29 aligned_a: Aligned sequence A (with gaps as -1) 

30 aligned_b: Aligned sequence B (with gaps as -1) 

31 score: Alignment score 

32 similarity: Similarity ratio (0-1) 

33 identity: Fraction of identical positions 

34 gaps: Number of gap positions 

35 conserved_regions: List of (start, end) tuples for conserved regions 

36 variable_regions: List of (start, end) tuples for variable regions 

37 """ 

38 

39 aligned_a: bytes | list[int] # Aligned sequence A (with gaps as -1) 

40 aligned_b: bytes | list[int] # Aligned sequence B (with gaps as -1) 

41 score: float 

42 similarity: float # 0-1 

43 identity: float # Fraction of identical positions 

44 gaps: int # Number of gap positions 

45 conserved_regions: list[tuple[int, int]] # (start, end) of conserved regions 

46 variable_regions: list[tuple[int, int]] # (start, end) of variable regions 

47 

48 

def align_global(
    seq_a: bytes | NDArray[Any],
    seq_b: bytes | NDArray[Any],
    gap_penalty: float = -1.0,
    match_score: float = 1.0,
    mismatch_penalty: float = -1.0,
) -> AlignmentResult:
    """Align two sequences end to end with the Needleman-Wunsch algorithm.

    Fills an (n + 1) x (m + 1) score table (O(mn) time and memory) and walks
    back from the bottom-right corner to recover one optimal alignment.
    When several moves tie, the diagonal is preferred, then "up", then "left".

    Args:
        seq_a: First sequence (bytes or array); converted to uint8.
        seq_b: Second sequence (bytes or array); converted to uint8.
        gap_penalty: Score added for each gap position.
        match_score: Score added when two values are equal.
        mismatch_penalty: Score added when two values differ.

    Returns:
        AlignmentResult holding both aligned sequences (gaps as -1), the
        final score and the derived statistics/regions.
    """
    # Codes stored in the traceback table.
    diag_move, up_move, left_move = 0, 1, 2

    arr_a, arr_b = (
        np.frombuffer(seq, dtype=np.uint8)
        if isinstance(seq, bytes)
        else np.array(seq, dtype=np.uint8)
        for seq in (seq_a, seq_b)
    )
    rows, cols = len(arr_a), len(arr_b)

    scores = np.zeros((rows + 1, cols + 1), dtype=np.float32)
    moves = np.zeros((rows + 1, cols + 1), dtype=np.int8)

    # Border cells: reaching (r, 0) or (0, c) is only possible through gaps.
    scores[1:, 0] = np.arange(1, rows + 1) * gap_penalty
    moves[1:, 0] = up_move  # gap in seq_b
    scores[0, 1:] = np.arange(1, cols + 1) * gap_penalty
    moves[0, 1:] = left_move  # gap in seq_a

    # Dynamic-programming fill.
    for r in range(1, rows + 1):
        value_a = arr_a[r - 1]
        for c in range(1, cols + 1):
            pair_score = match_score if value_a == arr_b[c - 1] else mismatch_penalty

            via_diag = scores[r - 1, c - 1] + pair_score
            via_up = scores[r - 1, c] + gap_penalty  # gap in seq_b
            via_left = scores[r, c - 1] + gap_penalty  # gap in seq_a

            best = max(via_diag, via_up, via_left)
            scores[r, c] = best

            if best == via_diag:
                moves[r, c] = diag_move
            elif best == via_up:
                moves[r, c] = up_move
            else:
                moves[r, c] = left_move

    # Walk back from the bottom-right corner; columns come out in reverse order.
    rev_a: list[int] = []
    rev_b: list[int] = []
    r, c = rows, cols
    while r > 0 or c > 0:
        move = moves[r, c]
        if move == diag_move:
            r -= 1
            c -= 1
            rev_a.append(int(arr_a[r]))
            rev_b.append(int(arr_b[c]))
        elif move == up_move:
            r -= 1
            rev_a.append(int(arr_a[r]))
            rev_b.append(-1)  # gap
        else:
            c -= 1
            rev_a.append(-1)  # gap
            rev_b.append(int(arr_b[c]))

    aligned_a = rev_a[::-1]
    aligned_b = rev_b[::-1]

    # Column-wise statistics; an empty alignment has nothing to count.
    columns = list(zip(aligned_a, aligned_b, strict=True))
    if columns:
        identity = sum(1 for x, y in columns if x == y and x != -1) / len(columns)
        gaps = sum(1 for x, y in columns if x == -1 or y == -1)
    else:
        identity = 0.0
        gaps = 0

    return AlignmentResult(
        aligned_a=aligned_a,
        aligned_b=aligned_b,
        score=float(scores[rows, cols]),
        similarity=compute_similarity(aligned_a, aligned_b),
        identity=identity,
        gaps=gaps,
        conserved_regions=_find_conserved_simple(aligned_a, aligned_b),
        variable_regions=_find_variable_simple(aligned_a, aligned_b),
    )

174 

175 

def align_local(
    seq_a: bytes | NDArray[Any],
    seq_b: bytes | NDArray[Any],
    gap_penalty: float = -1.0,
    match_score: float = 2.0,
    mismatch_penalty: float = -1.0,
) -> AlignmentResult:
    """Find the best-scoring local alignment with the Smith-Waterman algorithm.

    Cell scores are clamped at zero, so an alignment may begin anywhere. The
    traceback starts at the highest-scoring cell (the first one found in
    row-major order) and ends at a zero-score cell or at the table edge.
    Runs in O(mn) time and memory.

    Args:
        seq_a: First sequence (bytes or array); converted to uint8.
        seq_b: Second sequence (bytes or array); converted to uint8.
        gap_penalty: Score added for each gap position.
        match_score: Score added when two values are equal.
        mismatch_penalty: Score added when two values differ.

    Returns:
        AlignmentResult for the best local alignment; its sequences are empty
        when no cell scores above zero.
    """
    # Codes stored in the traceback table.
    stop_move, diag_move, up_move, left_move = -1, 0, 1, 2

    arr_a, arr_b = (
        np.frombuffer(seq, dtype=np.uint8)
        if isinstance(seq, bytes)
        else np.array(seq, dtype=np.uint8)
        for seq in (seq_a, seq_b)
    )
    rows, cols = len(arr_a), len(arr_b)

    scores = np.zeros((rows + 1, cols + 1), dtype=np.float32)
    moves = np.zeros((rows + 1, cols + 1), dtype=np.int8)

    # Best cell seen so far; the traceback begins there.
    best_score = 0.0
    best_r, best_c = 0, 0

    for r in range(1, rows + 1):
        value_a = arr_a[r - 1]
        for c in range(1, cols + 1):
            pair_score = match_score if value_a == arr_b[c - 1] else mismatch_penalty

            via_diag = scores[r - 1, c - 1] + pair_score
            via_up = scores[r - 1, c] + gap_penalty  # gap in seq_b
            via_left = scores[r, c - 1] + gap_penalty  # gap in seq_a

            # Zero is always available: the alignment may restart here.
            cell = max(0.0, via_diag, via_up, via_left)
            scores[r, c] = cell

            if cell == 0:
                moves[r, c] = stop_move
            elif cell == via_diag:
                moves[r, c] = diag_move
            elif cell == via_up:
                moves[r, c] = up_move
            else:
                moves[r, c] = left_move

            if cell > best_score:
                best_score = cell
                best_r, best_c = r, c

    # Walk back from the best cell; columns come out in reverse order.
    rev_a: list[int] = []
    rev_b: list[int] = []
    r, c = best_r, best_c
    while r > 0 and c > 0 and moves[r, c] != stop_move:
        move = moves[r, c]
        if move == diag_move:
            r -= 1
            c -= 1
            rev_a.append(int(arr_a[r]))
            rev_b.append(int(arr_b[c]))
        elif move == up_move:
            r -= 1
            rev_a.append(int(arr_a[r]))
            rev_b.append(-1)  # gap
        else:
            c -= 1
            rev_a.append(-1)  # gap
            rev_b.append(int(arr_b[c]))

    aligned_a = rev_a[::-1]
    aligned_b = rev_b[::-1]

    # Column-wise statistics; all zero when nothing was aligned.
    columns = list(zip(aligned_a, aligned_b, strict=True))
    if columns:
        similarity = compute_similarity(aligned_a, aligned_b)
        identity = sum(1 for x, y in columns if x == y and x != -1) / len(columns)
        gaps = sum(1 for x, y in columns if x == -1 or y == -1)
    else:
        similarity = 0.0
        identity = 0.0
        gaps = 0

    return AlignmentResult(
        aligned_a=aligned_a,
        aligned_b=aligned_b,
        score=float(best_score),
        similarity=similarity,
        identity=identity,
        gaps=gaps,
        conserved_regions=_find_conserved_simple(aligned_a, aligned_b),
        variable_regions=_find_variable_simple(aligned_a, aligned_b),
    )

301 

302 

303def align_multiple( 

304 sequences: list[bytes | NDArray[Any]], 

305 method: Literal["progressive", "iterative"] = "progressive", 

306) -> list[list[int]]: 

307 """Multiple sequence alignment. 

308 

309 : Progressive MSA using guide tree and pairwise alignment. 

310 

311 Args: 

312 sequences: List of sequences (bytes or arrays) 

313 method: Alignment method ('progressive' or 'iterative') 

314 

315 Returns: 

316 List of aligned sequences (as lists with -1 for gaps) 

317 """ 

318 if len(sequences) == 0: 

319 return [] 

320 if len(sequences) == 1: 

321 # Convert to list 

322 if isinstance(sequences[0], bytes): 322 ↛ 325line 322 didn't jump to line 325 because the condition on line 322 was always true

323 return [list(np.frombuffer(sequences[0], dtype=np.uint8))] 

324 else: 

325 return [list(sequences[0])] 

326 

327 # Progressive alignment 

328 if method == "progressive": 

329 # Start with first two sequences 

330 result = align_global(sequences[0], sequences[1]) 

331 # Convert to list[int] if needed 

332 aligned_a_list = ( 

333 list(result.aligned_a) if isinstance(result.aligned_a, bytes) else result.aligned_a 

334 ) 

335 aligned_b_list = ( 

336 list(result.aligned_b) if isinstance(result.aligned_b, bytes) else result.aligned_b 

337 ) 

338 aligned: list[list[int]] = [aligned_a_list, aligned_b_list] 

339 

340 # Add remaining sequences one by one 

341 for seq in sequences[2:]: 

342 # Align seq to consensus of current alignment 

343 consensus_seq = _compute_consensus(aligned) 

344 consensus_bytes = bytes([v if v != -1 else 0 for v in consensus_seq]) 

345 result = align_global(consensus_bytes, seq) 

346 

347 # Insert gaps in existing alignments 

348 new_aligned: list[list[int]] = [] 

349 result_a_list = ( 

350 list(result.aligned_a) if isinstance(result.aligned_a, bytes) else result.aligned_a 

351 ) 

352 for existing in aligned: 

353 new_seq = _insert_gaps_from_alignment(existing, result_a_list) 

354 new_aligned.append(new_seq) 

355 

356 # Add new sequence 

357 result_b_list = ( 

358 list(result.aligned_b) if isinstance(result.aligned_b, bytes) else result.aligned_b 

359 ) 

360 new_aligned.append(result_b_list) 

361 aligned = new_aligned 

362 

363 return aligned 

364 else: 

365 # Iterative not implemented, fall back to progressive 

366 return align_multiple(sequences, method="progressive") 

367 

368 

369def compute_similarity(aligned_a: bytes | list[int], aligned_b: bytes | list[int]) -> float: 

370 """Compute similarity between aligned sequences. 

371 

372 : Similarity calculation. 

373 

374 Args: 

375 aligned_a: First aligned sequence 

376 aligned_b: Second aligned sequence 

377 

378 Returns: 

379 Similarity ratio (0-1) 

380 

381 Raises: 

382 ValueError: If aligned sequences have different lengths. 

383 """ 

384 if len(aligned_a) != len(aligned_b): 

385 raise ValueError("Aligned sequences must have same length") 

386 

387 if len(aligned_a) == 0: 

388 return 0.0 

389 

390 matches = 0 

391 total = 0 

392 

393 for a, b in zip(aligned_a, aligned_b, strict=True): 

394 # Skip double gaps 

395 if a == -1 and b == -1: 

396 continue 

397 

398 total += 1 

399 if a == b and a != -1: 

400 matches += 1 

401 

402 if total == 0: 

403 return 0.0 

404 

405 return matches / total 

406 

407 

def find_conserved_regions(
    aligned_sequences: list[list[int]], min_conservation: float = 0.9, min_length: int = 4
) -> list[tuple[int, int]]:
    """Find highly conserved regions in aligned sequences.

    The conservation of a column is the share of its non-gap values that
    equal the column's most common non-gap value. A column made only of gaps
    has conservation 0.0. Consecutive columns whose conservation reaches
    ``min_conservation`` form a region.

    Args:
        aligned_sequences: List of aligned sequences (gaps encoded as -1).
            The first sequence determines how many columns are examined;
            shorter sequences simply do not contribute to later columns.
        min_conservation: Minimum conservation ratio (0-1) for a column.
        min_length: Minimum number of columns in a reported region.

    Returns:
        List of (start, end) tuples for conserved regions, end exclusive.
    """
    # Imported once per call rather than once per column.
    from collections import Counter

    if not aligned_sequences:
        return []

    length = len(aligned_sequences[0])

    # Conservation ratio for every column.
    conservation: list[float] = []
    for pos in range(length):
        non_gap_values = [
            seq[pos] for seq in aligned_sequences if pos < len(seq) and seq[pos] != -1
        ]
        if not non_gap_values:
            conservation.append(0.0)
            continue

        most_common_count = Counter(non_gap_values).most_common(1)[0][1]
        conservation.append(most_common_count / len(non_gap_values))

    # Collect runs of columns at or above the threshold.
    regions: list[tuple[int, int]] = []
    start = None
    for i, cons in enumerate(conservation):
        if cons >= min_conservation:
            if start is None:
                start = i
        elif start is not None:
            if i - start >= min_length:
                regions.append((start, i))
            start = None

    # A run may extend to the last column.
    if start is not None and length - start >= min_length:
        regions.append((start, length))

    return regions

469 

470 

def find_variable_regions(
    aligned_sequences: list[list[int]], max_conservation: float = 0.5, min_length: int = 2
) -> list[tuple[int, int]]:
    """Locate stretches of columns whose values vary strongly.

    The conservation of a column is the share of its non-gap values that
    equal the column's most common non-gap value. A column made only of gaps
    is treated as fully conserved (1.0). Consecutive columns whose
    conservation is at most ``max_conservation`` form a region.

    Args:
        aligned_sequences: List of aligned sequences (gaps encoded as -1).
            The first sequence determines how many columns are examined.
        max_conservation: Maximum conservation ratio (0-1) for a column.
        min_length: Minimum number of columns in a reported region.

    Returns:
        List of (start, end) tuples for variable regions, end exclusive.
    """
    from collections import Counter

    if not aligned_sequences:
        return []

    length = len(aligned_sequences[0])

    def column_conservation(pos: int) -> float:
        # Gaps, and sequences too short to reach this column, do not vote.
        observed = [
            seq[pos] for seq in aligned_sequences if pos < len(seq) and seq[pos] != -1
        ]
        if not observed:
            return 1.0  # all gaps = conserved
        top_count = Counter(observed).most_common(1)[0][1]
        return top_count / len(observed)

    conservation = [column_conservation(pos) for pos in range(length)]

    # Scan one index past the end so a trailing run is closed by the same code.
    regions = []
    run_start = None
    for idx in range(length + 1):
        if idx < length and conservation[idx] <= max_conservation:
            if run_start is None:
                run_start = idx
            continue
        if run_start is not None and idx - run_start >= min_length:
            regions.append((run_start, idx))
        run_start = None

    return regions

531 

532 

533def _find_conserved_simple(aligned_a: list[int], aligned_b: list[int]) -> list[tuple[int, int]]: 

534 """Find conserved regions in pairwise alignment. 

535 

536 Args: 

537 aligned_a: First aligned sequence 

538 aligned_b: Second aligned sequence 

539 

540 Returns: 

541 List of (start, end) tuples 

542 """ 

543 regions = [] 

544 start = None 

545 

546 for i, (a, b) in enumerate(zip(aligned_a, aligned_b, strict=True)): 

547 if a == b and a != -1: 

548 if start is None: 

549 start = i 

550 else: 

551 if start is not None: 

552 if i - start >= 4: # Min length 4 

553 regions.append((start, i)) 

554 start = None 

555 

556 # Handle region at end 

557 if start is not None and len(aligned_a) - start >= 4: 

558 regions.append((start, len(aligned_a))) 

559 

560 return regions 

561 

562 

563def _find_variable_simple(aligned_a: list[int], aligned_b: list[int]) -> list[tuple[int, int]]: 

564 """Find variable regions in pairwise alignment. 

565 

566 Args: 

567 aligned_a: First aligned sequence 

568 aligned_b: Second aligned sequence 

569 

570 Returns: 

571 List of (start, end) tuples 

572 """ 

573 regions = [] 

574 start = None 

575 

576 for i, (a, b) in enumerate(zip(aligned_a, aligned_b, strict=True)): 

577 if a != b: 

578 if start is None: 

579 start = i 

580 else: 

581 if start is not None: 

582 if i - start >= 2: # Min length 2 

583 regions.append((start, i)) 

584 start = None 

585 

586 # Handle region at end 

587 if start is not None and len(aligned_a) - start >= 2: 

588 regions.append((start, len(aligned_a))) 

589 

590 return regions 

591 

592 

593def _compute_consensus(aligned_sequences: list[list[int]]) -> list[int]: 

594 """Compute consensus sequence from multiple aligned sequences. 

595 

596 Args: 

597 aligned_sequences: List of aligned sequences 

598 

599 Returns: 

600 Consensus sequence 

601 """ 

602 if not aligned_sequences: 

603 return [] 

604 

605 length = max(len(seq) for seq in aligned_sequences) 

606 consensus = [] 

607 

608 for pos in range(length): 

609 values = [seq[pos] for seq in aligned_sequences if pos < len(seq)] 

610 

611 # Skip gaps when computing consensus 

612 non_gap_values = [v for v in values if v != -1] 

613 

614 if non_gap_values: 

615 # Most common value 

616 from collections import Counter 

617 

618 counts = Counter(non_gap_values) 

619 consensus_val = counts.most_common(1)[0][0] 

620 consensus.append(consensus_val) 

621 else: 

622 # All gaps 

623 consensus.append(-1) 

624 

625 return consensus 

626 

627 

628def _insert_gaps_from_alignment(sequence: list[int], alignment_template: list[int]) -> list[int]: 

629 """Insert gaps into sequence based on alignment template. 

630 

631 Args: 

632 sequence: Original sequence 

633 alignment_template: Template showing where gaps should be 

634 

635 Returns: 

636 Sequence with gaps inserted 

637 """ 

638 result = [] 

639 seq_idx = 0 

640 

641 for template_val in alignment_template: 

642 if template_val == -1: 

643 # Gap in template, insert gap 

644 result.append(-1) 

645 else: 

646 # Non-gap, copy from sequence 

647 if seq_idx < len(sequence): 

648 result.append(sequence[seq_idx]) 

649 seq_idx += 1 

650 else: 

651 result.append(-1) 

652 

653 return result