Coverage for src / tracekit / comparison / mask.py: 94%
137 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
"""Mask testing for TraceKit.

This module provides mask-based pass/fail testing for waveforms,
including eye diagram masks and custom polygon masks.

Example:
    >>> from tracekit.comparison import mask_test, eye_mask
    >>> mask = eye_mask(0.5, 0.4, 0.3)
    >>> result = mask_test(trace, mask)

References:
    IEEE 802.3: Ethernet eye diagram mask specifications
"""
16from __future__ import annotations
18from dataclasses import dataclass, field
19from typing import TYPE_CHECKING, Any, Literal
21import numpy as np
23from tracekit.core.exceptions import AnalysisError
25if TYPE_CHECKING:
26 from numpy.typing import NDArray
28 from tracekit.core.types import WaveformTrace
@dataclass
class MaskRegion:
    """A single polygonal region of a mask.

    Depending on ``region_type`` the polygon is either an area that waveform
    samples must stay out of ("violation") or one they must stay inside
    ("boundary").

    Attributes:
        vertices: Polygon corners as (x, y) pairs in drawing order; the
            polygon is closed implicitly from the last vertex back to the first.
        region_type: "violation" (must avoid) or "boundary" (must stay within).
        name: Optional label for the region.
    """

    vertices: list[tuple[float, float]]
    region_type: Literal["violation", "boundary"] = "violation"
    name: str = ""

    def contains_point(self, x: float, y: float) -> bool:
        """Return whether ``(x, y)`` lies inside the polygon.

        Uses the even-odd ray-casting rule: a horizontal ray is cast from the
        point toward +x, the polygon edges it crosses are counted, and an odd
        count means the point is inside.

        Args:
            x: X coordinate of the query point.
            y: Y coordinate of the query point.

        Returns:
            True if the point is inside the polygon, False otherwise
            (including when the polygon has no vertices).
        """
        crossings = 0
        for k in range(len(self.vertices)):
            # Edge from the previous vertex; k - 1 wraps to the last vertex at k == 0.
            x_cur, y_cur = self.vertices[k]
            x_prev, y_prev = self.vertices[k - 1]
            # Only edges straddling the ray's height can be crossed. This also
            # guarantees y_prev != y_cur, so the division below is safe.
            if (y_cur > y) != (y_prev > y):
                x_hit = (x_prev - x_cur) * (y - y_cur) / (y_prev - y_cur) + x_cur
                if x < x_hit:
                    crossings += 1
        return crossings % 2 == 1
@dataclass
class Mask:
    """A named collection of polygon regions used as pass/fail criteria.

    Attributes:
        regions: MaskRegion polygons that make up the mask.
        name: Human-readable mask name.
        x_unit: Label for the X axis unit (e.g. "UI", "ns", "samples").
        y_unit: Label for the Y axis unit (e.g. "V", "mV", "normalized").
        description: Free-form description.
    """

    regions: list[MaskRegion] = field(default_factory=list)
    name: str = "mask"
    x_unit: str = "UI"
    y_unit: str = "V"
    description: str = ""

    def add_region(
        self,
        vertices: list[tuple[float, float]],
        region_type: Literal["violation", "boundary"] = "violation",
        name: str = "",
    ) -> None:
        """Append a new polygon region to this mask.

        Args:
            vertices: Polygon corners as (x, y) pairs.
            region_type: "violation" (must avoid) or "boundary"
                (must stay within).
            name: Optional label for the region.
        """
        new_region = MaskRegion(vertices=vertices, region_type=region_type, name=name)
        self.regions.append(new_region)
@dataclass
class MaskTestResult:
    """Outcome of testing waveform samples against a mask.

    Attributes:
        passed: True if all samples pass the mask test.
        num_violations: Number of samples violating the mask.
        violation_rate: Fraction of samples violating the mask.
        violation_points: (x, y) coordinates of the violating samples.
        violations_by_region: Violation count keyed by region name.
        margin: Estimated distance to the nearest mask edge, if computed.
    """

    # Overall verdict and summary statistics (required, positional).
    passed: bool
    num_violations: int
    violation_rate: float
    # Detail containers: default_factory gives each instance its own object.
    violation_points: list[tuple[float, float]] = field(default_factory=list)
    violations_by_region: dict[str, int] = field(default_factory=dict)
    # Left as None when no margin estimate was produced.
    margin: float | None = None
def create_mask(
    regions: list[dict],  # type: ignore[type-arg]
    *,
    name: str = "custom_mask",
    x_unit: str = "samples",
    y_unit: str = "V",
) -> Mask:
    """Build a Mask from a list of plain region dictionaries.

    Each dictionary must provide ``"vertices"``. ``"type"`` defaults to
    ``"violation"`` and ``"name"`` defaults to an empty string; other keys
    are ignored.

    Args:
        regions: Region specifications, one dict per polygon.
        name: Mask name.
        x_unit: X axis unit label.
        y_unit: Y axis unit label.

    Returns:
        A new Mask whose regions appear in the same order as ``regions``.

    Raises:
        KeyError: If a region dictionary lacks the ``"vertices"`` key.

    Example:
        >>> mask = create_mask([
        ...     {"vertices": [(0, 0.5), (0.5, 0.5), (0.5, -0.5), (0, -0.5)],
        ...      "type": "violation", "name": "center"}
        ... ])
    """
    result = Mask(name=name, x_unit=x_unit, y_unit=y_unit)
    for spec in regions:
        result.add_region(
            spec["vertices"],
            spec.get("type", "violation"),
            spec.get("name", ""),
        )
    return result
def eye_mask(
    eye_width: float = 0.5,
    eye_height: float = 0.4,
    center_height: float = 0.3,
    *,
    x_margin: float = 0.0,
    y_margin: float = 0.1,
    unit_interval: float = 1.0,
    amplitude: float = 1.0,
) -> Mask:
    """Build a three-region eye diagram mask.

    All three regions are of type ``"violation"`` and are centred on x = 0.
    X is scaled by ``unit_interval`` and Y by ``amplitude``:

    * ``"eye_center"``: a hexagon. Its left/right tips sit at
      ``x = ±eye_width * unit_interval / 2, y = 0``. Its flat top and bottom
      edges are 70% of that width and sit at ``y = ±eye_height * amplitude / 2``.
    * ``"top"``: a band spanning the full UI (``x`` in ``±unit_interval / 2``)
      between ``y = eye_height * amplitude / 2 + center_height * amplitude``
      and ``y = amplitude / 2 + y_margin * amplitude``.
    * ``"bottom"``: the mirror image of ``"top"`` below y = 0.

    Args:
        eye_width: Width of the central hexagon (fraction of UI).
        eye_height: Height of the central hexagon (fraction of amplitude).
        center_height: Vertical distance from the hexagon's top/bottom edge
            to the inner edge of the top/bottom band (fraction of amplitude).
        x_margin: Not used by the current implementation (reserved).
        y_margin: Offset of the outer edge of the top/bottom bands beyond
            ``±amplitude / 2`` (fraction of amplitude).
        unit_interval: Duration of the unit interval (scales X).
        amplitude: Signal amplitude (scales Y).

    Returns:
        Mask named ``"eye_mask"`` with ``x_unit="UI"`` and
        ``y_unit="normalized"``.

    Example:
        >>> mask = eye_mask(0.5, 0.4)  # 50% width, 40% height
        >>> [region.name for region in mask.regions]
        ['eye_center', 'top', 'bottom']
    """
    mask = Mask(
        name="eye_mask",
        x_unit="UI",
        y_unit="normalized",
        description=f"Eye mask: {eye_width * 100:.0f}% width, {eye_height * 100:.0f}% height",
    )

    ui, amp = unit_interval, amplitude
    half_ui = ui / 2

    # Central hexagon: pointed at y = 0 on the left/right, flat top and bottom
    # edges narrowed to 70% of the full width.
    tip_x = eye_width * ui / 2
    flat_x = eye_width * ui * 0.7 / 2
    hex_top = eye_height * amp / 2
    hex_bottom = -eye_height * amp / 2
    mask.add_region(
        [
            (-tip_x, 0),
            (-flat_x, hex_top),
            (flat_x, hex_top),
            (tip_x, 0),
            (flat_x, hex_bottom),
            (-flat_x, hex_bottom),
        ],
        "violation",
        "eye_center",
    )

    # Band above the eye, spanning the whole unit interval.
    inner_top = hex_top + center_height * amp
    outer_top = amp / 2 + y_margin * amp
    mask.add_region(
        [(-half_ui, inner_top), (half_ui, inner_top), (half_ui, outer_top), (-half_ui, outer_top)],
        "violation",
        "top",
    )

    # Band below the eye (mirror of the top band).
    outer_bottom = -amp / 2 - y_margin * amp
    inner_bottom = hex_bottom - center_height * amp
    mask.add_region(
        [
            (-half_ui, outer_bottom),
            (half_ui, outer_bottom),
            (half_ui, inner_bottom),
            (-half_ui, inner_bottom),
        ],
        "violation",
        "bottom",
    )

    return mask
def mask_test(
    trace: WaveformTrace,
    mask: Mask,
    *,
    x_data: NDArray[np.floating[Any]] | None = None,
    normalize: bool = True,
    sample_rate: float | None = None,
) -> MaskTestResult:
    """Test waveform against a mask.

    Each sample ``(x_data[i], y[i])`` is checked against every region of
    ``mask``:

    * A ``"violation"`` region is violated when the sample lies inside it.
    * A ``"boundary"`` region is violated when the sample lies outside it.
    * Regions of any other type are listed in ``violations_by_region`` but
      are otherwise ignored.

    Args:
        trace: Input waveform trace; its ``data`` array supplies the Y values.
        mask: Mask to test against.
        x_data: X coordinate for each sample. Defaults to the sample index
            ``0 .. len(data) - 1``. If its length differs from the trace's,
            only the overlapping prefix is tested.
        normalize: If True, linearly rescale the Y data to [-1, 1] before
            testing. Skipped for empty or constant traces.
        sample_rate: Currently unused; accepted for API compatibility.

    Returns:
        MaskTestResult with these fields:

        * ``num_violations``: number of distinct samples that violate at
          least one region. A sample inside two overlapping regions counts
          once; two samples at identical coordinates count twice.
        * ``violation_points``: one ``(x, y)`` per violating sample, in
          sample order.
        * ``violation_rate``: ``num_violations`` divided by the number of
          tested samples.
        * ``violations_by_region``: per-region-name hit counts. Regions
          sharing a name, including unnamed ones reported as ``"unnamed"``,
          are summed.
        * ``margin``: computed only when there are no violations. It is the
          minimum distance from any sample to any edge of a violation
          region, or None if there are none.

    Example:
        >>> result = mask_test(eye_trace, mask)
        >>> print(f"Violations: {result.num_violations}")
    """
    y_data = trace.data.astype(np.float64)

    if x_data is None:
        x_data = np.arange(len(y_data), dtype=np.float64)

    # Rescale to [-1, 1]. np.min/np.max cannot reduce an empty array, and a
    # constant trace has no range to rescale, so both are left untouched.
    if normalize and len(y_data) > 0:
        y_min, y_max = np.min(y_data), np.max(y_data)
        if y_max - y_min > 0:
            y_data = 2 * (y_data - y_min) / (y_max - y_min) - 1

    # Convert once to plain floats. zip(strict=False) keeps the historical
    # behaviour of testing only the overlapping prefix on a length mismatch.
    points = [(float(x), float(y)) for x, y in zip(x_data, y_data, strict=False)]

    # Keyed by sample index: a sample hit by several regions counts once, while
    # distinct samples that happen to share coordinates are each counted.
    violating: dict[int, tuple[float, float]] = {}
    violations_by_region: dict[str, int] = {}

    for region in mask.regions:
        region_name = region.name or "unnamed"
        # setdefault (not "= 0"): regions sharing a name accumulate their counts.
        violations_by_region.setdefault(region_name, 0)

        if region.region_type == "violation":
            violates_when_inside = True
        elif region.region_type == "boundary":
            violates_when_inside = False
        else:
            continue

        for idx, (x, y) in enumerate(points):
            if bool(region.contains_point(x, y)) == violates_when_inside:
                violating.setdefault(idx, (x, y))
                violations_by_region[region_name] += 1

    violation_points = [violating[idx] for idx in sorted(violating)]
    num_violations = len(violation_points)
    violation_rate = num_violations / len(points) if points else 0.0

    # Margin estimate (simplified): distance to the nearest violation-region edge.
    margin: float | None = None
    if not violating and mask.regions:
        min_dist = float("inf")
        for region in mask.regions:
            if region.region_type != "violation":
                continue
            verts = region.vertices
            n_verts = len(verts)
            # Closed polygon: edge k joins vertex k to vertex k + 1 (wrapping).
            edges = [(verts[k], verts[(k + 1) % n_verts]) for k in range(n_verts)]
            for x, y in points:
                for (x1, y1), (x2, y2) in edges:
                    min_dist = min(min_dist, _point_to_segment_distance(x, y, x1, y1, x2, y2))
        if min_dist != float("inf"):
            margin = min_dist

    return MaskTestResult(
        passed=num_violations == 0,
        num_violations=num_violations,
        violation_rate=violation_rate,
        violation_points=violation_points,
        violations_by_region=violations_by_region,
        margin=margin,
    )
def _point_to_segment_distance(
    px: float, py: float, x1: float, y1: float, x2: float, y2: float
) -> float:
    """Return the Euclidean distance from point (px, py) to segment (x1, y1)-(x2, y2).

    The point is projected onto the segment's supporting line. The projection
    parameter is clamped to [0, 1], so the nearest point always lies on the
    segment itself. A zero-length segment is treated as a single point.

    Args:
        px: X coordinate of the point.
        py: Y coordinate of the point.
        x1: X coordinate of the segment start.
        y1: Y coordinate of the segment start.
        x2: X coordinate of the segment end.
        y2: Y coordinate of the segment end.

    Returns:
        The distance as a Python float.
    """
    dx = x2 - x1
    dy = y2 - y1
    length_sq = dx * dx + dy * dy

    if length_sq == 0:
        # Degenerate segment: distance to its single endpoint.
        return float(np.hypot(px - x1, py - y1))

    # Parameter of the orthogonal projection, clamped onto the segment.
    t = max(0.0, min(1.0, ((px - x1) * dx + (py - y1) * dy) / length_sq))
    proj_x = x1 + t * dx
    proj_y = y1 + t * dy

    # hypot avoids the intermediate overflow/underflow of sqrt(a**2 + b**2).
    return float(np.hypot(px - proj_x, py - proj_y))
def eye_diagram_mask_test(
    eye_data: NDArray[np.floating[Any]],
    *,
    eye_width: float = 0.5,
    eye_height: float = 0.4,
    unit_interval: float = 1.0,
) -> MaskTestResult:
    """Specialized eye diagram mask test.

    The mask is ``eye_mask(eye_width=..., eye_height=..., unit_interval=...,
    amplitude=1.0)``. The whole array is rescaled to [-1, 1] using its global
    min/max; a constant array is left unchanged. Sample ``k`` of every trace
    is placed at ``x = linspace(-0.5, 0.5, samples_per_ui)[k] * unit_interval``.

    Args:
        eye_data: 2D array of shape (num_traces, samples_per_ui).
        eye_width: Eye opening width (fraction of UI).
        eye_height: Eye opening height (fraction of amplitude).
        unit_interval: Duration of unit interval in samples.

    Returns:
        MaskTestResult for the eye diagram, with these fields:

        * ``num_violations``: number of (trace, sample) points inside at
          least one violation region. Points at identical coordinates in
          different traces are each counted.
        * ``violation_points``: one ``(x, y)`` per violating point, in scan
          order.
        * ``violations_by_region``: per-region hit counts.
        * ``margin``: always None.

        An empty array (zero traces or zero samples) yields a passing result
        with no violations.

    Raises:
        AnalysisError: If eye data is not a 2D array.
    """
    if eye_data.ndim != 2:
        raise AnalysisError("Eye data must be 2D array (num_traces x samples_per_ui)")

    num_traces, samples_per_ui = eye_data.shape
    total_points = num_traces * samples_per_ui

    mask = eye_mask(
        eye_width=eye_width,
        eye_height=eye_height,
        unit_interval=unit_interval,
        amplitude=1.0,
    )
    # NOTE(review): eye_mask places the outer edges of its top/bottom bands at
    # +/-(amplitude / 2 + y_margin * amplitude), i.e. +/-0.6 with its default
    # y_margin, whereas the data below is normalized to [-1, 1]. Confirm that
    # amplitude=1.0 is the intended scale.
    violations_by_region: dict[str, int] = {r.name or "unnamed": 0 for r in mask.regions}

    if total_points == 0:
        # np.min/np.max cannot reduce an empty array; there is nothing to test.
        return MaskTestResult(
            passed=True,
            num_violations=0,
            violation_rate=0.0,
            violation_points=[],
            violations_by_region=violations_by_region,
            margin=None,
        )

    # Global normalization to [-1, 1] (skipped for constant data).
    y_min, y_max = np.min(eye_data), np.max(eye_data)
    normalized = 2 * (eye_data - y_min) / (y_max - y_min) - 1 if y_max - y_min > 0 else eye_data

    # X coordinates relative to the UI centre, shared by every trace.
    x_coords = (np.linspace(-0.5, 0.5, samples_per_ui) * unit_interval).tolist()
    violation_regions = [r for r in mask.regions if r.region_type == "violation"]

    # One entry per violating (trace, sample) point. De-duplicating by (x, y)
    # would undercount, because every trace reuses the same x grid.
    violation_points: list[tuple[float, float]] = []
    for row in normalized:
        for x, y_value in zip(x_coords, row.tolist(), strict=True):
            y = float(y_value)
            hit = False
            for region in violation_regions:
                if region.contains_point(x, y):
                    violations_by_region[region.name or "unnamed"] += 1
                    hit = True
            if hit:
                violation_points.append((x, y))

    num_violations = len(violation_points)

    return MaskTestResult(
        passed=num_violations == 0,
        num_violations=num_violations,
        violation_rate=num_violations / total_points,
        violation_points=violation_points,
        violations_by_region=violations_by_region,
        margin=None,
    )