Coverage for src/tracekit/inference/message_format.py: 96%
241 statements
coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Message format inference using statistical analysis.
3Requirements addressed: PSI-001
5This module automatically infers message field structure from collections of
6similar messages for protocol reverse engineering.
8Key capabilities:
9- Detect field boundaries via entropy transitions
10- Classify fields as constant, variable, or sequential
11- Infer field types (integer, counter, timestamp, checksum)
12- Detect field dependencies (length fields, checksums)
13- Generate message format specifications
14"""

from dataclasses import dataclass
from dataclasses import field as dataclass_field
from typing import Any, Literal

import numpy as np
from numpy.typing import NDArray


@dataclass
class InferredField:
    """An inferred message field.

    Field classification and type inference.

    Attributes:
        name: Auto-generated field name
        offset: Byte offset from message start
        size: Field size in bytes
        field_type: Inferred field type classification
        entropy: Shannon entropy of field values
        variance: Statistical variance of field values
        confidence: Confidence score (0-1) for type inference
        values_seen: Sample values for validation
    """

    name: str
    offset: int
    size: int
    field_type: Literal["constant", "counter", "timestamp", "length", "checksum", "data", "unknown"]
    entropy: float
    variance: float
    confidence: float
    values_seen: list[Any] = dataclass_field(default_factory=list)  # Sample values
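

# Illustrative sketch (not part of the original module): an InferredField can
# also be constructed by hand, e.g. in tests or to describe a field that is
# already known from a datasheet. The name and all values below are made up
# for demonstration only.
_EXAMPLE_MAGIC_FIELD = InferredField(
    name="field_0",
    offset=0,
    size=2,
    field_type="constant",
    entropy=0.0,
    variance=0.0,
    confidence=1.0,
    values_seen=[0xAA55],
)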


@dataclass
class MessageSchema:
    """Inferred message format schema.

    Complete message format specification.

    Attributes:
        total_size: Total message size in bytes
        fields: List of inferred fields
        field_boundaries: Byte offsets of field starts
        header_size: Detected header size
        payload_offset: Start of payload region
        checksum_field: Detected checksum field if any
        length_field: Detected length field if any
    """

    total_size: int
    fields: list[InferredField]
    field_boundaries: list[int]  # Byte offsets of field starts
    header_size: int  # Detected header size
    payload_offset: int
    checksum_field: InferredField | None
    length_field: InferredField | None
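

# Illustrative sketch (not part of the original module): one way to inspect an
# inferred MessageSchema is to render a simple byte-layout summary. The helper
# name `_print_schema` is an assumption; `schema` would typically come from
# MessageFormatInferrer.infer_format() defined below.
def _print_schema(schema: MessageSchema) -> None:
    print(f"total size: {schema.total_size} bytes, header: {schema.header_size} bytes")
    for f in schema.fields:
        print(
            f"  [{f.offset:3d}:{f.offset + f.size:3d}] {f.name:<10} "
            f"{f.field_type:<9} confidence={f.confidence:.2f}"
        )
    if schema.checksum_field is not None:
        print(f"  checksum at offset {schema.checksum_field.offset}")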


class MessageFormatInferrer:
    """Infer message format from samples.

    Message format inference using entropy and variance analysis.

    Algorithm:
        1. Detect field boundaries using entropy transitions
        2. Classify fields based on statistical patterns
        3. Detect dependencies between fields
        4. Generate complete schema
    """

    def __init__(self, min_samples: int = 10):
        """Initialize inferrer.

        Args:
            min_samples: Minimum number of message samples required
        """
        self.min_samples = min_samples

    def infer_format(self, messages: list[bytes | NDArray[np.uint8]]) -> MessageSchema:
        """Infer message format from a collection of similar messages.

        Complete format inference workflow.

        Args:
            messages: List of message samples (bytes or np.ndarray)

        Returns:
            MessageSchema with inferred field structure

        Raises:
            ValueError: If there are insufficient samples or invalid input
        """
        if len(messages) < self.min_samples:
            raise ValueError(f"Need at least {self.min_samples} messages, got {len(messages)}")

        # Convert to numpy arrays for processing
        msg_arrays = []
        for msg in messages:
            if isinstance(msg, bytes):
                msg_arrays.append(np.frombuffer(msg, dtype=np.uint8))
            elif isinstance(msg, np.ndarray):
                msg_arrays.append(msg.astype(np.uint8))
            else:
                raise ValueError(f"Invalid message type: {type(msg)}")

        # Check that all messages are the same length
        lengths = [len(m) for m in msg_arrays]
        if len(set(lengths)) > 1:
            raise ValueError(f"Messages have varying lengths: {set(lengths)}")

        msg_len = lengths[0]

        # Detect field boundaries
        boundaries = self.detect_field_boundaries(msg_arrays, method="combined")

        # Detect field types
        fields = self.detect_field_types(msg_arrays, boundaries)

        # Determine header size (first high-entropy transition or first 4 fields)
        header_size = self._estimate_header_size(fields)

        # Find checksum and length fields
        checksum_field = None
        length_field = None

        for f in fields:
            if f.field_type == "checksum":
                checksum_field = f
            elif f.field_type == "length":
                length_field = f

        # Payload starts after the header
        payload_offset = header_size

        schema = MessageSchema(
            total_size=msg_len,
            fields=fields,
            field_boundaries=boundaries,
            header_size=header_size,
            payload_offset=payload_offset,
            checksum_field=checksum_field,
            length_field=length_field,
        )

        return schema

    def detect_field_boundaries(
        self,
        messages: list[NDArray[np.uint8]],
        method: Literal["entropy", "variance", "combined"] = "combined",
    ) -> list[int]:
        """Detect field boundaries using entropy transitions.

        Boundary detection via statistical transitions.

        Args:
            messages: List of message arrays
            method: Detection method ('entropy', 'variance', or 'combined')

        Returns:
            List of byte offsets marking field starts (always includes 0)
        """
        if not messages:
            return [0]

        msg_len = len(messages[0])
        boundaries = [0]  # Always start at offset 0

        if method in ["entropy", "combined"]:
            # Calculate entropy at each byte position
            entropies = []
            for offset in range(msg_len):
                entropy = self._calculate_byte_entropy(messages, offset)
                entropies.append(entropy)

            # Find transitions (entropy changes > threshold)
            entropy_threshold = 1.5  # bits
            for i in range(1, len(entropies)):
                delta = abs(entropies[i] - entropies[i - 1])
                if delta > entropy_threshold and i not in boundaries:
                    boundaries.append(i)

        if method in ["variance", "combined"]:
            # Calculate variance at each byte position
            variances = []
            for offset in range(msg_len):
                values = [msg[offset] for msg in messages]
                variance = np.var(values)
                variances.append(variance)

            # Find variance transitions
            var_threshold = 1000.0
            for i in range(1, len(variances)):
                delta = abs(variances[i] - variances[i - 1])
                if delta > var_threshold and i not in boundaries:
                    boundaries.append(i)

        # Sort and ensure we don't have too many tiny fields
        boundaries = sorted(set(boundaries))

        # Merge boundaries that are too close (< 2 bytes apart)
        merged = [boundaries[0]]
        for b in boundaries[1:]:
            if b - merged[-1] >= 2:
                merged.append(b)

        return merged

    def detect_field_types(
        self, messages: list[NDArray[np.uint8]], boundaries: list[int]
    ) -> list[InferredField]:
        """Classify field types based on value patterns.

        Field type classification.

        Args:
            messages: List of message arrays
            boundaries: Field boundary offsets

        Returns:
            List of InferredField objects
        """
        fields = []

        for i in range(len(boundaries)):
            offset = boundaries[i]

            # Determine field size
            if i < len(boundaries) - 1:
                size = boundaries[i + 1] - offset
            else:
                size = len(messages[0]) - offset

            # Extract field values
            values: list[int | tuple[int, ...]]
            if size <= 4:
                # Use integer representation for small fields (big-endian)
                int_values: list[int] = []
                for msg in messages:
                    if size == 1:
                        val_int = int(msg[offset])
                    elif size == 2:
                        val_int = int(msg[offset]) << 8 | int(msg[offset + 1])
                    elif size == 4:
                        val_int = (
                            int(msg[offset]) << 24
                            | int(msg[offset + 1]) << 16
                            | int(msg[offset + 2]) << 8
                            | int(msg[offset + 3])
                        )
                    else:  # size == 3
                        val_int = (
                            int(msg[offset]) << 16
                            | int(msg[offset + 1]) << 8
                            | int(msg[offset + 2])
                        )
                    int_values.append(val_int)
                values = list(int_values)  # type: ignore[assignment]
            else:
                # For larger fields, use byte tuples
                tuple_values: list[tuple[int, ...]] = []
                for msg in messages:
                    val_tuple = tuple(int(b) for b in msg[offset : offset + size])
                    tuple_values.append(val_tuple)
                values = list(tuple_values)  # type: ignore[assignment]

            # Calculate statistics
            if size > 4:
                # Bytes field - calculate entropy across all bytes
                all_bytes_list: list[int] = []
                for v in values:
                    if isinstance(v, tuple):
                        all_bytes_list.extend(v)
                all_bytes = np.array(all_bytes_list, dtype=np.uint8)
                entropy = self._calculate_entropy(all_bytes)
                variance = float(np.var(all_bytes))
            else:
                entropy = self._calculate_entropy(np.array(values, dtype=np.int64))
                variance = float(np.var(values))

            # Classify field type
            field_type, confidence = self._classify_field(values, offset, size, messages)

            # Sample values (first 5)
            sample_values = values[:5]

            field_obj = InferredField(
                name=f"field_{i}",
                offset=offset,
                size=size,
                field_type=field_type,  # type: ignore[arg-type]
                entropy=float(entropy),
                variance=float(variance),
                confidence=confidence,
                values_seen=sample_values,
            )

            fields.append(field_obj)

        return fields

    def find_dependencies(
        self, messages: list[NDArray[np.uint8]], schema: MessageSchema
    ) -> dict[str, str]:
        """Find dependencies between fields (e.g., length->payload).

        Field dependency detection.

        Args:
            messages: List of message arrays
            schema: Inferred message schema

        Returns:
            Dictionary mapping field names to dependency descriptions
        """
        dependencies = {}

        # Check for length field dependencies
        for field in schema.fields:
            if field.field_type == "length":
                # Check if any field size correlates with this length value
                for msg in messages:
                    _length_val = self._extract_field_value(msg, field)
                    # Look for fields that might be variable length
                    # This is a simplified check
                    dependencies[field.name] = "Potential length indicator"

        return dependencies

    def _calculate_byte_entropy(self, messages: list[NDArray[np.uint8]], offset: int) -> float:
        """Calculate entropy at byte offset across messages.

        Entropy calculation for boundary detection.

        Args:
            messages: List of message arrays
            offset: Byte offset to analyze

        Returns:
            Shannon entropy in bits
        """
        values = [msg[offset] for msg in messages]
        return float(self._calculate_entropy(np.array(values)))

    def _calculate_entropy(self, values: NDArray[np.int_ | np.uint8]) -> float:
        """Calculate Shannon entropy of values.

        Args:
            values: Array of values

        Returns:
            Entropy in bits
        """
        if len(values) == 0:
            return 0.0

        # Count frequencies
        _unique, counts = np.unique(values, return_counts=True)
        probabilities = counts / len(values)

        # Calculate Shannon entropy
        entropy = -np.sum(probabilities * np.log2(probabilities + 1e-10))
        return float(entropy)

    def _classify_field(
        self,
        values: list[int | tuple[int, ...]],
        offset: int,
        size: int,
        messages: list[NDArray[np.uint8]],
    ) -> tuple[str, float]:
        """Classify field type based on patterns.

        Field type classification logic.

        Args:
            values: Field values across all messages
            offset: Field offset
            size: Field size
            messages: Original messages

        Returns:
            Tuple of (field_type, confidence)
        """
        # Handle byte fields (larger than 4 bytes)
        if isinstance(values[0], tuple):
            # Check if all tuples are identical (truly constant)
            if len(set(values)) == 1:
                return ("constant", 1.0)

            entropy = self._calculate_entropy(np.concatenate([np.array(v) for v in values]))
            if entropy < 1.0:
                return ("constant", 0.9)
            elif entropy > 7.0:
                return ("data", 0.6)
            else:
                return ("data", 0.5)

        # Check for constant field
        if len(set(values)) == 1:
            return ("constant", 1.0)

        # Check for counter field
        if not isinstance(values[0], tuple) and self._detect_counter_field(  # type: ignore[misc, unreachable]
            [v for v in values if isinstance(v, int)]
        ):
            return ("counter", 0.9)

        # Check for checksum (if near end of message)
        msg_len = len(messages[0])
        if offset + size >= msg_len - 4:  # Within last 4 bytes
            if self._detect_checksum_field(messages, offset, size):
                return ("checksum", 0.8)

        # Check for length field (small values, near start)
        if offset < 8 and size <= 2:
            if not isinstance(values[0], tuple):  # type: ignore[unreachable]
                max_val = max(v for v in values if isinstance(v, int))
                if max_val < msg_len * 2:  # Reasonable length value
                    return ("length", 0.6)

        # Check variance for classification
        variance = np.var(values)
        entropy = self._calculate_entropy(np.array(values))

        if variance < 10:
            return ("constant", 0.6)
        elif entropy > 6.0:
            return ("data", 0.7)
        else:
            return ("unknown", 0.5)

    def _detect_counter_field(self, values: list[int]) -> bool:
        """Check if values form a counter sequence.

        Counter field detection.

        Args:
            values: List of integer values

        Returns:
            True if values appear to be a counter
        """
        if len(values) < 3:
            return False

        # Check for monotonic increase
        diffs = [values[i + 1] - values[i] for i in range(len(values) - 1)]

        # Allow wrapping
        diffs_filtered = [d for d in diffs if d >= 0]

        # Check if most differences are 1 (counter increments)
        if len(diffs_filtered) < len(diffs) * 0.7:
            return False

        ones = sum(1 for d in diffs_filtered if d == 1)
        return ones >= len(diffs_filtered) * 0.7

    def _detect_checksum_field(
        self, messages: list[NDArray[np.uint8]], field_offset: int, field_size: int
    ) -> bool:
        """Check if field is likely a checksum.

        Checksum field detection.

        Args:
            messages: List of message arrays
            field_offset: Offset of potential checksum field
            field_size: Size of potential checksum field

        Returns:
            True if field appears to be a checksum
        """
        if field_size not in [1, 2, 4]:
            return False

        # Try a simple XOR checksum
        for msg in messages[: min(5, len(messages))]:
            # Calculate XOR of all bytes before the checksum
            xor_sum = 0
            for i in range(field_offset):
                xor_sum ^= int(msg[i])

            # Extract checksum value
            if field_size == 1:
                checksum = int(msg[field_offset])
            elif field_size == 2:
                checksum = int(msg[field_offset]) << 8 | int(msg[field_offset + 1])
            else:
                checksum = (
                    int(msg[field_offset]) << 24
                    | int(msg[field_offset + 1]) << 16
                    | int(msg[field_offset + 2]) << 8
                    | int(msg[field_offset + 3])
                )

            # Only single-byte XOR checksums are verified; wider fields never match
            if field_size == 1 and (xor_sum & 0xFF) == checksum:
                continue
            else:
                return False  # Not a match

        return True  # All sampled messages matched

    def _estimate_header_size(self, fields: list[InferredField]) -> int:
        """Estimate header size from field patterns.

        Args:
            fields: List of inferred fields

        Returns:
            Estimated header size in bytes
        """
        # Look for transition from low-entropy to high-entropy
        for i, field in enumerate(fields):
            if field.field_type == "data" and field.entropy > 6.0:
                if i > 0:
                    return field.offset

        # Default: first 4 fields or 16 bytes
        if len(fields) >= 5:
            # Header includes first 4 fields, so return offset of 5th field
            return fields[4].offset
        elif len(fields) >= 4:
            # If exactly 4 fields, header is up to end of 4th field
            return fields[3].offset + fields[3].size
        elif fields:
            # Fewer than 4 fields - use offset of last field
            return min(16, fields[-1].offset)
        else:
            return 16

    def _extract_field_value(self, msg: NDArray[np.uint8], field: InferredField) -> int:
        """Extract field value from message.

        Args:
            msg: Message array
            field: Field definition

        Returns:
            Field value as integer
        """
        if field.size == 1:
            return int(msg[field.offset])
        elif field.size == 2:
            return int(msg[field.offset]) << 8 | int(msg[field.offset + 1])
        elif field.size == 4:
            return (
                int(msg[field.offset]) << 24
                | int(msg[field.offset + 1]) << 16
                | int(msg[field.offset + 2]) << 8
                | int(msg[field.offset + 3])
            )
        else:
            # Return first byte for larger fields
            return int(msg[field.offset])
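

# Illustrative usage sketch (not part of the original module): build a handful
# of synthetic fixed-length messages -- a constant magic, an incrementing
# counter, a random payload, and a trailing XOR checksum -- and run the
# inferrer over them. The helper name `_demo_inference` is an assumption, and
# the exact fields reported depend on the entropy/variance thresholds above,
# so this demonstrates the API rather than a guaranteed output.
def _demo_inference() -> None:
    rng = np.random.default_rng(0)
    messages: list[bytes] = []
    for seq in range(32):
        payload = rng.integers(0, 256, size=12, dtype=np.uint8)
        body = bytes([0xAA, 0x55, seq & 0xFF]) + payload.tobytes()
        checksum = 0
        for b in body:
            checksum ^= b
        messages.append(body + bytes([checksum]))

    schema = MessageFormatInferrer(min_samples=10).infer_format(messages)
    for f in schema.fields:
        print(
            f"{f.name}: offset={f.offset} size={f.size} "
            f"type={f.field_type} entropy={f.entropy:.2f}"
        )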


def infer_format(messages: list[bytes | NDArray[np.uint8]], min_samples: int = 10) -> MessageSchema:
    """Convenience function for format inference.

    Top-level API for message format inference.

    Args:
        messages: List of message samples (bytes or np.ndarray)
        min_samples: Minimum required samples

    Returns:
        MessageSchema with inferred structure
    """
    inferrer = MessageFormatInferrer(min_samples=min_samples)
    return inferrer.infer_format(messages)
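

# Illustrative sketch (not part of the original module): the convenience wrapper
# mirrors MessageFormatInferrer.infer_format(). Both the helper name and the
# `captured_messages` parameter (a hypothetical list of equal-length bytes
# objects captured from a link) are assumptions for demonstration.
def _demo_convenience(captured_messages: list[bytes]) -> MessageSchema:
    # At least `min_samples` equal-length messages are required
    return infer_format(captured_messages, min_samples=20)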


def detect_field_types(
    messages: list[bytes | NDArray[np.uint8]], boundaries: list[int]
) -> list[InferredField]:
    """Detect field types at boundaries.

    Field type detection.

    Args:
        messages: List of message samples
        boundaries: Field boundary offsets

    Returns:
        List of InferredField objects

    Raises:
        ValueError: If message type is invalid.
    """
    inferrer = MessageFormatInferrer()

    # Convert to arrays
    msg_arrays = []
    for msg in messages:
        if isinstance(msg, bytes):
            msg_arrays.append(np.frombuffer(msg, dtype=np.uint8))
        elif isinstance(msg, np.ndarray):
            msg_arrays.append(msg.astype(np.uint8))
        else:
            raise ValueError(f"Invalid message type: {type(msg)}")

    return inferrer.detect_field_types(msg_arrays, boundaries)
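

# Illustrative sketch (not part of the original module): when field boundaries
# are already known (e.g. from a datasheet), detect_field_types() can classify
# the fields directly without boundary detection. The helper name, the
# `captured_messages` parameter, and the offsets below (for a hypothetical
# 16-byte message) are assumptions for demonstration.
def _demo_known_boundaries(captured_messages: list[bytes]) -> list[InferredField]:
    known_boundaries = [0, 2, 3, 15]  # magic, counter, payload, checksum
    return detect_field_types(captured_messages, known_boundaries)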


def find_dependencies(
    messages: list[bytes | NDArray[np.uint8]], schema: MessageSchema
) -> dict[str, str]:
    """Find field dependencies.

    Field dependency analysis.

    Args:
        messages: List of message samples
        schema: Message schema

    Returns:
        Dictionary of dependencies

    Raises:
        ValueError: If message type is invalid.
    """
    inferrer = MessageFormatInferrer()

    # Convert to arrays
    msg_arrays = []
    for msg in messages:
        if isinstance(msg, bytes):
            msg_arrays.append(np.frombuffer(msg, dtype=np.uint8))
        elif isinstance(msg, np.ndarray):
            msg_arrays.append(msg.astype(np.uint8))
        else:
            raise ValueError(f"Invalid message type: {type(msg)}")

    return inferrer.find_dependencies(msg_arrays, schema)
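

# Illustrative sketch (not part of the original module): typical end-to-end use
# of the module-level helpers -- infer the schema first, then look for
# dependencies between the inferred fields. The helper name and the
# `captured_messages` parameter (a hypothetical list of equal-length bytes
# objects) are assumptions for demonstration.
def _demo_dependencies(captured_messages: list[bytes]) -> None:
    schema = infer_format(captured_messages)
    deps = find_dependencies(captured_messages, schema)
    for field_name, description in deps.items():
        print(f"{field_name}: {description}")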