# Coverage listing: src/tracekit/analyzers/packet/payload_analysis.py
# (476 statements, 0% covered)

1"""Payload field inference and comparison analysis.
3RE-PAY-004: Payload Field Inference
4RE-PAY-005: Payload Comparison and Differential Analysis
6This module provides field structure inference, payload comparison,
7similarity computation, and clustering for binary payloads.
8"""
from __future__ import annotations

import logging
import struct
from collections import Counter
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal

import numpy as np

if TYPE_CHECKING:
    from tracekit.analyzers.packet.payload_extraction import PayloadInfo

logger = logging.getLogger(__name__)


# =============================================================================
# RE-PAY-004: Field Inference Data Classes
# =============================================================================


@dataclass
class InferredField:
    """Inferred field from binary payload.

    Implements RE-PAY-004: Inferred field structure.

    Attributes:
        name: Field name (auto-generated).
        offset: Byte offset within message.
        size: Field size in bytes.
        inferred_type: Inferred data type.
        endianness: Detected endianness.
        is_constant: Whether field is constant across messages.
        is_sequence: Whether field appears to be a counter/sequence.
        is_checksum: Whether field appears to be a checksum.
        constant_value: Value if constant.
        confidence: Inference confidence (0-1).
        sample_values: Sample values from messages.
    """

    name: str
    offset: int
    size: int
    inferred_type: Literal[
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float32",
        "float64",
        "bytes",
        "string",
        "unknown",
    ]
    endianness: Literal["big", "little", "n/a"] = "n/a"
    is_constant: bool = False
    is_sequence: bool = False
    is_checksum: bool = False
    constant_value: bytes | None = None
    confidence: float = 0.5
    sample_values: list[Any] = field(default_factory=list)


@dataclass
class MessageSchema:
    """Inferred message schema.

    Implements RE-PAY-004: Complete message schema.

    Attributes:
        fields: List of inferred fields.
        message_length: Total message length.
        fixed_length: Whether all messages have same length.
        length_range: (min, max) length range.
        sample_count: Number of samples analyzed.
        confidence: Overall schema confidence.
    """

    fields: list[InferredField]
    message_length: int
    fixed_length: bool
    length_range: tuple[int, int]
    sample_count: int
    confidence: float


# =============================================================================
# RE-PAY-005: Comparison Data Classes
# =============================================================================


@dataclass
class PayloadDiff:
    """Difference between two payloads.

    Implements RE-PAY-005: Payload comparison result.

    Attributes:
        common_prefix_length: Length of common prefix.
        common_suffix_length: Length of common suffix.
        differences: List of (offset, byte_a, byte_b) for differences;
            a value of -1 marks a byte absent from the shorter payload.
        similarity: Similarity score (0-1).
        edit_distance: Levenshtein edit distance.
    """

    common_prefix_length: int
    common_suffix_length: int
    differences: list[tuple[int, int, int]]
    similarity: float
    edit_distance: int


@dataclass
class VariablePositions:
    """Analysis of which byte positions vary across payloads.

    Implements RE-PAY-005: Variable position analysis.

    Attributes:
        constant_positions: Positions that are constant.
        variable_positions: Positions that vary.
        constant_values: Values at constant positions.
        variance_by_position: Variance at each position.
    """

    constant_positions: list[int]
    variable_positions: list[int]
    constant_values: dict[int, int]
    variance_by_position: np.ndarray[tuple[int], np.dtype[np.float64]]


@dataclass
class PayloadCluster:
    """Cluster of similar payloads.

    Implements RE-PAY-005: Payload clustering result.

    Attributes:
        cluster_id: Cluster identifier.
        payloads: List of payload data in cluster.
        indices: Original indices of payloads.
        representative: Representative payload (the cluster seed).
        size: Number of payloads in cluster.
    """

    cluster_id: int
    payloads: list[bytes]
    indices: list[int]
    representative: bytes
    size: int


# =============================================================================
# RE-PAY-004: Field Inference Class
# =============================================================================


class FieldInferrer:
    """Infer field structure within binary payloads.

    Implements RE-PAY-004: Payload Field Inference.

    Uses statistical analysis, alignment detection, and type inference
    to reconstruct message formats from binary payload samples.

    Example:
        >>> inferrer = FieldInferrer()
        >>> messages = [pkt.data for pkt in udp_packets]
        >>> schema = inferrer.infer_fields(messages)
        >>> for field in schema.fields:
        ...     print(f"{field.name}: {field.inferred_type} at offset {field.offset}")
    """

    def __init__(
        self,
        min_samples: int = 10,
        entropy_threshold: float = 0.5,
        sequence_threshold: int = 3,
    ) -> None:
        """Initialize field inferrer.

        Args:
            min_samples: Minimum samples for reliable inference.
            entropy_threshold: Entropy change threshold for boundary detection.
            sequence_threshold: Minimum consecutive incrementing values for sequence.
        """
        self.min_samples = min_samples
        self.entropy_threshold = entropy_threshold
        self.sequence_threshold = sequence_threshold

    def infer_fields(
        self,
        messages: Sequence[bytes],
        min_samples: int | None = None,
    ) -> MessageSchema:
        """Infer field structure from message samples.

        Implements RE-PAY-004: Complete field inference.

        Args:
            messages: List of binary message samples.
            min_samples: Override minimum sample count.

        Returns:
            MessageSchema with inferred field structure.

        Example:
            >>> schema = inferrer.infer_fields(messages)
            >>> print(f"Detected {len(schema.fields)} fields")
        """
        if not messages:
            return MessageSchema(
                fields=[],
                message_length=0,
                fixed_length=True,
                length_range=(0, 0),
                sample_count=0,
                confidence=0.0,
            )

        min_samples = min_samples or self.min_samples
        if len(messages) < min_samples:
            # Inference still proceeds, but results may be unreliable
            logger.debug(
                f"Only {len(messages)} samples (< {min_samples}); inference may be unreliable"
            )

        lengths = [len(m) for m in messages]
        min_len = min(lengths)
        max_len = max(lengths)
        fixed_length = min_len == max_len

        # Use shortest message length for analysis
        analysis_length = min_len

        # Find field boundaries using entropy transitions
        boundaries = self._detect_field_boundaries(messages, analysis_length)

        # Infer field types for each segment
        fields = []
        for i, (start, end) in enumerate(boundaries):
            fields.append(self._infer_field(messages, start, end, i))

        # Calculate overall confidence
        if fields:
            confidence = sum(f.confidence for f in fields) / len(fields)
        else:
            confidence = 0.0

        return MessageSchema(
            fields=fields,
            message_length=analysis_length,
            fixed_length=fixed_length,
            length_range=(min_len, max_len),
            sample_count=len(messages),
            confidence=confidence,
        )

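    # Illustrative sketch (not part of the library): on a handful of frames
    # with a constant two-byte magic followed by a one-byte counter, the
    # inferred schema typically reports the magic as a constant field and
    # the counter as a variable one. Exact boundaries depend on the entropy
    # threshold, so only the length bookkeeping is asserted here:
    #
    #     >>> msgs = [b"\xca\xfe" + bytes([n]) for n in range(16)]
    #     >>> schema = FieldInferrer(min_samples=5).infer_fields(msgs)
    #     >>> schema.fixed_length, schema.sample_count
    #     (True, 16)
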
    def detect_field_types(
        self,
        messages: Sequence[bytes],
        boundaries: list[tuple[int, int]],
    ) -> list[InferredField]:
        """Detect field types for given boundaries.

        Implements RE-PAY-004: Field type detection.

        Args:
            messages: Message samples.
            boundaries: List of (start, end) field boundaries.

        Returns:
            List of InferredField with type information.
        """
        return [
            self._infer_field(messages, start, end, i)
            for i, (start, end) in enumerate(boundaries)
        ]

    def find_sequence_fields(
        self,
        messages: Sequence[bytes],
    ) -> list[tuple[int, int]]:
        """Find fields that appear to be sequence/counter values.

        Implements RE-PAY-004: Sequence field detection.

        Args:
            messages: Message samples (should be in order).

        Returns:
            List of (offset, size) for sequence fields. Offset/size
            combinations that cannot be extracted are skipped and logged.
        """
        if len(messages) < self.sequence_threshold:
            return []

        min_len = min(len(m) for m in messages)
        sequence_fields = []

        # Check each possible field size at each offset
        for size in [1, 2, 4]:
            for offset in range(min_len - size + 1):
                try:
                    # Try both endiannesses: a little-endian counter does not
                    # look sequential when decoded big-endian, and vice versa.
                    values_be = []
                    values_le = []
                    for msg in messages:
                        chunk = msg[offset : offset + size]
                        values_be.append(int.from_bytes(chunk, "big"))
                        values_le.append(int.from_bytes(chunk, "little"))

                    if self._is_sequence(values_be) or self._is_sequence(values_le):
                        sequence_fields.append((offset, size))
                except (ValueError, IndexError) as e:
                    # Skip this offset/size combination if extraction fails
                    logger.debug(f"Skipping field at offset={offset}, size={size}: {e}")
                    continue

        return sequence_fields

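    # Example (hedged): with a constant header byte and a one-byte counter,
    # the counter offset is reported. Overlapping wider decodings may also
    # qualify, since a constant prefix followed by a counter decodes as an
    # incrementing 16-bit integer:
    #
    #     >>> msgs = [bytes([0xAA, n]) for n in range(8)]
    #     >>> FieldInferrer().find_sequence_fields(msgs)
    #     [(1, 1), (0, 2)]
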
    def find_checksum_fields(
        self,
        messages: Sequence[bytes],
    ) -> list[tuple[int, int, str]]:
        """Find fields that appear to be checksums.

        Implements RE-PAY-004: Checksum field detection.

        Args:
            messages: Message samples.

        Returns:
            List of (offset, size, algorithm_hint) for checksum fields.
            Offset/size combinations exceeding the minimum message length
            are skipped and logged.
        """
        if len(messages) < 5:
            return []

        min_len = min(len(m) for m in messages)
        checksum_fields = []

        # Common checksum sizes; check the last position (most common)
        # and the first position for each size.
        for size in [1, 2, 4]:
            for offset in [min_len - size, 0]:
                if offset < 0:
                    continue

                try:
                    # Validate offset and size before processing
                    if offset + size > min_len:
                        raise ValueError(
                            f"Invalid checksum field: offset={offset} + size={size} exceeds "
                            f"minimum message length={min_len}"
                        )

                    score = self._check_checksum_correlation(messages, offset, size)

                    # The correlation score saturates at 0.8, so the
                    # comparison must be inclusive to ever match.
                    if score >= 0.8:
                        algorithm = self._guess_checksum_algorithm(messages, offset, size)
                        checksum_fields.append((offset, size, algorithm))
                except (ValueError, IndexError) as e:
                    # Skip this offset/size combination if validation fails
                    logger.debug(f"Skipping checksum field at offset={offset}, size={size}: {e}")
                    continue

        return checksum_fields

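    # Example (hedged): six frames ending in a modular sum of the preceding
    # bytes are flagged at the trailing position. The correlation check is a
    # coarse heuristic, so treat the algorithm hint as a starting point only:
    #
    #     >>> msgs = [bytes([0xAA, 0x55, n, (0xFF + n) % 256]) for n in range(6)]
    #     >>> FieldInferrer().find_checksum_fields(msgs)
    #     [(3, 1, 'xor8_or_sum8')]
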
    def _detect_field_boundaries(
        self,
        messages: Sequence[bytes],
        max_length: int,
    ) -> list[tuple[int, int]]:
        """Detect field boundaries using entropy analysis.

        Args:
            messages: Message samples.
            max_length: Maximum length to analyze.

        Returns:
            List of (start, end) boundaries.
        """
        if max_length == 0:
            return []

        # Calculate per-byte Shannon entropy
        byte_entropies = []
        for pos in range(max_length):
            values = [m[pos] for m in messages if len(m) > pos]
            if len(values) < 2:
                byte_entropies.append(0.0)
                continue

            counts = Counter(values)
            total = len(values)
            entropy = 0.0
            for count in counts.values():
                if count > 0:
                    p = count / total
                    entropy -= p * np.log2(p)
            byte_entropies.append(entropy)

        # Find boundaries at entropy transitions; a sharp change suggests a
        # constant region giving way to a variable one (or vice versa)
        boundaries = []
        current_start = 0

        for i in range(1, len(byte_entropies)):
            delta = abs(byte_entropies[i] - byte_entropies[i - 1])

            if delta > self.entropy_threshold:
                if i > current_start:
                    boundaries.append((current_start, i))
                current_start = i

        # Add final segment
        if max_length > current_start:
            boundaries.append((current_start, max_length))

        # Merge very small adjacent segments
        merged: list[tuple[int, int]] = []
        for start, end in boundaries:
            if merged and start - merged[-1][1] == 0 and end - start < 2:
                # Merge with previous
                merged[-1] = (merged[-1][0], end)
            else:
                merged.append((start, end))

        return merged if merged else [(0, max_length)]

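    # Worked example of the per-byte entropy: if half the samples carry 0x00
    # and half carry 0x01 at some position, H = -(2 * 0.5 * log2(0.5)) = 1 bit;
    # a constant position has H = 0. A boundary is declared wherever the
    # entropy jumps by more than `entropy_threshold` between neighbours.
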
    def _infer_field(
        self,
        messages: Sequence[bytes],
        start: int,
        end: int,
        index: int,
    ) -> InferredField:
        """Infer type for a single field.

        Args:
            messages: Message samples.
            start: Field start offset.
            end: Field end offset.
            index: Field index for naming.

        Returns:
            InferredField with inferred type.
        """
        size = end - start
        name = f"field_{index}"

        # Extract field values
        raw_values = []
        for msg in messages:
            if len(msg) >= end:
                raw_values.append(msg[start:end])

        if not raw_values:
            return InferredField(
                name=name,
                offset=start,
                size=size,
                inferred_type="unknown",
                confidence=0.0,
            )

        # Check if constant
        is_constant = len(set(raw_values)) == 1

        # Check if sequence
        is_sequence = False
        if not is_constant and size in [1, 2, 4, 8]:
            int_values = [int.from_bytes(v, "big") for v in raw_values]
            is_sequence = self._is_sequence(int_values)

        # Check for checksum patterns (only near the end of the message)
        is_checksum = False
        if start >= min(len(m) for m in messages) - 4:
            score = self._check_checksum_correlation(messages, start, size)
            is_checksum = score > 0.7

        # Infer type
        inferred_type, endianness, confidence = self._infer_type(raw_values, size)

        # Sample values for debugging
        sample_values: list[int | str] = []
        for v in raw_values[:5]:
            if inferred_type.startswith("uint") or inferred_type.startswith("int"):
                try:
                    # Cast endianness to Literal type for type checker
                    byte_order: Literal["big", "little"] = (
                        "big" if endianness == "n/a" else endianness  # type: ignore[assignment]
                    )
                    sample_values.append(int.from_bytes(v, byte_order))
                except Exception:
                    sample_values.append(v.hex())
            elif inferred_type == "string":
                try:
                    sample_values.append(v.decode("utf-8", errors="replace"))
                except Exception:
                    sample_values.append(v.hex())
            else:
                sample_values.append(v.hex())

        # Cast to Literal types for type checker
        inferred_type_literal: Literal[
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "int8",
            "int16",
            "int32",
            "int64",
            "float32",
            "float64",
            "bytes",
            "string",
            "unknown",
        ] = inferred_type  # type: ignore[assignment]
        endianness_literal: Literal["big", "little", "n/a"] = endianness  # type: ignore[assignment]

        return InferredField(
            name=name,
            offset=start,
            size=size,
            inferred_type=inferred_type_literal,
            endianness=endianness_literal,
            is_constant=is_constant,
            is_sequence=is_sequence,
            is_checksum=is_checksum,
            constant_value=raw_values[0] if is_constant else None,
            confidence=confidence,
            sample_values=sample_values,
        )

    def _infer_type(
        self,
        values: list[bytes],
        size: int,
    ) -> tuple[str, str, float]:
        """Infer data type from values.

        Args:
            values: Field values.
            size: Field size.

        Returns:
            Tuple of (type, endianness, confidence).
        """
        if not values:
            return "unknown", "n/a", 0.0

        # Check for string (high printable ratio)
        printable_ratio = sum(
            1 for v in values for b in v if 32 <= b <= 126 or b in (9, 10, 13)
        ) / (len(values) * size)

        if printable_ratio > 0.8:
            return "string", "n/a", printable_ratio

        # Check for standard integer sizes
        if size == 1:
            return "uint8", "n/a", 0.9

        elif size == 2:
            # Pick the byte order whose decoded values vary least
            be_variance = np.var([int.from_bytes(v, "big") for v in values])
            le_variance = np.var([int.from_bytes(v, "little") for v in values])
            endian = "big" if be_variance < le_variance else "little"
            return "uint16", endian, 0.8

        elif size == 4:
            # Check for float32 (big-endian probe)
            float_valid = 0
            for v in values:
                try:
                    f = struct.unpack(">f", v)[0]
                    if not (np.isnan(f) or np.isinf(f)) and -1e10 < f < 1e10:
                        float_valid += 1
                except Exception:
                    pass

            if float_valid / len(values) > 0.8:
                return "float32", "big", 0.7

            # Otherwise treat as an integer
            be_variance = np.var([int.from_bytes(v, "big") for v in values])
            le_variance = np.var([int.from_bytes(v, "little") for v in values])
            endian = "big" if be_variance < le_variance else "little"
            return "uint32", endian, 0.8

        elif size == 8:
            # No float64 probe; fall back to uint64 with an endianness guess
            be_variance = np.var([int.from_bytes(v, "big") for v in values])
            le_variance = np.var([int.from_bytes(v, "little") for v in values])
            endian = "big" if be_variance < le_variance else "little"
            return "uint64", endian, 0.7

        else:
            return "bytes", "n/a", 0.6

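    # Rationale for the variance heuristic (sketch): a small little-endian
    # counter stored in two bytes, e.g. 0, 1, 2 -> b"\x00\x00", b"\x01\x00",
    # b"\x02\x00", decodes little-endian as 0, 1, 2 (tiny variance) but
    # big-endian as 0, 256, 512 (variance 65536x larger), so the decoding
    # with the lower variance is the more plausible byte order.
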
    def _is_sequence(self, values: list[int]) -> bool:
        """Check if values form a sequence.

        Args:
            values: Integer values.

        Returns:
            True if values are incrementing/decrementing; a constant run
            of identical values does not count.
        """
        if len(values) < self.sequence_threshold:
            return False

        # A constant field is not a counter, even though all diffs are 0
        if len(set(values)) == 1:
            return False

        # Check for incrementing sequence
        diffs = [values[i + 1] - values[i] for i in range(len(values) - 1)]

        # Most diffs should be 1 (or a consistent small step); 0 is allowed
        # so that slow counters with repeated values still qualify
        counter = Counter(diffs)
        if not counter:
            return False

        most_common_diff, count = counter.most_common(1)[0]
        ratio = count / len(diffs)

        return ratio > 0.8 and most_common_diff in [1, -1, 0]

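    # Example (hedged doctest):
    #
    #     >>> inf = FieldInferrer(sequence_threshold=3)
    #     >>> inf._is_sequence([5, 6, 7, 8])
    #     True
    #     >>> inf._is_sequence([5, 5, 5, 5])   # constant, not a counter
    #     False
    #     >>> inf._is_sequence([1, 9, 2, 7])   # no consistent step
    #     False
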
    def _check_checksum_correlation(
        self,
        messages: Sequence[bytes],
        offset: int,
        size: int,
    ) -> float:
        """Check if field correlates with message content like a checksum.

        Args:
            messages: Message samples.
            offset: Field offset.
            size: Field size.

        Returns:
            Correlation score (0-1).
        """
        # Simple heuristic: checksum fields have high correlation with
        # changes in other parts of the message
        if len(messages) < 5:
            return 0.0

        # Extract checksum values and message content
        checksums = []
        contents = []

        for msg in messages:
            if len(msg) >= offset + size:
                checksums.append(int.from_bytes(msg[offset : offset + size], "big"))
                # Content excluding the candidate checksum field
                content = msg[:offset] + msg[offset + size :]
                contents.append(sum(content) % 65536)

        if len(checksums) < 5:
            return 0.0

        # Check if checksum changes accompany content changes
        unique_contents = len(set(contents))
        unique_checksums = len(set(checksums))

        if unique_contents == 1 and unique_checksums == 1:
            return 0.3  # Both constant - inconclusive

        # Coarse correlation: both vary together
        if unique_contents > 1 and unique_checksums > 1:
            return 0.8

        return 0.3

    def _guess_checksum_algorithm(
        self,
        messages: Sequence[bytes],
        offset: int,
        size: int,
    ) -> str:
        """Guess the checksum algorithm from the field size alone.

        Args:
            messages: Message samples (currently unused).
            offset: Checksum offset (currently unused).
            size: Checksum size.

        Returns:
            Algorithm name hint.
        """
        if size == 1:
            return "xor8_or_sum8"
        elif size == 2:
            return "crc16_or_sum16"
        elif size == 4:
            return "crc32"
        return "unknown"


# =============================================================================
# RE-PAY-004: Convenience Functions
# =============================================================================


def infer_fields(messages: Sequence[bytes], min_samples: int = 10) -> MessageSchema:
    """Infer field structure from message samples.

    Implements RE-PAY-004: Payload Field Inference.

    Args:
        messages: List of binary message samples.
        min_samples: Minimum samples for reliable inference.

    Returns:
        MessageSchema with inferred field structure.

    Example:
        >>> messages = [pkt.data for pkt in packets]
        >>> schema = infer_fields(messages)
        >>> for field in schema.fields:
        ...     print(f"{field.name}: {field.inferred_type}")
    """
    inferrer = FieldInferrer(min_samples=min_samples)
    return inferrer.infer_fields(messages)


def detect_field_types(
    messages: Sequence[bytes],
    boundaries: list[tuple[int, int]],
) -> list[InferredField]:
    """Detect field types for given boundaries.

    Implements RE-PAY-004: Field type detection.

    Args:
        messages: Message samples.
        boundaries: List of (start, end) field boundaries.

    Returns:
        List of InferredField with type information.
    """
    inferrer = FieldInferrer()
    return inferrer.detect_field_types(messages, boundaries)


def find_sequence_fields(messages: Sequence[bytes]) -> list[tuple[int, int]]:
    """Find fields that appear to be sequence/counter values.

    Implements RE-PAY-004: Sequence field detection.

    Args:
        messages: Message samples (should be in order).

    Returns:
        List of (offset, size) for sequence fields.
    """
    inferrer = FieldInferrer()
    return inferrer.find_sequence_fields(messages)


def find_checksum_fields(messages: Sequence[bytes]) -> list[tuple[int, int, str]]:
    """Find fields that appear to be checksums.

    Implements RE-PAY-004: Checksum field detection.

    Args:
        messages: Message samples.

    Returns:
        List of (offset, size, algorithm_hint) for checksum fields.
    """
    inferrer = FieldInferrer()
    return inferrer.find_checksum_fields(messages)


# =============================================================================
# RE-PAY-005: Comparison Functions
# =============================================================================


def diff_payloads(payload_a: bytes, payload_b: bytes) -> PayloadDiff:
    """Compare two payloads and identify differences.

    Implements RE-PAY-005: Payload differential analysis.

    Args:
        payload_a: First payload.
        payload_b: Second payload.

    Returns:
        PayloadDiff with comparison results.

    Example:
        >>> diff = diff_payloads(pkt1.data, pkt2.data)
        >>> print(f"Common prefix: {diff.common_prefix_length} bytes")
        >>> print(f"Different bytes: {len(diff.differences)}")
    """
    # Find common prefix
    common_prefix = 0
    min_len = min(len(payload_a), len(payload_b))
    for i in range(min_len):
        if payload_a[i] == payload_b[i]:
            common_prefix += 1
        else:
            break

    # Find common suffix (never overlapping the common prefix)
    common_suffix = 0
    for i in range(1, min_len - common_prefix + 1):
        if payload_a[-i] == payload_b[-i]:
            common_suffix += 1
        else:
            break

    # Find all differences in the overlapping region
    differences = []
    for i in range(min_len):
        if payload_a[i] != payload_b[i]:
            differences.append((i, payload_a[i], payload_b[i]))

    # Add length differences (-1 marks the missing byte)
    if len(payload_a) > len(payload_b):
        for i in range(len(payload_b), len(payload_a)):
            differences.append((i, payload_a[i], -1))
    elif len(payload_b) > len(payload_a):
        for i in range(len(payload_a), len(payload_b)):
            differences.append((i, -1, payload_b[i]))

    # Calculate similarity: matching bytes over the longer length
    max_len = max(len(payload_a), len(payload_b))
    if max_len == 0:
        similarity = 1.0
    else:
        matching = min_len - len([d for d in differences if d[0] < min_len])
        similarity = matching / max_len

    # Calculate Levenshtein edit distance
    edit_distance = _levenshtein_distance(payload_a, payload_b)

    return PayloadDiff(
        common_prefix_length=common_prefix,
        common_suffix_length=common_suffix,
        differences=differences,
        similarity=similarity,
        edit_distance=edit_distance,
    )


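# Example (hedged doctest): a single mid-payload byte change yields one
# difference entry, prefix/suffix lengths of 1, and edit distance 1:
#
#     >>> d = diff_payloads(b"\xaa\x01\xff", b"\xaa\x02\xff")
#     >>> d.common_prefix_length, d.common_suffix_length, d.edit_distance
#     (1, 1, 1)
#     >>> d.differences
#     [(1, 1, 2)]

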
def find_common_bytes(payloads: Sequence[bytes]) -> bytes:
    """Find the common prefix across all payloads.

    Implements RE-PAY-005: Common byte analysis.

    Args:
        payloads: List of payloads to analyze.

    Returns:
        Common prefix bytes.
    """
    if not payloads:
        return b""

    if len(payloads) == 1:
        return payloads[0]

    # Find minimum length
    min_len = min(len(p) for p in payloads)

    # Find common prefix
    common = bytearray()
    for i in range(min_len):
        byte = payloads[0][i]
        if all(p[i] == byte for p in payloads):
            common.append(byte)
        else:
            break

    return bytes(common)


def find_variable_positions(payloads: Sequence[bytes]) -> VariablePositions:
    """Identify which byte positions vary across payloads.

    Implements RE-PAY-005: Variable position detection.

    Args:
        payloads: List of payloads to analyze.

    Returns:
        VariablePositions with constant and variable position info.

    Example:
        >>> result = find_variable_positions(payloads)
        >>> print(f"Constant positions: {result.constant_positions}")
        >>> print(f"Variable positions: {result.variable_positions}")
    """
    if not payloads:
        return VariablePositions(
            constant_positions=[],
            variable_positions=[],
            constant_values={},
            variance_by_position=np.array([]),
        )

    # Use shortest payload length
    min_len = min(len(p) for p in payloads)

    constant_positions = []
    variable_positions = []
    constant_values = {}
    variances = []

    for i in range(min_len):
        values = [p[i] for p in payloads]
        unique = set(values)

        if len(unique) == 1:
            constant_positions.append(i)
            constant_values[i] = values[0]
            variances.append(0.0)
        else:
            variable_positions.append(i)
            variances.append(float(np.var(values)))

    return VariablePositions(
        constant_positions=constant_positions,
        variable_positions=variable_positions,
        constant_values=constant_values,
        variance_by_position=np.array(variances),
    )


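# Example (hedged doctest): a fixed header byte followed by a counter
# splits cleanly into one constant and one variable position:
#
#     >>> vp = find_variable_positions([b"\xaa\x00", b"\xaa\x01", b"\xaa\x02"])
#     >>> vp.constant_positions, vp.variable_positions, vp.constant_values
#     ([0], [1], {0: 170})

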
def compute_similarity(
    payload_a: bytes,
    payload_b: bytes,
    metric: Literal["levenshtein", "hamming", "jaccard"] = "levenshtein",
) -> float:
    """Compute similarity between two payloads.

    Implements RE-PAY-005: Similarity computation.

    Args:
        payload_a: First payload.
        payload_b: Second payload.
        metric: Similarity metric to use.

    Returns:
        Similarity score (0-1).
    """
    if metric == "levenshtein":
        max_len = max(len(payload_a), len(payload_b))
        if max_len == 0:
            return 1.0
        distance = _levenshtein_distance(payload_a, payload_b)
        return 1.0 - (distance / max_len)

    elif metric == "hamming":
        if len(payload_a) != len(payload_b):
            # Pad the shorter payload with zero bytes
            max_len = max(len(payload_a), len(payload_b))
            payload_a = payload_a.ljust(max_len, b"\x00")
            payload_b = payload_b.ljust(max_len, b"\x00")

        matches = sum(a == b for a, b in zip(payload_a, payload_b, strict=True))
        return matches / len(payload_a) if payload_a else 1.0

    # metric == "jaccard": treat payloads as sets of byte values
    set_a = set(payload_a)
    set_b = set(payload_b)
    intersection = len(set_a & set_b)
    union = len(set_a | set_b)
    return intersection / union if union > 0 else 1.0


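# Example (hedged doctest): one differing byte out of four gives 0.75 under
# both Levenshtein and Hamming; Jaccard ignores positions and compares the
# byte-value sets instead:
#
#     >>> compute_similarity(b"abcd", b"abcf", metric="hamming")
#     0.75
#     >>> compute_similarity(b"abcd", b"abcf", metric="jaccard")
#     0.6

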
def cluster_payloads(
    payloads: Sequence[bytes],
    threshold: float = 0.8,
    algorithm: Literal["greedy", "dbscan"] = "greedy",
) -> list[PayloadCluster]:
    """Cluster similar payloads together.

    Implements RE-PAY-005: Payload clustering.

    Args:
        payloads: List of payloads to cluster.
        threshold: Similarity threshold for clustering.
        algorithm: Clustering algorithm.

    Returns:
        List of PayloadCluster objects.

    Example:
        >>> clusters = cluster_payloads(payloads, threshold=0.85)
        >>> for c in clusters:
        ...     print(f"Cluster {c.cluster_id}: {c.size} payloads")
    """
    if not payloads:
        return []

    if algorithm == "greedy":
        return _cluster_greedy_optimized(payloads, threshold)
    # algorithm == "dbscan"
    return _cluster_dbscan(payloads, threshold)


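# Example (hedged): two near-identical GET lines cluster together while the
# POST line, which shares no prefix or suffix with them, is rejected early
# by the fast pre-filter and lands in its own cluster:
#
#     >>> clusters = cluster_payloads([b"GET /a", b"GET /b", b"POST /x"])
#     >>> [c.size for c in clusters]
#     [2, 1]

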
def correlate_request_response(
    requests: Sequence[PayloadInfo],
    responses: Sequence[PayloadInfo],
    max_delay: float = 1.0,
) -> list[tuple[PayloadInfo, PayloadInfo, float]]:
    """Correlate request payloads with responses.

    Implements RE-PAY-005: Request-response correlation.

    Each request is paired with the earliest response that follows it
    within max_delay; responses are not consumed, so one response may
    pair with several requests.

    Args:
        requests: List of request PayloadInfo.
        responses: List of response PayloadInfo.
        max_delay: Maximum time between request and response.

    Returns:
        List of (request, response, latency) tuples.
    """
    pairs = []

    for request in requests:
        if request.timestamp is None:
            continue

        best_response = None
        best_latency = float("inf")

        for response in responses:
            if response.timestamp is None:
                continue

            latency = response.timestamp - request.timestamp
            if 0 <= latency <= max_delay and latency < best_latency:
                best_response = response
                best_latency = latency

        if best_response is not None:
            pairs.append((request, best_response, best_latency))

    return pairs


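# Sketch of the matching rule (no PayloadInfo constructed here, since its
# exact shape lives in payload_extraction): requests at t=0.0 and t=1.0 with
# responses at t=0.2 and t=1.3 pair as (req0, resp0, 0.2) and (req1, resp1,
# 0.3); the second response is outside max_delay=1.0 for the first request,
# and negative latencies are never considered.

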
# =============================================================================
# Helper Functions
# =============================================================================


def _levenshtein_distance(a: bytes, b: bytes) -> int:
    """Calculate Levenshtein edit distance between two byte sequences."""
    if len(a) < len(b):
        return _levenshtein_distance(b, a)

    if len(b) == 0:
        return len(a)

    previous_row: list[int] = list(range(len(b) + 1))
    for i, c1 in enumerate(a):
        current_row = [i + 1]
        for j, c2 in enumerate(b):
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row

    return previous_row[-1]


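# Example (hedged doctest): the classic "kitten" -> "sitting" pair needs
# three edits (two substitutions and one insertion). The implementation
# keeps only two DP rows, so memory is O(len(b)) rather than O(len(a)*len(b)):
#
#     >>> _levenshtein_distance(b"kitten", b"sitting")
#     3

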
def _fast_similarity(payload_a: bytes, payload_b: bytes, threshold: float) -> float | None:
    """Fast similarity check with early termination.

    Uses length-based filtering and sampling to quickly reject dissimilar
    payloads. Returns a similarity value when the pair can be decided
    cheaply, or None when a full comparison is still needed.

    Args:
        payload_a: First payload.
        payload_b: Second payload.
        threshold: Similarity threshold for clustering.

    Returns:
        Similarity value if quickly determined, None if full check needed.
    """
    len_a = len(payload_a)
    len_b = len(payload_b)

    # Empty payloads
    if len_a == 0 and len_b == 0:
        return 1.0
    if len_a == 0 or len_b == 0:
        return 0.0

    # Length filter: edit distance is at least max_len - min_len, so
    # similarity can never exceed min_len / max_len
    max_len = max(len_a, len_b)
    min_len = min(len_a, len_b)

    max_possible_similarity = min_len / max_len
    if max_possible_similarity < threshold:
        return max_possible_similarity

    # For same-length payloads, use fast Hamming similarity
    if len_a == len_b:
        # Sample comparison for large payloads
        if len_a > 50:
            # Sample the first 16 bytes, the last 16 bytes, and a stride of
            # middle bytes, counting the positions actually compared
            mismatches = 0
            samples = 0

            # First 16 bytes
            for i in range(16):
                samples += 1
                if payload_a[i] != payload_b[i]:
                    mismatches += 1

            # Last 16 bytes
            for i in range(1, 17):
                samples += 1
                if payload_a[-i] != payload_b[-i]:
                    mismatches += 1

            # Middle samples (len_a > 32 always holds here since len_a > 50)
            step = (len_a - 32) // 16
            if step > 0:
                for i in range(16, len_a - 16, step):
                    samples += 1
                    if payload_a[i] != payload_b[i]:
                        mismatches += 1

            # Estimate similarity from the sample
            estimated_similarity = 1.0 - (mismatches / samples)

            # If the sample shows very low similarity, reject early
            if estimated_similarity < threshold * 0.8:
                return estimated_similarity

        # Full Hamming comparison for same-length payloads (faster than Levenshtein)
        matches = sum(a == b for a, b in zip(payload_a, payload_b, strict=True))
        return matches / len_a

    # For different-length payloads, use a common prefix/suffix heuristic
    common_prefix = 0
    for i in range(min_len):
        if payload_a[i] == payload_b[i]:
            common_prefix += 1
        else:
            break

    common_suffix = 0
    for i in range(1, min_len - common_prefix + 1):
        if payload_a[-i] == payload_b[-i]:
            common_suffix += 1
        else:
            break

    # Estimate similarity from prefix/suffix
    common_bytes = common_prefix + common_suffix
    estimated_similarity = common_bytes / max_len

    # If the common bytes suggest low similarity, reject
    if estimated_similarity < threshold * 0.7:
        return estimated_similarity

    # Need full comparison
    return None


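# Why the length filter is sound (sketch): Levenshtein distance is at least
# |len_a - len_b|, so similarity = 1 - d/max_len can never exceed
# 1 - (max_len - min_len)/max_len = min_len/max_len. When that bound is
# already below the clustering threshold, the pair is safely rejected
# without computing the full distance.

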
def _cluster_greedy_optimized(
    payloads: Sequence[bytes],
    threshold: float,
) -> list[PayloadCluster]:
    """Optimized greedy clustering algorithm.

    Uses fast pre-filtering based on length and sampling to avoid
    expensive Levenshtein distance calculations when possible.

    Args:
        payloads: Sequence of payload bytes to cluster.
        threshold: Similarity threshold for clustering (0.0 to 1.0).

    Returns:
        List of PayloadCluster objects containing clustered payloads.
    """
    clusters: list[PayloadCluster] = []
    assigned = [False] * len(payloads)

    # Precompute lengths for fast filtering
    lengths = [len(p) for p in payloads]

    for i, payload in enumerate(payloads):
        if assigned[i]:
            continue

        # Start new cluster
        cluster_payloads = [payload]
        cluster_indices = [i]
        assigned[i] = True

        payload_len = lengths[i]

        # Find similar payloads
        for j in range(i + 1, len(payloads)):
            if assigned[j]:
                continue

            other_len = lengths[j]

            # Quick length-based rejection (guard against two empty payloads)
            max_len = max(payload_len, other_len)
            min_len = min(payload_len, other_len)
            if max_len > 0 and min_len / max_len < threshold:
                continue

            # Try fast similarity check first
            fast_result = _fast_similarity(payload, payloads[j], threshold)

            if fast_result is not None:
                similarity = fast_result
            else:
                # Fall back to Levenshtein for uncertain cases
                similarity = compute_similarity(payload, payloads[j])

            if similarity >= threshold:
                cluster_payloads.append(payloads[j])
                cluster_indices.append(j)
                assigned[j] = True

        clusters.append(
            PayloadCluster(
                cluster_id=len(clusters),
                payloads=cluster_payloads,
                indices=cluster_indices,
                representative=payload,
                size=len(cluster_payloads),
            )
        )

    return clusters


def _cluster_greedy(
    payloads: Sequence[bytes],
    threshold: float,
) -> list[PayloadCluster]:
    """Greedy clustering algorithm (legacy alias for the optimized version)."""
    return _cluster_greedy_optimized(payloads, threshold)


def _cluster_dbscan(
    payloads: Sequence[bytes],
    threshold: float,
) -> list[PayloadCluster]:
    """DBSCAN-style clustering (simplified).

    A full DBSCAN would require scipy or a custom implementation; for now
    this falls back to the greedy algorithm.
    """
    return _cluster_greedy_optimized(payloads, threshold)


__all__ = [
    "FieldInferrer",
    "InferredField",
    "MessageSchema",
    "PayloadCluster",
    "PayloadDiff",
    "VariablePositions",
    "cluster_payloads",
    "compute_similarity",
    "correlate_request_response",
    "detect_field_types",
    "diff_payloads",
    "find_checksum_fields",
    "find_common_bytes",
    "find_sequence_fields",
    "find_variable_positions",
    "infer_fields",
]