Coverage for src / tracekit / inference / state_machine.py: 96%

228 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""State machine inference using RPNI algorithm. 

2 

3Requirements addressed: PSI-002 

4 

5This module infers protocol state machines from observed message sequences using 

6passive learning algorithms (no system interaction required). 

7 

8Key capabilities: 

9- RPNI algorithm for passive DFA learning 

10- State merging to minimize automaton 

11- Export to DOT format for visualization 

12- Export to NetworkX graph for analysis 

13""" 

14 

15from copy import deepcopy 

16from dataclasses import dataclass 

17from typing import Any 

18 

19 

@dataclass
class State:
    """Single node of an inferred automaton.

    Attributes:
        id: Unique state identifier
        name: Human-readable state name
        is_initial: Whether this is the initial state
        is_accepting: Whether this is an accepting state
    """

    # Required fields come first; both flags default to "off".
    id: int
    name: str
    is_initial: bool = False
    is_accepting: bool = False

37 

38 

@dataclass
class Transition:
    """Labelled edge between two automaton states.

    Attributes:
        source: Source state ID
        target: Target state ID
        symbol: Transition label/symbol
        count: Number of times observed
    """

    source: int  # ID of the state the edge leaves
    target: int  # ID of the state the edge enters
    symbol: str  # Label carried by the edge
    count: int = 1  # Observation counter, starts at one

56 

57 

@dataclass
class FiniteAutomaton:
    """An inferred finite automaton with export and query helpers.

    Attributes:
        states: List of all states
        transitions: List of all transitions
        alphabet: Set of all symbols
        initial_state: Initial state ID
        accepting_states: Set of accepting state IDs
    """

    states: "list[State]"
    transitions: "list[Transition]"
    alphabet: set[str]
    initial_state: int
    accepting_states: set[int]

    @staticmethod
    def _dot_escape(text: str) -> str:
        """Escape backslashes and double quotes for use inside a DOT quoted string."""
        return str(text).replace("\\", "\\\\").replace('"', '\\"')

    @staticmethod
    def _dot_id(name: str) -> str:
        """Return ``name`` as a valid DOT node ID, quoting it only when required.

        Plain ASCII identifiers that are not DOT keywords are emitted verbatim,
        so the output for conventional names such as ``q0`` is unchanged.
        """
        name = str(name)
        reserved = ("node", "edge", "graph", "digraph", "subgraph", "strict")
        if name.isascii() and name.isidentifier() and name.lower() not in reserved:
            return name
        return '"' + FiniteAutomaton._dot_escape(name) + '"'

    def to_dot(self) -> str:
        """Export to DOT format for Graphviz.

        State names that are not plain identifiers are quoted, and quotes or
        backslashes inside names and labels are escaped, so the result stays
        valid DOT for arbitrary symbols.

        Returns:
            DOT format string

        Raises:
            ValueError: If the initial state or a transition endpoint refers to
                a state ID that is not present in ``states``.
        """
        # Iterate in reverse so that, should two states share an ID, the one
        # listed first wins (same result as a first-match linear search).
        node_names = {s.id: self._dot_id(s.name) for s in reversed(self.states)}

        def node(state_id: int) -> str:
            try:
                return node_names[state_id]
            except KeyError:
                raise ValueError(
                    f"Automaton references unknown state ID {state_id}"
                ) from None

        lines = ["digraph finite_automaton {", " rankdir=LR;", " node [shape=circle];"]

        # Accepting states are drawn as double circles; skip the statement
        # entirely when there is nothing to list.
        accepting_names = [
            self._dot_id(s.name) for s in self.states if s.id in self.accepting_states
        ]
        if accepting_names:
            lines.append(f" node [shape=doublecircle]; {' '.join(accepting_names)};")
            lines.append(" node [shape=circle];")

        # Invisible start node with an arrow into the initial state.
        lines.append(' __start__ [shape=none, label=""];')
        lines.append(f" __start__ -> {node(self.initial_state)};")

        for trans in self.transitions:
            label = trans.symbol
            if trans.count > 1:
                label = f"{trans.symbol} ({trans.count})"
            escaped = self._dot_escape(label)
            lines.append(f' {node(trans.source)} -> {node(trans.target)} [label="{escaped}"];')

        lines.append("}")
        return "\n".join(lines)

    def to_networkx(self) -> Any:
        """Export to NetworkX graph.

        Nodes are keyed by state ID and carry ``name``, ``is_initial`` and
        ``is_accepting`` attributes; edges carry ``symbol`` and ``count``.

        Returns:
            NetworkX DiGraph

        Raises:
            ImportError: If NetworkX is not installed.
        """
        try:
            import networkx as nx  # type: ignore[import-untyped]
        except ImportError as err:
            raise ImportError("NetworkX is required for graph export") from err

        graph = nx.DiGraph()

        for state in self.states:
            graph.add_node(
                state.id,
                name=state.name,
                is_initial=state.is_initial,
                is_accepting=state.is_accepting,
            )

        for trans in self.transitions:
            graph.add_edge(trans.source, trans.target, symbol=trans.symbol, count=trans.count)

        return graph

    def accepts(self, sequence: list[str]) -> bool:
        """Check if automaton accepts sequence.

        The walk starts at ``initial_state``. If several transitions share a
        source and symbol, the one listed first in ``transitions`` is followed.

        Args:
            sequence: List of symbols

        Returns:
            True if the walk consumes every symbol and ends in an accepting state
        """
        # One pass over the transitions builds the lookup table; setdefault
        # keeps the first transition seen for each (source, symbol) pair.
        step: dict[tuple[int, str], int] = {}
        for trans in self.transitions:
            step.setdefault((trans.source, trans.symbol), trans.target)

        current_state = self.initial_state
        for symbol in sequence:
            key = (current_state, symbol)
            if key not in step:
                return False  # No valid transition
            current_state = step[key]

        return current_state in self.accepting_states

    def get_successors(self, state_id: int) -> dict[str, int]:
        """Get successor states from given state.

        If several transitions leave ``state_id`` on the same symbol, the one
        listed last in ``transitions`` is reported.

        Args:
            state_id: State ID to query

        Returns:
            Dictionary mapping symbols to target state IDs
        """
        return {
            trans.symbol: trans.target
            for trans in self.transitions
            if trans.source == state_id
        }

189 

190 

191class StateMachineInferrer: 

192 """Infer state machines using passive learning. 

193 

194 : RPNI algorithm for DFA inference. 

195 

196 The RPNI (Regular Positive and Negative Inference) algorithm: 

197 1. Build Prefix Tree Acceptor from positive samples 

198 2. Iteratively merge compatible state pairs 

199 3. Validate against negative samples 

200 4. Converge to minimal consistent DFA 

201 """ 

202 

203 def __init__(self) -> None: 

204 """Initialize inferrer.""" 

205 self._next_state_id = 0 

206 

207 def infer_rpni( 

208 self, positive_traces: list[list[str]], negative_traces: list[list[str]] | None = None 

209 ) -> FiniteAutomaton: 

210 """Infer DFA using RPNI (Regular Positive and Negative Inference). 

211 

212 : Complete RPNI algorithm. 

213 

214 Args: 

215 positive_traces: List of accepted sequences (list of symbols) 

216 negative_traces: List of rejected sequences (optional) 

217 

218 Returns: 

219 Inferred FiniteAutomaton 

220 

221 Raises: 

222 ValueError: If no positive traces provided. 

223 """ 

224 if not positive_traces: 

225 raise ValueError("Need at least one positive trace") 

226 

227 # Build alphabet from all traces 

228 alphabet: set[str] = set() 

229 neg_traces = negative_traces if negative_traces is not None else [] 

230 for trace in positive_traces + neg_traces: 

231 alphabet.update(trace) 

232 

233 # Build Prefix Tree Acceptor from positive traces 

234 pta = self._build_pta(positive_traces) 

235 

236 # RPNI merging process 

237 automaton = pta 

238 states = sorted([s.id for s in automaton.states]) 

239 

240 # Try to merge states in order 

241 i = 1 # Start from second state (never merge initial state) 

242 while i < len(states): 

243 merged = False 

244 

245 # Try to merge states[i] with any earlier state 

246 for j in range(i): 

247 if self._is_compatible(automaton, states[j], states[i], neg_traces): 

248 # Merge states[i] into states[j] 

249 automaton = self._merge_states(automaton, states[j], states[i]) 

250 # Update state list 

251 states = sorted([s.id for s in automaton.states]) 

252 merged = True 

253 break 

254 

255 if not merged: 

256 i += 1 

257 

258 return automaton 

259 

260 def _build_pta(self, traces: list[list[str]]) -> FiniteAutomaton: 

261 """Build Prefix Tree Acceptor from traces. 

262 

263 : PTA construction. 

264 

265 Args: 

266 traces: List of sequences 

267 

268 Returns: 

269 Prefix Tree Acceptor as FiniteAutomaton 

270 """ 

271 # Reset state counter 

272 self._next_state_id = 0 

273 

274 # Create initial state 

275 initial_state = State( 

276 id=self._get_next_state_id(), name="q0", is_initial=True, is_accepting=False 

277 ) 

278 

279 states: list[State] = [initial_state] 

280 transitions: list[Transition] = [] 

281 alphabet: set[str] = set() 

282 

283 # Build tree from traces 

284 for trace in traces: 

285 current_state_id = initial_state.id 

286 

287 # Walk/build tree for this trace 

288 for symbol in trace: 

289 alphabet.add(symbol) 

290 

291 # Check if transition exists 

292 next_state_id = None 

293 for trans in transitions: 

294 if trans.source == current_state_id and trans.symbol == symbol: 

295 next_state_id = trans.target 

296 break 

297 

298 if next_state_id is None: 

299 # Create new state and transition 

300 new_state_id = self._get_next_state_id() 

301 new_state = State( 

302 id=new_state_id, 

303 name=f"q{new_state_id}", 

304 is_initial=False, 

305 is_accepting=False, 

306 ) 

307 states.append(new_state) 

308 

309 new_trans = Transition( 

310 source=current_state_id, target=new_state_id, symbol=symbol 

311 ) 

312 transitions.append(new_trans) 

313 

314 next_state_id = new_state_id 

315 

316 current_state_id = next_state_id 

317 

318 # Mark final state as accepting 

319 for state in states: 

320 if state.id == current_state_id: 

321 state.is_accepting = True 

322 

323 accepting_states = {s.id for s in states if s.is_accepting} 

324 

325 return FiniteAutomaton( 

326 states=states, 

327 transitions=transitions, 

328 alphabet=alphabet, 

329 initial_state=initial_state.id, 

330 accepting_states=accepting_states, 

331 ) 

332 

333 def _merge_states( 

334 self, automaton: FiniteAutomaton, state_a: int, state_b: int 

335 ) -> FiniteAutomaton: 

336 """Merge two states in automaton. 

337 

338 : State merging operation. 

339 

340 Merges state_b into state_a. 

341 

342 Args: 

343 automaton: Current automaton 

344 state_a: Target state ID (survives) 

345 state_b: Source state ID (removed) 

346 

347 Returns: 

348 New automaton with merged states 

349 """ 

350 # Deep copy to avoid modifying original 

351 new_automaton = deepcopy(automaton) 

352 

353 # Remove state_b 

354 new_automaton.states = [s for s in new_automaton.states if s.id != state_b] 

355 

356 # Update transitions: redirect all transitions to/from state_b to state_a 

357 for trans in new_automaton.transitions: 

358 if trans.source == state_b: 

359 trans.source = state_a 

360 if trans.target == state_b: 

361 trans.target = state_a 

362 

363 # Merge accepting status 

364 if state_b in new_automaton.accepting_states: 

365 new_automaton.accepting_states.add(state_a) 

366 new_automaton.accepting_states.discard(state_b) 

367 

368 # Merge duplicate transitions (same source, target, symbol) 

369 unique_transitions = [] 

370 seen = set() 

371 

372 for trans in new_automaton.transitions: 

373 key = (trans.source, trans.target, trans.symbol) 

374 if key not in seen: 

375 seen.add(key) 

376 unique_transitions.append(trans) 

377 else: 

378 # Increment count on existing transition 

379 for ut in unique_transitions: 379 ↛ 372line 379 didn't jump to line 372 because the loop on line 379 didn't complete

380 if (ut.source, ut.target, ut.symbol) == key: 

381 ut.count += trans.count 

382 break 

383 

384 new_automaton.transitions = unique_transitions 

385 

386 return new_automaton 

387 

388 def _is_compatible( 

389 self, 

390 automaton: FiniteAutomaton, 

391 state_a: int, 

392 state_b: int, 

393 negative_traces: list[list[str]], 

394 ) -> bool: 

395 """Check if two states can be merged without accepting negatives. 

396 

397 : Compatibility checking for state merging. 

398 

399 Args: 

400 automaton: Current automaton 

401 state_a: First state ID 

402 state_b: Second state ID 

403 negative_traces: Negative example traces 

404 

405 Returns: 

406 True if states are compatible 

407 """ 

408 # Get accepting status 

409 _a_accepting = state_a in automaton.accepting_states 

410 _b_accepting = state_b in automaton.accepting_states 

411 

412 # If one is accepting and other is not, they might still be compatible 

413 # (we'll merge accepting status), but check negative traces 

414 

415 # Try merging and test 

416 test_automaton = self._merge_states(automaton, state_a, state_b) 

417 

418 # Check that no negative traces are accepted 

419 for neg_trace in negative_traces: 

420 if test_automaton.accepts(neg_trace): 

421 return False 

422 

423 # Recursively check successor compatibility 

424 _succ_a = test_automaton.get_successors(state_a) 

425 # state_b has been merged, so its successors are now in state_a 

426 

427 return True 

428 

429 def _get_next_state_id(self) -> int: 

430 """Get next available state ID. 

431 

432 Returns: 

433 Next state ID 

434 """ 

435 state_id = self._next_state_id 

436 self._next_state_id += 1 

437 return state_id 

438 

439 

def minimize_dfa(automaton: FiniteAutomaton) -> FiniteAutomaton:
    """Minimize DFA using partition refinement.

    Starts from the accepting / non-accepting split and repeatedly splits
    every block whose states disagree on which block each symbol leads to
    (Moore-style refinement, not Hopcroft's worklist variant). Each final
    block becomes one state named ``q<index>``. Transitions that collapse
    onto the same (source, target, symbol) edge are combined and their
    counts summed. States unreachable from the initial state are not removed.

    Args:
        automaton: DFA to minimize (not modified)

    Returns:
        Minimized FiniteAutomaton
    """
    # Start with two partitions: accepting and non-accepting
    accepting = automaton.accepting_states
    non_accepting = {s.id for s in automaton.states if s.id not in accepting}
    partitions = [block for block in (accepting, non_accepting) if block]

    symbols = sorted(automaton.alphabet)

    # Refine partitions until a full pass splits nothing
    changed = True
    while changed:
        changed = False

        # Block index of every state for this round
        block_of = {sid: idx for idx, block in enumerate(partitions) for sid in block}
        new_partitions = []

        for partition in partitions:
            if len(partition) <= 1:
                new_partitions.append(partition)
                continue

            # Group states by transition signature: for each symbol, the
            # block its successor lies in (None when there is no successor
            # or the successor is not a known state).
            groups: dict[tuple[tuple[str, int | None], ...], set[int]] = {}
            for state_id in partition:
                successors = automaton.get_successors(state_id)
                signature = tuple(
                    (symbol, block_of.get(successors[symbol]) if symbol in successors else None)
                    for symbol in symbols
                )
                groups.setdefault(signature, set()).add(state_id)

            if len(groups) > 1:
                changed = True

            new_partitions.extend(groups.values())

        partitions = new_partitions

    # Map old state IDs to partition IDs
    state_to_partition = {sid: idx for idx, block in enumerate(partitions) for sid in block}

    # One new state per partition
    new_states = [
        State(
            id=idx,
            name=f"q{idx}",
            is_initial=automaton.initial_state in block,
            is_accepting=any(sid in automaton.accepting_states for sid in block),
        )
        for idx, block in enumerate(partitions)
    ]

    # Re-target transitions onto partitions, merging duplicates
    collapsed: dict[tuple[int, int, str], Transition] = {}
    for trans in automaton.transitions:
        src_partition = state_to_partition[trans.source]
        tgt_partition = state_to_partition[trans.target]
        key = (src_partition, tgt_partition, trans.symbol)

        existing = collapsed.get(key)
        if existing is None:
            collapsed[key] = Transition(
                source=src_partition,
                target=tgt_partition,
                symbol=trans.symbol,
                count=trans.count,
            )
        else:
            # Same edge seen again after collapsing: keep the observations
            existing.count += trans.count

    new_initial = state_to_partition[automaton.initial_state]
    new_accepting = {s.id for s in new_states if s.is_accepting}

    return FiniteAutomaton(
        states=new_states,
        transitions=list(collapsed.values()),
        # Copy so the result does not share a mutable set with the input
        alphabet=set(automaton.alphabet),
        initial_state=new_initial,
        accepting_states=new_accepting,
    )

558 

559 

def to_dot(automaton: FiniteAutomaton) -> str:
    """Render an automaton as Graphviz DOT text.

    Functional shortcut that delegates to the automaton's own ``to_dot``
    method.

    Args:
        automaton: Automaton to export

    Returns:
        DOT format string
    """
    dot_source = automaton.to_dot()
    return dot_source

572 

573 

def to_networkx(automaton: FiniteAutomaton) -> Any:
    """Convert an automaton into a NetworkX graph.

    Functional shortcut that delegates to the automaton's own
    ``to_networkx`` method.

    Args:
        automaton: Automaton to export

    Returns:
        NetworkX DiGraph
    """
    graph = automaton.to_networkx()
    return graph

586 

587 

def infer_rpni(
    positive_traces: list[list[str]], negative_traces: list[list[str]] | None = None
) -> FiniteAutomaton:
    """Run RPNI inference with a throwaway inferrer.

    Creates a fresh ``StateMachineInferrer`` and forwards both arguments to
    its ``infer_rpni`` method.

    Args:
        positive_traces: List of accepted sequences
        negative_traces: List of rejected sequences (optional)

    Returns:
        Inferred FiniteAutomaton
    """
    return StateMachineInferrer().infer_rpni(positive_traces, negative_traces)