Coverage for src / tracekit / math / arithmetic.py: 85%

215 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 23:04 +0000

1"""Signal arithmetic operations for TraceKit. 

2 

3This module provides element-wise arithmetic operations for waveform traces 

4including addition, subtraction, multiplication, division, differentiation, 

5and integration. 

6 

7 

8Example: 

9 >>> from tracekit.math import add, differentiate 

10 >>> combined = add(trace1, trace2) 

11 >>> derivative = differentiate(trace) 

12 

13References: 

14 IEEE 181-2011: Standard for Transitional Waveform Definitions 

15""" 

16 

17from __future__ import annotations 

18 

19import ast 

20import operator 

21from collections.abc import Callable 

22from typing import Any, Union 

23 

24import numpy as np 

25from numpy.typing import NDArray 

26from scipy import integrate as sp_integrate 

27 

28from tracekit.core.exceptions import AnalysisError, InsufficientDataError 

29from tracekit.core.types import TraceMetadata, WaveformTrace 

30 

31# Type alias for trace or scalar 

32TraceOrScalar = Union[WaveformTrace, float, NDArray[np.floating[Any]]] 

33 

34 

def _ensure_compatible_traces(
    trace1: WaveformTrace, trace2: WaveformTrace
) -> tuple[NDArray[np.float64], NDArray[np.float64], TraceMetadata]:
    """Validate and align two traces for element-wise arithmetic.

    Args:
        trace1: First trace (its metadata is used for the result).
        trace2: Second trace.

    Returns:
        Tuple of (data1, data2, metadata) as float64 arrays of equal length.

    Raises:
        AnalysisError: If the traces have incompatible sample rates.
    """
    # Sample rates must agree to within 0.1%.
    rate_ratio = trace1.metadata.sample_rate / trace2.metadata.sample_rate
    if not (0.999 <= rate_ratio <= 1.001):
        raise AnalysisError(
            "Sample rates must match for arithmetic operations",
            details={  # type: ignore[arg-type]
                "trace1_rate": trace1.metadata.sample_rate,
                "trace2_rate": trace2.metadata.sample_rate,
            },
        )

    data1 = trace1.data.astype(np.float64)
    data2 = trace2.data.astype(np.float64)

    # On length mismatch, silently truncate both to the shorter trace.
    if len(data1) != len(data2):
        n = min(len(data1), len(data2))
        data1 = data1[:n]
        data2 = data2[:n]

    return data1, data2, trace1.metadata

72 

73 

def add(
    trace1: WaveformTrace,
    trace2: TraceOrScalar,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Add two traces, or add an array or scalar to a trace.

    Element-wise addition. The second operand may be another trace
    (sample rates must match), a numpy array of the same length, or a
    scalar added to every sample.

    Args:
        trace1: First trace (base trace).
        trace2: Second trace, array, or scalar value to add.
        channel_name: Name for the result trace (optional).

    Returns:
        New WaveformTrace containing the sum.

    Raises:
        AnalysisError: If the operands are incompatible.

    Example:
        >>> combined = add(trace1, trace2)
        >>> offset_trace = add(trace, 0.5)  # Add 0.5V offset

    References:
        ARITH-001
    """
    if isinstance(trace2, int | float):
        # Scalar: broadcast over every sample.
        metadata = trace1.metadata
        result_data = trace1.data.astype(np.float64) + float(trace2)
    elif isinstance(trace2, np.ndarray):
        # Raw array: lengths must match exactly.
        if len(trace2) != len(trace1.data):
            raise AnalysisError(
                "Array length must match trace length",
                details={"trace_len": len(trace1.data), "array_len": len(trace2)},  # type: ignore[arg-type]
            )
        metadata = trace1.metadata
        result_data = trace1.data.astype(np.float64) + trace2.astype(np.float64)
    else:
        # Trace + trace: validate compatibility first.
        data1, data2, metadata = _ensure_compatible_traces(trace1, trace2)
        result_data = data1 + data2

    result_name = channel_name or f"{metadata.channel_name or 'trace'}_sum"
    new_metadata = TraceMetadata(
        sample_rate=metadata.sample_rate,
        vertical_scale=metadata.vertical_scale,
        vertical_offset=metadata.vertical_offset,
        acquisition_time=metadata.acquisition_time,
        trigger_info=metadata.trigger_info,
        source_file=metadata.source_file,
        channel_name=result_name,
    )
    return WaveformTrace(data=result_data, metadata=new_metadata)

133 

134 

def subtract(
    trace1: WaveformTrace,
    trace2: TraceOrScalar,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Subtract a trace, array, or scalar from a trace.

    Element-wise difference (trace1 - trace2). The subtrahend may be
    another trace, a numpy array of the same length, or a scalar
    subtracted from every sample.

    Args:
        trace1: Trace to subtract from.
        trace2: Trace, array, or scalar to subtract.
        channel_name: Name for the result trace (optional).

    Returns:
        New WaveformTrace containing the difference.

    Raises:
        AnalysisError: If the operands are incompatible.

    Example:
        >>> diff = subtract(trace1, trace2)  # trace1 - trace2
        >>> centered = subtract(trace, np.mean(trace.data))  # Remove DC

    References:
        ARITH-002
    """
    if isinstance(trace2, int | float):
        metadata = trace1.metadata
        result_data = trace1.data.astype(np.float64) - float(trace2)
    elif isinstance(trace2, np.ndarray):
        if len(trace2) != len(trace1.data):
            raise AnalysisError(
                "Array length must match trace length",
                details={"trace_len": len(trace1.data), "array_len": len(trace2)},  # type: ignore[arg-type]
            )
        metadata = trace1.metadata
        result_data = trace1.data.astype(np.float64) - trace2.astype(np.float64)
    else:
        data1, data2, metadata = _ensure_compatible_traces(trace1, trace2)
        result_data = data1 - data2

    result_name = channel_name or f"{metadata.channel_name or 'trace'}_diff"
    new_metadata = TraceMetadata(
        sample_rate=metadata.sample_rate,
        vertical_scale=metadata.vertical_scale,
        vertical_offset=metadata.vertical_offset,
        acquisition_time=metadata.acquisition_time,
        trigger_info=metadata.trigger_info,
        source_file=metadata.source_file,
        channel_name=result_name,
    )
    return WaveformTrace(data=result_data, metadata=new_metadata)

190 

191 

def multiply(
    trace1: WaveformTrace,
    trace2: TraceOrScalar,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Multiply two traces, or multiply a trace by an array or scalar.

    Element-wise product. The second operand may be another trace, a
    numpy array of the same length, or a scalar applied to every sample.

    Args:
        trace1: First trace.
        trace2: Second trace, array, or scalar multiplier.
        channel_name: Name for the result trace (optional).

    Returns:
        New WaveformTrace containing the product.

    Raises:
        AnalysisError: If the operands are incompatible.

    Example:
        >>> product = multiply(voltage_trace, current_trace)  # Power = V * I
        >>> scaled = multiply(trace, 2.0)  # Double amplitude

    References:
        ARITH-003
    """
    if isinstance(trace2, int | float):
        metadata = trace1.metadata
        result_data = trace1.data.astype(np.float64) * float(trace2)
    elif isinstance(trace2, np.ndarray):
        if len(trace2) != len(trace1.data):
            raise AnalysisError(
                "Array length must match trace length",
                details={"trace_len": len(trace1.data), "array_len": len(trace2)},  # type: ignore[arg-type]
            )
        metadata = trace1.metadata
        result_data = trace1.data.astype(np.float64) * trace2.astype(np.float64)
    else:
        data1, data2, metadata = _ensure_compatible_traces(trace1, trace2)
        result_data = data1 * data2

    result_name = channel_name or f"{metadata.channel_name or 'trace'}_mult"
    new_metadata = TraceMetadata(
        sample_rate=metadata.sample_rate,
        vertical_scale=metadata.vertical_scale,
        vertical_offset=metadata.vertical_offset,
        acquisition_time=metadata.acquisition_time,
        trigger_info=metadata.trigger_info,
        source_file=metadata.source_file,
        channel_name=result_name,
    )
    return WaveformTrace(data=result_data, metadata=new_metadata)

247 

248 

def divide(
    trace1: WaveformTrace,
    trace2: TraceOrScalar,
    *,
    channel_name: str | None = None,
    fill_value: float = np.nan,
) -> WaveformTrace:
    """Divide a trace by another trace, an array, or a scalar.

    Element-wise quotient (trace1 / trace2). Samples where the division
    is undefined (zero denominator, non-finite result) are replaced by
    ``fill_value``.

    Args:
        trace1: Numerator trace.
        trace2: Denominator trace, array, or scalar.
        channel_name: Name for the result trace (optional).
        fill_value: Replacement for division-by-zero results (default NaN).

    Returns:
        New WaveformTrace containing the quotient.

    Raises:
        AnalysisError: If the operands are incompatible.

    Example:
        >>> ratio = divide(trace1, trace2)
        >>> normalized = divide(trace, np.max(trace.data))

    References:
        ARITH-004
    """
    if isinstance(trace2, int | float):
        metadata = trace1.metadata
        if trace2 == 0:
            # Scalar zero denominator: every sample becomes fill_value.
            result_data = np.full_like(trace1.data, fill_value, dtype=np.float64)
        else:
            result_data = trace1.data.astype(np.float64) / float(trace2)
    elif isinstance(trace2, np.ndarray):
        if len(trace2) != len(trace1.data):
            raise AnalysisError(
                "Array length must match trace length",
                details={"trace_len": len(trace1.data), "array_len": len(trace2)},  # type: ignore[arg-type]
            )
        metadata = trace1.metadata
        # Suppress numpy warnings; non-finite samples are patched below.
        with np.errstate(divide="ignore", invalid="ignore"):
            quotient = trace1.data.astype(np.float64) / trace2.astype(np.float64)
        result_data = np.where(np.isfinite(quotient), quotient, fill_value)
    else:
        data1, data2, metadata = _ensure_compatible_traces(trace1, trace2)
        with np.errstate(divide="ignore", invalid="ignore"):
            quotient = data1 / data2
        result_data = np.where(np.isfinite(quotient), quotient, fill_value)

    result_name = channel_name or f"{metadata.channel_name or 'trace'}_div"
    new_metadata = TraceMetadata(
        sample_rate=metadata.sample_rate,
        vertical_scale=metadata.vertical_scale,
        vertical_offset=metadata.vertical_offset,
        acquisition_time=metadata.acquisition_time,
        trigger_info=metadata.trigger_info,
        source_file=metadata.source_file,
        channel_name=result_name,
    )
    return WaveformTrace(data=result_data, metadata=new_metadata)

313 

314 

def scale(
    trace: WaveformTrace,
    factor: float,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Scale a trace by a constant factor.

    Convenience wrapper around :func:`multiply` that multiplies every
    sample by ``factor``.

    Args:
        trace: Input trace.
        factor: Scale factor to apply.
        channel_name: Name for the result trace (optional).

    Returns:
        Scaled WaveformTrace.

    Example:
        >>> amplified = scale(trace, 2.0)  # Double amplitude
        >>> attenuated = scale(trace, 0.5)  # Halve amplitude
    """
    result_name = channel_name or f"{trace.metadata.channel_name or 'trace'}_scaled"
    return multiply(trace, factor, channel_name=result_name)

343 

344 

def offset(
    trace: WaveformTrace,
    value: float,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Add a constant offset to a trace.

    Convenience wrapper around :func:`add` that adds ``value`` to every
    sample.

    Args:
        trace: Input trace.
        value: Offset value to add.
        channel_name: Name for the result trace (optional).

    Returns:
        Offset WaveformTrace.

    Example:
        >>> shifted = offset(trace, 1.0)  # Shift up by 1V
    """
    result_name = channel_name or f"{trace.metadata.channel_name or 'trace'}_offset"
    return add(trace, value, channel_name=result_name)

371 

372 

def invert(
    trace: WaveformTrace,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Invert trace polarity (multiply by -1).

    Convenience wrapper around :func:`scale` with a factor of -1.

    Args:
        trace: Input trace.
        channel_name: Name for the result trace (optional).

    Returns:
        Inverted WaveformTrace.

    Example:
        >>> inverted = invert(trace)  # Flip polarity
    """
    result_name = channel_name or f"{trace.metadata.channel_name or 'trace'}_inverted"
    return scale(trace, -1.0, channel_name=result_name)

397 

398 

def absolute(
    trace: WaveformTrace,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Compute the absolute value of every sample in a trace.

    Args:
        trace: Input trace.
        channel_name: Name for the result trace (optional).

    Returns:
        WaveformTrace with absolute values.

    Example:
        >>> rectified = absolute(trace)  # Full-wave rectification
    """
    rectified = np.abs(trace.data.astype(np.float64))

    src = trace.metadata
    result_name = channel_name or f"{src.channel_name or 'trace'}_abs"
    new_metadata = TraceMetadata(
        sample_rate=src.sample_rate,
        vertical_scale=src.vertical_scale,
        vertical_offset=src.vertical_offset,
        acquisition_time=src.acquisition_time,
        trigger_info=src.trigger_info,
        source_file=src.source_file,
        channel_name=result_name,
    )
    return WaveformTrace(data=rectified, metadata=new_metadata)

431 

432 

def differentiate(
    trace: WaveformTrace,
    *,
    order: int = 1,
    method: str = "central",
    channel_name: str | None = None,
) -> WaveformTrace:
    """Compute the numerical derivative of a trace.

    Calculates the rate of change of the waveform, applying the chosen
    finite-difference scheme ``order`` times. Result units are V/s.

    Args:
        trace: Input trace.
        order: Derivative order (1 = first derivative, 2 = second, ...).
        method: Differentiation scheme:
            - "central": Central difference (default, most accurate)
            - "forward": Forward difference
            - "backward": Backward difference
        channel_name: Name for the result trace (optional).

    Returns:
        Differentiated WaveformTrace in V/s.

    Raises:
        InsufficientDataError: If the trace has fewer than order + 1 samples.
        ValueError: If order is not positive or method is unknown.

    Example:
        >>> velocity = differentiate(position_trace)  # dx/dt
        >>> acceleration = differentiate(position_trace, order=2)  # d2x/dt2

    References:
        ARITH-005, IEEE 181-2011
    """
    if order < 1:
        raise ValueError(f"Order must be positive, got {order}")

    data = trace.data.astype(np.float64)
    dt = trace.metadata.time_base

    min_samples = order + 1
    if len(data) < min_samples:
        raise InsufficientDataError(
            f"Need at least {order + 1} samples for order-{order} derivative",
            required=min_samples,
            available=len(data),
            analysis_type="differentiate",
        )

    # Validated after the sample-count check so InsufficientDataError
    # keeps precedence over a bad method name.
    if method not in ("central", "forward", "backward"):
        raise ValueError(f"Unknown method: {method}")

    result = data.copy()
    for _ in range(order):
        if method == "forward":
            # First element is 0 (result[0] - result[0]).
            result = np.diff(result, prepend=result[0]) / dt
        elif method == "backward":
            # Last element is 0 (result[-1] - result[-1]).
            result = np.diff(result, append=result[-1]) / dt
        else:
            # Central difference in the interior, one-sided at the edges;
            # every index is written, so empty_like is safe.
            diff = np.empty_like(result)
            diff[1:-1] = (result[2:] - result[:-2]) / (2 * dt)
            diff[0] = (result[1] - result[0]) / dt
            diff[-1] = (result[-1] - result[-2]) / dt
            result = diff

    src = trace.metadata
    new_metadata = TraceMetadata(
        sample_rate=src.sample_rate,
        vertical_scale=None,  # Units changed (V -> V/s)
        vertical_offset=None,
        acquisition_time=src.acquisition_time,
        trigger_info=src.trigger_info,
        source_file=src.source_file,
        channel_name=channel_name or f"{src.channel_name or 'trace'}_d{order}",
    )
    return WaveformTrace(data=result, metadata=new_metadata)

512 

513 

def integrate(
    trace: WaveformTrace,
    *,
    method: str = "trapezoid",
    initial: float = 0.0,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Compute the cumulative numerical integral of a trace.

    Calculates the running integral of the waveform. Result units are
    volt-seconds (V*s).

    Args:
        trace: Input trace.
        method: Integration method:
            - "trapezoid": Trapezoidal rule (default)
            - "simpson": Cumulative Simpson's rule (SciPy >= 1.12;
              falls back to the trapezoidal rule on older SciPy)
            - "cumsum": Simple cumulative sum (rectangle rule)
        initial: Initial value for the cumulative integral (default 0).
        channel_name: Name for the result trace (optional).

    Returns:
        Integrated WaveformTrace in V*s.

    Raises:
        InsufficientDataError: If the trace has fewer than 2 samples.
        ValueError: If method is unknown.

    Example:
        >>> position = integrate(velocity_trace)
        >>> charge = integrate(current_trace)  # Q = integral(I dt)

    References:
        ARITH-006
    """
    data = trace.data.astype(np.float64)
    dt = trace.metadata.time_base

    if len(data) < 2:
        raise InsufficientDataError(
            "Need at least 2 samples for integration",
            required=2,
            available=len(data),
            analysis_type="integrate",
        )

    if method == "trapezoid":
        result = sp_integrate.cumulative_trapezoid(data, dx=dt, initial=initial)
    elif method == "simpson":
        # Bug fix: this branch previously ran the trapezoidal rule while
        # claiming to be Simpson's rule. Use SciPy's cumulative Simpson
        # integration when available (scipy >= 1.12); keep the old
        # trapezoidal fallback for older SciPy so behavior degrades
        # gracefully rather than crashing.
        cumulative_simpson = getattr(sp_integrate, "cumulative_simpson", None)
        if cumulative_simpson is not None:
            result = cumulative_simpson(data, dx=dt, initial=initial)
        else:
            result = sp_integrate.cumulative_trapezoid(data, dx=dt, initial=initial)
    elif method == "cumsum":
        # Rectangle rule: sample value times dt, accumulated.
        result = np.cumsum(data) * dt + initial
    else:
        raise ValueError(f"Unknown method: {method}")

    src = trace.metadata
    new_metadata = TraceMetadata(
        sample_rate=src.sample_rate,
        vertical_scale=None,  # Units changed (V -> V*s)
        vertical_offset=None,
        acquisition_time=src.acquisition_time,
        trigger_info=src.trigger_info,
        source_file=src.source_file,
        channel_name=channel_name or f"{src.channel_name or 'trace'}_integral",
    )
    return WaveformTrace(data=result, metadata=new_metadata)

584 

585 

586class _SafeExpressionEvaluator(ast.NodeVisitor): 

587 """Safe AST-based expression evaluator for math expressions. 

588 

589 This evaluator only allows safe operations: 

590 - Binary operations: +, -, *, /, //, %, ** 

591 - Comparison operations: ==, !=, <, <=, >, >= 

592 - Unary operations: +, -, not 

593 - Function calls to whitelisted functions 

594 - Variable names and constants 

595 

596 Security: 

597 Uses AST parsing to avoid eval() security risks. Only explicitly 

598 whitelisted operations are permitted. 

599 """ 

600 

601 def __init__(self, namespace: dict[str, Any]): 

602 """Initialize evaluator with namespace. 

603 

604 Args: 

605 namespace: Variable and function namespace 

606 """ 

607 self.namespace = namespace 

608 # Whitelisted operations 

609 self.binary_ops: dict[type[ast.operator], Callable[[Any, Any], Any]] = { 

610 ast.Add: operator.add, 

611 ast.Sub: operator.sub, 

612 ast.Mult: operator.mul, 

613 ast.Div: operator.truediv, 

614 ast.FloorDiv: operator.floordiv, 

615 ast.Mod: operator.mod, 

616 ast.Pow: operator.pow, 

617 } 

618 self.compare_ops: dict[type[ast.cmpop], Callable[[Any, Any], bool]] = { 

619 ast.Eq: operator.eq, 

620 ast.NotEq: operator.ne, 

621 ast.Lt: operator.lt, 

622 ast.LtE: operator.le, 

623 ast.Gt: operator.gt, 

624 ast.GtE: operator.ge, 

625 } 

626 self.unary_ops: dict[type[ast.unaryop], Callable[[Any], Any]] = { 

627 ast.UAdd: operator.pos, 

628 ast.USub: operator.neg, 

629 } 

630 

631 def eval(self, expression: str) -> Any: 

632 """Evaluate expression safely. 

633 

634 Args: 

635 expression: Math expression string 

636 

637 Returns: 

638 Evaluated result 

639 

640 Raises: 

641 AnalysisError: If expression contains disallowed operations 

642 """ 

643 try: 

644 tree = ast.parse(expression, mode="eval") 

645 return self.visit(tree.body) 

646 except (SyntaxError, ValueError) as e: 

647 raise AnalysisError(f"Invalid expression syntax: {e}") from e 

648 

649 def visit_BinOp(self, node: ast.BinOp) -> Any: 

650 """Visit binary operation node.""" 

651 if type(node.op) not in self.binary_ops: 651 ↛ 652line 651 didn't jump to line 652 because the condition on line 651 was never true

652 raise AnalysisError(f"Operation {node.op.__class__.__name__} not allowed") 

653 left = self.visit(node.left) 

654 right = self.visit(node.right) 

655 return self.binary_ops[type(node.op)](left, right) 

656 

657 def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: 

658 """Visit unary operation node.""" 

659 if type(node.op) not in self.unary_ops: 

660 raise AnalysisError(f"Operation {node.op.__class__.__name__} not allowed") 

661 operand = self.visit(node.operand) 

662 return self.unary_ops[type(node.op)](operand) 

663 

664 def visit_Compare(self, node: ast.Compare) -> Any: 

665 """Visit comparison operation node.""" 

666 left = self.visit(node.left) 

667 for op, comparator in zip(node.ops, node.comparators, strict=True): 

668 if type(op) not in self.compare_ops: 

669 raise AnalysisError(f"Operation {op.__class__.__name__} not allowed") 

670 right = self.visit(comparator) 

671 if not self.compare_ops[type(op)](left, right): 

672 return False 

673 left = right 

674 return True 

675 

676 def visit_Call(self, node: ast.Call) -> Any: 

677 """Visit function call node.""" 

678 if isinstance(node.func, ast.Name): 678 ↛ 685line 678 didn't jump to line 685 because the condition on line 678 was always true

679 func_name = node.func.id 

680 if func_name not in self.namespace: 

681 raise AnalysisError(f"Function '{func_name}' not allowed") 

682 func = self.namespace[func_name] 

683 args = [self.visit(arg) for arg in node.args] 

684 return func(*args) 

685 elif isinstance(node.func, ast.Attribute): 

686 # Handle np.function() style calls 

687 obj = self.visit(node.func.value) 

688 attr_name = node.func.attr 

689 if not hasattr(obj, attr_name): 

690 raise AnalysisError(f"Attribute '{attr_name}' not allowed") 

691 func = getattr(obj, attr_name) 

692 args = [self.visit(arg) for arg in node.args] 

693 return func(*args) 

694 else: 

695 raise AnalysisError("Complex function calls not allowed") 

696 

697 def visit_Name(self, node: ast.Name) -> Any: 

698 """Visit variable name node.""" 

699 if node.id not in self.namespace: 699 ↛ 700line 699 didn't jump to line 700 because the condition on line 699 was never true

700 raise AnalysisError(f"Variable '{node.id}' not defined") 

701 return self.namespace[node.id] 

702 

703 def visit_Constant(self, node: ast.Constant) -> Any: 

704 """Visit constant node (numbers, strings).""" 

705 return node.value 

706 

707 def visit_Num(self, node: ast.Num) -> Any: 

708 """Visit number node (Python <3.8 compatibility).""" 

709 return node.n 

710 

711 def visit_Attribute(self, node: ast.Attribute) -> Any: 

712 """Visit attribute access node.""" 

713 obj = self.visit(node.value) 

714 return getattr(obj, node.attr) 

715 

716 def generic_visit(self, node: ast.AST) -> Any: 

717 """Catch-all for disallowed node types.""" 

718 raise AnalysisError(f"AST node type {node.__class__.__name__} not allowed") 

719 

720 

def math_expression(
    expression: str,
    traces: dict[str, WaveformTrace],
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Evaluate a mathematical expression over named traces.

    The expression string is evaluated with each dictionary key bound to
    that trace's sample data, plus a set of whitelisted math functions.

    Args:
        expression: Math expression (e.g. "CH1 + CH2", "abs(CH1 - CH2)").
        traces: Dictionary mapping variable names to traces.
        channel_name: Name for the result trace (optional).

    Returns:
        Result WaveformTrace.

    Raises:
        AnalysisError: If the expression is invalid or the traces are
            incompatible (different lengths or sample rates).

    Example:
        >>> power = math_expression(
        ...     "voltage * current",
        ...     {"voltage": v_trace, "current": i_trace}
        ... )

    Security:
        Uses AST-based safe evaluation (not eval()). Only whitelisted
        operations are permitted, so no arbitrary code can execute.
    """
    if not traces:
        raise AnalysisError("No traces provided for expression evaluation")

    # First trace supplies the reference length, rate, and metadata.
    ref_trace = next(iter(traces.values()))
    ref_len = len(ref_trace.data)
    sample_rate = ref_trace.metadata.sample_rate

    # Every trace must match the reference length and rate (0.1% tolerance).
    for name, trace in traces.items():
        if len(trace.data) != ref_len:
            raise AnalysisError(
                f"Trace '{name}' has different length",
                details={"expected": ref_len, "got": len(trace.data)},  # type: ignore[arg-type]
            )
        if not (0.999 <= trace.metadata.sample_rate / sample_rate <= 1.001):
            raise AnalysisError(
                f"Trace '{name}' has different sample rate",
                details={"expected": sample_rate, "got": trace.metadata.sample_rate},  # type: ignore[arg-type]
            )

    # Whitelisted functions/constants visible to the expression.
    safe_namespace: dict[str, Any] = {
        "np": np,
        "abs": np.abs,
        "sqrt": np.sqrt,
        "sin": np.sin,
        "cos": np.cos,
        "tan": np.tan,
        "exp": np.exp,
        "log": np.log,
        "log10": np.log10,
        "max": np.maximum,
        "min": np.minimum,
        "mean": np.mean,
        "std": np.std,
        "pi": np.pi,
    }
    # Bind each trace's samples under its variable name.
    for name, trace in traces.items():
        safe_namespace[name] = trace.data.astype(np.float64)

    # Safe AST-based evaluation; never eval().
    evaluator = _SafeExpressionEvaluator(safe_namespace)
    try:
        result = evaluator.eval(expression)
    except AnalysisError:
        raise  # Already descriptive; pass through unchanged.
    except Exception as e:
        raise AnalysisError(
            f"Failed to evaluate expression: {e}",
            details={"expression": expression},  # type: ignore[arg-type]
        ) from e

    # A scalar result (e.g. "mean(CH1)") is broadcast to a full trace.
    if not isinstance(result, np.ndarray):
        result = np.full(ref_len, result, dtype=np.float64)

    new_metadata = TraceMetadata(
        sample_rate=sample_rate,
        vertical_scale=None,
        vertical_offset=None,
        acquisition_time=ref_trace.metadata.acquisition_time,
        trigger_info=ref_trace.metadata.trigger_info,
        source_file=ref_trace.metadata.source_file,
        channel_name=channel_name or f"expr({expression[:20]})",
    )
    return WaveformTrace(data=result.astype(np.float64), metadata=new_metadata)