Coverage for src / tracekit / inference / alignment.py: 99%
286 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Sequence alignment algorithms for binary message comparison.
3Requirements addressed: PSI-003
5This module applies sequence alignment algorithms to compare binary messages
6for identifying common structures and variations.
8Key capabilities:
9- Needleman-Wunsch for global alignment
10- Smith-Waterman for local alignment
11- Multiple sequence alignment
12- Conserved/variable region detection
13"""
15from dataclasses import dataclass
16from typing import Any, Literal
18import numpy as np
19from numpy.typing import NDArray
@dataclass
class AlignmentResult:
    """Result of a pairwise sequence alignment.

    ``align_global`` and ``align_local`` in this module always fill
    ``aligned_a``/``aligned_b`` with lists of ints; the ``bytes`` option in the
    annotation is accepted by consumers such as ``align_multiple``.

    Attributes:
        aligned_a: Aligned sequence A; gap positions are ``-1``.
        aligned_b: Aligned sequence B; gap positions are ``-1``. Same length
            as ``aligned_a``.
        score: Raw alignment score taken from the dynamic-programming matrix.
        similarity: Identical non-gap columns divided by columns that are not
            a gap in both sequences (0-1), as computed by
            ``compute_similarity``.
        identity: Identical non-gap columns divided by the total alignment
            length (0-1).
        gaps: Number of alignment columns with a gap in either sequence.
        conserved_regions: Half-open ``(start, end)`` column ranges covering
            runs of identical non-gap columns (at least 4 columns long).
        variable_regions: Half-open ``(start, end)`` column ranges covering
            runs of differing columns (at least 2 columns long).
    """

    aligned_a: bytes | list[int]  # Aligned sequence A (with gaps as -1)
    aligned_b: bytes | list[int]  # Aligned sequence B (with gaps as -1)
    score: float
    similarity: float  # 0-1
    identity: float  # Fraction of identical positions
    gaps: int  # Number of gap positions
    conserved_regions: list[tuple[int, int]]  # (start, end) of conserved regions
    variable_regions: list[tuple[int, int]]  # (start, end) of variable regions
def align_global(
    seq_a: bytes | NDArray[Any],
    seq_b: bytes | NDArray[Any],
    gap_penalty: float = -1.0,
    match_score: float = 1.0,
    mismatch_penalty: float = -1.0,
) -> AlignmentResult:
    """Global alignment using the Needleman-Wunsch algorithm.

    Aligns the full length of both sequences in O(n*m) time and memory,
    where n and m are the sequence lengths.

    Args:
        seq_a: First sequence. Either ``bytes`` or an array-like of values
            0-255; non-``bytes`` input is converted with ``dtype=np.uint8``.
        seq_b: Second sequence, in the same accepted forms as ``seq_a``.
        gap_penalty: Score added per gap column (normally negative).
        match_score: Score added when the aligned values are equal.
        mismatch_penalty: Score added when the aligned values differ.

    Returns:
        AlignmentResult with these fields:

        - ``aligned_a``/``aligned_b``: equal-length lists of ints, with
          ``-1`` marking gaps.
        - ``score``: the bottom-right DP cell.
        - ``identity``: identical non-gap columns divided by the alignment
          length.
        - ``gaps``: the number of columns with a gap in either sequence.

        Two empty inputs yield empty alignments with zero statistics.
    """
    arr_a = (
        np.frombuffer(seq_a, dtype=np.uint8)
        if isinstance(seq_a, bytes)
        else np.array(seq_a, dtype=np.uint8)
    )
    arr_b = (
        np.frombuffer(seq_b, dtype=np.uint8)
        if isinstance(seq_b, bytes)
        else np.array(seq_b, dtype=np.uint8)
    )
    # Plain Python ints are cheaper to compare and emit inside the O(n*m)
    # loops than NumPy scalars.
    vals_a = arr_a.tolist()
    vals_b = arr_b.tolist()
    n, m = len(vals_a), len(vals_b)

    # float64 rather than float32: float32 storage rounded non-dyadic scores
    # (e.g. gap_penalty=-0.1) at every cell and leaked that error into `score`.
    score_matrix = np.zeros((n + 1, m + 1), dtype=np.float64)
    # Traceback codes: 0 = diagonal, 1 = up (gap in seq_b), 2 = left (gap in seq_a).
    traceback = np.zeros((n + 1, m + 1), dtype=np.int8)

    # Boundary: a prefix aligned against nothing costs one gap per element.
    score_matrix[1:, 0] = np.arange(1, n + 1) * gap_penalty
    traceback[1:, 0] = 1
    score_matrix[0, 1:] = np.arange(1, m + 1) * gap_penalty
    traceback[0, 1:] = 2

    for i in range(1, n + 1):
        a_val = vals_a[i - 1]
        for j in range(1, m + 1):
            step = match_score if a_val == vals_b[j - 1] else mismatch_penalty
            diag_score = score_matrix[i - 1, j - 1] + step
            up_score = score_matrix[i - 1, j] + gap_penalty  # gap in seq_b
            left_score = score_matrix[i, j - 1] + gap_penalty  # gap in seq_a

            best = max(diag_score, up_score, left_score)
            score_matrix[i, j] = best
            # Ties prefer diagonal, then up, then left.
            if best == diag_score:
                traceback[i, j] = 0
            elif best == up_score:
                traceback[i, j] = 1
            else:
                traceback[i, j] = 2

    # Walk back from the bottom-right corner to (0, 0), collecting columns
    # in reverse order.
    aligned_a: list[int] = []
    aligned_b: list[int] = []
    i, j = n, m
    while i > 0 or j > 0:
        move = traceback[i, j]
        if move == 0:
            aligned_a.append(vals_a[i - 1])
            aligned_b.append(vals_b[j - 1])
            i -= 1
            j -= 1
        elif move == 1:
            aligned_a.append(vals_a[i - 1])
            aligned_b.append(-1)
            i -= 1
        else:
            aligned_a.append(-1)
            aligned_b.append(vals_b[j - 1])
            j -= 1
    aligned_a.reverse()
    aligned_b.reverse()

    final_score = float(score_matrix[n, m])
    similarity = compute_similarity(aligned_a, aligned_b)

    if aligned_a:
        pairs = list(zip(aligned_a, aligned_b, strict=True))
        identity = sum(1 for a, b in pairs if a == b and a != -1) / len(pairs)
        gaps = sum(1 for a, b in pairs if a == -1 or b == -1)
    else:
        identity = 0.0
        gaps = 0

    return AlignmentResult(
        aligned_a=aligned_a,
        aligned_b=aligned_b,
        score=final_score,
        similarity=similarity,
        identity=identity,
        gaps=gaps,
        conserved_regions=_find_conserved_simple(aligned_a, aligned_b),
        variable_regions=_find_variable_simple(aligned_a, aligned_b),
    )
def align_local(
    seq_a: bytes | NDArray[Any],
    seq_b: bytes | NDArray[Any],
    gap_penalty: float = -1.0,
    match_score: float = 2.0,
    mismatch_penalty: float = -1.0,
) -> AlignmentResult:
    """Local alignment using the Smith-Waterman algorithm.

    Finds the highest-scoring pair of subsequences in O(n*m) time and memory.
    When several cells share the maximum score, the first one in row-major
    order ends the reported alignment.

    Args:
        seq_a: First sequence. Either ``bytes`` or an array-like of values
            0-255; non-``bytes`` input is converted with ``dtype=np.uint8``.
        seq_b: Second sequence, in the same accepted forms as ``seq_a``.
        gap_penalty: Score added per gap column (normally negative).
        match_score: Score added when the aligned values are equal.
        mismatch_penalty: Score added when the aligned values differ.

    Returns:
        AlignmentResult for the best local alignment, with ``-1`` marking
        gaps. If no cell scores above zero, the aligned sequences are empty
        and ``score``, ``similarity``, ``identity`` and ``gaps`` are all zero.
    """
    arr_a = (
        np.frombuffer(seq_a, dtype=np.uint8)
        if isinstance(seq_a, bytes)
        else np.array(seq_a, dtype=np.uint8)
    )
    arr_b = (
        np.frombuffer(seq_b, dtype=np.uint8)
        if isinstance(seq_b, bytes)
        else np.array(seq_b, dtype=np.uint8)
    )
    # Plain Python ints are cheaper to compare and emit inside the O(n*m)
    # loops than NumPy scalars.
    vals_a = arr_a.tolist()
    vals_b = arr_b.tolist()
    n, m = len(vals_a), len(vals_b)

    # float64 rather than float32: float32 storage rounded non-dyadic scores
    # at every cell and leaked that error into the reported `score`.
    score_matrix = np.zeros((n + 1, m + 1), dtype=np.float64)
    # Traceback codes: -1 = stop, 0 = diagonal, 1 = up (gap in seq_b),
    # 2 = left (gap in seq_a). Row/column 0 is never followed (loop guard).
    traceback = np.zeros((n + 1, m + 1), dtype=np.int8)

    max_score = 0.0
    max_i, max_j = 0, 0

    for i in range(1, n + 1):
        a_val = vals_a[i - 1]
        for j in range(1, m + 1):
            step = match_score if a_val == vals_b[j - 1] else mismatch_penalty
            diag_score = score_matrix[i - 1, j - 1] + step
            up_score = score_matrix[i - 1, j] + gap_penalty
            left_score = score_matrix[i, j - 1] + gap_penalty

            # Smith-Waterman floors every cell at 0 so an alignment may restart.
            cell_score = max(0.0, diag_score, up_score, left_score)
            score_matrix[i, j] = cell_score

            if cell_score == 0:
                traceback[i, j] = -1
            elif cell_score == diag_score:
                traceback[i, j] = 0
            elif cell_score == up_score:
                traceback[i, j] = 1
            else:
                traceback[i, j] = 2

            # Strict '>' keeps the first (row-major) cell reaching the maximum.
            if cell_score > max_score:
                max_score = cell_score
                max_i, max_j = i, j

    # Walk back from the best cell until a stop marker or the matrix edge.
    aligned_a: list[int] = []
    aligned_b: list[int] = []
    i, j = max_i, max_j
    while i > 0 and j > 0 and traceback[i, j] != -1:
        move = traceback[i, j]
        if move == 0:
            aligned_a.append(vals_a[i - 1])
            aligned_b.append(vals_b[j - 1])
            i -= 1
            j -= 1
        elif move == 1:
            aligned_a.append(vals_a[i - 1])
            aligned_b.append(-1)
            i -= 1
        else:
            aligned_a.append(-1)
            aligned_b.append(vals_b[j - 1])
            j -= 1
    aligned_a.reverse()
    aligned_b.reverse()

    if aligned_a:
        similarity = compute_similarity(aligned_a, aligned_b)
        pairs = list(zip(aligned_a, aligned_b, strict=True))
        identity = sum(1 for a, b in pairs if a == b and a != -1) / len(pairs)
        gaps = sum(1 for a, b in pairs if a == -1 or b == -1)
    else:
        similarity = 0.0
        identity = 0.0
        gaps = 0

    return AlignmentResult(
        aligned_a=aligned_a,
        aligned_b=aligned_b,
        score=float(max_score),
        similarity=similarity,
        identity=identity,
        gaps=gaps,
        conserved_regions=_find_conserved_simple(aligned_a, aligned_b),
        variable_regions=_find_variable_simple(aligned_a, aligned_b),
    )
303def align_multiple(
304 sequences: list[bytes | NDArray[Any]],
305 method: Literal["progressive", "iterative"] = "progressive",
306) -> list[list[int]]:
307 """Multiple sequence alignment.
309 : Progressive MSA using guide tree and pairwise alignment.
311 Args:
312 sequences: List of sequences (bytes or arrays)
313 method: Alignment method ('progressive' or 'iterative')
315 Returns:
316 List of aligned sequences (as lists with -1 for gaps)
317 """
318 if len(sequences) == 0:
319 return []
320 if len(sequences) == 1:
321 # Convert to list
322 if isinstance(sequences[0], bytes): 322 ↛ 325line 322 didn't jump to line 325 because the condition on line 322 was always true
323 return [list(np.frombuffer(sequences[0], dtype=np.uint8))]
324 else:
325 return [list(sequences[0])]
327 # Progressive alignment
328 if method == "progressive":
329 # Start with first two sequences
330 result = align_global(sequences[0], sequences[1])
331 # Convert to list[int] if needed
332 aligned_a_list = (
333 list(result.aligned_a) if isinstance(result.aligned_a, bytes) else result.aligned_a
334 )
335 aligned_b_list = (
336 list(result.aligned_b) if isinstance(result.aligned_b, bytes) else result.aligned_b
337 )
338 aligned: list[list[int]] = [aligned_a_list, aligned_b_list]
340 # Add remaining sequences one by one
341 for seq in sequences[2:]:
342 # Align seq to consensus of current alignment
343 consensus_seq = _compute_consensus(aligned)
344 consensus_bytes = bytes([v if v != -1 else 0 for v in consensus_seq])
345 result = align_global(consensus_bytes, seq)
347 # Insert gaps in existing alignments
348 new_aligned: list[list[int]] = []
349 result_a_list = (
350 list(result.aligned_a) if isinstance(result.aligned_a, bytes) else result.aligned_a
351 )
352 for existing in aligned:
353 new_seq = _insert_gaps_from_alignment(existing, result_a_list)
354 new_aligned.append(new_seq)
356 # Add new sequence
357 result_b_list = (
358 list(result.aligned_b) if isinstance(result.aligned_b, bytes) else result.aligned_b
359 )
360 new_aligned.append(result_b_list)
361 aligned = new_aligned
363 return aligned
364 else:
365 # Iterative not implemented, fall back to progressive
366 return align_multiple(sequences, method="progressive")
369def compute_similarity(aligned_a: bytes | list[int], aligned_b: bytes | list[int]) -> float:
370 """Compute similarity between aligned sequences.
372 : Similarity calculation.
374 Args:
375 aligned_a: First aligned sequence
376 aligned_b: Second aligned sequence
378 Returns:
379 Similarity ratio (0-1)
381 Raises:
382 ValueError: If aligned sequences have different lengths.
383 """
384 if len(aligned_a) != len(aligned_b):
385 raise ValueError("Aligned sequences must have same length")
387 if len(aligned_a) == 0:
388 return 0.0
390 matches = 0
391 total = 0
393 for a, b in zip(aligned_a, aligned_b, strict=True):
394 # Skip double gaps
395 if a == -1 and b == -1:
396 continue
398 total += 1
399 if a == b and a != -1:
400 matches += 1
402 if total == 0:
403 return 0.0
405 return matches / total
def find_conserved_regions(
    aligned_sequences: list[list[int]], min_conservation: float = 0.9, min_length: int = 4
) -> list[tuple[int, int]]:
    """Find highly conserved column ranges in a multiple alignment.

    A column's conservation is the count of its most common non-gap value
    divided by its number of non-gap values. All-gap columns score 0.0 and
    are never conserved. Columns are scanned up to the length of the first
    sequence; a shorter sequence contributes nothing past its end.

    Args:
        aligned_sequences: Aligned sequences, with gaps encoded as ``-1``.
        min_conservation: Minimum conservation ratio (0-1, inclusive) for a
            column to count as conserved.
        min_length: Minimum number of consecutive conserved columns.

    Returns:
        Half-open ``(start, end)`` column ranges in ascending order.
    """
    from collections import Counter

    if not aligned_sequences:
        return []

    length = len(aligned_sequences[0])

    conservation: list[float] = []
    for pos in range(length):
        non_gap_values = [
            seq[pos] for seq in aligned_sequences if pos < len(seq) and seq[pos] != -1
        ]
        if not non_gap_values:
            conservation.append(0.0)
            continue
        most_common_count = Counter(non_gap_values).most_common(1)[0][1]
        conservation.append(most_common_count / len(non_gap_values))

    regions: list[tuple[int, int]] = []
    start: int | None = None
    for i, cons in enumerate(conservation):
        if cons >= min_conservation:
            if start is None:
                start = i
        elif start is not None:
            if i - start >= min_length:
                regions.append((start, i))
            start = None

    # A run that reaches the last column has no closing column.
    if start is not None and length - start >= min_length:
        regions.append((start, length))

    return regions
def find_variable_regions(
    aligned_sequences: list[list[int]], max_conservation: float = 0.5, min_length: int = 2
) -> list[tuple[int, int]]:
    """Find highly variable column ranges in a multiple alignment.

    A column's conservation is the count of its most common non-gap value
    divided by its number of non-gap values. All-gap columns are treated as
    fully conserved (1.0), so they never count as variable. Columns are
    scanned up to the length of the first sequence.

    Args:
        aligned_sequences: Aligned sequences, with gaps encoded as ``-1``.
        max_conservation: Maximum conservation ratio (0-1, inclusive) for a
            column to count as variable.
        min_length: Minimum number of consecutive variable columns.

    Returns:
        Half-open ``(start, end)`` column ranges in ascending order.
    """
    from collections import Counter

    if not aligned_sequences:
        return []

    length = len(aligned_sequences[0])

    conservation: list[float] = []
    for pos in range(length):
        non_gap_values = [
            seq[pos] for seq in aligned_sequences if pos < len(seq) and seq[pos] != -1
        ]
        if not non_gap_values:
            conservation.append(1.0)  # All gaps = conserved
            continue
        most_common_count = Counter(non_gap_values).most_common(1)[0][1]
        conservation.append(most_common_count / len(non_gap_values))

    regions: list[tuple[int, int]] = []
    start: int | None = None
    for i, cons in enumerate(conservation):
        if cons <= max_conservation:
            if start is None:
                start = i
        elif start is not None:
            if i - start >= min_length:
                regions.append((start, i))
            start = None

    # A run that reaches the last column has no closing column.
    if start is not None and length - start >= min_length:
        regions.append((start, length))

    return regions
533def _find_conserved_simple(aligned_a: list[int], aligned_b: list[int]) -> list[tuple[int, int]]:
534 """Find conserved regions in pairwise alignment.
536 Args:
537 aligned_a: First aligned sequence
538 aligned_b: Second aligned sequence
540 Returns:
541 List of (start, end) tuples
542 """
543 regions = []
544 start = None
546 for i, (a, b) in enumerate(zip(aligned_a, aligned_b, strict=True)):
547 if a == b and a != -1:
548 if start is None:
549 start = i
550 else:
551 if start is not None:
552 if i - start >= 4: # Min length 4
553 regions.append((start, i))
554 start = None
556 # Handle region at end
557 if start is not None and len(aligned_a) - start >= 4:
558 regions.append((start, len(aligned_a)))
560 return regions
563def _find_variable_simple(aligned_a: list[int], aligned_b: list[int]) -> list[tuple[int, int]]:
564 """Find variable regions in pairwise alignment.
566 Args:
567 aligned_a: First aligned sequence
568 aligned_b: Second aligned sequence
570 Returns:
571 List of (start, end) tuples
572 """
573 regions = []
574 start = None
576 for i, (a, b) in enumerate(zip(aligned_a, aligned_b, strict=True)):
577 if a != b:
578 if start is None:
579 start = i
580 else:
581 if start is not None:
582 if i - start >= 2: # Min length 2
583 regions.append((start, i))
584 start = None
586 # Handle region at end
587 if start is not None and len(aligned_a) - start >= 2:
588 regions.append((start, len(aligned_a)))
590 return regions
def _compute_consensus(aligned_sequences: list[list[int]]) -> list[int]:
    """Build a per-column majority consensus of aligned sequences.

    Args:
        aligned_sequences: Aligned sequences, with gaps encoded as ``-1``.
            They may differ in length; the longest one sets the consensus
            length.

    Returns:
        One value per column. Each is the most frequent non-gap value in that
        column; ties go to the value seen first when scanning sequences in
        order. A column whose values are all gaps yields ``-1``.
    """
    from collections import Counter

    if not aligned_sequences:
        return []

    width = max(map(len, aligned_sequences))

    def column_winner(col: int) -> int:
        tally = Counter(
            row[col] for row in aligned_sequences if col < len(row) and row[col] != -1
        )
        return tally.most_common(1)[0][0] if tally else -1

    return [column_winner(col) for col in range(width)]
def _insert_gaps_from_alignment(sequence: list[int], alignment_template: list[int]) -> list[int]:
    """Project a template's gap pattern onto ``sequence``.

    Walks ``alignment_template`` entry by entry:

    - Each ``-1`` entry yields a gap (``-1``).
    - Every other entry consumes the next element of ``sequence``, or yields
      ``-1`` once ``sequence`` is exhausted.

    Elements of ``sequence`` beyond the template's non-gap count are dropped.

    Args:
        sequence: Values to lay out (may already contain ``-1`` gaps).
        alignment_template: Sequence whose ``-1`` entries mark new gaps.

    Returns:
        A list with the same length as ``alignment_template``.
    """
    remaining = iter(sequence)
    return [
        -1 if template_val == -1 else next(remaining, -1)
        for template_val in alignment_template
    ]