Coverage for src / tracekit / analyzers / packet / payload_analysis.py: 0%

476 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Payload field inference and comparison analysis. 

2 

3RE-PAY-004: Payload Field Inference 

4RE-PAY-005: Payload Comparison and Differential Analysis 

5 

6This module provides field structure inference, payload comparison, 

7similarity computation, and clustering for binary payloads. 

8""" 

9 

10from __future__ import annotations 

11 

12import logging 

13import struct 

14from collections import Counter 

15from collections.abc import Sequence 

16from dataclasses import dataclass, field 

17from typing import TYPE_CHECKING, Any, Literal 

18 

19import numpy as np 

20 

21if TYPE_CHECKING: 

22 from tracekit.analyzers.packet.payload_extraction import PayloadInfo 

23 

24logger = logging.getLogger(__name__) 

25 

26 

27# ============================================================================= 

28# RE-PAY-004: Field Inference Data Classes 

29# ============================================================================= 

30 

31 

@dataclass
class InferredField:
    """Inferred field from binary payload.

    Implements RE-PAY-004: Inferred field structure.

    Attributes:
        name: Field name (auto-generated).
        offset: Byte offset within message.
        size: Field size in bytes.
        inferred_type: Inferred data type.
        endianness: Detected endianness.
        is_constant: Whether field is constant across messages.
        is_sequence: Whether field appears to be a counter/sequence.
        is_checksum: Whether field appears to be a checksum.
        constant_value: Value if constant.
        confidence: Inference confidence (0-1).
        sample_values: Sample values from messages.
    """

    name: str
    offset: int
    size: int
    # Closed set of types the inference heuristics can emit (see FieldInferrer._infer_type).
    inferred_type: Literal[
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float32",
        "float64",
        "bytes",
        "string",
        "unknown",
    ]
    # "n/a" covers fields where byte order is meaningless (single bytes, strings, raw bytes).
    endianness: Literal["big", "little", "n/a"] = "n/a"
    is_constant: bool = False
    is_sequence: bool = False
    is_checksum: bool = False
    constant_value: bytes | None = None
    confidence: float = 0.5
    # default_factory avoids sharing one mutable list across instances.
    sample_values: list[Any] = field(default_factory=list)

77 

78 

@dataclass
class MessageSchema:
    """Inferred message schema.

    Implements RE-PAY-004: Complete message schema.

    Attributes:
        fields: List of inferred fields.
        message_length: Total message length.
        fixed_length: Whether all messages have same length.
        length_range: (min, max) length range.
        sample_count: Number of samples analyzed.
        confidence: Overall schema confidence.
    """

    fields: list[InferredField]
    # Length actually analyzed; for variable-length samples this is the shortest length.
    message_length: int
    fixed_length: bool
    length_range: tuple[int, int]
    sample_count: int
    # Mean of the per-field confidences, in [0, 1]; 0.0 when no fields were inferred.
    confidence: float

100 

101 

102# ============================================================================= 

103# RE-PAY-005: Comparison Data Classes 

104# ============================================================================= 

105 

106 

@dataclass
class PayloadDiff:
    """Difference between two payloads.

    Implements RE-PAY-005: Payload comparison result.

    Attributes:
        common_prefix_length: Length of common prefix.
        common_suffix_length: Length of common suffix.
        differences: List of (offset, byte_a, byte_b) for differences.
        similarity: Similarity score (0-1).
        edit_distance: Levenshtein edit distance.
    """

    common_prefix_length: int
    common_suffix_length: int
    # A byte value of -1 marks positions present in only one payload (length mismatch).
    differences: list[tuple[int, int, int]]
    similarity: float
    edit_distance: int

126 

127 

@dataclass
class VariablePositions:
    """Analysis of which byte positions vary across payloads.

    Implements RE-PAY-005: Variable position analysis.

    Attributes:
        constant_positions: Positions that are constant.
        variable_positions: Positions that vary.
        constant_values: Values at constant positions.
        variance_by_position: Variance at each position.
    """

    constant_positions: list[int]
    variable_positions: list[int]
    # Maps constant byte position -> the byte value observed there.
    constant_values: dict[int, int]
    # One float64 entry per analyzed position (up to the shortest payload length);
    # constant positions contribute 0.0.
    variance_by_position: np.ndarray[tuple[int], np.dtype[np.float64]]

145 

146 

@dataclass
class PayloadCluster:
    """Cluster of similar payloads.

    Implements RE-PAY-005: Payload clustering result.

    Attributes:
        cluster_id: Cluster identifier.
        payloads: List of payload data in cluster.
        indices: Original indices of payloads.
        representative: Representative payload (centroid).
        size: Number of payloads in cluster.
    """

    cluster_id: int
    payloads: list[bytes]
    # Positions of the member payloads in the original input sequence.
    indices: list[int]
    representative: bytes
    size: int

166 

167 

168# ============================================================================= 

169# RE-PAY-004: Field Inference Class 

170# ============================================================================= 

171 

172 

class FieldInferrer:
    """Infer field structure within binary payloads.

    Implements RE-PAY-004: Payload Field Inference.

    Uses statistical analysis, alignment detection, and type inference
    to reconstruct message formats from binary payload samples.

    Example:
        >>> inferrer = FieldInferrer()
        >>> messages = [pkt.data for pkt in udp_packets]
        >>> schema = inferrer.infer_fields(messages)
        >>> for field in schema.fields:
        ...     print(f"{field.name}: {field.inferred_type} at offset {field.offset}")
    """

    def __init__(
        self,
        min_samples: int = 10,
        entropy_threshold: float = 0.5,
        sequence_threshold: int = 3,
    ) -> None:
        """Initialize field inferrer.

        Args:
            min_samples: Minimum samples for reliable inference.
            entropy_threshold: Entropy change threshold for boundary detection.
            sequence_threshold: Minimum consecutive incrementing values for sequence.
        """
        self.min_samples = min_samples
        self.entropy_threshold = entropy_threshold
        self.sequence_threshold = sequence_threshold

    def infer_fields(
        self,
        messages: Sequence[bytes],
        min_samples: int | None = None,
    ) -> MessageSchema:
        """Infer field structure from message samples.

        Implements RE-PAY-004: Complete field inference.

        Args:
            messages: List of binary message samples.
            min_samples: Override minimum sample count.

        Returns:
            MessageSchema with inferred field structure.

        Example:
            >>> schema = inferrer.infer_fields(messages)
            >>> print(f"Detected {len(schema.fields)} fields")
        """
        if not messages:
            # Empty input: return an empty, zero-confidence schema rather than raising.
            return MessageSchema(
                fields=[],
                message_length=0,
                fixed_length=True,
                length_range=(0, 0),
                sample_count=0,
                confidence=0.0,
            )

        # NOTE(review): normalized here but not referenced again in this method.
        min_samples = min_samples or self.min_samples
        lengths = [len(m) for m in messages]
        min_len = min(lengths)
        max_len = max(lengths)
        fixed_length = min_len == max_len

        # Use shortest message length for analysis
        analysis_length = min_len

        # Find field boundaries using entropy transitions
        boundaries = self._detect_field_boundaries(messages, analysis_length)

        # Infer field types for each segment
        fields = []
        for i, (start, end) in enumerate(boundaries):
            # NOTE(review): local name shadows dataclasses.field imported at module level.
            field = self._infer_field(messages, start, end, i)
            fields.append(field)

        # Calculate overall confidence
        if fields:
            confidence = sum(f.confidence for f in fields) / len(fields)
        else:
            confidence = 0.0

        return MessageSchema(
            fields=fields,
            message_length=analysis_length,
            fixed_length=fixed_length,
            length_range=(min_len, max_len),
            sample_count=len(messages),
            confidence=confidence,
        )

    def detect_field_types(
        self,
        messages: Sequence[bytes],
        boundaries: list[tuple[int, int]],
    ) -> list[InferredField]:
        """Detect field types for given boundaries.

        Implements RE-PAY-004: Field type detection.

        Args:
            messages: Message samples.
            boundaries: List of (start, end) field boundaries.

        Returns:
            List of InferredField with type information.
        """
        fields = []
        for i, (start, end) in enumerate(boundaries):
            field = self._infer_field(messages, start, end, i)
            fields.append(field)
        return fields

    def find_sequence_fields(
        self,
        messages: Sequence[bytes],
    ) -> list[tuple[int, int]]:
        """Find fields that appear to be sequence/counter values.

        Implements RE-PAY-004: Sequence field detection.

        Args:
            messages: Message samples (should be in order).

        Returns:
            List of (offset, size) for sequence fields.

        Raises:
            ValueError: If a message is too short for the expected field size.
        """
        if len(messages) < self.sequence_threshold:
            return []

        min_len = min(len(m) for m in messages)
        sequence_fields = []

        # Check each possible field size at each offset
        for size in [1, 2, 4]:
            for offset in range(min_len - size + 1):
                values = []
                try:
                    for msg in messages:
                        # Validate message length before slicing
                        # NOTE(review): defensive only — offsets are derived from min_len,
                        # so every message satisfies len(msg) >= offset + size here.
                        if len(msg) < offset + size:
                            raise ValueError(
                                f"Message too short: expected at least {offset + size} bytes, "
                                f"got {len(msg)} bytes"
                            )
                        # Try both endianness
                        # NOTE(review): only big-endian is actually decoded below.
                        val_be = int.from_bytes(msg[offset : offset + size], "big")
                        values.append(val_be)

                    if self._is_sequence(values):
                        sequence_fields.append((offset, size))
                except (ValueError, IndexError) as e:
                    # Skip this offset/size combination if extraction fails
                    logger.debug(f"Skipping field at offset={offset}, size={size}: {e}")
                    continue

        return sequence_fields

    def find_checksum_fields(
        self,
        messages: Sequence[bytes],
    ) -> list[tuple[int, int, str]]:
        """Find fields that appear to be checksums.

        Implements RE-PAY-004: Checksum field detection.

        Args:
            messages: Message samples.

        Returns:
            List of (offset, size, algorithm_hint) for checksum fields.

        Raises:
            ValueError: If checksum field offset and size exceed message length.
        """
        # Correlation heuristic needs a minimum number of samples to be meaningful.
        if len(messages) < 5:
            return []

        min_len = min(len(m) for m in messages)
        checksum_fields = []

        # Common checksum sizes and positions
        for size in [1, 2, 4]:
            # Check last position (most common)
            # NOTE(review): when min_len == size both candidates are offset 0,
            # so that offset is examined twice.
            for offset in [min_len - size, 0]:
                if offset < 0:
                    continue

                try:
                    # Validate offset and size before processing
                    # NOTE(review): defensive only — both candidate offsets satisfy
                    # offset + size <= min_len by construction.
                    if offset + size > min_len:
                        raise ValueError(
                            f"Invalid checksum field: offset={offset} + size={size} exceeds "
                            f"minimum message length={min_len}"
                        )

                    # Extract field values and message content
                    score = self._check_checksum_correlation(messages, offset, size)

                    if score > 0.8:
                        algorithm = self._guess_checksum_algorithm(messages, offset, size)
                        checksum_fields.append((offset, size, algorithm))
                except (ValueError, IndexError) as e:
                    # Skip this offset/size combination if validation fails
                    logger.debug(f"Skipping checksum field at offset={offset}, size={size}: {e}")
                    continue

        return checksum_fields

    def _detect_field_boundaries(
        self,
        messages: Sequence[bytes],
        max_length: int,
    ) -> list[tuple[int, int]]:
        """Detect field boundaries using entropy analysis.

        Args:
            messages: Message samples.
            max_length: Maximum length to analyze.

        Returns:
            List of (start, end) boundaries.
        """
        if max_length == 0:
            return []

        # Calculate per-byte entropy
        byte_entropies = []
        for pos in range(max_length):
            values = [m[pos] for m in messages if len(m) > pos]
            if len(values) < 2:
                # Too few samples at this position to estimate entropy.
                byte_entropies.append(0.0)
                continue

            # Shannon entropy (base 2) of the byte distribution at this position.
            counts = Counter(values)
            total = len(values)
            entropy = 0.0
            for count in counts.values():
                if count > 0:
                    p = count / total
                    entropy -= p * np.log2(p)
            byte_entropies.append(entropy)

        # Find boundaries at entropy transitions
        boundaries = []
        current_start = 0

        for i in range(1, len(byte_entropies)):
            delta = abs(byte_entropies[i] - byte_entropies[i - 1])

            # Also check for constant vs variable patterns
            if delta > self.entropy_threshold:
                if i > current_start:
                    boundaries.append((current_start, i))
                current_start = i

        # Add final segment
        if max_length > current_start:
            boundaries.append((current_start, max_length))

        # Merge very small segments
        merged: list[tuple[int, int]] = []
        for start, end in boundaries:
            # Only merge a segment shorter than 2 bytes that is directly adjacent
            # to the previous one.
            if merged and start - merged[-1][1] == 0 and end - start < 2:
                # Merge with previous
                merged[-1] = (merged[-1][0], end)
            else:
                merged.append((start, end))

        return merged if merged else [(0, max_length)]

    def _infer_field(
        self,
        messages: Sequence[bytes],
        start: int,
        end: int,
        index: int,
    ) -> InferredField:
        """Infer type for a single field.

        Args:
            messages: Message samples.
            start: Field start offset.
            end: Field end offset.
            index: Field index for naming.

        Returns:
            InferredField with inferred type.
        """
        size = end - start
        name = f"field_{index}"

        # Extract field values
        # NOTE(review): `values` and `raw_values` receive identical content;
        # only `raw_values` is used after the emptiness check.
        values = []
        raw_values = []
        for msg in messages:
            if len(msg) >= end:
                field_bytes = msg[start:end]
                raw_values.append(field_bytes)
                values.append(field_bytes)

        if not values:
            # No message is long enough to contain this field.
            return InferredField(
                name=name,
                offset=start,
                size=size,
                inferred_type="unknown",
                confidence=0.0,
            )

        # Check if constant
        unique_values = set(raw_values)
        is_constant = len(unique_values) == 1

        # Check if sequence
        is_sequence = False
        if not is_constant and size in [1, 2, 4, 8]:
            int_values = [int.from_bytes(v, "big") for v in raw_values]
            is_sequence = self._is_sequence(int_values)

        # Check for checksum patterns
        # Only fields near the end of the shortest message are considered.
        is_checksum = False
        if start >= min(len(m) for m in messages) - 4:
            score = self._check_checksum_correlation(messages, start, size)
            is_checksum = score > 0.7

        # Infer type
        inferred_type, endianness, confidence = self._infer_type(raw_values, size)

        # Sample values for debugging
        sample_values: list[int | str] = []
        for v in raw_values[:5]:
            if inferred_type.startswith("uint") or inferred_type.startswith("int"):
                try:
                    # Cast endianness to Literal type for type checker
                    byte_order: Literal["big", "little"] = (
                        "big" if endianness == "n/a" else endianness  # type: ignore[assignment]
                    )
                    sample_values.append(int.from_bytes(v, byte_order))
                except Exception:
                    # Fall back to hex if integer decoding fails.
                    sample_values.append(v.hex())
            elif inferred_type == "string":
                try:
                    sample_values.append(v.decode("utf-8", errors="replace"))
                except Exception:
                    sample_values.append(v.hex())
            else:
                sample_values.append(v.hex())

        # Cast to Literal types for type checker
        inferred_type_literal: Literal[
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "int8",
            "int16",
            "int32",
            "int64",
            "float32",
            "float64",
            "bytes",
            "string",
            "unknown",
        ] = inferred_type  # type: ignore[assignment]
        endianness_literal: Literal["big", "little", "n/a"] = endianness  # type: ignore[assignment]

        return InferredField(
            name=name,
            offset=start,
            size=size,
            inferred_type=inferred_type_literal,
            endianness=endianness_literal,
            is_constant=is_constant,
            is_sequence=is_sequence,
            is_checksum=is_checksum,
            constant_value=raw_values[0] if is_constant else None,
            confidence=confidence,
            sample_values=sample_values,
        )

    def _infer_type(
        self,
        values: list[bytes],
        size: int,
    ) -> tuple[str, str, float]:
        """Infer data type from values.

        Args:
            values: Field values.
            size: Field size.

        Returns:
            Tuple of (type, endianness, confidence).
        """
        if not values:
            return "unknown", "n/a", 0.0

        # Check for string (high printable ratio)
        # Printable ASCII plus tab/LF/CR counted across all sampled bytes.
        printable_ratio = sum(
            1 for v in values for b in v if 32 <= b <= 126 or b in (9, 10, 13)
        ) / (len(values) * size)

        if printable_ratio > 0.8:
            return "string", "n/a", printable_ratio

        # Check for standard integer sizes
        if size == 1:
            return "uint8", "n/a", 0.9

        elif size == 2:
            # Try to detect endianness
            # Heuristic: the byte order yielding the lower variance is assumed correct
            # (counters/small values vary in their low-order bytes).
            be_variance = np.var([int.from_bytes(v, "big") for v in values])
            le_variance = np.var([int.from_bytes(v, "little") for v in values])

            if be_variance < le_variance:
                endian = "big"
            else:
                endian = "little"

            return "uint16", endian, 0.8

        elif size == 4:
            # Check for float
            # A value counts as float-like if it decodes to a finite number
            # of plausible magnitude.
            float_valid = 0
            for v in values:
                try:
                    f = struct.unpack(">f", v)[0]
                    if not (np.isnan(f) or np.isinf(f)) and -1e10 < f < 1e10:
                        float_valid += 1
                except Exception:
                    pass

            if float_valid / len(values) > 0.8:
                return "float32", "big", 0.7

            # Otherwise integer
            be_variance = np.var([int.from_bytes(v, "big") for v in values])
            le_variance = np.var([int.from_bytes(v, "little") for v in values])
            endian = "big" if be_variance < le_variance else "little"
            return "uint32", endian, 0.8

        elif size == 8:
            # Check for float64 or uint64
            # NOTE(review): no float64 decode is attempted; uint64 is always returned.
            be_variance = np.var([int.from_bytes(v, "big") for v in values])
            le_variance = np.var([int.from_bytes(v, "little") for v in values])
            endian = "big" if be_variance < le_variance else "little"
            return "uint64", endian, 0.7

        else:
            # Non-standard size: treat as opaque bytes.
            return "bytes", "n/a", 0.6

    def _is_sequence(self, values: list[int]) -> bool:
        """Check if values form a sequence.

        Args:
            values: Integer values.

        Returns:
            True if values are incrementing/decrementing.
        """
        if len(values) < self.sequence_threshold:
            return False

        # Check for incrementing sequence
        diffs = [values[i + 1] - values[i] for i in range(len(values) - 1)]

        # Most diffs should be 1 (or consistent)
        counter = Counter(diffs)
        if not counter:
            return False

        most_common_diff, count = counter.most_common(1)[0]
        ratio = count / len(diffs)

        # A dominant step of +1/-1/0 (>80% of transitions) marks a counter-like field.
        return ratio > 0.8 and most_common_diff in [1, -1, 0]

    def _check_checksum_correlation(
        self,
        messages: Sequence[bytes],
        offset: int,
        size: int,
    ) -> float:
        """Check if field correlates with message content like a checksum.

        Args:
            messages: Message samples.
            offset: Field offset.
            size: Field size.

        Returns:
            Correlation score (0-1).
        """
        # Simple heuristic: checksum fields have high correlation with
        # changes in other parts of the message

        if len(messages) < 5:
            return 0.0

        # Extract checksum values and message content
        checksums = []
        contents = []

        for msg in messages:
            if len(msg) >= offset + size:
                checksums.append(int.from_bytes(msg[offset : offset + size], "big"))
                # Content before checksum
                # Message bytes excluding the candidate field, reduced to a 16-bit sum.
                content = msg[:offset] + msg[offset + size :]
                contents.append(sum(content) % 65536)

        if len(checksums) < 5:
            return 0.0

        # Check if checksum changes correlate with content changes
        unique_contents = len(set(contents))
        unique_checksums = len(set(checksums))

        if unique_contents == 1 and unique_checksums == 1:
            return 0.3  # Both constant - inconclusive

        # Simple correlation check
        # NOTE(review): coarse — any variation on both sides scores 0.8; no per-message
        # pairing of content change to checksum change is performed.
        if unique_contents > 1 and unique_checksums > 1:
            return 0.8

        return 0.3

    def _guess_checksum_algorithm(
        self,
        messages: Sequence[bytes],
        offset: int,
        size: int,
    ) -> str:
        """Guess the checksum algorithm.

        Args:
            messages: Message samples.
            offset: Checksum offset.
            size: Checksum size.

        Returns:
            Algorithm name hint.
        """
        # Hint is based purely on field width; message content is not inspected.
        if size == 1:
            return "xor8_or_sum8"
        elif size == 2:
            return "crc16_or_sum16"
        elif size == 4:
            return "crc32"
        return "unknown"

730 

731 

732# ============================================================================= 

733# RE-PAY-004: Convenience Functions 

734# ============================================================================= 

735 

736 

def infer_fields(messages: Sequence[bytes], min_samples: int = 10) -> MessageSchema:
    """Infer a message schema from binary samples.

    Implements RE-PAY-004: Payload Field Inference. Thin convenience
    wrapper around :class:`FieldInferrer`.

    Args:
        messages: Binary message samples to analyze.
        min_samples: Minimum number of samples for reliable inference.

    Returns:
        MessageSchema describing the inferred field layout.

    Example:
        >>> messages = [pkt.data for pkt in packets]
        >>> schema = infer_fields(messages)
        >>> for field in schema.fields:
        ...     print(f"{field.name}: {field.inferred_type}")
    """
    return FieldInferrer(min_samples=min_samples).infer_fields(messages)

757 

758 

def detect_field_types(
    messages: Sequence[bytes],
    boundaries: list[tuple[int, int]],
) -> list[InferredField]:
    """Infer a type for each pre-computed field boundary.

    Implements RE-PAY-004: Field type detection. Convenience wrapper
    over :meth:`FieldInferrer.detect_field_types` with default settings.

    Args:
        messages: Binary message samples.
        boundaries: (start, end) byte ranges, one per field.

    Returns:
        One InferredField per boundary, carrying type information.
    """
    return FieldInferrer().detect_field_types(messages, boundaries)

776 

777 

def find_sequence_fields(messages: Sequence[bytes]) -> list[tuple[int, int]]:
    """Locate counter/sequence-like fields in ordered message samples.

    Implements RE-PAY-004: Sequence field detection. Convenience wrapper
    over :meth:`FieldInferrer.find_sequence_fields` with default settings.

    Args:
        messages: Message samples, expected in capture order.

    Returns:
        (offset, size) pairs for fields that behave like sequences.
    """
    return FieldInferrer().find_sequence_fields(messages)

791 

792 

def find_checksum_fields(messages: Sequence[bytes]) -> list[tuple[int, int, str]]:
    """Locate checksum-like fields in message samples.

    Implements RE-PAY-004: Checksum field detection. Convenience wrapper
    over :meth:`FieldInferrer.find_checksum_fields` with default settings.

    Args:
        messages: Binary message samples.

    Returns:
        (offset, size, algorithm_hint) triples for candidate checksum fields.
    """
    return FieldInferrer().find_checksum_fields(messages)

806 

807 

808# ============================================================================= 

809# RE-PAY-005: Comparison Functions 

810# ============================================================================= 

811 

812 

def diff_payloads(payload_a: bytes, payload_b: bytes) -> PayloadDiff:
    """Compare two payloads byte-by-byte and summarize their differences.

    Implements RE-PAY-005: Payload differential analysis.

    Args:
        payload_a: First payload.
        payload_b: Second payload.

    Returns:
        PayloadDiff with prefix/suffix lengths, per-byte differences
        (-1 marks a byte present in only one payload), a similarity
        score in [0, 1], and the Levenshtein edit distance.

    Example:
        >>> diff = diff_payloads(pkt1.data, pkt2.data)
        >>> print(f"Common prefix: {diff.common_prefix_length} bytes")
        >>> print(f"Different bytes: {len(diff.differences)}")
    """
    len_a, len_b = len(payload_a), len(payload_b)
    overlap = min(len_a, len_b)

    # Leading run of matching bytes.
    prefix_len = 0
    while prefix_len < overlap and payload_a[prefix_len] == payload_b[prefix_len]:
        prefix_len += 1

    # Trailing run of matching bytes; never allowed to overlap the prefix.
    suffix_len = 0
    while (
        suffix_len < overlap - prefix_len
        and payload_a[len_a - 1 - suffix_len] == payload_b[len_b - 1 - suffix_len]
    ):
        suffix_len += 1

    # All positions in the overlap where the payloads disagree.
    differences: list[tuple[int, int, int]] = [
        (pos, payload_a[pos], payload_b[pos])
        for pos in range(overlap)
        if payload_a[pos] != payload_b[pos]
    ]

    # Bytes present in only one payload are reported with -1 on the other side.
    if len_a > len_b:
        differences.extend((pos, payload_a[pos], -1) for pos in range(len_b, len_a))
    elif len_b > len_a:
        differences.extend((pos, -1, payload_b[pos]) for pos in range(len_a, len_b))

    # Similarity: matching bytes in the overlap, normalized by the longer length.
    longer = max(len_a, len_b)
    if longer == 0:
        similarity = 1.0
    else:
        mismatches_in_overlap = len([d for d in differences if d[0] < overlap])
        similarity = (overlap - mismatches_in_overlap) / longer

    return PayloadDiff(
        common_prefix_length=prefix_len,
        common_suffix_length=suffix_len,
        differences=differences,
        similarity=similarity,
        edit_distance=_levenshtein_distance(payload_a, payload_b),
    )

879 

880 

def find_common_bytes(payloads: Sequence[bytes]) -> bytes:
    """Return the longest prefix shared by every payload.

    Implements RE-PAY-005: Common byte analysis.

    Args:
        payloads: Payloads to analyze.

    Returns:
        The common prefix (empty input yields ``b""``; a single payload
        is returned unchanged).
    """
    if not payloads:
        return b""
    if len(payloads) == 1:
        return payloads[0]

    reference = payloads[0]
    # The prefix can never be longer than the shortest payload.
    limit = min(len(p) for p in payloads)

    prefix_end = 0
    while prefix_end < limit and all(
        p[prefix_end] == reference[prefix_end] for p in payloads
    ):
        prefix_end += 1

    return reference[:prefix_end]

911 

912 

def find_variable_positions(payloads: Sequence[bytes]) -> VariablePositions:
    """Classify each byte position as constant or varying across payloads.

    Implements RE-PAY-005: Variable position detection. Only positions up
    to the shortest payload length are analyzed.

    Args:
        payloads: Payloads to analyze.

    Returns:
        VariablePositions listing constant positions (with their values),
        varying positions, and the per-position variance.

    Example:
        >>> result = find_variable_positions(payloads)
        >>> print(f"Constant positions: {result.constant_positions}")
        >>> print(f"Variable positions: {result.variable_positions}")
    """
    if not payloads:
        return VariablePositions(
            constant_positions=[],
            variable_positions=[],
            constant_values={},
            variance_by_position=np.array([]),
        )

    limit = min(len(p) for p in payloads)

    constant: list[int] = []
    varying: list[int] = []
    const_vals: dict[int, int] = {}
    variances: list[float] = []

    for pos in range(limit):
        column = [p[pos] for p in payloads]
        if len(set(column)) == 1:
            # Every payload agrees on this byte.
            constant.append(pos)
            const_vals[pos] = column[0]
            variances.append(0.0)
        else:
            varying.append(pos)
            variances.append(float(np.var(column)))

    return VariablePositions(
        constant_positions=constant,
        variable_positions=varying,
        constant_values=const_vals,
        variance_by_position=np.array(variances),
    )

963 

964 

965def compute_similarity( 

966 payload_a: bytes, 

967 payload_b: bytes, 

968 metric: Literal["levenshtein", "hamming", "jaccard"] = "levenshtein", 

969) -> float: 

970 """Compute similarity between two payloads. 

971 

972 Implements RE-PAY-005: Similarity computation. 

973 

974 Args: 

975 payload_a: First payload. 

976 payload_b: Second payload. 

977 metric: Similarity metric to use. 

978 

979 Returns: 

980 Similarity score (0-1). 

981 """ 

982 if metric == "levenshtein": 

983 max_len = max(len(payload_a), len(payload_b)) 

984 if max_len == 0: 

985 return 1.0 

986 distance = _levenshtein_distance(payload_a, payload_b) 

987 return 1.0 - (distance / max_len) 

988 

989 elif metric == "hamming": 

990 if len(payload_a) != len(payload_b): 

991 # Pad shorter one 

992 max_len = max(len(payload_a), len(payload_b)) 

993 payload_a = payload_a.ljust(max_len, b"\x00") 

994 payload_b = payload_b.ljust(max_len, b"\x00") 

995 

996 matches = sum(a == b for a, b in zip(payload_a, payload_b, strict=True)) 

997 return matches / len(payload_a) if payload_a else 1.0 

998 

999 # metric == "jaccard" 

1000 # Treat bytes as sets 

1001 set_a = set(payload_a) 

1002 set_b = set(payload_b) 

1003 intersection = len(set_a & set_b) 

1004 union = len(set_a | set_b) 

1005 return intersection / union if union > 0 else 1.0 

1006 

1007 

def cluster_payloads(
    payloads: Sequence[bytes],
    threshold: float = 0.8,
    algorithm: Literal["greedy", "dbscan"] = "greedy",
) -> list[PayloadCluster]:
    """Group similar payloads into clusters.

    Implements RE-PAY-005: Payload clustering.

    Args:
        payloads: Payloads to cluster.
        threshold: Minimum similarity for two payloads to share a cluster.
        algorithm: "greedy" (default) or "dbscan".

    Returns:
        List of PayloadCluster objects (empty for empty input).

    Example:
        >>> clusters = cluster_payloads(payloads, threshold=0.85)
        >>> for c in clusters:
        ...     print(f"Cluster {c.cluster_id}: {c.size} payloads")
    """
    if not payloads:
        return []

    # Dispatch to the selected backend; anything other than "greedy" routes to DBSCAN.
    cluster_fn = _cluster_greedy_optimized if algorithm == "greedy" else _cluster_dbscan
    return cluster_fn(payloads, threshold)

1037 

1038 

def correlate_request_response(
    requests: Sequence[PayloadInfo],
    responses: Sequence[PayloadInfo],
    max_delay: float = 1.0,
) -> list[tuple[PayloadInfo, PayloadInfo, float]]:
    """Pair each request with its closest-in-time response.

    Implements RE-PAY-005: Request-response correlation. For every request
    with a timestamp, the response with the smallest non-negative latency
    not exceeding ``max_delay`` is selected; requests without a match are
    omitted.

    Args:
        requests: Request PayloadInfo objects.
        responses: Response PayloadInfo objects.
        max_delay: Maximum allowed request-to-response delay in seconds.

    Returns:
        (request, response, latency) tuples for each matched request.
    """
    matched: list[tuple[PayloadInfo, PayloadInfo, float]] = []

    for req in requests:
        if req.timestamp is None:
            # Cannot correlate without a timestamp.
            continue

        chosen = None
        chosen_latency = float("inf")

        for resp in responses:
            if resp.timestamp is None:
                continue
            delay = resp.timestamp - req.timestamp
            # Keep the earliest response within the allowed window.
            if 0 <= delay <= max_delay and delay < chosen_latency:
                chosen = resp
                chosen_latency = delay

        if chosen is not None:
            matched.append((req, chosen, chosen_latency))

    return matched

1078 

1079 

1080# ============================================================================= 

1081# Helper Functions 

1082# ============================================================================= 

1083 

1084 

1085def _levenshtein_distance(a: bytes, b: bytes) -> int: 

1086 """Calculate Levenshtein edit distance between two byte sequences.""" 

1087 if len(a) < len(b): 

1088 return _levenshtein_distance(b, a) 

1089 

1090 if len(b) == 0: 

1091 return len(a) 

1092 

1093 previous_row: list[int] = list(range(len(b) + 1)) 

1094 for i, c1 in enumerate(a): 

1095 current_row = [i + 1] 

1096 for j, c2 in enumerate(b): 

1097 insertions = previous_row[j + 1] + 1 

1098 deletions = current_row[j] + 1 

1099 substitutions = previous_row[j] + (c1 != c2) 

1100 current_row.append(min(insertions, deletions, substitutions)) 

1101 previous_row = current_row 

1102 

1103 return previous_row[-1] 

1104 

1105 

1106def _fast_similarity(payload_a: bytes, payload_b: bytes, threshold: float) -> float | None: 

1107 """Fast similarity check with early termination. 

1108 

1109 Uses length-based filtering and sampling to quickly reject dissimilar payloads. 

1110 Returns None if payloads are likely similar (needs full check), 

1111 or a similarity value if they can be quickly determined. 

1112 

1113 Args: 

1114 payload_a: First payload. 

1115 payload_b: Second payload. 

1116 threshold: Similarity threshold for clustering. 

1117 

1118 Returns: 

1119 Similarity value if quickly determined, None if full check needed. 

1120 """ 

1121 len_a = len(payload_a) 

1122 len_b = len(payload_b) 

1123 

1124 # Empty payloads 

1125 if len_a == 0 and len_b == 0: 

1126 return 1.0 

1127 if len_a == 0 or len_b == 0: 

1128 return 0.0 

1129 

1130 # Length difference filter: if lengths differ by more than (1-threshold)*max_len, 

1131 # similarity can't exceed threshold 

1132 max_len = max(len_a, len_b) 

1133 min_len = min(len_a, len_b) 

1134 _length_diff = max_len - min_len 

1135 

1136 # Maximum possible similarity given length difference 

1137 max_possible_similarity = min_len / max_len 

1138 if max_possible_similarity < threshold: 

1139 return max_possible_similarity 

1140 

1141 # For same-length payloads, use fast hamming similarity 

1142 if len_a == len_b: 

1143 # Sample comparison for large payloads 

1144 if len_a > 50: 

1145 # Sample first 16, last 16, and some middle bytes 

1146 sample_size = min(48, len_a) 

1147 mismatches = 0 

1148 

1149 # First 16 bytes 

1150 for i in range(min(16, len_a)): 

1151 if payload_a[i] != payload_b[i]: 

1152 mismatches += 1 

1153 

1154 # Last 16 bytes 

1155 for i in range(1, min(17, len_a + 1)): 

1156 if payload_a[-i] != payload_b[-i]: 

1157 mismatches += 1 

1158 

1159 # Middle samples (len_a > 32 always true here since len_a > 50) 

1160 step = (len_a - 32) // 16 

1161 if step > 0: 

1162 for i in range(16, len_a - 16, step): 

1163 if payload_a[i] != payload_b[i]: 

1164 mismatches += 1 

1165 

1166 # Estimate similarity from sample 

1167 estimated_similarity = 1.0 - (mismatches / sample_size) 

1168 

1169 # If sample shows very low similarity, reject early 

1170 if estimated_similarity < threshold * 0.8: 

1171 return estimated_similarity 

1172 

1173 # Full hamming comparison for same-length payloads (faster than Levenshtein) 

1174 matches = sum(a == b for a, b in zip(payload_a, payload_b, strict=True)) 

1175 return matches / len_a 

1176 

1177 # For different-length payloads, use common prefix/suffix heuristic 

1178 common_prefix = 0 

1179 for i in range(min_len): 

1180 if payload_a[i] == payload_b[i]: 

1181 common_prefix += 1 

1182 else: 

1183 break 

1184 

1185 common_suffix = 0 

1186 for i in range(1, min_len - common_prefix + 1): 

1187 if payload_a[-i] == payload_b[-i]: 

1188 common_suffix += 1 

1189 else: 

1190 break 

1191 

1192 # Estimate similarity from prefix/suffix 

1193 common_bytes = common_prefix + common_suffix 

1194 estimated_similarity = common_bytes / max_len 

1195 

1196 # If common bytes suggest low similarity, reject 

1197 if estimated_similarity < threshold * 0.7: 

1198 return estimated_similarity 

1199 

1200 # Need full comparison 

1201 return None 

1202 

1203 

def _cluster_greedy_optimized(
    payloads: Sequence[bytes],
    threshold: float,
) -> list[PayloadCluster]:
    """Greedy single-pass clustering with cheap pre-filters.

    Each still-unassigned payload seeds a new cluster; every later
    unassigned payload joins that cluster when its similarity to the seed
    meets the threshold. A length-ratio bound and a sampled fast check
    skip most full Levenshtein computations.

    Args:
        payloads: Sequence of payload bytes to cluster.
        threshold: Similarity threshold for clustering (0.0 to 1.0).

    Returns:
        List of PayloadCluster objects containing clustered payloads.
    """
    total = len(payloads)
    taken = [False] * total
    # Precompute lengths once for the cheap length-ratio filter.
    sizes = [len(p) for p in payloads]
    result: list[PayloadCluster] = []

    for seed_idx in range(total):
        if taken[seed_idx]:
            continue

        # This payload seeds a new cluster.
        seed = payloads[seed_idx]
        seed_len = sizes[seed_idx]
        members = [seed]
        member_indices = [seed_idx]
        taken[seed_idx] = True

        # Scan the remainder for payloads similar to the seed.
        for cand_idx in range(seed_idx + 1, total):
            if taken[cand_idx]:
                continue

            cand_len = sizes[cand_idx]
            # Length-ratio bound: similarity cannot exceed shorter/longer.
            shorter, longer = sorted((seed_len, cand_len))
            if shorter / longer < threshold:
                continue

            candidate = payloads[cand_idx]
            score = _fast_similarity(seed, candidate, threshold)
            if score is None:
                # Uncertain — fall back to the exact similarity measure.
                score = compute_similarity(seed, candidate)

            if score >= threshold:
                members.append(candidate)
                member_indices.append(cand_idx)
                taken[cand_idx] = True

        result.append(
            PayloadCluster(
                cluster_id=len(result),
                payloads=members,
                indices=member_indices,
                representative=seed,
                size=len(members),
            )
        )

    return result

1275 

1276 

def _cluster_greedy(
    payloads: Sequence[bytes],
    threshold: float,
) -> list[PayloadCluster]:
    """Legacy greedy clustering entry point.

    Kept for backward compatibility; all work is delegated to the
    optimized implementation.
    """
    return _cluster_greedy_optimized(payloads, threshold)

1283 

1284 

def _cluster_dbscan(
    payloads: Sequence[bytes],
    threshold: float,
) -> list[PayloadCluster]:
    """DBSCAN-style clustering (simplified placeholder).

    A true DBSCAN would need a pairwise-distance implementation (e.g. via
    scipy); until then this intentionally delegates to the greedy
    algorithm so the "dbscan" option still produces usable clusters.
    """
    return _cluster_greedy_optimized(payloads, threshold)

1293 

1294 

# Public API surface for this module (RE-PAY-004 field inference and
# RE-PAY-005 comparison/clustering); kept alphabetically sorted.
__all__ = [
    "FieldInferrer",
    "InferredField",
    "MessageSchema",
    "PayloadCluster",
    "PayloadDiff",
    "VariablePositions",
    "cluster_payloads",
    "compute_similarity",
    "correlate_request_response",
    "detect_field_types",
    "diff_payloads",
    "find_checksum_fields",
    "find_common_bytes",
    "find_sequence_fields",
    "find_variable_positions",
    "infer_fields",
]