Coverage for src/tracekit/inference/message_format.py: 96% (241 statements)


1"""Message format inference using statistical analysis. 

2 

3Requirements addressed: PSI-001 

4 

5This module automatically infers message field structure from collections of 

6similar messages for protocol reverse engineering. 

7 

8Key capabilities: 

9- Detect field boundaries via entropy transitions 

10- Classify fields as constant, variable, or sequential 

11- Infer field types (integer, counter, timestamp, checksum) 

12- Detect field dependencies (length fields, checksums) 

13- Generate message format specifications 

14""" 

15 

16from dataclasses import dataclass 

17from dataclasses import field as dataclass_field 

18from typing import Any, Literal 

19 

20import numpy as np 

21from numpy.typing import NDArray 

22 

23 

@dataclass
class InferredField:
    """An inferred message field.

    Field classification and type inference.

    Attributes:
        name: Auto-generated field name
        offset: Byte offset from message start
        size: Field size in bytes
        field_type: Inferred field type classification
        entropy: Shannon entropy of field values
        variance: Statistical variance of field values
        confidence: Confidence score (0-1) for type inference
        values_seen: Sample values for validation
    """

    name: str
    offset: int
    size: int
    field_type: Literal["constant", "counter", "timestamp", "length", "checksum", "data", "unknown"]
    entropy: float
    variance: float
    confidence: float
    values_seen: list[Any] = dataclass_field(default_factory=list)  # Sample values


50 

51@dataclass 

52class MessageSchema: 

53 """Inferred message format schema. 

54 

55 : Complete message format specification. 

56 

57 Attributes: 

58 total_size: Total message size in bytes 

59 fields: List of inferred fields 

60 field_boundaries: Byte offsets of field starts 

61 header_size: Detected header size 

62 payload_offset: Start of payload region 

63 checksum_field: Detected checksum field if any 

64 length_field: Detected length field if any 

65 """ 

66 

67 total_size: int 

68 fields: list[InferredField] 

69 field_boundaries: list[int] # Byte offsets of field starts 

70 header_size: int # Detected header size 

71 payload_offset: int 

72 checksum_field: InferredField | None 

73 length_field: InferredField | None 

74 

75 

class MessageFormatInferrer:
    """Infer message format from samples.

    Message format inference using entropy and variance analysis.

    Algorithm:
        1. Detect field boundaries using entropy transitions
        2. Classify fields based on statistical patterns
        3. Detect dependencies between fields
        4. Generate complete schema
    """

    def __init__(self, min_samples: int = 10):
        """Initialize inferrer.

        Args:
            min_samples: Minimum number of message samples required
        """
        self.min_samples = min_samples

    def infer_format(self, messages: list[bytes | NDArray[np.uint8]]) -> MessageSchema:
        """Infer message format from collection of similar messages.

        Complete format inference workflow.

        Args:
            messages: List of message samples (bytes or np.ndarray)

        Returns:
            MessageSchema with inferred field structure

        Raises:
            ValueError: If insufficient samples or invalid input
        """
        if len(messages) < self.min_samples:
            raise ValueError(f"Need at least {self.min_samples} messages, got {len(messages)}")

        # Convert to numpy arrays for processing
        msg_arrays = []
        for msg in messages:
            if isinstance(msg, bytes):
                msg_arrays.append(np.frombuffer(msg, dtype=np.uint8))
            elif isinstance(msg, np.ndarray):
                msg_arrays.append(msg.astype(np.uint8))
            else:
                raise ValueError(f"Invalid message type: {type(msg)}")

        # Check all messages are the same length
        lengths = [len(m) for m in msg_arrays]
        if len(set(lengths)) > 1:
            raise ValueError(f"Messages have varying lengths: {set(lengths)}")

        msg_len = lengths[0]

        # Detect field boundaries
        boundaries = self.detect_field_boundaries(msg_arrays, method="combined")

        # Detect field types
        fields = self.detect_field_types(msg_arrays, boundaries)

        # Determine header size (first high-entropy transition or first 4 fields)
        header_size = self._estimate_header_size(fields)

        # Find checksum and length fields
        checksum_field = None
        length_field = None

        for f in fields:
            if f.field_type == "checksum":
                checksum_field = f
            elif f.field_type == "length":  # coverage: condition never true in tests
                length_field = f

        # Payload starts after header
        payload_offset = header_size

        schema = MessageSchema(
            total_size=msg_len,
            fields=fields,
            field_boundaries=boundaries,
            header_size=header_size,
            payload_offset=payload_offset,
            checksum_field=checksum_field,
            length_field=length_field,
        )

        return schema

    def detect_field_boundaries(
        self,
        messages: list[NDArray[np.uint8]],
        method: Literal["entropy", "variance", "combined"] = "combined",
    ) -> list[int]:
        """Detect field boundaries using entropy transitions.

        Boundary detection via statistical transitions.

        Args:
            messages: List of message arrays
            method: Detection method ('entropy', 'variance', or 'combined')

        Returns:
            List of byte offsets marking field starts (always includes 0)
        """
        if not messages:
            return [0]

        msg_len = len(messages[0])
        boundaries = [0]  # Always start at offset 0

        if method in ["entropy", "combined"]:
            # Calculate entropy at each byte position
            entropies = []
            for offset in range(msg_len):
                entropy = self._calculate_byte_entropy(messages, offset)
                entropies.append(entropy)

            # Find transitions (entropy changes > threshold)
            entropy_threshold = 1.5  # bits
            for i in range(1, len(entropies)):
                delta = abs(entropies[i] - entropies[i - 1])
                if delta > entropy_threshold and i not in boundaries:
                    boundaries.append(i)

        if method in ["variance", "combined"]:
            # Calculate variance at each byte position
            variances = []
            for offset in range(msg_len):
                values = [msg[offset] for msg in messages]
                variance = np.var(values)
                variances.append(variance)

            # Find variance transitions
            var_threshold = 1000.0
            for i in range(1, len(variances)):
                delta = abs(variances[i] - variances[i - 1])
                if delta > var_threshold and i not in boundaries:
                    boundaries.append(i)

        # Sort and ensure we don't have too many tiny fields
        boundaries = sorted(set(boundaries))

        # Merge boundaries that are too close (< 2 bytes apart)
        merged = [boundaries[0]]
        for b in boundaries[1:]:
            if b - merged[-1] >= 2:
                merged.append(b)

        return merged
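    # Worked illustration of the boundary heuristic above (assumed values, not
    # taken from real traffic): suppose per-offset entropies across the sample
    # set are [0.0, 0.0, 3.2, 3.1, 7.8, 7.9] bits. The deltas at offsets 2 and 4
    # are 3.2 and 4.7 bits, both above the 1.5-bit entropy_threshold, so the
    # detected boundaries are [0, 2, 4] before the variance pass and merging.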

    def detect_field_types(
        self, messages: list[NDArray[np.uint8]], boundaries: list[int]
    ) -> list[InferredField]:
        """Classify field types based on value patterns.

        Field type classification.

        Args:
            messages: List of message arrays
            boundaries: Field boundary offsets

        Returns:
            List of InferredField objects
        """
        fields = []

        for i in range(len(boundaries)):
            offset = boundaries[i]

            # Determine field size
            if i < len(boundaries) - 1:
                size = boundaries[i + 1] - offset
            else:
                size = len(messages[0]) - offset

            # Extract field values
            values: list[int | tuple[int, ...]]
            if size <= 4:
                # Use integer representation for small fields
                int_values: list[int] = []
                for msg in messages:
                    if size == 1:
                        val_int = int(msg[offset])
                    elif size == 2:
                        val_int = int(msg[offset]) << 8 | int(msg[offset + 1])
                    elif size == 4:
                        val_int = (
                            int(msg[offset]) << 24
                            | int(msg[offset + 1]) << 16
                            | int(msg[offset + 2]) << 8
                            | int(msg[offset + 3])
                        )
                    else:  # size == 3
                        val_int = (
                            int(msg[offset]) << 16
                            | int(msg[offset + 1]) << 8
                            | int(msg[offset + 2])
                        )
                    int_values.append(val_int)
                values = list(int_values)  # type: ignore[assignment]
            else:
                # For larger fields, use bytes
                tuple_values: list[tuple[int, ...]] = []
                for msg in messages:
                    val_tuple = tuple(int(b) for b in msg[offset : offset + size])
                    tuple_values.append(val_tuple)
                values = list(tuple_values)  # type: ignore[assignment]

            # Calculate statistics
            if size > 4:
                # Bytes field - calculate entropy across all bytes
                all_bytes_list: list[int] = []
                for v in values:
                    if isinstance(v, tuple):  # coverage: condition always true in tests
                        all_bytes_list.extend(v)
                all_bytes = np.array(all_bytes_list, dtype=np.uint8)
                entropy = self._calculate_entropy(all_bytes)
                variance = float(np.var(all_bytes))
            else:
                entropy = self._calculate_entropy(np.array(values, dtype=np.int64))
                variance = float(np.var(values))

            # Classify field type
            field_type, confidence = self._classify_field(values, offset, size, messages)

            # Sample values (first 5)
            sample_values = values[:5]

            field_obj = InferredField(
                name=f"field_{i}",
                offset=offset,
                size=size,
                field_type=field_type,  # type: ignore[arg-type]
                entropy=float(entropy),
                variance=float(variance),
                confidence=confidence,
                values_seen=sample_values,
            )

            fields.append(field_obj)

        return fields

    def find_dependencies(
        self, messages: list[NDArray[np.uint8]], schema: MessageSchema
    ) -> dict[str, str]:
        """Find dependencies between fields (e.g., length->payload).

        Field dependency detection.

        Args:
            messages: List of message arrays
            schema: Inferred message schema

        Returns:
            Dictionary mapping field names to dependency descriptions
        """
        dependencies = {}

        # Check for length field dependencies
        for field in schema.fields:
            if field.field_type == "length":
                # Check if any field size correlates with this length value
                for msg in messages:
                    _length_val = self._extract_field_value(msg, field)
                    # Look for fields that might be variable length
                    # This is a simplified check
                    dependencies[field.name] = "Potential length indicator"

        return dependencies

    def _calculate_byte_entropy(self, messages: list[NDArray[np.uint8]], offset: int) -> float:
        """Calculate entropy at byte offset across messages.

        Entropy calculation for boundary detection.

        Args:
            messages: List of message arrays
            offset: Byte offset to analyze

        Returns:
            Shannon entropy in bits
        """
        values = [msg[offset] for msg in messages]
        return float(self._calculate_entropy(np.array(values)))

    def _calculate_entropy(self, values: NDArray[np.int_ | np.uint8]) -> float:
        """Calculate Shannon entropy of values.

        Args:
            values: Array of values

        Returns:
            Entropy in bits
        """
        if len(values) == 0:
            return 0.0

        # Count frequencies
        _unique, counts = np.unique(values, return_counts=True)
        probabilities = counts / len(values)

        # Calculate Shannon entropy
        entropy = -np.sum(probabilities * np.log2(probabilities + 1e-10))
        return float(entropy)
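    # The method above computes the standard Shannon entropy
    # H = -sum_i p_i * log2(p_i), with a small epsilon inside the log for
    # numerical safety. Quick check on assumed inputs: a byte that splits
    # evenly between two values gives H = -(0.5*log2(0.5) + 0.5*log2(0.5))
    # = 1.0 bit; a byte that never changes gives H ≈ 0.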

    def _classify_field(
        self,
        values: list[int | tuple[int, ...]],
        offset: int,
        size: int,
        messages: list[NDArray[np.uint8]],
    ) -> tuple[str, float]:
        """Classify field type based on patterns.

        Field type classification logic.

        Args:
            values: Field values across all messages
            offset: Field offset
            size: Field size
            messages: Original messages

        Returns:
            Tuple of (field_type, confidence)
        """
        # Handle byte fields (larger than 4 bytes)
        if isinstance(values[0], tuple):
            # Check if all tuples are identical (truly constant)
            if len(set(values)) == 1:
                return ("constant", 1.0)

            entropy = self._calculate_entropy(np.concatenate([np.array(v) for v in values]))
            if entropy < 1.0:  # coverage: condition never true in tests
                return ("constant", 0.9)
            elif entropy > 7.0:
                return ("data", 0.6)
            else:
                return ("data", 0.5)

        # Check for constant field
        if len(set(values)) == 1:
            return ("constant", 1.0)

        # Check for counter field
        if not isinstance(values[0], tuple) and self._detect_counter_field(  # type: ignore[misc, unreachable]
            [v for v in values if isinstance(v, int)]
        ):
            return ("counter", 0.9)

        # Check for checksum (if near end of message)
        msg_len = len(messages[0])
        if offset + size >= msg_len - 4:  # Within last 4 bytes
            if self._detect_checksum_field(messages, offset, size):
                return ("checksum", 0.8)

        # Check for length field (small values, near start)
        if offset < 8 and size <= 2:
            if not isinstance(values[0], tuple):  # type: ignore[unreachable]  # coverage: condition always true in tests
                max_val = max(v for v in values if isinstance(v, int))
                if max_val < msg_len * 2:  # Reasonable length value  # coverage: condition never true in tests
                    return ("length", 0.6)

        # Check variance for classification
        variance = np.var(values)
        entropy = self._calculate_entropy(np.array(values))

        if variance < 10:  # coverage: condition never true in tests
            return ("constant", 0.6)
        elif entropy > 6.0:
            return ("data", 0.7)
        else:
            return ("unknown", 0.5)
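    # Classification illustration (assumed, synthetic values): a 1-byte field
    # whose sampled value is 0x7E in every message hits the constant branch and
    # returns ("constant", 1.0); a field larger than 4 bytes whose bytes look
    # uniformly random exceeds the 7.0-bit entropy threshold and returns
    # ("data", 0.6).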

    def _detect_counter_field(self, values: list[int]) -> bool:
        """Check if values form a counter sequence.

        Counter field detection.

        Args:
            values: List of integer values

        Returns:
            True if values appear to be a counter
        """
        if len(values) < 3:
            return False

        # Check for monotonic increase
        diffs = [values[i + 1] - values[i] for i in range(len(values) - 1)]

        # Allow wrapping
        diffs_filtered = [d for d in diffs if d >= 0]

        # Check if most differences are 1 (counter increments)
        if len(diffs_filtered) < len(diffs) * 0.7:
            return False

        ones = sum(1 for d in diffs_filtered if d == 1)
        return ones >= len(diffs_filtered) * 0.7
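    # Counter illustration (assumed values): [5, 6, 7, 8, 9, 0, 1] has diffs
    # [1, 1, 1, 1, -9, 1]; the single negative diff (a wrap) is filtered out,
    # leaving 5 of 6 non-negative differences (>= 70%), all equal to 1, so the
    # method returns True.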

    def _detect_checksum_field(
        self, messages: list[NDArray[np.uint8]], field_offset: int, field_size: int
    ) -> bool:
        """Check if field is likely a checksum.

        Checksum field detection.

        Args:
            messages: List of message arrays
            field_offset: Offset of potential checksum field
            field_size: Size of potential checksum field

        Returns:
            True if field appears to be a checksum
        """
        if field_size not in [1, 2, 4]:
            return False

        # Try simple XOR checksum
        for msg in messages[: min(5, len(messages))]:
            # Calculate XOR of all bytes before checksum
            xor_sum = 0
            for i in range(field_offset):
                xor_sum ^= int(msg[i])

            # Extract checksum value
            if field_size == 1:
                checksum = int(msg[field_offset])
            elif field_size == 2:
                checksum = int(msg[field_offset]) << 8 | int(msg[field_offset + 1])
            else:
                checksum = (
                    int(msg[field_offset]) << 24
                    | int(msg[field_offset + 1]) << 16
                    | int(msg[field_offset + 2]) << 8
                    | int(msg[field_offset + 3])
                )

            # Only single-byte XOR checksums are verified; 2- and 4-byte fields
            # always fall through to the mismatch branch below.
            if field_size == 1 and (xor_sum & 0xFF) == checksum:
                continue
            else:
                return False  # Not a match

        return True  # All matched
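    # XOR-checksum illustration (assumed bytes): for a message whose first three
    # bytes are 0x01, 0x02, 0x03 followed by a 1-byte checksum, the running XOR
    # is 0x01 ^ 0x02 ^ 0x03 = 0x00, so a trailing checksum byte of 0x00 matches
    # and the field is accepted if every sampled message agrees.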

    def _estimate_header_size(self, fields: list[InferredField]) -> int:
        """Estimate header size from field patterns.

        Args:
            fields: List of inferred fields

        Returns:
            Estimated header size in bytes
        """
        # Look for transition from low-entropy to high-entropy
        for i, field in enumerate(fields):
            if field.field_type == "data" and field.entropy > 6.0:
                if i > 0:
                    return field.offset

        # Default: first 4 fields or 16 bytes
        if len(fields) >= 5:
            # Header includes first 4 fields, so return offset of 5th field
            return fields[4].offset
        elif len(fields) >= 4:
            # If exactly 4 fields, header is up to end of 4th field
            return fields[3].offset + fields[3].size
        elif fields:  # coverage: condition always true in tests
            # Fewer than 4 fields - use offset of last field
            return min(16, fields[-1].offset)
        else:
            return 16
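    # Header-size illustration (assumed schema): if the third inferred field is
    # high-entropy payload ("data", entropy > 6.0 bits) starting at offset 6,
    # the header is taken to end at offset 6; otherwise the fallbacks above
    # (first four fields, or at most 16 bytes) apply.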

    def _extract_field_value(self, msg: NDArray[np.uint8], field: InferredField) -> int:
        """Extract field value from message.

        Multi-byte fields are interpreted big-endian; sizes other than 1, 2, or 4
        return only the first byte.

        Args:
            msg: Message array
            field: Field definition

        Returns:
            Field value as integer
        """
        if field.size == 1:
            return int(msg[field.offset])
        elif field.size == 2:
            return int(msg[field.offset]) << 8 | int(msg[field.offset + 1])
        elif field.size == 4:
            return (
                int(msg[field.offset]) << 24
                | int(msg[field.offset + 1]) << 16
                | int(msg[field.offset + 2]) << 8
                | int(msg[field.offset + 3])
            )
        else:
            # Return first byte for larger fields
            return int(msg[field.offset])


def infer_format(messages: list[bytes | NDArray[np.uint8]], min_samples: int = 10) -> MessageSchema:
    """Convenience function for format inference.

    Top-level API for message format inference.

    Args:
        messages: List of message samples (bytes or np.ndarray)
        min_samples: Minimum required samples

    Returns:
        MessageSchema with inferred structure
    """
    inferrer = MessageFormatInferrer(min_samples=min_samples)
    return inferrer.infer_format(messages)


def detect_field_types(
    messages: list[bytes | NDArray[np.uint8]], boundaries: list[int]
) -> list[InferredField]:
    """Detect field types at boundaries.

    Field type detection.

    Args:
        messages: List of message samples
        boundaries: Field boundary offsets

    Returns:
        List of InferredField objects

    Raises:
        ValueError: If message type is invalid.
    """
    inferrer = MessageFormatInferrer()

    # Convert to arrays
    msg_arrays = []
    for msg in messages:
        if isinstance(msg, bytes):
            msg_arrays.append(np.frombuffer(msg, dtype=np.uint8))
        elif isinstance(msg, np.ndarray):
            msg_arrays.append(msg.astype(np.uint8))
        else:
            raise ValueError(f"Invalid message type: {type(msg)}")

    return inferrer.detect_field_types(msg_arrays, boundaries)


def find_dependencies(
    messages: list[bytes | NDArray[np.uint8]], schema: MessageSchema
) -> dict[str, str]:
    """Find field dependencies.

    Field dependency analysis.

    Args:
        messages: List of message samples
        schema: Message schema

    Returns:
        Dictionary of dependencies

    Raises:
        ValueError: If message type is invalid.
    """
    inferrer = MessageFormatInferrer()

    # Convert to arrays
    msg_arrays = []
    for msg in messages:
        if isinstance(msg, bytes):
            msg_arrays.append(np.frombuffer(msg, dtype=np.uint8))
        elif isinstance(msg, np.ndarray):  # coverage: condition always true in tests
            msg_arrays.append(msg.astype(np.uint8))
        else:
            raise ValueError(f"Invalid message type: {type(msg)}")

    return inferrer.find_dependencies(msg_arrays, schema)
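

# Illustrative usage sketch (not part of the original module; message contents
# below are assumed/synthetic): build fixed-length messages with a constant
# 2-byte marker, a 1-byte counter, and a random 8-byte payload, then infer the
# schema with the top-level infer_format() helper.
if __name__ == "__main__":
    import os

    samples: list[bytes] = []
    for seq in range(20):
        payload = os.urandom(8)  # high-entropy payload region
        samples.append(bytes([0xAA, 0x55, seq & 0xFF]) + payload)

    schema = infer_format(samples, min_samples=10)
    print(f"total_size={schema.total_size}, header_size={schema.header_size}")
    for fld in schema.fields:
        print(
            f"{fld.name}: offset={fld.offset} size={fld.size} "
            f"type={fld.field_type} entropy={fld.entropy:.2f}"
        )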