Coverage for src / tracekit / testing / synthetic.py: 99%
244 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Synthetic test data generation with known ground truth.
3Provides utilities for generating synthetic test data with known properties
4for validation and testing purposes.
5"""
7import struct
8from dataclasses import dataclass, field
9from pathlib import Path
10from typing import Any, Literal, cast
12import numpy as np
13from numpy.typing import NDArray
@dataclass
class SyntheticPacketConfig:
    """Knobs controlling the layout of generated binary packets.

    Attributes:
        packet_size: Nominal size of one packet in bytes.
        header_size: Number of bytes the header is zero-padded up to.
        sync_pattern: Byte sequence written at the start of every packet.
        include_sequence: Whether a 2-byte sequence number follows the sync.
        include_timestamp: Whether a 4-byte timestamp is written in the header.
        include_checksum: Whether a 2-byte checksum trails the packet.
        checksum_algorithm: Name of the checksum algorithm. The generator in
            this module always computes CRC-16 and does not read this value.
        noise_level: Fraction (0-1) of packets to corrupt.
    """

    # Overall geometry
    packet_size: int = 1024
    header_size: int = 16
    sync_pattern: bytes = b"\xaa\x55"

    # Optional header / trailer fields
    include_sequence: bool = True
    include_timestamp: bool = True
    include_checksum: bool = True
    checksum_algorithm: str = "crc16"

    # 0-1, fraction of corrupted packets
    noise_level: float = 0.0
@dataclass
class SyntheticSignalConfig:
    """Knobs controlling a generated digital waveform.

    Attributes:
        pattern_type: Which waveform family to produce.
        sample_rate: Sampling rate in samples per second.
        duration_samples: Length of the generated signal in samples.
        frequency: Fundamental frequency for the clock/square-wave pattern.
        noise_snr_db: Signal-to-noise ratio, in dB, of the added noise.
    """

    pattern_type: Literal["square", "uart", "spi", "i2c", "random"] = "square"
    sample_rate: float = 100e6
    duration_samples: int = 10000
    # For clock/square wave
    frequency: float = 1e6
    # Signal-to-noise ratio
    noise_snr_db: float = 40
@dataclass
class SyntheticMessageConfig:
    """Knobs controlling generated protocol messages.

    Attributes:
        message_size: Target size of one message in bytes.
        num_fields: Nominal number of fields. The message generator in this
            module uses a fixed field layout and does not read this value.
        include_header: Whether each message starts with a 2-byte sync pattern.
        include_length: Whether a 2-byte length field is included.
        include_checksum: Whether a 2-byte checksum trails the message.
        variation: Fraction of payload bytes that vary between messages.
    """

    message_size: int = 64
    num_fields: int = 5

    # Optional framing fields
    include_header: bool = True
    include_length: bool = True
    include_checksum: bool = True

    # Fraction of variable bytes
    variation: float = 0.1
53@dataclass
54class GroundTruth:
55 """Ground truth data for validation."""
57 field_boundaries: list[int] = field(default_factory=list)
58 field_types: list[str] = field(default_factory=list)
59 sequence_numbers: list[int] = field(default_factory=list)
60 pattern_period: int | None = None
61 cluster_labels: list[int] = field(default_factory=list)
62 checksum_offsets: list[int] = field(default_factory=list)
63 decoded_bytes: list[int] = field(default_factory=list)
64 edge_positions: list[int] = field(default_factory=list)
65 frequency_hz: float | None = None
class SyntheticDataGenerator:
    """Generate synthetic test data with known ground truth.

    All randomness is drawn from a single seeded NumPy generator stored on
    ``self.rng``, so the same seed and the same sequence of calls reproduce
    the same output.
    """

    def __init__(self, seed: int = 42):
        """Initialize with random seed for reproducibility.

        Args:
            seed: Random seed for reproducible generation.
        """
        self.rng = np.random.default_rng(seed)

    def generate_packets(
        self, config: SyntheticPacketConfig, count: int = 100
    ) -> tuple[bytes, GroundTruth]:
        """Generate synthetic binary packets.

        Each packet is laid out as: sync pattern, optional 2-byte sequence
        number, optional 4-byte timestamp, zero padding up to
        ``config.header_size``, a counter pattern of little-endian 16-bit
        samples, and an optional 2-byte CRC-16 over everything before it.

        Sequence numbers, timestamps and sample counters wrap at their field
        width instead of overflowing, so any ``count`` can be generated. The
        recorded ``sequence_numbers`` are the values written to the packets.

        Args:
            config: Packet generation configuration.
            count: Number of packets to generate.

        Returns:
            Tuple of (binary_data, ground_truth). ``checksum_offsets`` holds
            the byte offset of each checksum within ``binary_data``.
        """
        stream = bytearray()
        ground_truth = GroundTruth()

        # The sample payload does not depend on the packet index, so build it
        # once. An odd sample area leaves its last byte unused, as before.
        sample_data_size = config.packet_size - config.header_size
        if config.include_checksum:
            sample_data_size -= 2
        sample_count = max(sample_data_size // 2, 0)
        sample_block = (np.arange(sample_count) & 0xFFFF).astype("<u2").tobytes()

        for i in range(count):
            packet = bytearray(config.sync_pattern)

            # Sequence number (2 bytes, wraps at 65536)
            if config.include_sequence:
                wire_sequence = i & 0xFFFF
                packet.extend(struct.pack("<H", wire_sequence))
                ground_truth.sequence_numbers.append(wire_sequence)

            # Timestamp (4 bytes), incrementing by 1000 per packet
            if config.include_timestamp:
                packet.extend(struct.pack("<I", (i * 1000) & 0xFFFFFFFF))

            # Zero padding up to the header size
            missing = config.header_size - len(packet)
            if missing > 0:
                packet.extend(bytes(missing))

            # Counter pattern for samples (easier to validate)
            packet.extend(sample_block)

            # Checksum (CRC-16) over everything written so far
            if config.include_checksum:
                ground_truth.checksum_offsets.append(len(stream) + len(packet))
                packet.extend(struct.pack("<H", self._calculate_crc16(packet)))

            stream.extend(packet)

        # Apply noise (corrupt random packets). The stride is the real packet
        # length so the sync byte is hit even when the configured sizes do not
        # add up to exactly ``packet_size``.
        if config.noise_level > 0 and count > 0 and stream:
            stride = len(stream) // count
            corrupted = self.corrupt_packets(bytes(stream), stride, config.noise_level)
            return corrupted, ground_truth

        return bytes(stream), ground_truth

    def generate_digital_signal(
        self, config: SyntheticSignalConfig
    ) -> tuple[NDArray[np.float64], GroundTruth]:
        """Generate synthetic digital signal.

        ``"square"`` produces a square wave at ``config.frequency``,
        ``"uart"`` an 8N1 frame stream, ``"random"`` independent random
        levels, and any other pattern type a fixed repeating 8-sample
        pattern. The result is scaled to 3.3 V logic levels and, unless
        ``config.noise_snr_db`` is infinite, Gaussian noise is added.

        Args:
            config: Signal generation configuration.

        Returns:
            Tuple of (signal_array, ground_truth).
        """
        ground_truth = GroundTruth()
        signal: NDArray[np.float64]

        if config.pattern_type == "square":
            ground_truth.pattern_period = int(config.sample_rate / config.frequency)
            ground_truth.frequency_hz = config.frequency

            t = np.arange(config.duration_samples)
            signal = (np.sin(2 * np.pi * config.frequency * t / config.sample_rate) > 0).astype(
                np.float64
            )

            # An edge is the index of the first sample after a level change.
            ground_truth.edge_positions = (np.flatnonzero(np.diff(signal)) + 1).tolist()

        elif config.pattern_type == "uart":
            signal, uart_truth = self._generate_uart_signal(config)
            ground_truth.decoded_bytes = uart_truth["bytes"]
            ground_truth.edge_positions = uart_truth["edges"]

        elif config.pattern_type == "random":
            signal = self.rng.choice([0.0, 1.0], size=config.duration_samples)

        else:
            # Remaining pattern types share one simple repeating pattern.
            pattern = np.array([1, 1, 0, 1, 0, 0, 1, 0], dtype=np.float64)
            repeats = config.duration_samples // len(pattern) + 1
            signal = np.tile(pattern, repeats)[: config.duration_samples]
            ground_truth.pattern_period = len(pattern)

        # Scale to 3.3V logic levels
        signal = signal * 3.3

        # Add noise
        if config.noise_snr_db < np.inf:
            # add_noise returns an ndarray for ndarray input.
            signal = cast("NDArray[np.float64]", self.add_noise(signal, config.noise_snr_db))

        return signal, ground_truth

    def generate_protocol_messages(
        self, config: SyntheticMessageConfig, count: int = 100
    ) -> tuple[list[bytes], GroundTruth]:
        """Generate synthetic protocol messages.

        The field layout is: optional 2-byte sync, optional 2-byte length,
        2-byte sequence number, 4-byte timestamp, payload, optional 2-byte
        CRC-16. Payload bytes are ``0x42`` except for a random fraction
        (``config.variation``) that take random values.

        Sequence, timestamp and length values wrap at their field width. A
        ``message_size`` too small for the fixed fields yields an empty
        payload rather than field boundaries that go backwards.

        Args:
            config: Message generation configuration.
            count: Number of messages to generate.

        Returns:
            Tuple of (message_list, ground_truth).
        """
        ground_truth = GroundTruth()

        # Fixed-width fields that precede the payload, in wire order.
        layout: list[tuple[str, int]] = []
        if config.include_header:
            layout.append(("constant", 2))
        if config.include_length:
            layout.append(("length", 2))
        layout.append(("sequence", 2))
        layout.append(("timestamp", 4))

        fixed_size = sum(width for _, width in layout)
        payload_size = config.message_size - fixed_size
        if config.include_checksum:
            payload_size -= 2
        payload_size = max(payload_size, 0)

        layout.append(("data", payload_size))
        if config.include_checksum:
            layout.append(("checksum", 2))

        field_boundaries = [0]
        field_types = []
        for kind, width in layout:
            field_boundaries.append(field_boundaries[-1] + width)
            field_types.append(kind)

        ground_truth.field_boundaries = field_boundaries
        ground_truth.field_types = field_types

        messages = []
        for i in range(count):
            message = bytearray()

            if config.include_header:
                message.extend(b"\xaa\x55")

            if config.include_length:
                message.extend(struct.pack("<H", config.message_size & 0xFFFF))

            # Sequence number (wraps at 65536)
            wire_sequence = i & 0xFFFF
            message.extend(struct.pack("<H", wire_sequence))
            ground_truth.sequence_numbers.append(wire_sequence)

            # Timestamp, incrementing by 100 per message
            message.extend(struct.pack("<I", (i * 100) & 0xFFFFFFFF))

            # Payload: the draws stay one-per-byte so the output for a given
            # seed is unchanged.
            for _ in range(payload_size):
                if self.rng.random() < config.variation:
                    message.append(int(self.rng.integers(0, 256)))
                else:
                    message.append(0x42)  # Constant byte

            # Checksum over everything written so far
            if config.include_checksum:
                message.extend(struct.pack("<H", self._calculate_crc16(message)))

            messages.append(bytes(message))

        return messages, ground_truth

    def add_noise(
        self, data: bytes | NDArray[np.float64], snr_db: float = 20
    ) -> bytes | NDArray[np.float64]:
        """Add noise to data.

        Byte input (``bytes`` or ``bytearray``) has a random subset of its
        bytes XOR-ed with a non-zero random value; each byte is selected with
        probability ``10 ** (-snr_db / 10)``. Array input receives zero-mean
        Gaussian noise whose power is the mean signal power divided by
        ``10 ** (snr_db / 10)``; the noise matches the array's shape, so
        multi-dimensional arrays are supported.

        Args:
            data: Input data (bytes or numpy array).
            snr_db: Signal-to-noise ratio in dB.

        Returns:
            Noisy data (bytes for byte input, array for array input).
        """
        if isinstance(data, (bytes, bytearray)):
            raw = np.frombuffer(data, dtype=np.uint8)
            noise_rate = 10 ** (-snr_db / 10)
            mask = self.rng.random(len(raw)) < noise_rate
            noisy = raw.copy()
            # XOR with a value in 1..255 so every selected byte really changes.
            flips = self.rng.integers(1, 256, size=int(np.sum(mask)), dtype=np.uint8)
            noisy[mask] ^= flips
            return noisy.tobytes()

        samples = np.asarray(data)
        signal_power = np.mean(samples**2)
        noise_power = signal_power / (10 ** (snr_db / 10))
        noise = self.rng.normal(0, np.sqrt(noise_power), samples.shape)
        return samples + noise

    def corrupt_packets(
        self, packets: bytes, packet_size: int, corruption_rate: float = 0.01
    ) -> bytes:
        """Corrupt random packets for testing error handling.

        Each whole packet is independently selected with probability
        ``corruption_rate``; a selected packet has its first byte inverted.
        Trailing bytes that do not form a whole packet are left untouched.

        Args:
            packets: Binary packet data.
            packet_size: Size of each packet in bytes.
            corruption_rate: Fraction of packets to corrupt (0-1).

        Returns:
            Corrupted packet data.

        Raises:
            ValueError: If ``packet_size`` is not positive.
        """
        if packet_size <= 0:
            raise ValueError(f"packet_size must be positive, got {packet_size}")

        buffer = bytearray(packets)
        for start in range(0, (len(buffer) // packet_size) * packet_size, packet_size):
            if self.rng.random() < corruption_rate:
                # Corrupt sync marker
                buffer[start] ^= 0xFF

        return bytes(buffer)

    def _calculate_crc16(self, data: bytes | bytearray) -> int:
        """Calculate CRC-16 checksum.

        Bitwise, LSB-first CRC with initial value 0xFFFF and reflected
        polynomial 0xA001.

        Args:
            data: Input data.

        Returns:
            CRC-16 checksum value.
        """
        crc = 0xFFFF
        for byte in data:
            crc ^= byte
            for _ in range(8):
                lsb_set = crc & 0x0001
                crc >>= 1
                if lsb_set:
                    crc ^= 0xA001
        return crc & 0xFFFF

    def _generate_uart_signal(
        self, config: SyntheticSignalConfig
    ) -> tuple[NDArray[np.float64], dict[str, Any]]:
        """Generate UART signal encoding a test message.

        The message ``b"Hello, World!"`` is framed 8N1 (start bit low, eight
        data bits LSB first, stop bit high) at 9600 baud. The waveform is cut
        or padded with idle-high samples to ``config.duration_samples``.

        The reported edges are derived from the returned waveform: each is
        the index of the first sample after a level change, so every edge
        lies inside the signal and matches a real transition.

        Args:
            config: Signal configuration.

        Returns:
            Tuple of (signal, metadata_dict) where the dict has the keys
            ``"bytes"`` (encoded byte values) and ``"edges"``.
        """
        # UART parameters
        baud_rate = 9600
        # At least one sample per bit, even when the sample rate is below baud.
        samples_per_bit = max(1, int(config.sample_rate / baud_rate))

        # Test message
        message = b"Hello, World!"

        # One logic level per bit time: start, 8 data bits (LSB first), stop.
        frame_levels: list[int] = []
        for byte_val in message:
            frame_levels.append(0)
            frame_levels.extend((byte_val >> i) & 1 for i in range(8))
            frame_levels.append(1)

        duration = max(0, config.duration_samples)

        # Only expand the bit times that can fall inside the requested window.
        bits_needed = -(-duration // samples_per_bit)
        levels = np.array(frame_levels[:bits_needed], dtype=np.float64)
        signal = np.repeat(levels, samples_per_bit)[:duration]

        # Pad to duration with the idle (high) level
        if len(signal) < duration:
            padding = np.ones(duration - len(signal), dtype=np.float64)
            signal = np.concatenate([signal, padding])

        edges = (np.flatnonzero(np.diff(signal)) + 1).tolist()
        metadata = {"bytes": list(message), "edges": edges}

        return signal, metadata
413# =============================================================================
414# Convenience functions
415# =============================================================================
def generate_packets(count: int = 100, **kwargs: Any) -> tuple[bytes, GroundTruth]:
    """Build packets using a default-seeded generator.

    Keyword arguments are forwarded unchanged to ``SyntheticPacketConfig``.

    Args:
        count: Number of packets to generate.
        **kwargs: Additional configuration parameters.

    Returns:
        Tuple of (binary_data, ground_truth).
    """
    packet_config = SyntheticPacketConfig(**kwargs)
    return SyntheticDataGenerator().generate_packets(packet_config, count)
def generate_digital_signal(
    pattern: str = "square", **kwargs: Any
) -> tuple[NDArray[np.float64], GroundTruth]:
    """Build a digital signal using a default-seeded generator.

    An unrecognised ``pattern`` falls back to ``"square"``. A
    ``pattern_type`` entry in ``kwargs`` is discarded so that it cannot clash
    with the value derived from ``pattern``.

    Args:
        pattern: Pattern type ('square', 'uart', 'random', etc.).
        **kwargs: Additional configuration parameters.

    Returns:
        Tuple of (signal_array, ground_truth).
    """
    known_patterns = ("square", "uart", "spi", "i2c", "random")
    chosen = pattern if pattern in known_patterns else "square"

    # Drop any caller-supplied pattern_type to avoid a duplicate-argument error.
    options = dict(kwargs)
    options.pop("pattern_type", None)

    signal_config = SyntheticSignalConfig(
        pattern_type=cast("Literal['square', 'uart', 'spi', 'i2c', 'random']", chosen),
        **options,
    )
    return SyntheticDataGenerator().generate_digital_signal(signal_config)
def generate_protocol_messages(count: int = 100, **kwargs: Any) -> tuple[list[bytes], GroundTruth]:
    """Build protocol messages using a default-seeded generator.

    Keyword arguments are forwarded unchanged to ``SyntheticMessageConfig``.

    Args:
        count: Number of messages to generate.
        **kwargs: Additional configuration parameters.

    Returns:
        Tuple of (message_list, ground_truth).
    """
    message_config = SyntheticMessageConfig(**kwargs)
    return SyntheticDataGenerator().generate_protocol_messages(message_config, count)
def generate_test_dataset(
    output_dir: str,
    num_packets: int = 1000,
    num_signals: int = 10,
    num_messages: int = 500,
) -> dict[str, Any]:
    """Write a full synthetic dataset, plus a JSON index of it, to disk.

    One generator instance produces, in order: a packet file
    (``test_packets.bin``), ``num_signals`` signal files
    (``test_signal_NNN.npy``, alternating square and UART patterns), and a
    message file (``test_messages.bin``). A summary of every file is then
    written to ``dataset_metadata.json``.

    Args:
        output_dir: Directory to save test data.
        num_packets: Number of packets to generate.
        num_signals: Number of signals to generate.
        num_messages: Number of messages to generate.

    Returns:
        Dictionary with dataset metadata and file paths. The returned dict
        also carries a ``"metadata_file"`` key that is not part of the JSON
        written to disk.
    """
    import json

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    generator = SyntheticDataGenerator()
    file_entries: list[dict[str, Any]] = []
    metadata: dict[str, Any] = {
        "dataset_type": "synthetic_test_data",
        "generated_files": file_entries,
    }

    # Packets
    packets, packet_truth = generator.generate_packets(SyntheticPacketConfig(), num_packets)
    packet_file = target_dir / "test_packets.bin"
    packet_file.write_bytes(packets)
    file_entries.append(
        {
            "path": str(packet_file),
            "type": "packets",
            "count": num_packets,
            # Only the first 10 entries are recorded to keep the index small.
            "ground_truth": {
                "sequence_numbers": packet_truth.sequence_numbers[:10],
                "checksum_offsets": packet_truth.checksum_offsets[:10],
            },
        }
    )

    # Signals: even indices are square waves, odd indices are UART.
    for index in range(num_signals):
        kind = "uart" if index % 2 else "square"
        signal_config = SyntheticSignalConfig(
            pattern_type=kind,
            frequency=1e6 * (index + 1),
        )
        signal, signal_truth = generator.generate_digital_signal(signal_config)
        signal_file = target_dir / f"test_signal_{index:03d}.npy"
        np.save(signal_file, signal)
        file_entries.append(
            {
                "path": str(signal_file),
                "type": "signal",
                "pattern": signal_config.pattern_type,
                "ground_truth": {
                    "frequency_hz": signal_truth.frequency_hz,
                    "period_samples": signal_truth.pattern_period,
                },
            }
        )

    # Protocol messages, concatenated back to back in one file
    message_config = SyntheticMessageConfig()
    messages, message_truth = generator.generate_protocol_messages(message_config, num_messages)
    messages_file = target_dir / "test_messages.bin"
    messages_file.write_bytes(b"".join(messages))
    file_entries.append(
        {
            "path": str(messages_file),
            "type": "messages",
            "count": num_messages,
            "ground_truth": {
                "field_boundaries": message_truth.field_boundaries,
                "field_types": message_truth.field_types,
                "message_size": message_config.message_size,
            },
        }
    )

    # Save metadata
    metadata_file = target_dir / "dataset_metadata.json"
    with metadata_file.open("w") as handle:
        json.dump(metadata, handle, indent=2)

    # Added after the dump, so only the returned dict carries this key.
    metadata["metadata_file"] = str(metadata_file)

    return metadata