Coverage for src / tracekit / comparison / mask.py: 94%

137 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Mask testing for TraceKit. 

2 

3This module provides mask-based pass/fail testing for waveforms, 

4including eye diagram masks and custom polygon masks. 

5 

6 

7Example: 

8 >>> from tracekit.comparison import mask_test, eye_mask 

9 >>> mask = eye_mask(0.5, 0.4, 0.3) 

10 >>> result = mask_test(trace, mask) 

11 

12References: 

13 IEEE 802.3: Ethernet eye diagram mask specifications 

14""" 

15 

16from __future__ import annotations 

17 

18from dataclasses import dataclass, field 

19from typing import TYPE_CHECKING, Any, Literal 

20 

21import numpy as np 

22 

23from tracekit.core.exceptions import AnalysisError 

24 

25if TYPE_CHECKING: 

26 from numpy.typing import NDArray 

27 

28 from tracekit.core.types import WaveformTrace 

29 

30 

@dataclass
class MaskRegion:
    """A single polygonal region of a mask.

    Depending on ``region_type`` the polygon is either an area that waveform
    samples must not enter ("violation") or an area they must not leave
    ("boundary").

    Attributes:
        vertices: Polygon corners as (x, y) pairs in drawing order; the last
            vertex is implicitly joined back to the first.
        region_type: "violation" (samples must avoid the polygon) or
            "boundary" (samples must stay inside it).
        name: Optional label for the region.
    """

    vertices: list[tuple[float, float]]
    region_type: Literal["violation", "boundary"] = "violation"
    name: str = ""

    def contains_point(self, x: float, y: float) -> bool:
        """Return whether (x, y) lies inside the polygon (even-odd ray casting).

        A horizontal ray is cast from the query point towards +x and the
        polygon edges it crosses are counted; an odd count means "inside".
        Strict comparisons are used, so points exactly on an edge may be
        classified either way. An empty vertex list contains no points.

        Args:
            x: X coordinate of the query point.
            y: Y coordinate of the query point.

        Returns:
            True if the point is inside the polygon, False otherwise.
        """
        verts = self.vertices
        if len(verts) == 0:
            return False

        # Pair every vertex with its predecessor (the first pairs with the
        # last) so each polygon edge is visited exactly once. The straddle
        # test is evaluated first, so the division never sees a horizontal
        # edge (yi == yj).
        predecessors = [verts[-1], *verts[:-1]]
        crossings = sum(
            1
            for (xi, yi), (xj, yj) in zip(verts, predecessors)
            if ((yi > y) != (yj > y)) and x < (xj - xi) * (y - yi) / (yj - yi) + xi
        )
        return crossings % 2 == 1

73 

74 

@dataclass
class Mask:
    """Named collection of polygon regions used for pass/fail waveform testing.

    Attributes:
        regions: The MaskRegion polygons that make up the mask. Empty by
            default; every instance gets its own list.
        name: Human-readable mask name.
        x_unit: Label for the X axis unit, e.g. "UI", "ns" or "samples".
        y_unit: Label for the Y axis unit, e.g. "V", "mV" or "normalized".
        description: Free-form description text.
    """

    regions: list[MaskRegion] = field(default_factory=list)
    name: str = "mask"
    x_unit: str = "UI"
    y_unit: str = "V"
    description: str = ""

    def add_region(
        self,
        vertices: list[tuple[float, float]],
        region_type: Literal["violation", "boundary"] = "violation",
        name: str = "",
    ) -> None:
        """Append a new polygon region to this mask.

        Args:
            vertices: Polygon corners as (x, y) pairs.
            region_type: "violation" or "boundary".
            name: Optional label for the region.
        """
        new_region = MaskRegion(vertices=vertices, region_type=region_type, name=name)
        self.regions.append(new_region)

110 

111 

@dataclass
class MaskTestResult:
    """Outcome of testing waveform samples against a Mask.

    Attributes:
        passed: True when no sample violated the mask.
        num_violations: Number of samples that violated the mask.
        violation_rate: Fraction of samples that violated the mask.
        violation_points: (x, y) coordinates of the violating samples.
        violations_by_region: Violation count keyed by region name.
        margin: Estimated distance to the nearest mask edge, or None when no
            estimate was made.
    """

    # Overall verdict.
    passed: bool
    # Violation statistics.
    num_violations: int
    violation_rate: float
    # Details; the factories give every result its own containers.
    violation_points: list[tuple[float, float]] = field(default_factory=list)
    violations_by_region: dict[str, int] = field(default_factory=dict)
    # Optional margin estimate.
    margin: float | None = None

131 

132 

def create_mask(
    regions: list[dict[str, Any]],
    *,
    name: str = "custom_mask",
    x_unit: str = "samples",
    y_unit: str = "V",
) -> Mask:
    """Create a mask from region definitions.

    Args:
        regions: Region dicts. Each must have a ``"vertices"`` key holding
            (x, y) pairs and may have ``"type"`` ("violation", the default,
            or "boundary") and ``"name"`` (default "").
        name: Mask name.
        x_unit: X axis unit.
        y_unit: Y axis unit.

    Returns:
        A new Mask with one region per entry of ``regions``, in order. Vertex
        coordinates are copied into new ``(float, float)`` tuples, so later
        changes to the caller's vertex containers do not affect the mask.

    Raises:
        KeyError: If a region dict has no ``"vertices"`` key.
        ValueError: If a region's ``"type"`` is neither "violation" nor
            "boundary". ``mask_test`` evaluates only these two types, so any
            other value would create a region that is never checked.

    Example:
        >>> mask = create_mask([
        ...     {"vertices": [(0, 0.5), (0.5, 0.5), (0.5, -0.5), (0, -0.5)],
        ...      "type": "violation", "name": "center"}
        ... ])
    """
    mask = Mask(name=name, x_unit=x_unit, y_unit=y_unit)

    for index, region in enumerate(regions):
        raw_vertices = region["vertices"]
        region_type = region.get("type", "violation")
        if region_type not in ("violation", "boundary"):
            raise ValueError(
                f"Region {index} has invalid type {region_type!r}; "
                "expected 'violation' or 'boundary'"
            )
        # Copy into plain float tuples: avoids aliasing the caller's container
        # and guarantees the declared list[tuple[float, float]] vertex type.
        vertices = [(float(x), float(y)) for x, y in raw_vertices]
        mask.add_region(vertices, region_type, region.get("name", ""))

    return mask

167 

168 

def eye_mask(
    eye_width: float = 0.5,
    eye_height: float = 0.4,
    center_height: float = 0.3,
    *,
    x_margin: float = 0.0,
    y_margin: float = 0.1,
    unit_interval: float = 1.0,
    amplitude: float = 1.0,
) -> Mask:
    """Build a three-region eye diagram mask centred on the origin.

    The mask always contains three "violation" regions, added in this order:

    * ``"eye_center"``: a hexagon centred at (0, 0). It is
      ``eye_width * unit_interval`` wide at y = 0 and 70% of that width along
      its flat top and bottom edges. It is ``eye_height * amplitude`` tall.
    * ``"top"``: a rectangle spanning x = -UI/2 .. +UI/2. It runs from
      ``center_height * amplitude`` above the hexagon's top edge to
      ``amplitude / 2 + y_margin * amplitude``.
    * ``"bottom"``: the mirror image of ``"top"`` below the hexagon.

    Args:
        eye_width: Hexagon width at its widest point, as a fraction of UI.
        eye_height: Hexagon height, as a fraction of ``amplitude``.
        center_height: Vertical gap between the hexagon and the top/bottom
            regions, as a fraction of ``amplitude``.
        x_margin: Accepted but currently unused (reserved for future use).
        y_margin: How far the top/bottom regions extend beyond
            +/- ``amplitude / 2``, as a fraction of ``amplitude``.
        unit_interval: Length of one unit interval on the X axis.
        amplitude: Signal amplitude used to scale every Y coordinate.

    Returns:
        A Mask named "eye_mask" with x_unit "UI" and y_unit "normalized".

    Example:
        >>> mask = eye_mask(0.5, 0.4)  # 50% wide, 40% tall centre hexagon
        >>> [r.name for r in mask.regions]
        ['eye_center', 'top', 'bottom']
    """
    ui = unit_interval
    amp = amplitude

    mask = Mask(
        name="eye_mask",
        x_unit="UI",
        y_unit="normalized",
        description=f"Eye mask: {eye_width * 100:.0f}% width, {eye_height * 100:.0f}% height",
    )

    # Centre hexagon, listed clockwise starting from its left point.
    full_w = eye_width * ui
    flat_w = full_w * 0.7  # flat top/bottom edges are narrower than the waist
    hex_top = eye_height * amp / 2
    hex_bottom = -eye_height * amp / 2
    hexagon = [
        (-full_w / 2, 0),
        (-flat_w / 2, hex_top),
        (flat_w / 2, hex_top),
        (full_w / 2, 0),
        (flat_w / 2, hex_bottom),
        (-flat_w / 2, hex_bottom),
    ]

    def band(first_y: float, second_y: float) -> list[tuple[float, float]]:
        """Full-UI-wide rectangle: left/right corners at first_y, then at second_y."""
        return [(-ui / 2, first_y), (ui / 2, first_y), (ui / 2, second_y), (-ui / 2, second_y)]

    gap = center_height * amp
    overshoot = y_margin * amp
    top_band = band(hex_top + gap, amp / 2 + overshoot)
    bottom_band = band(-amp / 2 - overshoot, hex_bottom - gap)

    for vertices, label in ((hexagon, "eye_center"), (top_band, "top"), (bottom_band, "bottom")):
        mask.add_region(vertices, "violation", label)

    return mask

249 

250 

def mask_test(
    trace: WaveformTrace,
    mask: Mask,
    *,
    x_data: NDArray[np.floating[Any]] | None = None,
    normalize: bool = True,
    sample_rate: float | None = None,
) -> MaskTestResult:
    """Test a waveform's samples against every region of a mask.

    Each sample becomes a point (x, y). The violation rule depends on the
    region type:

    * A sample violates a "violation" region when it lies inside the polygon.
    * It violates a "boundary" region when it lies outside the polygon.
    * Regions of any other type are not evaluated.

    Args:
        trace: Input waveform trace; only ``trace.data`` is read.
        mask: Mask to test against.
        x_data: X coordinate of each sample. Defaults to the sample index
            (0, 1, 2, ...). If its length differs from the trace, the extra
            samples of the longer array are ignored.
        normalize: Linearly rescale the Y data to [-1, 1] before testing.
            Empty or constant data is left unchanged.
        sample_rate: Accepted for API compatibility; not used by the current
            implementation.

    Returns:
        MaskTestResult whose fields are filled as follows:

        * ``num_violations`` counts distinct violating samples. A sample
          that violates several regions is counted once; separate samples at
          identical coordinates are each counted.
        * ``violation_points`` holds one (x, y) per violating sample, in
          sample order.
        * ``violations_by_region`` counts violations per region name.
          Unnamed regions are reported as "unnamed", and regions that share a
          name have their counts summed.
        * ``margin`` is computed only when there are no violations. It is
          the smallest distance from any sample to any edge of a "violation"
          region, or None if there is no such edge.

    Example:
        >>> result = mask_test(eye_trace, mask)
        >>> print(f"Violations: {result.num_violations}")
    """
    y_data = np.asarray(trace.data, dtype=np.float64)
    if x_data is None:
        x_values = np.arange(len(y_data), dtype=np.float64)
    else:
        x_values = np.asarray(x_data, dtype=np.float64)

    # Size guard: np.min/np.max raise on empty arrays.
    if normalize and y_data.size > 0:
        y_min, y_max = np.min(y_data), np.max(y_data)
        if y_max - y_min > 0:
            y_data = 2 * (y_data - y_min) / (y_max - y_min) - 1

    # Pair samples like zip(): a length mismatch truncates to the shorter array.
    n_points = min(len(x_values), len(y_data))
    x_values = x_values[:n_points]
    y_values = y_data[:n_points]
    # Convert to Python floats once instead of once per region.
    points = list(zip(x_values.tolist(), y_values.tolist()))

    violating_samples: set[int] = set()
    violations_by_region: dict[str, int] = {}

    for region in mask.regions:
        region_name = region.name or "unnamed"
        # setdefault (not "= 0") so regions sharing a name accumulate counts
        # instead of the later region wiping out the earlier one's.
        violations_by_region.setdefault(region_name, 0)

        if region.region_type == "violation":
            violates_when_inside = True
        elif region.region_type == "boundary":
            violates_when_inside = False
        else:
            continue

        for idx, (x, y) in enumerate(points):
            if bool(region.contains_point(x, y)) == violates_when_inside:
                violating_samples.add(idx)
                violations_by_region[region_name] += 1

    # Deduplicate by sample index so a sample hit by several regions counts
    # once. Deduplicating by coordinates would merge distinct samples that
    # share an (x, y) position.
    ordered = sorted(violating_samples)
    violation_points = [points[idx] for idx in ordered]
    num_violations = len(ordered)
    violation_rate = num_violations / len(y_data) if len(y_data) > 0 else 0.0

    margin = None
    if num_violations == 0 and mask.regions:
        margin = _min_distance_to_violation_edges(x_values, y_values, mask.regions)

    return MaskTestResult(
        passed=num_violations == 0,
        num_violations=num_violations,
        violation_rate=violation_rate,
        violation_points=violation_points,
        violations_by_region=violations_by_region,
        margin=margin,
    )


def _min_distance_to_violation_edges(
    x_values: NDArray[np.float64],
    y_values: NDArray[np.float64],
    regions: list[MaskRegion],
) -> float | None:
    """Return the smallest point-to-edge distance over all "violation" regions.

    This computes the same point-to-segment distance as
    ``_point_to_segment_distance``, but vectorized over all points for each
    polygon edge. Non-finite points are skipped because they cannot lower
    the minimum. Returns None when there are no finite points or no
    "violation" region edges.
    """
    finite = np.isfinite(x_values) & np.isfinite(y_values)
    xs = x_values[finite]
    ys = y_values[finite]
    if xs.size == 0:
        return None

    best = float("inf")
    for region in regions:
        if region.region_type != "violation":
            continue
        verts = region.vertices
        n_verts = len(verts)
        for k, (x1, y1) in enumerate(verts):
            # Closing edge: the last vertex connects back to the first.
            x2, y2 = verts[(k + 1) % n_verts]
            dx, dy = x2 - x1, y2 - y1
            length_sq = dx * dx + dy * dy
            if length_sq == 0:
                t = 0.0  # degenerate edge: nearest point is (x1, y1)
            else:
                t = np.clip(((xs - x1) * dx + (ys - y1) * dy) / length_sq, 0.0, 1.0)
            dist = np.hypot(xs - (x1 + t * dx), ys - (y1 + t * dy))
            best = min(best, float(np.min(dist)))

    return best if best != float("inf") else None

341 

342 

343def _point_to_segment_distance( 

344 px: float, py: float, x1: float, y1: float, x2: float, y2: float 

345) -> float: 

346 """Calculate distance from point to line segment.""" 

347 dx = x2 - x1 

348 dy = y2 - y1 

349 length_sq = dx * dx + dy * dy 

350 

351 if length_sq == 0: 351 ↛ 353line 351 didn't jump to line 353 because the condition on line 351 was never true

352 # Segment is a point 

353 return np.sqrt((px - x1) ** 2 + (py - y1) ** 2) # type: ignore[no-any-return] 

354 

355 # Project point onto line 

356 t = max(0, min(1, ((px - x1) * dx + (py - y1) * dy) / length_sq)) 

357 proj_x = x1 + t * dx 

358 proj_y = y1 + t * dy 

359 

360 return float(np.sqrt((px - proj_x) ** 2 + (py - proj_y) ** 2)) 

361 

362 

def eye_diagram_mask_test(
    eye_data: NDArray[np.floating[Any]],
    *,
    eye_width: float = 0.5,
    eye_height: float = 0.4,
    unit_interval: float = 1.0,
) -> MaskTestResult:
    """Specialized eye diagram mask test.

    Builds ``eye_mask(eye_width, eye_height, unit_interval=unit_interval,
    amplitude=1.0)``. It rescales all data jointly to [-1, 1] (constant or
    empty data is left unchanged) and places each trace's samples on an
    evenly spaced x grid from ``-0.5 * unit_interval`` to
    ``+0.5 * unit_interval``. It then checks every sample against the mask's
    "violation" regions.

    Args:
        eye_data: Array-like of shape (num_traces, samples_per_ui).
        eye_width: Eye opening width (fraction of UI).
        eye_height: Eye opening height (fraction of amplitude).
        unit_interval: Duration of unit interval in samples.

    Returns:
        MaskTestResult whose fields are filled as follows:

        * ``num_violations`` counts violating samples. A sample inside
          several regions counts once; samples from different traces at
          identical coordinates are each counted.
        * ``violation_points`` lists one (x, y) per violating sample, in
          trace-major order.
        * ``violations_by_region`` counts hits per region name.
        * ``margin`` is always None.

    Raises:
        AnalysisError: If eye data is not a 2D array.
    """
    eye_data = np.asarray(eye_data)
    if eye_data.ndim != 2:
        raise AnalysisError("Eye data must be 2D array (num_traces x samples_per_ui)")

    num_traces, samples_per_ui = eye_data.shape

    # NOTE(review): the data is rescaled to [-1, 1] (peak-to-peak 2) below,
    # but the mask is built with amplitude=1.0. As a result eye_height acts as
    # a fraction of half the normalized swing, and the "top"/"bottom" regions
    # (which end at +/-(0.5 + y_margin)) appear to lie inside the signal
    # range. Confirm this scaling is intended.
    mask = eye_mask(
        eye_width=eye_width,
        eye_height=eye_height,
        unit_interval=unit_interval,
        amplitude=1.0,
    )

    # Joint rescale to [-1, 1]. The size guard is needed because np.min/np.max
    # raise on empty arrays.
    normalized = eye_data
    if eye_data.size > 0:
        y_min, y_max = np.min(eye_data), np.max(eye_data)
        if y_max - y_min > 0:
            normalized = 2 * (eye_data - y_min) / (y_max - y_min) - 1
    # Convert to Python floats once, up front.
    rows = np.asarray(normalized, dtype=np.float64).tolist()

    # X coordinates relative to the UI centre, shared by every trace.
    x_coords = (np.linspace(-0.5, 0.5, samples_per_ui) * unit_interval).tolist()

    violation_regions = [r for r in mask.regions if r.region_type == "violation"]
    violations_by_region: dict[str, int] = {r.name or "unnamed": 0 for r in mask.regions}
    violation_points: list[tuple[float, float]] = []

    for row in rows:
        for x, y in zip(x_coords, row):
            hit = False
            for region in violation_regions:
                if region.contains_point(x, y):
                    violations_by_region[region.name or "unnamed"] += 1
                    hit = True
            # Record per sample (not per unique coordinate) so identical
            # samples from different traces are all counted.
            if hit:
                violation_points.append((x, y))

    num_violations = len(violation_points)
    total_points = num_traces * samples_per_ui

    return MaskTestResult(
        passed=num_violations == 0,
        num_violations=num_violations,
        violation_rate=num_violations / total_points if total_points > 0 else 0.0,
        violation_points=violation_points,
        violations_by_region=violations_by_region,
        margin=None,
    )