Coverage for src / tracekit / testing / synthetic.py: 99%

244 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Synthetic test data generation with known ground truth. 

2 

3Provides utilities for generating synthetic test data with known properties 

4for validation and testing purposes. 

5""" 

6 

7import struct 

8from dataclasses import dataclass, field 

9from pathlib import Path 

10from typing import Any, Literal, cast 

11 

12import numpy as np 

13from numpy.typing import NDArray 

14 

15 

@dataclass
class SyntheticPacketConfig:
    """Layout options consumed by ``SyntheticDataGenerator.generate_packets``.

    Attributes:
        packet_size: Intended total length of one packet in bytes.
        header_size: Bytes reserved for sync/sequence/timestamp; the header
            is zero-padded up to this length.
        sync_pattern: Marker bytes written at the start of every packet.
        include_sequence: Emit a 2-byte little-endian sequence number.
        include_timestamp: Emit a 4-byte little-endian timestamp.
        include_checksum: Append a 2-byte CRC-16 at the end of each packet.
        checksum_algorithm: Informational label. NOTE(review): the generator
            in this module always computes CRC-16 and does not read it.
        noise_level: Per-packet probability in [0, 1] that the packet's
            first (sync) byte is corrupted.
    """

    packet_size: int = field(default=1024)
    header_size: int = field(default=16)
    sync_pattern: bytes = field(default=b"\xaa\x55")
    include_sequence: bool = field(default=True)
    include_timestamp: bool = field(default=True)
    include_checksum: bool = field(default=True)
    checksum_algorithm: str = field(default="crc16")
    noise_level: float = field(default=0.0)

28 

29 

@dataclass
class SyntheticSignalConfig:
    """Options consumed by ``SyntheticDataGenerator.generate_digital_signal``.

    Attributes:
        pattern_type: Waveform to synthesize. "square" and "uart" have
            dedicated generators, "random" draws i.i.d. 0/1 samples, and any
            other value (e.g. "spi", "i2c") falls back to a fixed 8-sample
            repeating pattern.
        sample_rate: Samples per second; sets square-wave period and UART
            bit timing.
        duration_samples: Length of the returned signal in samples.
        frequency: Square-wave frequency in Hz (unused by other patterns).
        noise_snr_db: SNR in dB of the additive Gaussian noise; ``inf``
            disables noise.
    """

    pattern_type: Literal["square", "uart", "spi", "i2c", "random"] = field(default="square")
    sample_rate: float = field(default=100e6)
    duration_samples: int = field(default=10000)
    frequency: float = field(default=1e6)
    noise_snr_db: float = field(default=40)

39 

40 

@dataclass
class SyntheticMessageConfig:
    """Options consumed by ``SyntheticDataGenerator.generate_protocol_messages``.

    Attributes:
        message_size: Total message length in bytes, including header,
            length, sequence, timestamp, payload and checksum; also the value
            written into the length field.
        num_fields: NOTE(review): not read by the generator in this module.
        include_header: Prefix each message with a 2-byte sync marker.
        include_length: Emit a 2-byte little-endian length field.
        include_checksum: Append a 2-byte CRC-16.
        variation: Per-byte probability that a payload byte is random rather
            than the constant filler.
    """

    message_size: int = field(default=64)
    num_fields: int = field(default=5)
    include_header: bool = field(default=True)
    include_length: bool = field(default=True)
    include_checksum: bool = field(default=True)
    variation: float = field(default=0.1)

51 

52 

53@dataclass 

54class GroundTruth: 

55 """Ground truth data for validation.""" 

56 

57 field_boundaries: list[int] = field(default_factory=list) 

58 field_types: list[str] = field(default_factory=list) 

59 sequence_numbers: list[int] = field(default_factory=list) 

60 pattern_period: int | None = None 

61 cluster_labels: list[int] = field(default_factory=list) 

62 checksum_offsets: list[int] = field(default_factory=list) 

63 decoded_bytes: list[int] = field(default_factory=list) 

64 edge_positions: list[int] = field(default_factory=list) 

65 frequency_hz: float | None = None 

66 

67 

68class SyntheticDataGenerator: 

69 """Generate synthetic test data with known ground truth.""" 

70 

71 def __init__(self, seed: int = 42): 

72 """Initialize with random seed for reproducibility. 

73 

74 Args: 

75 seed: Random seed for reproducible generation. 

76 """ 

77 self.rng = np.random.default_rng(seed) 

78 

79 def generate_packets( 

80 self, config: SyntheticPacketConfig, count: int = 100 

81 ) -> tuple[bytes, GroundTruth]: 

82 """Generate synthetic binary packets. 

83 

84 Args: 

85 config: Packet generation configuration. 

86 count: Number of packets to generate. 

87 

88 Returns: 

89 Tuple of (binary_data, ground_truth). 

90 """ 

91 packets = bytearray() 

92 ground_truth = GroundTruth() 

93 

94 for i in range(count): 

95 packet = bytearray() 

96 

97 # Sync pattern 

98 packet.extend(config.sync_pattern) 

99 

100 # Sequence number (2 bytes) 

101 if config.include_sequence: 

102 packet.extend(struct.pack("<H", i)) 

103 ground_truth.sequence_numbers.append(i) 

104 

105 # Timestamp (4 bytes) 

106 if config.include_timestamp: 

107 timestamp = i * 1000 # Incrementing by 1000 microseconds 

108 packet.extend(struct.pack("<I", timestamp)) 

109 

110 # Padding to header size 

111 while len(packet) < config.header_size: 

112 packet.append(0x00) 

113 

114 # Sample data 

115 sample_data_size = config.packet_size - config.header_size 

116 if config.include_checksum: 

117 sample_data_size -= 2 

118 

119 # Counter pattern for samples (easier to validate) 

120 for j in range(sample_data_size // 2): 

121 packet.extend(struct.pack("<H", j)) 

122 

123 # Checksum (CRC-16) 

124 if config.include_checksum: 

125 checksum = self._calculate_crc16(packet) 

126 checksum_offset = len(packets) + len(packet) 

127 ground_truth.checksum_offsets.append(checksum_offset) 

128 packet.extend(struct.pack("<H", checksum)) 

129 

130 packets.extend(packet) 

131 

132 # Apply noise (corrupt random packets) 

133 if config.noise_level > 0: 

134 packets = bytearray( 

135 self.corrupt_packets(bytes(packets), config.packet_size, config.noise_level) 

136 ) 

137 

138 return bytes(packets), ground_truth 

139 

140 def generate_digital_signal( 

141 self, config: SyntheticSignalConfig 

142 ) -> tuple[NDArray[np.float64], GroundTruth]: 

143 """Generate synthetic digital signal. 

144 

145 Args: 

146 config: Signal generation configuration. 

147 

148 Returns: 

149 Tuple of (signal_array, ground_truth). 

150 """ 

151 ground_truth = GroundTruth() 

152 signal: NDArray[np.float64] 

153 

154 if config.pattern_type == "square": 

155 # Generate square wave 

156 period_samples = int(config.sample_rate / config.frequency) 

157 ground_truth.pattern_period = period_samples 

158 ground_truth.frequency_hz = config.frequency 

159 

160 t = np.arange(config.duration_samples) 

161 signal = (np.sin(2 * np.pi * config.frequency * t / config.sample_rate) > 0).astype( 

162 np.float64 

163 ) 

164 

165 # Track edge positions 

166 edges = np.where(np.diff(signal) != 0)[0] + 1 

167 ground_truth.edge_positions = edges.tolist() 

168 

169 elif config.pattern_type == "uart": 

170 # Generate UART signal (8N1) 

171 signal, uart_truth = self._generate_uart_signal(config) 

172 ground_truth.decoded_bytes = uart_truth["bytes"] 

173 ground_truth.edge_positions = uart_truth["edges"] 

174 

175 elif config.pattern_type == "random": 

176 # Random digital signal 

177 signal = self.rng.choice([0.0, 1.0], size=config.duration_samples) 

178 

179 else: 

180 # Default to simple pattern 

181 pattern = np.array([1, 1, 0, 1, 0, 0, 1, 0], dtype=np.float64) 

182 signal = np.tile(pattern, config.duration_samples // len(pattern) + 1)[ 

183 : config.duration_samples 

184 ] 

185 ground_truth.pattern_period = len(pattern) 

186 

187 # Scale to 3.3V logic levels 

188 signal = signal * 3.3 

189 

190 # Add noise 

191 if config.noise_snr_db < np.inf: 

192 noisy_signal = self.add_noise(signal, config.noise_snr_db) 

193 assert isinstance(noisy_signal, np.ndarray), ( 

194 "add_noise should return ndarray for ndarray input" 

195 ) 

196 signal = noisy_signal 

197 

198 return signal, ground_truth 

199 

200 def generate_protocol_messages( 

201 self, config: SyntheticMessageConfig, count: int = 100 

202 ) -> tuple[list[bytes], GroundTruth]: 

203 """Generate synthetic protocol messages. 

204 

205 Args: 

206 config: Message generation configuration. 

207 count: Number of messages to generate. 

208 

209 Returns: 

210 Tuple of (message_list, ground_truth). 

211 """ 

212 messages = [] 

213 ground_truth = GroundTruth() 

214 

215 # Define field structure 

216 field_boundaries = [0] 

217 field_types = [] 

218 

219 current_offset = 0 

220 

221 if config.include_header: 

222 # 2-byte sync pattern 

223 field_boundaries.append(current_offset + 2) 

224 field_types.append("constant") 

225 current_offset += 2 

226 

227 if config.include_length: 

228 # 2-byte length field 

229 field_boundaries.append(current_offset + 2) 

230 field_types.append("length") 

231 current_offset += 2 

232 

233 # Sequence number (2 bytes) 

234 field_boundaries.append(current_offset + 2) 

235 field_types.append("sequence") 

236 current_offset += 2 

237 

238 # Timestamp (4 bytes) 

239 field_boundaries.append(current_offset + 4) 

240 field_types.append("timestamp") 

241 current_offset += 4 

242 

243 # Payload (variable size) 

244 payload_size = config.message_size - current_offset 

245 if config.include_checksum: 

246 payload_size -= 2 

247 

248 field_boundaries.append(current_offset + payload_size) 

249 field_types.append("data") 

250 current_offset += payload_size 

251 

252 if config.include_checksum: 

253 field_boundaries.append(current_offset + 2) 

254 field_types.append("checksum") 

255 

256 ground_truth.field_boundaries = field_boundaries 

257 ground_truth.field_types = field_types 

258 

259 # Generate messages 

260 for i in range(count): 

261 message = bytearray() 

262 

263 if config.include_header: 

264 message.extend(b"\xaa\x55") 

265 

266 if config.include_length: 

267 message.extend(struct.pack("<H", config.message_size)) 

268 

269 # Sequence number 

270 message.extend(struct.pack("<H", i)) 

271 ground_truth.sequence_numbers.append(i) 

272 

273 # Timestamp 

274 timestamp = i * 100 # Incrementing 

275 message.extend(struct.pack("<I", timestamp)) 

276 

277 # Payload (partially random, partially constant based on variation) 

278 for _ in range(payload_size): 

279 if self.rng.random() < config.variation: 

280 message.append(self.rng.integers(0, 256)) 

281 else: 

282 message.append(0x42) # Constant byte 

283 

284 # Checksum 

285 if config.include_checksum: 

286 checksum = self._calculate_crc16(message) 

287 message.extend(struct.pack("<H", checksum)) 

288 

289 messages.append(bytes(message)) 

290 

291 return messages, ground_truth 

292 

293 def add_noise( 

294 self, data: bytes | NDArray[np.float64], snr_db: float = 20 

295 ) -> bytes | NDArray[np.float64]: 

296 """Add noise to data. 

297 

298 Args: 

299 data: Input data (bytes or numpy array). 

300 snr_db: Signal-to-noise ratio in dB. 

301 

302 Returns: 

303 Noisy data (same type as input). 

304 """ 

305 if isinstance(data, bytes): 

306 # For bytes, add random bit flips 

307 data_array = np.frombuffer(data, dtype=np.uint8) 

308 noise_rate = 10 ** (-snr_db / 10) 

309 mask = self.rng.random(len(data_array)) < noise_rate 

310 noisy = data_array.copy() 

311 noisy[mask] ^= self.rng.integers(1, 256, size=int(np.sum(mask)), dtype=np.uint8) 

312 return noisy.tobytes() 

313 else: 

314 # For arrays, add Gaussian noise 

315 signal_power = np.mean(data**2) 

316 noise_power = signal_power / (10 ** (snr_db / 10)) 

317 noise = self.rng.normal(0, np.sqrt(noise_power), len(data)) 

318 return data + noise 

319 

320 def corrupt_packets( 

321 self, packets: bytes, packet_size: int, corruption_rate: float = 0.01 

322 ) -> bytes: 

323 """Corrupt random packets for testing error handling. 

324 

325 Args: 

326 packets: Binary packet data. 

327 packet_size: Size of each packet in bytes. 

328 corruption_rate: Fraction of packets to corrupt (0-1). 

329 

330 Returns: 

331 Corrupted packet data. 

332 """ 

333 packets_array = bytearray(packets) 

334 num_packets = len(packets) // packet_size 

335 

336 for i in range(num_packets): 

337 if self.rng.random() < corruption_rate: 

338 # Corrupt sync marker 

339 offset = i * packet_size 

340 packets_array[offset] ^= 0xFF 

341 

342 return bytes(packets_array) 

343 

344 def _calculate_crc16(self, data: bytes | bytearray) -> int: 

345 """Calculate CRC-16 checksum. 

346 

347 Args: 

348 data: Input data. 

349 

350 Returns: 

351 CRC-16 checksum value. 

352 """ 

353 crc = 0xFFFF 

354 for byte in data: 

355 crc ^= byte 

356 for _ in range(8): 

357 if crc & 0x0001: 

358 crc = (crc >> 1) ^ 0xA001 

359 else: 

360 crc >>= 1 

361 return crc & 0xFFFF 

362 

363 def _generate_uart_signal( 

364 self, config: SyntheticSignalConfig 

365 ) -> tuple[NDArray[np.float64], dict[str, Any]]: 

366 """Generate UART signal encoding a test message. 

367 

368 Args: 

369 config: Signal configuration. 

370 

371 Returns: 

372 Tuple of (signal, metadata_dict). 

373 """ 

374 # UART parameters 

375 baud_rate = 9600 

376 samples_per_bit = int(config.sample_rate / baud_rate) 

377 

378 # Test message 

379 message = b"Hello, World!" 

380 

381 # Encode with start/stop bits 

382 bits = [] 

383 edges = [] 

384 

385 for byte_val in message: 

386 # Start bit (0) 

387 bits.extend([0] * samples_per_bit) 

388 edges.append(len(bits)) 

389 

390 # Data bits (LSB first) 

391 for i in range(8): 

392 bit = (byte_val >> i) & 1 

393 bits.extend([bit] * samples_per_bit) 

394 if i > 0 and bits[-1] != bits[-samples_per_bit - 1]: 

395 edges.append(len(bits) - samples_per_bit) 

396 

397 # Stop bit (1) 

398 bits.extend([1] * samples_per_bit) 

399 if bits[-1] != bits[-samples_per_bit - 1]: 399 ↛ 385line 399 didn't jump to line 385 because the condition on line 399 was always true

400 edges.append(len(bits) - samples_per_bit) 

401 

402 # Pad to duration 

403 signal = np.array(bits[: config.duration_samples], dtype=np.float64) 

404 if len(signal) < config.duration_samples: 404 ↛ 405line 404 didn't jump to line 405 because the condition on line 404 was never true

405 padding = np.ones(config.duration_samples - len(signal), dtype=np.float64) 

406 signal = np.concatenate([signal, padding]) 

407 

408 metadata = {"bytes": list(message), "edges": edges[: len(signal)]} 

409 

410 return signal, metadata 

411 

412 

413# ============================================================================= 

414# Convenience functions 

415# ============================================================================= 

416 

417 

def generate_packets(count: int = 100, **kwargs: Any) -> tuple[bytes, GroundTruth]:
    """Build ``count`` synthetic packets with a freshly seeded default generator.

    Args:
        count: How many packets to emit.
        **kwargs: Field overrides forwarded verbatim to ``SyntheticPacketConfig``.

    Returns:
        The concatenated packet bytes together with their ``GroundTruth``.
    """
    packet_config = SyntheticPacketConfig(**kwargs)
    return SyntheticDataGenerator().generate_packets(packet_config, count)

431 

432 

def generate_digital_signal(
    pattern: str = "square", **kwargs: Any
) -> tuple[NDArray[np.float64], GroundTruth]:
    """Generate synthetic signal with defaults.

    Args:
        pattern: Pattern type ('square', 'uart', 'spi', 'i2c', 'random').
            Unrecognized values fall back to 'square'.
        **kwargs: Additional ``SyntheticSignalConfig`` fields. ``pattern_type``
            is accepted as an alias for ``pattern`` and is used when
            ``pattern`` is left at its default.

    Returns:
        Tuple of (signal_array, ground_truth).
    """
    valid_patterns = ("square", "uart", "spi", "i2c", "random")

    # ``pattern_type`` cannot be forwarded alongside the explicit one below.
    # Previously it was silently dropped, so pattern_type="uart" produced a
    # square wave. Honor it when the caller did not choose ``pattern``.
    alias = kwargs.pop("pattern_type", None)
    if alias is not None and pattern == "square":
        pattern = alias

    pattern_type = pattern if pattern in valid_patterns else "square"

    config = SyntheticSignalConfig(
        pattern_type=cast("Literal['square', 'uart', 'spi', 'i2c', 'random']", pattern_type),
        **kwargs,
    )
    generator = SyntheticDataGenerator()
    return generator.generate_digital_signal(config)

458 

459 

def generate_protocol_messages(count: int = 100, **kwargs: Any) -> tuple[list[bytes], GroundTruth]:
    """Build ``count`` synthetic protocol messages with a default-seeded generator.

    Args:
        count: How many messages to emit.
        **kwargs: Field overrides forwarded verbatim to ``SyntheticMessageConfig``.

    Returns:
        The list of encoded messages together with their ``GroundTruth``.
    """
    message_config = SyntheticMessageConfig(**kwargs)
    return SyntheticDataGenerator().generate_protocol_messages(message_config, count)

473 

474 

def generate_test_dataset(
    output_dir: str,
    num_packets: int = 1000,
    num_signals: int = 10,
    num_messages: int = 500,
) -> dict[str, Any]:
    """Generate complete test dataset with ground truth.

    Writes the following into ``output_dir``, which is created if needed:

    - ``test_packets.bin``: concatenated packets.
    - ``test_signal_NNN.npy``: signals, alternating square and UART patterns.
    - ``test_messages.bin``: messages concatenated back to back.
    - ``dataset_metadata.json``: the dataset metadata.

    Args:
        output_dir: Directory to save test data.
        num_packets: Number of packets to generate.
        num_signals: Number of signals to generate.
        num_messages: Number of messages to generate.

    Returns:
        Dictionary with dataset metadata and file paths. This is the content
        of ``dataset_metadata.json`` plus a ``metadata_file`` key pointing at it.
    """
    import json

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    generator = SyntheticDataGenerator()
    metadata: dict[str, Any] = {
        "dataset_type": "synthetic_test_data",
        "generated_files": [],
    }

    # Packets
    packet_config = SyntheticPacketConfig()
    packets, packet_truth = generator.generate_packets(packet_config, num_packets)
    packet_file = output_path / "test_packets.bin"
    packet_file.write_bytes(packets)
    metadata["generated_files"].append(
        {
            "path": str(packet_file),
            "type": "packets",
            "count": num_packets,
            "ground_truth": {
                "sequence_numbers": packet_truth.sequence_numbers[:10],  # First 10
                "checksum_offsets": packet_truth.checksum_offsets[:10],
            },
        }
    )

    # Signals
    for i in range(num_signals):
        signal_config = SyntheticSignalConfig(
            pattern_type="square" if i % 2 == 0 else "uart",
            frequency=1e6 * (i + 1),
        )
        signal, signal_truth = generator.generate_digital_signal(signal_config)
        signal_file = output_path / f"test_signal_{i:03d}.npy"
        np.save(signal_file, signal)
        metadata["generated_files"].append(
            {
                "path": str(signal_file),
                "type": "signal",
                "pattern": signal_config.pattern_type,
                "ground_truth": {
                    "frequency_hz": signal_truth.frequency_hz,
                    "period_samples": signal_truth.pattern_period,
                    # UART signals have no frequency/period, so also record
                    # what they do carry. The edge count is stored instead of
                    # the full (potentially large) edge list.
                    "decoded_bytes": signal_truth.decoded_bytes,
                    "num_edges": len(signal_truth.edge_positions),
                },
            }
        )

    # Protocol messages
    message_config = SyntheticMessageConfig()
    messages, message_truth = generator.generate_protocol_messages(message_config, num_messages)
    messages_file = output_path / "test_messages.bin"
    messages_file.write_bytes(b"".join(messages))
    metadata["generated_files"].append(
        {
            "path": str(messages_file),
            "type": "messages",
            "count": num_messages,
            "ground_truth": {
                "field_boundaries": message_truth.field_boundaries,
                "field_types": message_truth.field_types,
                "message_size": message_config.message_size,
            },
        }
    )

    metadata_file = output_path / "dataset_metadata.json"
    with metadata_file.open("w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    # Added after dumping: the file does not reference itself.
    metadata["metadata_file"] = str(metadata_file)

    return metadata