Coverage for src/tracekit/pipeline/reverse_engineering.py: 84%

369 statements  

coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Reverse Engineering Pipeline for integrated protocol analysis. 

2 

3 - RE-INT-001: RE Pipeline Integration 

4 

5 This module provides an integrated pipeline for complete reverse engineering 

6 workflows from raw packet capture to decoded messages with automatic tool 

7 selection and chaining. 

8 """ 

9 

10 from __future__ import annotations 

11 

12 import json 

13 import os 

14 import time 

15 from collections.abc import Callable, Sequence 

16 from dataclasses import dataclass, field 

17 from datetime import datetime 

18 from pathlib import Path 

19 from typing import Any, ClassVar, Literal 

20 

21 

22 @dataclass 

23 class FlowInfo: 

24 """Information about a network flow. 

25 

26 Attributes: 

27 flow_id: Unique flow identifier. 

28 src_ip: Source IP address. 

29 dst_ip: Destination IP address. 

30 src_port: Source port. 

31 dst_port: Destination port. 

32 protocol: Transport protocol. 

33 packet_count: Number of packets in flow. 

34 byte_count: Total bytes in flow. 

35 start_time: Flow start timestamp. 

36 end_time: Flow end timestamp. 

37 """ 

38 

39 flow_id: str 

40 src_ip: str 

41 dst_ip: str 

42 src_port: int 

43 dst_port: int 

44 protocol: str 

45 packet_count: int 

46 byte_count: int 

47 start_time: float 

48 end_time: float 

49 
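# Example (hedged; the values are illustrative): constructing a FlowInfo by
# hand for a single Modbus/TCP-style conversation. Only the dataclass defined
# above is assumed.
#
#     >>> flow = FlowInfo(
#     ...     flow_id="10.0.0.1:49152-10.0.0.2:502-tcp",
#     ...     src_ip="10.0.0.1", dst_ip="10.0.0.2",
#     ...     src_port=49152, dst_port=502, protocol="tcp",
#     ...     packet_count=12, byte_count=1480,
#     ...     start_time=1700000000.0, end_time=1700000001.5,
#     ... )
#     >>> flow.byte_count
#     1480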

50 

51 @dataclass 

52 class MessageTypeInfo: 

53 """Information about detected message types. 

54 

55 Attributes: 

56 type_id: Unique type identifier. 

57 name: Type name (auto-generated or from inference). 

58 sample_count: Number of messages of this type. 

59 avg_length: Average message length. 

60 field_count: Number of detected fields. 

61 signature: Representative byte signature. 

62 cluster_id: Associated cluster ID. 

63 """ 

64 

65 type_id: str 

66 name: str 

67 sample_count: int 

68 avg_length: float 

69 field_count: int 

70 signature: bytes 

71 cluster_id: int 

72 

73 

74 @dataclass 

75 class ProtocolCandidate: 

76 """Candidate protocol identification. 

77 

78 Attributes: 

79 name: Protocol name. 

80 confidence: Detection confidence (0-1). 

81 matched_patterns: Patterns that matched. 

82 port_hint: Whether port suggested this protocol. 

83 header_match: Whether header matched signature. 

84 """ 

85 

86 name: str 

87 confidence: float 

88 matched_patterns: list[str] = field(default_factory=list) 

89 port_hint: bool = False 

90 header_match: bool = False 

91 

92 

93 @dataclass 

94 class REAnalysisResult: 

95 """Complete reverse engineering analysis result. 

96 

97 Implements RE-INT-001: Analysis result structure. 

98 

99 Attributes: 

100 flow_count: Number of flows analyzed. 

101 message_count: Total messages extracted. 

102 message_types: Detected message types. 

103 protocol_candidates: Candidate protocol identifications. 

104 field_schemas: Inferred field schemas per message type. 

105 state_machine: Inferred state machine (if available). 

106 statistics: Analysis statistics. 

107 warnings: Warnings encountered during analysis. 

108 duration_seconds: Analysis duration. 

109 timestamp: Analysis timestamp. 

110 """ 

111 

112 flow_count: int 

113 message_count: int 

114 message_types: list[MessageTypeInfo] 

115 protocol_candidates: list[ProtocolCandidate] 

116 field_schemas: dict[str, Any] 

117 state_machine: Any | None 

118 statistics: dict[str, Any] 

119 warnings: list[str] 

120 duration_seconds: float 

121 timestamp: str 

122 

123 

124 @dataclass 

125 class StageResult: 

126 """Result from a single pipeline stage. 

127 

128 Attributes: 

129 stage_name: Name of the stage. 

130 success: Whether stage completed successfully. 

131 duration: Stage duration in seconds. 

132 output: Stage output data. 

133 error: Error message if failed. 

134 """ 

135 

136 stage_name: str 

137 success: bool 

138 duration: float 

139 output: Any 

140 error: str | None = None 

141 

142 

143 class REPipeline: 

144 """Integrated reverse engineering pipeline. 

145 

146 Implements RE-INT-001: RE Pipeline Integration. 

147 

148 Chains all RE tools in coherent pipeline with automatic tool 

149 selection based on data characteristics. 

150 

151 Example: 

152 >>> pipeline = REPipeline() 

153 >>> results = pipeline.analyze(packet_data) 

154 >>> print(f"Detected {len(results.message_types)} message types") 

155 >>> pipeline.generate_report(results, "report.html") 

156 """ 

157 

158 # Default pipeline stages 

159 DEFAULT_STAGES: ClassVar[list[str]] = [ 

160 "flow_extraction", 

161 "payload_analysis", 

162 "pattern_discovery", 

163 "field_inference", 

164 "protocol_detection", 

165 "state_machine", 

166 ] 

167 

168 def __init__( 

169 self, 

170 stages: list[str] | None = None, 

171 config: dict[str, Any] | None = None, 

172 ) -> None: 

173 """Initialize RE pipeline. 

174 

175 Args: 

176 stages: List of stage names to execute. 

177 config: Configuration options. 

178 """ 

179 self.stages = stages or self.DEFAULT_STAGES 

180 self.config = config or {} 

181 

182 # Default configuration 

183 self.config.setdefault("min_samples", 10) 

184 self.config.setdefault("entropy_threshold", 6.0) 

185 self.config.setdefault("cluster_threshold", 0.8) 

186 self.config.setdefault("state_machine_algorithm", "rpni") 

187 self.config.setdefault("max_message_types", 50) 

188 

189 # Stage handlers 

190 self._stage_handlers: dict[str, Callable[[dict[str, Any]], dict[str, Any]]] = { 

191 "flow_extraction": self._stage_flow_extraction, 

192 "payload_analysis": self._stage_payload_analysis, 

193 "pattern_discovery": self._stage_pattern_discovery, 

194 "field_inference": self._stage_field_inference, 

195 "protocol_detection": self._stage_protocol_detection, 

196 "state_machine": self._stage_state_machine, 

197 } 

198 

199 # Progress callback 

200 self._progress_callback: Callable[[str, float], None] | None = None 

201 

202 # Checkpoint support 

203 self._checkpoint_path: str | None = None 

204 self._checkpoint_data: dict[str, Any] = {} 

205 
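# Example (hedged): overriding a single default at construction time.
# setdefault() means caller-supplied keys win and unspecified keys keep
# their defaults.
#
#     >>> p = REPipeline(config={"cluster_threshold": 0.9})
#     >>> p.config["cluster_threshold"]
#     0.9
#     >>> p.config["min_samples"]
#     10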

206 def analyze( 

207 self, 

208 data: bytes | Sequence[dict[str, Any]] | Sequence[bytes], 

209 checkpoint: str | None = None, 

210 progress_callback: Callable[[str, float], None] | None = None, 

211 ) -> REAnalysisResult: 

212 """Run full reverse engineering analysis. 

213 

214 Implements RE-INT-001: Complete analysis workflow. 

215 

216 Args: 

217 data: Raw binary data, a sequence of packet dicts, or a sequence of raw payloads; use analyze_pcap() for PCAP files. 

218 checkpoint: Path for checkpointing progress. 

219 progress_callback: Callback for progress reporting. 

220 

221 Returns: 

222 REAnalysisResult with complete analysis. 

223 

224 Example: 

225 >>> results = pipeline.analyze(packets) 

226 >>> for msg_type in results.message_types: 

227 ... print(f"{msg_type.name}: {msg_type.sample_count} samples") 

228 """ 

229 start_time = time.time() 

230 self._progress_callback = progress_callback 

231 self._checkpoint_path = checkpoint 

232 self._checkpoint_data = {} 

233 

234 # Load checkpoint if available 

235 if checkpoint and os.path.exists(checkpoint): 

236 self._load_checkpoint(checkpoint) 

237 

238 # Initialize context 

239 context: dict[str, Any] = { 

240 "raw_data": data, 

241 "flows": [], 

242 "payloads": [], 

243 "messages": [], 

244 "patterns": [], 

245 "clusters": [], 

246 "schemas": {}, 

247 "protocol_candidates": [], 

248 "state_machine": None, 

249 "warnings": [], 

250 "statistics": {}, 

251 } 

252 

253 # Execute stages 

254 stage_results = [] 

255 total_stages = len(self.stages) 

256 

257 for i, stage_name in enumerate(self.stages): 

258 if stage_name in self._checkpoint_data: 

259 # Skip completed stages 

260 context.update(self._checkpoint_data[stage_name]) 

261 continue 

262 

263 self._report_progress(stage_name, (i / total_stages) * 100) 

264 

265 handler = self._stage_handlers.get(stage_name) 

266 if handler: 

267 try: 

268 stage_start = time.time() 

269 output = handler(context) 

270 stage_duration = time.time() - stage_start 

271 

272 stage_results.append( 

273 StageResult( 

274 stage_name=stage_name, 

275 success=True, 

276 duration=stage_duration, 

277 output=output, 

278 ) 

279 ) 

280 

281 # Update context with stage output 

282 if output:  # coverage: 282 ↛ 286, condition always true 

283 context.update(output) 

284 

285 # Checkpoint after each stage 

286 if checkpoint: 

287 self._save_checkpoint(checkpoint, stage_name, context) 

288 

289 except Exception as e: 

290 stage_results.append( 

291 StageResult( 

292 stage_name=stage_name, 

293 success=False, 

294 duration=0, 

295 output=None, 

296 error=str(e), 

297 ) 

298 ) 

299 warnings_list: list[str] = context.get("warnings", []) 

300 warnings_list.append(f"Stage {stage_name} failed: {e}") 

301 context["warnings"] = warnings_list 

302 

303 self._report_progress("complete", 100) 

304 

305 # Build result 

306 duration = time.time() - start_time 

307 

308 flows_list: list[Any] = context.get("flows", []) 

309 messages_list: list[Any] = context.get("messages", []) 

310 protocol_candidates_list: list[ProtocolCandidate] = context.get("protocol_candidates", []) 

311 schemas_dict: dict[str, Any] = context.get("schemas", {}) 

312 warnings_list_result: list[str] = context.get("warnings", []) 

313 

314 return REAnalysisResult( 

315 flow_count=len(flows_list), 

316 message_count=len(messages_list), 

317 message_types=self._build_message_types(context), 

318 protocol_candidates=protocol_candidates_list, 

319 field_schemas=schemas_dict, 

320 state_machine=context.get("state_machine"), 

321 statistics=self._build_statistics(context, stage_results), 

322 warnings=warnings_list_result, 

323 duration_seconds=duration, 

324 timestamp=datetime.now().isoformat(), 

325 ) 

326 
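# Example (hedged): a minimal progress callback. analyze() invokes it with
# (stage_name, percent_complete) before each stage and once at completion.
#
#     >>> def on_progress(stage: str, percent: float) -> None:
#     ...     print(f"[{percent:5.1f}%] {stage}")
#     >>> results = REPipeline().analyze(b"\x01\x02\x03", progress_callback=on_progress)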

327 def analyze_pcap( 

328 self, 

329 path: str | Path, 

330 checkpoint: str | None = None, 

331 ) -> REAnalysisResult: 

332 """Analyze packets from a PCAP file. 

333 

334 Implements RE-INT-001: PCAP file analysis. 

335 

336 Args: 

337 path: Path to PCAP file. 

338 checkpoint: Optional checkpoint path. 

339 

340 Returns: 

341 REAnalysisResult with analysis results. 

342 

343 Raises: 

344 FileNotFoundError: If PCAP file not found. 

345 """ 

346 # Load raw PCAP bytes (simplified; a full implementation would parse packets with scapy or pyshark) 

347 path = Path(path) 

348 if not path.exists(): 

349 raise FileNotFoundError(f"PCAP file not found: {path}") 

350 

351 with open(path, "rb") as f: 

352 data = f.read() 

353 

354 return self.analyze(data, checkpoint=checkpoint) 

355 

356 def generate_report( 

357 self, 

358 results: REAnalysisResult, 

359 output_path: str | Path, 

360 format: Literal["html", "json", "markdown"] = "html", 

361 ) -> None: 

362 """Generate analysis report. 

363 

364 Implements RE-INT-001: Report generation. 

365 

366 Args: 

367 results: Analysis results. 

368 output_path: Output file path. 

369 format: Report format. 

370 

371 Example: 

372 >>> pipeline.generate_report(results, "report.html") 

373 """ 

374 output_path = Path(output_path) 

375 

376 if format == "json": 

377 self._generate_json_report(results, output_path) 

378 elif format == "markdown": 

379 self._generate_markdown_report(results, output_path) 

380 else: 

381 self._generate_html_report(results, output_path) 

382 
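# Example (hedged): one result rendered three ways; the format argument
# selects the writer and defaults to HTML.
#
#     >>> pipeline.generate_report(results, "report.json", format="json")
#     >>> pipeline.generate_report(results, "report.md", format="markdown")
#     >>> pipeline.generate_report(results, "report.html")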

383 # ========================================================================= 

384 # Pipeline Stages 

385 # ========================================================================= 

386 

387 def _stage_flow_extraction(self, context: dict[str, Any]) -> dict[str, Any]: 

388 """Extract network flows from raw data. 

389 

390 Args: 

391 context: Pipeline context. 

392 

393 Returns: 

394 Updated context with flows. 

395 """ 

396 data = context["raw_data"] 

397 flows = [] 

398 payloads = [] 

399 

400 if isinstance(data, bytes): 

401 # Raw binary - treat as single payload 

402 payloads.append(data) 

403 flows.append( 

404 FlowInfo( 

405 flow_id="flow_0", 

406 src_ip="unknown", 

407 dst_ip="unknown", 

408 src_port=0, 

409 dst_port=0, 

410 protocol="unknown", 

411 packet_count=1, 

412 byte_count=len(data), 

413 start_time=0, 

414 end_time=0, 

415 ) 

416 ) 

417 

418 elif isinstance(data, list | tuple):  # coverage: 418 ↛ 503, condition always true 

419 # List of packets 

420 flow_map: dict[str, dict[str, Any]] = {} 

421 raw_bytes_payloads: list[bytes] = [] 

422 

423 for _i, pkt in enumerate(data): 

424 if isinstance(pkt, dict):  # coverage: 424 ↛ 462, condition always true 

425 # Packet with metadata 

426 payload_raw = pkt.get("data", pkt.get("payload", b"")) 

427 if isinstance(payload_raw, list | tuple):  # coverage: 427 ↛ 428, condition never true 

428 payload = bytes(payload_raw) 

429 else: 

430 payload = payload_raw if isinstance(payload_raw, bytes) else b"" 

431 

432 # Create flow key 

433 src_ip = pkt.get("src_ip", "0.0.0.0") 

434 dst_ip = pkt.get("dst_ip", "0.0.0.0") 

435 src_port = pkt.get("src_port", 0) 

436 dst_port = pkt.get("dst_port", 0) 

437 protocol = pkt.get("protocol", "unknown") 

438 

439 flow_key = f"{src_ip}:{src_port}-{dst_ip}:{dst_port}-{protocol}" 

440 

441 if flow_key not in flow_map: 

442 flow_map[flow_key] = { 

443 "src_ip": src_ip, 

444 "dst_ip": dst_ip, 

445 "src_port": src_port, 

446 "dst_port": dst_port, 

447 "protocol": protocol, 

448 "packets": [], 

449 "payloads": [], 

450 "timestamps": [], 

451 } 

452 

453 flow_map[flow_key]["packets"].append(pkt) 

454 flow_map[flow_key]["payloads"].append(payload) 

455 if "timestamp" in pkt: 

456 flow_map[flow_key]["timestamps"].append(pkt["timestamp"]) 

457 

458 payloads.append(payload) 

459 

460 else: 

461 # Raw bytes - collect for default flow 

462 raw_payload = pkt if isinstance(pkt, bytes) else bytes(pkt) 

463 payloads.append(raw_payload) 

464 raw_bytes_payloads.append(raw_payload) 

465 

466 # Build flow objects from flow_map 

467 for flow_id, flow_data in flow_map.items(): 

468 timestamps = flow_data.get("timestamps", [0]) 

469 flows.append( 

470 FlowInfo( 

471 flow_id=flow_id, 

472 src_ip=flow_data["src_ip"], 

473 dst_ip=flow_data["dst_ip"], 

474 src_port=flow_data["src_port"], 

475 dst_port=flow_data["dst_port"], 

476 protocol=flow_data["protocol"], 

477 packet_count=len(flow_data["packets"]), 

478 byte_count=sum(len(p) for p in flow_data["payloads"]), 

479 start_time=min(timestamps) if timestamps else 0, 

480 end_time=max(timestamps) if timestamps else 0, 

481 ) 

482 ) 

483 

484 # Create a default flow for raw bytes if we have any 

485 # This ensures flow_count >= 1 when analyzing raw byte sequences 

486 if raw_bytes_payloads and not flows:  # coverage: 486 ↛ 487, condition never true 

487 flows.append( 

488 FlowInfo( 

489 flow_id="flow_default", 

490 src_ip="unknown", 

491 dst_ip="unknown", 

492 src_port=0, 

493 dst_port=0, 

494 protocol="unknown", 

495 packet_count=len(raw_bytes_payloads), 

496 byte_count=sum(len(p) for p in raw_bytes_payloads), 

497 start_time=0, 

498 end_time=0, 

499 ) 

500 ) 

501 

502 # Initialize statistics dict if it doesn't exist 

503 if "statistics" not in context: 

504 context["statistics"] = {} 

505 

506 context["statistics"]["flow_extraction"] = { 

507 "flow_count": len(flows), 

508 "payload_count": len(payloads), 

509 "total_bytes": sum(len(p) for p in payloads), 

510 } 

511 

512 return {"flows": flows, "payloads": payloads} 

513 
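# Example (hedged): the packet-dict shape this stage consumes. Only the
# "data"/"payload" key matters for message extraction; the addressing fields
# fall back to the defaults shown above ("0.0.0.0", port 0, "unknown").
#
#     >>> pkt = {
#     ...     "src_ip": "10.0.0.1", "dst_ip": "10.0.0.2",
#     ...     "src_port": 49152, "dst_port": 502, "protocol": "tcp",
#     ...     "timestamp": 1700000000.0,
#     ...     "data": b"\x00\x01\x00\x00\x00\x06\x01\x03",
#     ... }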

514 def _stage_payload_analysis(self, context: dict[str, Any]) -> dict[str, Any]: 

515 """Analyze payloads for structure. 

516 

517 Args: 

518 context: Pipeline context. 

519 

520 Returns: 

521 Updated context with payload analysis. 

522 """ 

523 payloads = context.get("payloads", []) 

524 

525 # Filter non-empty payloads 

526 valid_payloads = [p for p in payloads if p]  # truthiness already excludes empty bytes 

527 

528 # Basic statistics 

529 if valid_payloads: 

530 lengths = [len(p) for p in valid_payloads] 

531 avg_len = sum(lengths) / len(lengths) 

532 min_len = min(lengths) 

533 max_len = max(lengths) 

534 else: 

535 avg_len = min_len = max_len = 0 

536 

537 # Detect delimiter patterns 

538 delimiter_info = None 

539 if valid_payloads: 

540 try: 

541 from tracekit.analyzers.packet.payload import detect_delimiter 

542 

543 concat = b"".join(valid_payloads) 

544 delimiter_result = detect_delimiter(concat) 

545 if delimiter_result.confidence > 0.5: 

546 delimiter_info = { 

547 "delimiter": delimiter_result.delimiter.hex(), 

548 "confidence": delimiter_result.confidence, 

549 } 

550 except Exception: 

551 pass 

552 

553 # Initialize statistics dict if it doesn't exist 

554 if "statistics" not in context: 

555 context["statistics"] = {} 

556 

557 context["statistics"]["payload_analysis"] = { 

558 "payload_count": len(valid_payloads), 

559 "avg_length": avg_len, 

560 "min_length": min_len, 

561 "max_length": max_len, 

562 "delimiter": delimiter_info, 

563 } 

564 

565 return {"messages": valid_payloads} 

566 

567 def _stage_pattern_discovery(self, context: dict[str, Any]) -> dict[str, Any]: 

568 """Discover patterns in messages. 

569 

570 Args: 

571 context: Pipeline context. 

572 

573 Returns: 

574 Updated context with patterns. 

575 """ 

576 messages = context.get("messages", []) 

577 patterns = [] 

578 clusters = [] 

579 

580 if len(messages) >= 2: 

581 try: 

582 from tracekit.analyzers.packet.payload import cluster_payloads 

583 

584 # Cluster similar messages 

585 threshold = self.config.get("cluster_threshold", 0.8) 

586 clusters = cluster_payloads(messages, threshold=threshold) 

587 

588 # Extract common patterns from clusters 

589 for cluster in clusters: 

590 if len(cluster.payloads) >= 2:  # coverage: 590 ↛ 592, condition never true 

591 # Find common prefix 

592 common_prefix = cluster.payloads[0] 

593 for payload in cluster.payloads[1:]: 

594 new_prefix = bytearray() 

595 for i in range(min(len(common_prefix), len(payload))): 

596 if common_prefix[i] == payload[i]: 

597 new_prefix.append(common_prefix[i]) 

598 else: 

599 break 

600 common_prefix = bytes(new_prefix) 

601 

602 if len(common_prefix) >= 2: 

603 patterns.append( 

604 { 

605 "pattern": common_prefix, 

606 "cluster_id": cluster.cluster_id, 

607 "frequency": len(cluster.payloads), 

608 } 

609 ) 

610 

611 except Exception as e: 

612 context["warnings"].append(f"Pattern discovery failed: {e}") 

613 

614 # Initialize statistics dict if it doesn't exist 

615 if "statistics" not in context: 

616 context["statistics"] = {} 

617 

618 context["statistics"]["pattern_discovery"] = { 

619 "cluster_count": len(clusters), 

620 "pattern_count": len(patterns), 

621 } 

622 

623 return {"patterns": patterns, "clusters": clusters} 

624 

625 def _stage_field_inference(self, context: dict[str, Any]) -> dict[str, Any]: 

626 """Infer field structure in messages. 

627 

628 Args: 

629 context: Pipeline context. 

630 

631 Returns: 

632 Updated context with field schemas. 

633 """ 

634 clusters = context.get("clusters", []) 

635 schemas = {} 

636 

637 for cluster in clusters: 

638 if not hasattr(cluster, "payloads") or len(cluster.payloads) < 5:  # coverage: 638 ↛ 641, condition always true 

639 continue 

640 

641 try: 

642 from tracekit.analyzers.packet.payload import FieldInferrer 

643 

644 inferrer = FieldInferrer(min_samples=self.config.get("min_samples", 10)) 

645 schema = inferrer.infer_fields(cluster.payloads) 

646 

647 if schema.fields: 

648 cluster_id = getattr(cluster, "cluster_id", 0) 

649 schemas[f"type_{cluster_id}"] = { 

650 "field_count": len(schema.fields), 

651 "message_length": schema.message_length, 

652 "fixed_length": schema.fixed_length, 

653 "confidence": schema.confidence, 

654 "fields": [ 

655 { 

656 "name": f.name, 

657 "offset": f.offset, 

658 "size": f.size, 

659 "type": f.inferred_type, 

660 "is_constant": f.is_constant, 

661 "is_sequence": f.is_sequence, 

662 } 

663 for f in schema.fields 

664 ], 

665 } 

666 

667 except Exception as e: 

668 context["warnings"].append(f"Field inference failed for cluster: {e}") 

669 

670 # Initialize statistics dict if it doesn't exist 

671 if "statistics" not in context: 

672 context["statistics"] = {} 

673 

674 context["statistics"]["field_inference"] = { 

675 "schema_count": len(schemas), 

676 } 

677 

678 return {"schemas": schemas} 

679 

680 def _stage_protocol_detection(self, context: dict[str, Any]) -> dict[str, Any]: 

681 """Detect protocol candidates. 

682 

683 Args: 

684 context: Pipeline context. 

685 

686 Returns: 

687 Updated context with protocol candidates. 

688 """ 

689 messages = context.get("messages", []) 

690 flows = context.get("flows", []) 

691 candidates = [] 

692 

693 # Check well-known port mappings 

694 port_protocols = { 

695 53: "dns", 

696 80: "http", 

697 443: "https", 

698 502: "modbus_tcp", 

699 1883: "mqtt", 

700 5683: "coap", 

701 47808: "bacnet", 

702 } 

703 

704 for flow in flows: 

705 port = flow.dst_port or flow.src_port 

706 if port in port_protocols:  # coverage: 706 ↛ 707, condition never true 

707 candidates.append( 

708 ProtocolCandidate( 

709 name=port_protocols[port], 

710 confidence=0.6, 

711 port_hint=True, 

712 ) 

713 ) 

714 

715 # Check magic byte signatures 

716 if messages: 

717 try: 

718 from tracekit.inference.binary import MagicByteDetector 

719 

720 detector = MagicByteDetector() 

721 sample = messages[0]  # non-empty: guarded by the enclosing "if messages" 

722 

723 if len(sample) >= 2: 

724 result = detector.detect(sample) 

725 if result and result.known_format: 

726 candidates.append( 

727 ProtocolCandidate( 

728 name=result.known_format, 

729 confidence=result.confidence, 

730 header_match=True, 

731 ) 

732 ) 

733 

734 except Exception: 

735 pass 

736 

737 # Check protocol library 

738 try: 

739 from tracekit.inference.protocol_library import get_library 

740 

741 library = get_library() 

742 

743 for protocol in library.list_protocols(): 

744 if protocol.definition:  # coverage: 744 ↛ 743, condition always true 

745 # Check if first bytes match protocol header 

746 for msg in messages[:10]: 

747 if len(msg) >= 4: 

748 # Simple header matching 

749 first_field = ( 

750 protocol.definition.fields[0] 

751 if protocol.definition.fields 

752 else None 

753 ) 

754 if first_field and hasattr(first_field, "value"):  # coverage: 754 ↛ 746, condition always true 

755 # Has expected value 

756 candidates.append( 

757 ProtocolCandidate( 

758 name=protocol.name, 

759 confidence=0.4, 

760 matched_patterns=["header_value"], 

761 ) 

762 ) 

763 break 

764 

765 except Exception: 

766 pass 

767 

768 # Deduplicate by name, keeping highest confidence 

769 unique_candidates: dict[str, ProtocolCandidate] = {} 

770 for c in candidates: 

771 if (  # coverage: 771 ↛ 770, condition always true 

772 c.name not in unique_candidates 

773 or c.confidence > unique_candidates[c.name].confidence 

774 ): 

775 unique_candidates[c.name] = c 

776 

777 return {"protocol_candidates": list(unique_candidates.values())} 

778 

779 def _stage_state_machine(self, context: dict[str, Any]) -> dict[str, Any]: 

780 """Infer protocol state machine. 

781 

782 Args: 

783 context: Pipeline context. 

784 

785 Returns: 

786 Updated context with state machine. 

787 """ 

788 clusters = context.get("clusters", []) 

789 

790 if len(clusters) < 2: 

791 return {"state_machine": None} 

792 

793 try: 

794 # Build sequences from cluster transitions 

795 messages = context.get("messages", []) 

796 message_to_cluster = {} 

797 

798 for cluster in clusters: 

799 for idx in getattr(cluster, "indices", []): 

800 message_to_cluster[idx] = getattr(cluster, "cluster_id", 0) 

801 

802 # Build observation sequence 

803 sequence = [ 

804 f"type_{message_to_cluster.get(i, 0)}" 

805 for i in range(len(messages)) 

806 if i in message_to_cluster 

807 ] 

808 

809 if len(sequence) >= 3:  # coverage: 809 ↛ 810, condition never true 

810 from tracekit.inference.state_machine import StateMachineInferrer 

811 

812 inferrer = StateMachineInferrer() 

813 automaton = inferrer.infer_rpni([sequence]) 

814 

815 return { 

816 "state_machine": { 

817 "states": len(automaton.states) if automaton is not None else 0, 

818 "transitions": len(automaton.transitions) if automaton is not None else 0, 

819 "automaton": automaton, 

820 } 

821 } 

822 

823 except Exception as e: 

824 context["warnings"].append(f"State machine inference failed: {e}") 

825 

826 return {"state_machine": None} 

827 
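# Example (hedged): shape of the observation sequence handed to the RPNI
# inferrer, i.e. cluster labels in original message order.
#
#     >>> message_to_cluster = {0: 0, 1: 1, 2: 0, 3: 1}
#     >>> [f"type_{message_to_cluster[i]}" for i in sorted(message_to_cluster)]
#     ['type_0', 'type_1', 'type_0', 'type_1']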

828 # ========================================================================= 

829 # Helper Methods 

830 # ========================================================================= 

831 

832 def _report_progress(self, stage: str, percent: float) -> None: 

833 """Report progress to callback.""" 

834 if self._progress_callback: 

835 self._progress_callback(stage, percent) 

836 

837 def _load_checkpoint(self, path: str) -> None: 

838 """Load checkpoint data.""" 

839 try: 

840 with open(path) as f: 

841 self._checkpoint_data = json.load(f) 

842 except Exception: 

843 self._checkpoint_data = {} 

844 

845 def _save_checkpoint(self, path: str, stage: str, context: dict[str, Any]) -> None: 

846 """Save checkpoint data.""" 

847 try: 

848 # Extract serializable parts of context 

849 checkpoint = { 

850 stage: { 

851 "flow_count": len(context.get("flows", [])), 

852 "message_count": len(context.get("messages", [])), 

853 "cluster_count": len(context.get("clusters", [])), 

854 } 

855 } 

856 

857 if os.path.exists(path): 

858 with open(path) as f: 

859 existing = json.load(f) 

860 checkpoint = {**existing, **checkpoint}  # merge so the fresh stage entry wins over stale data 

861 

862 with open(path, "w") as f: 

863 json.dump(checkpoint, f, indent=2) 

864 

865 except Exception: 

866 pass 

867 
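# Example (hedged; counts are illustrative): checkpoint file contents after
# two stages. Only summary counts are persisted per stage, not the payloads,
# so resuming skips completed stages rather than restoring their full output.
#
#     {
#       "flow_extraction": {"flow_count": 3, "message_count": 0, "cluster_count": 0},
#       "payload_analysis": {"flow_count": 3, "message_count": 20, "cluster_count": 0}
#     }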

868 def _build_message_types(self, context: dict[str, Any]) -> list[MessageTypeInfo]: 

869 """Build message type information from context.""" 

870 clusters = context.get("clusters", []) 

871 message_types = [] 

872 

873 for cluster in clusters: 

874 payloads = getattr(cluster, "payloads", []) 

875 if not payloads: 

876 continue 

877 

878 cluster_id = getattr(cluster, "cluster_id", 0) 

879 avg_len = sum(len(p) for p in payloads) / len(payloads) if payloads else 0 

880 

881 # Get schema if available 

882 schema = context.get("schemas", {}).get(f"type_{cluster_id}", {}) 

883 field_count = schema.get("field_count", 0) 

884 

885 # Use representative as signature 

886 signature = payloads[0][:16] if payloads else b"" 

887 

888 message_types.append( 

889 MessageTypeInfo( 

890 type_id=f"type_{cluster_id}", 

891 name=f"Message Type {cluster_id}", 

892 sample_count=len(payloads), 

893 avg_length=avg_len, 

894 field_count=field_count, 

895 signature=signature, 

896 cluster_id=cluster_id, 

897 ) 

898 ) 

899 

900 return message_types 

901 

902 def _build_statistics( 

903 self, context: dict[str, Any], stage_results: list[StageResult] 

904 ) -> dict[str, Any]: 

905 """Build analysis statistics.""" 

906 stats: dict[str, Any] = context.get("statistics", {}) 

907 

908 # Add stage timing 

909 stats["stage_timing"] = {r.stage_name: r.duration for r in stage_results} 

910 

911 # Add success info 

912 stats["stages_completed"] = sum(1 for r in stage_results if r.success) 

913 stats["stages_failed"] = sum(1 for r in stage_results if not r.success) 

914 

915 return stats 

916 

917 def _generate_json_report(self, results: REAnalysisResult, path: Path) -> None: 

918 """Generate JSON report.""" 

919 report = { 

920 "flow_count": results.flow_count, 

921 "message_count": results.message_count, 

922 "message_types": [ 

923 { 

924 "type_id": mt.type_id, 

925 "name": mt.name, 

926 "sample_count": mt.sample_count, 

927 "avg_length": mt.avg_length, 

928 "field_count": mt.field_count, 

929 "signature": mt.signature.hex(), 

930 } 

931 for mt in results.message_types 

932 ], 

933 "protocol_candidates": [ 

934 { 

935 "name": pc.name, 

936 "confidence": pc.confidence, 

937 "port_hint": pc.port_hint, 

938 "header_match": pc.header_match, 

939 } 

940 for pc in results.protocol_candidates 

941 ], 

942 "field_schemas": results.field_schemas, 

943 "statistics": results.statistics, 

944 "warnings": results.warnings, 

945 "duration_seconds": results.duration_seconds, 

946 "timestamp": results.timestamp, 

947 } 

948 

949 with open(path, "w") as f: 

950 json.dump(report, f, indent=2) 

951 

952 def _generate_markdown_report(self, results: REAnalysisResult, path: Path) -> None: 

953 """Generate Markdown report.""" 

954 lines = [ 

955 "# Reverse Engineering Analysis Report", 

956 "", 

957 f"**Generated:** {results.timestamp}", 

958 f"**Duration:** {results.duration_seconds:.2f} seconds", 

959 "", 

960 "## Summary", 

961 "", 

962 f"- Flows analyzed: {results.flow_count}", 

963 f"- Messages extracted: {results.message_count}", 

964 f"- Message types detected: {len(results.message_types)}", 

965 f"- Protocol candidates: {len(results.protocol_candidates)}", 

966 "", 

967 "## Message Types", 

968 "", 

969 ] 

970 

971 for mt in results.message_types: 

972 lines.extend( 

973 [ 

974 f"### {mt.name}", 

975 f"- Samples: {mt.sample_count}", 

976 f"- Average length: {mt.avg_length:.1f} bytes", 

977 f"- Fields detected: {mt.field_count}", 

978 f"- Signature: `{mt.signature.hex()}`", 

979 "", 

980 ] 

981 ) 

982 

983 if results.protocol_candidates:  # coverage: 983 ↛ 994, condition always true 

984 lines.extend( 

985 [ 

986 "## Protocol Candidates", 

987 "", 

988 ] 

989 ) 

990 for pc in results.protocol_candidates: 

991 lines.append(f"- **{pc.name}** (confidence: {pc.confidence:.2%})") 

992 lines.append("") 

993 

994 if results.warnings:  # coverage: 994 ↛ 995, condition never true 

995 lines.extend( 

996 [ 

997 "## Warnings", 

998 "", 

999 ] 

1000 ) 

1001 for warning in results.warnings: 

1002 lines.append(f"- {warning}") 

1003 

1004 with open(path, "w") as f: 

1005 f.write("\n".join(lines)) 

1006 

1007 def _generate_html_report(self, results: REAnalysisResult, path: Path) -> None: 

1008 """Generate HTML report.""" 

1009 html = f"""<!DOCTYPE html> 

1010 <html> 

1011 <head> 

1012 <title>RE Analysis Report</title> 

1013 <style> 

1014 body {{ font-family: Arial, sans-serif; margin: 40px; }} 

1015 h1 {{ color: #333; }} 

1016 h2 {{ color: #666; border-bottom: 1px solid #ddd; }} 

1017 .summary {{ background: #f5f5f5; padding: 20px; border-radius: 5px; }} 

1018 .type-card {{ border: 1px solid #ddd; padding: 15px; margin: 10px 0; border-radius: 5px; }} 

1019 .signature {{ font-family: monospace; background: #eee; padding: 5px; }} 

1020 .warning {{ color: #856404; background: #fff3cd; padding: 10px; border-radius: 5px; }} 

1021 </style> 

1022 </head> 

1023 <body> 

1024 <h1>Reverse Engineering Analysis Report</h1> 

1025 

1026 <div class="summary"> 

1027 <p><strong>Generated:</strong> {results.timestamp}</p> 

1028 <p><strong>Duration:</strong> {results.duration_seconds:.2f} seconds</p> 

1029 <p><strong>Flows:</strong> {results.flow_count}</p> 

1030 <p><strong>Messages:</strong> {results.message_count}</p> 

1031 <p><strong>Types:</strong> {len(results.message_types)}</p> 

1032 </div> 

1033 

1034 <h2>Message Types</h2> 

1035""" 

1036 for mt in results.message_types: 

1037 html += f""" 

1038 <div class="type-card"> 

1039 <h3>{mt.name}</h3> 

1040 <p><strong>Samples:</strong> {mt.sample_count}</p> 

1041 <p><strong>Avg Length:</strong> {mt.avg_length:.1f} bytes</p> 

1042 <p><strong>Fields:</strong> {mt.field_count}</p> 

1043 <p><strong>Signature:</strong> <span class="signature">{mt.signature.hex()}</span></p> 

1044 </div> 

1045""" 

1046 if results.protocol_candidates:  # coverage: 1046 ↛ 1052, condition always true 

1047 html += "<h2>Protocol Candidates</h2><ul>" 

1048 for pc in results.protocol_candidates: 

1049 html += f"<li><strong>{pc.name}</strong> ({pc.confidence:.0%})</li>" 

1050 html += "</ul>" 

1051 

1052 if results.warnings:  # coverage: 1052 ↛ 1053, condition never true 

1053 html += "<h2>Warnings</h2>" 

1054 for warning in results.warnings: 

1055 html += f'<div class="warning">{warning}</div>' 

1056 

1057 html += """ 

1058 </body> 

1059 </html> 

1060 """ 

1061 with open(path, "w") as f: 

1062 f.write(html) 

1063 

1064 

1065 def analyze( 

1066 data: bytes | Sequence[dict[str, Any]] | Sequence[bytes], 

1067 stages: list[str] | None = None, 

1068 config: dict[str, Any] | None = None, 

1069) -> REAnalysisResult: 

1070 """Run reverse engineering analysis on data. 

1071 

1072 Implements RE-INT-001: Quick analysis function. 

1073 

1074 Args: 

1075 data: Data to analyze. 

1076 stages: Pipeline stages to run. 

1077 config: Configuration options. 

1078 

1079 Returns: 

1080 REAnalysisResult with analysis. 

1081 """ 

1082 pipeline = REPipeline(stages=stages, config=config) 

1083 return pipeline.analyze(data) 

1084 

1085 

1086 __all__ = [ 

1087 "FlowInfo", 

1088 "MessageTypeInfo", 

1089 "ProtocolCandidate", 

1090 "REAnalysisResult", 

1091 "REPipeline", 

1092 "StageResult", 

1093 "analyze", 

1094 ]
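
# Minimal end-to-end smoke test (a hedged sketch: the synthetic packets and the
# repeated b"\xaa\xbb" prefix are illustrative, not a real protocol).
if __name__ == "__main__":
    import random

    packets = [
        {
            "src_ip": "10.0.0.1",
            "dst_ip": "10.0.0.2",
            "src_port": 49152,
            "dst_port": 502,
            "protocol": "tcp",
            "timestamp": float(i),
            # Fixed 2-byte prefix plus a random tail, so clustering and
            # pattern discovery have a stable signature to find.
            "data": b"\xaa\xbb" + bytes(random.randrange(256) for _ in range(6)),
        }
        for i in range(20)
    ]

    results = analyze(packets)
    print(f"flows={results.flow_count} messages={results.message_count}")
    for mt in results.message_types:
        print(f"  {mt.type_id}: {mt.sample_count} samples, avg {mt.avg_length:.1f} B")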