Coverage for src / tracekit / extensibility / measurements.py: 31%

79 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Custom measurement framework for user-defined measurements. 

2 

3This module implements a framework for defining and registering custom 

4measurements that integrate seamlessly with batch processing and export. 

5""" 

6 

7from __future__ import annotations 

8 

9import inspect 

10from dataclasses import dataclass, field 

11from typing import TYPE_CHECKING, Any 

12 

13from .registry import AlgorithmRegistry 

14 

15if TYPE_CHECKING: 

16 from collections.abc import Callable 

17 

18 from ..core.types import WaveformTrace 

19 

20 

@dataclass
class MeasurementDefinition:
    """Definition of a custom measurement with metadata.

    Defines a measurement function along with metadata about units, category,
    and documentation. Measurements can be registered globally and used in
    batch processing.

    Attributes:
        name: Unique name for the measurement.
        func: Callable that computes the measurement. It is invoked as
            ``func(trace, **kwargs)``, so its first parameter must be able to
            receive the trace positionally.
        units: Units of measurement (e.g., 'V', 'Hz', 's', 'ratio').
        category: Measurement category (e.g., 'amplitude', 'timing', 'frequency').
        description: Human-readable description.
        tags: Optional tags for categorization and search.

    Example:
        >>> import tracekit as tk
        >>> def calculate_crest_factor(trace, **kwargs):
        ...     peak = abs(trace.data).max()
        ...     rms = (trace.data ** 2).mean() ** 0.5
        ...     return peak / rms
        >>> tk.register_measurement(
        ...     name='crest_factor',
        ...     func=calculate_crest_factor,
        ...     units='ratio',
        ...     category='amplitude'
        ... )
        >>> cf = tk.measure(trace, 'crest_factor')

    Advanced Example:
        >>> # Define measurement with full metadata
        >>> slew_rate_defn = tk.MeasurementDefinition(
        ...     name='max_slew_rate',
        ...     func=lambda trace: abs(trace.derivative()).max(),
        ...     units='V/s',
        ...     category='edge',
        ...     description='Maximum slew rate in trace',
        ...     tags=['edge', 'derivative', 'speed']
        ... )
        >>> tk.register_measurement(definition=slew_rate_defn)

    References:
        API-008: Custom Measurement Framework
        API-006: Algorithm Override Hooks
    """

    name: str
    func: Callable[[WaveformTrace], float]
    units: str
    category: str
    description: str = ""
    tags: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Validate measurement definition.

        Raises:
            ValueError: If measurement name is empty.
            TypeError: If func is not callable or has invalid signature.
        """
        if not self.name:
            raise ValueError("Measurement name cannot be empty")

        if not callable(self.func):
            raise TypeError(f"Measurement func must be callable, got {type(self.func).__name__}")

        # Validate function signature
        self._validate_signature()

    def _validate_signature(self) -> None:
        """Validate that func can be called as ``func(trace, **kwargs)``.

        Callables whose signature cannot be introspected are accepted without
        checking; any mismatch then surfaces when the measurement is called.

        Raises:
            TypeError: If func takes no parameters, or its first parameter
                cannot receive the trace as a named positional argument.
        """
        try:
            sig = inspect.signature(self.func)
        except (TypeError, ValueError):
            # Some callables (e.g. certain C builtins) expose no signature.
            # They cannot be validated here, so defer problems to call time.
            return

        params = list(sig.parameters.values())

        # Should have at least one parameter (trace)
        if not params:
            # functools.partial objects and callable instances have no
            # __name__, so fall back to repr() for the error message.
            func_name = getattr(self.func, "__name__", repr(self.func))
            raise TypeError(
                f"Measurement function must accept at least one parameter "
                f"(trace). Got {func_name} with no parameters."
            )

        # __call__ passes the trace positionally. A keyword-only parameter can
        # never receive it, and *args / **kwargs are not accepted as the slot
        # for the trace, so only named positional parameters qualify.
        first_param = params[0]
        if first_param.kind not in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            raise TypeError(
                f"First parameter must be a regular parameter (trace), got {first_param.kind}"
            )

    def __call__(self, trace: WaveformTrace, **kwargs: Any) -> float:
        """Call measurement function.

        Args:
            trace: WaveformTrace to measure.
            **kwargs: Additional parameters for measurement.

        Returns:
            Measured value.

        Example:
            >>> defn = MeasurementDefinition(
            ...     name='peak',
            ...     func=lambda trace: abs(trace.data).max(),
            ...     units='V',
            ...     category='amplitude'
            ... )
            >>> value = defn(trace)
        """
        return self.func(trace, **kwargs)

    def __repr__(self) -> str:
        """String representation.

        Returns:
            String representation of the measurement definition.
        """
        return (
            f"MeasurementDefinition(name='{self.name}', "
            f"units='{self.units}', category='{self.category}')"
        )

73 tags: list[str] = field(default_factory=list) 

74 

75 def __post_init__(self) -> None: 

76 """Validate measurement definition. 

77 

78 Raises: 

79 ValueError: If measurement name is empty. 

80 TypeError: If func is not callable or has invalid signature. 

81 """ 

82 if not self.name: 

83 raise ValueError("Measurement name cannot be empty") 

84 

85 if not callable(self.func): 

86 raise TypeError(f"Measurement func must be callable, got {type(self.func).__name__}") 

87 

88 # Validate function signature 

89 self._validate_signature() 

90 

91 def _validate_signature(self) -> None: 

92 """Validate that function has correct signature. 

93 

94 Measurement functions should accept (trace, **kwargs) -> float. 

95 

96 Raises: 

97 TypeError: If signature is invalid. 

98 """ 

99 sig = inspect.signature(self.func) 

100 params = list(sig.parameters.values()) 

101 

102 # Should have at least one parameter (trace) 

103 if len(params) == 0: 

104 raise TypeError( 

105 f"Measurement function must accept at least one parameter " 

106 f"(trace). Got {self.func.__name__} with no parameters." 

107 ) 

108 

109 # Check if first parameter could accept WaveformTrace 

110 first_param = params[0] 

111 if first_param.kind in ( 

112 inspect.Parameter.VAR_POSITIONAL, 

113 inspect.Parameter.VAR_KEYWORD, 

114 ): 

115 raise TypeError( 

116 f"First parameter must be a regular parameter (trace), got {first_param.kind}" 

117 ) 

118 

119 def __call__(self, trace: WaveformTrace, **kwargs: Any) -> float: 

120 """Call measurement function. 

121 

122 Args: 

123 trace: WaveformTrace to measure. 

124 **kwargs: Additional parameters for measurement. 

125 

126 Returns: 

127 Measured value. 

128 

129 Example: 

130 >>> defn = MeasurementDefinition( 

131 ... name='peak', 

132 ... func=lambda trace: abs(trace.data).max(), 

133 ... units='V', 

134 ... category='amplitude' 

135 ... ) 

136 >>> value = defn(trace) 

137 """ 

138 return self.func(trace, **kwargs) 

139 

140 def __repr__(self) -> str: 

141 """String representation. 

142 

143 Returns: 

144 String representation of the measurement definition. 

145 """ 

146 return ( 

147 f"MeasurementDefinition(name='{self.name}', " 

148 f"units='{self.units}', category='{self.category}')" 

149 ) 

150 

151 

152class MeasurementRegistry: 

153 """Registry for custom measurements. 

154 

155 Manages registration and lookup of custom measurements. Integrates with 

156 the AlgorithmRegistry for storage. 

157 

158 Example: 

159 >>> registry = MeasurementRegistry() 

160 >>> registry.register( 

161 ... name='crest_factor', 

162 ... func=calculate_crest_factor, 

163 ... units='ratio', 

164 ... category='amplitude' 

165 ... ) 

166 >>> measurement = registry.get('crest_factor') 

167 >>> value = measurement(trace) 

168 

169 References: 

170 API-008: Custom Measurement Framework 

171 """ 

172 

173 MEASUREMENT_CATEGORY = "measurement" 

174 

175 def __init__(self) -> None: 

176 """Initialize measurement registry.""" 

177 self._definitions: dict[str, MeasurementDefinition] = {} 

178 self._algorithm_registry = AlgorithmRegistry() 

179 

180 def register( 

181 self, 

182 name: str | None = None, 

183 func: Callable[[WaveformTrace], float] | None = None, 

184 units: str | None = None, 

185 category: str | None = None, 

186 description: str = "", 

187 tags: list[str] | None = None, 

188 definition: MeasurementDefinition | None = None, 

189 ) -> None: 

190 """Register a custom measurement. 

191 

192 Can be called with individual parameters or with a MeasurementDefinition. 

193 

194 Args: 

195 name: Measurement name (required if definition not provided). 

196 func: Measurement function (required if definition not provided). 

197 units: Units of measurement (required if definition not provided). 

198 category: Measurement category (required if definition not provided). 

199 description: Optional description. 

200 tags: Optional tags. 

201 definition: Pre-built MeasurementDefinition (alternative to individual args). 

202 

203 Raises: 

204 ValueError: If required parameters missing or name already exists. 

205 

206 Example: 

207 >>> registry = MeasurementRegistry() 

208 >>> # Register with individual parameters 

209 >>> registry.register( 

210 ... name='peak', 

211 ... func=lambda trace: abs(trace.data).max(), 

212 ... units='V', 

213 ... category='amplitude' 

214 ... ) 

215 >>> # Register with definition 

216 >>> defn = MeasurementDefinition(...) 

217 >>> registry.register(definition=defn) 

218 """ 

219 # Handle definition argument 

220 if definition is not None: 

221 defn = definition 

222 else: 

223 # Validate required parameters 

224 if name is None or func is None or units is None or category is None: 

225 raise ValueError( 

226 "Must provide either 'definition' or all of (name, func, units, category)" 

227 ) 

228 

229 defn = MeasurementDefinition( 

230 name=name, 

231 func=func, 

232 units=units, 

233 category=category, 

234 description=description, 

235 tags=tags or [], 

236 ) 

237 

238 # Check for duplicates 

239 if defn.name in self._definitions: 

240 raise ValueError(f"Measurement '{defn.name}' already registered") 

241 

242 # Register in both registries 

243 self._definitions[defn.name] = defn 

244 self._algorithm_registry.register( 

245 name=defn.name, 

246 func=defn.func, 

247 category=self.MEASUREMENT_CATEGORY, 

248 validate=False, # Already validated by MeasurementDefinition 

249 ) 

250 

251 def get(self, name: str) -> MeasurementDefinition: 

252 """Get measurement definition by name. 

253 

254 Args: 

255 name: Measurement name. 

256 

257 Returns: 

258 MeasurementDefinition for the measurement. 

259 

260 Raises: 

261 KeyError: If measurement not found. 

262 

263 Example: 

264 >>> measurement = registry.get('crest_factor') 

265 >>> value = measurement(trace) 

266 """ 

267 if name not in self._definitions: 

268 available = list(self._definitions.keys()) 

269 raise KeyError(f"Measurement '{name}' not found. Available: {available}") 

270 

271 return self._definitions[name] 

272 

273 def has_measurement(self, name: str) -> bool: 

274 """Check if measurement is registered. 

275 

276 Args: 

277 name: Measurement name. 

278 

279 Returns: 

280 True if measurement is registered. 

281 

282 Example: 

283 >>> if registry.has_measurement('crest_factor'): 

284 ... cf = registry.get('crest_factor')(trace) 

285 """ 

286 return name in self._definitions 

287 

288 def list_measurements( 

289 self, 

290 category: str | None = None, 

291 tags: list[str] | None = None, 

292 ) -> list[str]: 

293 """List registered measurements. 

294 

295 Args: 

296 category: Filter by category (optional). 

297 tags: Filter by tags (optional). 

298 

299 Returns: 

300 List of measurement names. 

301 

302 Example: 

303 >>> # List all measurements 

304 >>> all_measurements = registry.list_measurements() 

305 >>> # List amplitude measurements 

306 >>> amplitude = registry.list_measurements(category='amplitude') 

307 >>> # List measurements with 'edge' tag 

308 >>> edge_measurements = registry.list_measurements(tags=['edge']) 

309 """ 

310 measurements = [] 

311 

312 for name, defn in self._definitions.items(): 

313 # Filter by category 

314 if category is not None and defn.category != category: 

315 continue 

316 

317 # Filter by tags 

318 if tags is not None and not any(tag in defn.tags for tag in tags): 

319 continue 

320 

321 measurements.append(name) 

322 

323 return measurements 

324 

325 def get_metadata(self, name: str) -> dict[str, Any]: 

326 """Get metadata for a measurement. 

327 

328 Args: 

329 name: Measurement name. 

330 

331 Returns: 

332 Dictionary with measurement metadata. 

333 

334 Example: 

335 >>> metadata = registry.get_metadata('crest_factor') 

336 >>> print(f"Units: {metadata['units']}") 

337 >>> print(f"Category: {metadata['category']}") 

338 """ 

339 defn = self.get(name) 

340 return { 

341 "name": defn.name, 

342 "units": defn.units, 

343 "category": defn.category, 

344 "description": defn.description, 

345 "tags": defn.tags, 

346 } 

347 

348 def unregister(self, name: str) -> None: 

349 """Unregister a measurement. 

350 

351 Args: 

352 name: Measurement name. 

353 

354 Example: 

355 >>> registry.unregister('crest_factor') 

356 """ 

357 if name in self._definitions: 

358 del self._definitions[name] 

359 

360 if self._algorithm_registry.has_algorithm(self.MEASUREMENT_CATEGORY, name): 

361 self._algorithm_registry.unregister(self.MEASUREMENT_CATEGORY, name) 

362 

363 

364# Global measurement registry 

365_registry = MeasurementRegistry() 

366 

367 

368def register_measurement( 

369 name: str | None = None, 

370 func: Callable[[WaveformTrace], float] | None = None, 

371 units: str | None = None, 

372 category: str | None = None, 

373 description: str = "", 

374 tags: list[str] | None = None, 

375 definition: MeasurementDefinition | None = None, 

376) -> None: 

377 """Register a custom measurement in the global registry. 

378 

379 Convenience function for registering measurements without accessing 

380 the registry directly. 

381 

382 Args: 

383 name: Measurement name. 

384 func: Measurement function. 

385 units: Units of measurement. 

386 category: Measurement category. 

387 description: Optional description. 

388 tags: Optional tags. 

389 definition: Pre-built MeasurementDefinition. 

390 

391 Example: 

392 >>> import tracekit as tk 

393 >>> def calculate_crest_factor(trace, **kwargs): 

394 ... peak = abs(trace.data).max() 

395 ... rms = (trace.data ** 2).mean() ** 0.5 

396 ... return peak / rms 

397 >>> tk.register_measurement( 

398 ... name='crest_factor', 

399 ... func=calculate_crest_factor, 

400 ... units='ratio', 

401 ... category='amplitude' 

402 ... ) 

403 

404 References: 

405 API-008: Custom Measurement Framework 

406 """ 

407 _registry.register( 

408 name=name, 

409 func=func, 

410 units=units, 

411 category=category, 

412 description=description, 

413 tags=tags, 

414 definition=definition, 

415 ) 

416 

417 

418def measure(trace: WaveformTrace, name: str, **kwargs: Any) -> float: 

419 """Execute a registered measurement. 

420 

421 Args: 

422 trace: WaveformTrace to measure. 

423 name: Measurement name. 

424 **kwargs: Additional parameters for the measurement. 

425 

426 Returns: 

427 Measured value. 

428 

429 Example: 

430 >>> import tracekit as tk 

431 >>> cf = tk.measure(trace, 'crest_factor') 

432 >>> print(f"Crest factor: {cf:.2f}") 

433 

434 References: 

435 API-008: Custom Measurement Framework 

436 """ 

437 defn = _registry.get(name) 

438 return defn(trace, **kwargs) 

439 

440 

441def list_measurements( 

442 category: str | None = None, 

443 tags: list[str] | None = None, 

444) -> list[str]: 

445 """List registered measurements. 

446 

447 Args: 

448 category: Filter by category (optional). 

449 tags: Filter by tags (optional). 

450 

451 Returns: 

452 List of measurement names. 

453 

454 Example: 

455 >>> import tracekit as tk 

456 >>> measurements = tk.list_measurements(category='amplitude') 

457 >>> print(f"Amplitude measurements: {measurements}") 

458 

459 References: 

460 API-008: Custom Measurement Framework 

461 """ 

462 return _registry.list_measurements(category=category, tags=tags) 

463 

464 

465def get_measurement_registry() -> MeasurementRegistry: 

466 """Get the global measurement registry. 

467 

468 Returns: 

469 Global MeasurementRegistry instance. 

470 

471 Example: 

472 >>> registry = tk.get_measurement_registry() 

473 >>> metadata = registry.get_metadata('crest_factor') 

474 """ 

475 return _registry 

476 

477 

478__all__ = [ 

479 "MeasurementDefinition", 

480 "MeasurementRegistry", 

481 "get_measurement_registry", 

482 "list_measurements", 

483 "measure", 

484 "register_measurement", 

485]