Coverage for src / tracekit / math / arithmetic.py: 85%
215 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 23:04 +0000
1"""Signal arithmetic operations for TraceKit.
3This module provides element-wise arithmetic operations for waveform traces
4including addition, subtraction, multiplication, division, differentiation,
5and integration.
8Example:
9 >>> from tracekit.math import add, differentiate
10 >>> combined = add(trace1, trace2)
11 >>> derivative = differentiate(trace)
13References:
14 IEEE 181-2011: Standard for Transitional Waveform Definitions
15"""
17from __future__ import annotations
19import ast
20import operator
21from collections.abc import Callable
22from typing import Any, Union
24import numpy as np
25from numpy.typing import NDArray
26from scipy import integrate as sp_integrate
28from tracekit.core.exceptions import AnalysisError, InsufficientDataError
29from tracekit.core.types import TraceMetadata, WaveformTrace
# Right-hand operand accepted by add/subtract/multiply/divide: another
# waveform trace, a plain scalar, or a raw sample array.
TraceOrScalar = Union[
    WaveformTrace,
    float,
    NDArray[np.floating[Any]],
]
35def _ensure_compatible_traces(
36 trace1: WaveformTrace, trace2: WaveformTrace
37) -> tuple[NDArray[np.float64], NDArray[np.float64], TraceMetadata]:
38 """Ensure two traces are compatible for arithmetic operations.
40 Args:
41 trace1: First trace.
42 trace2: Second trace.
44 Returns:
45 Tuple of (data1, data2, metadata) with compatible arrays.
47 Raises:
48 AnalysisError: If traces have incompatible sample rates or lengths.
49 """
50 # Check sample rate compatibility (allow 0.1% tolerance)
51 rate_ratio = trace1.metadata.sample_rate / trace2.metadata.sample_rate
52 if not (0.999 <= rate_ratio <= 1.001):
53 raise AnalysisError(
54 "Sample rates must match for arithmetic operations",
55 details={ # type: ignore[arg-type]
56 "trace1_rate": trace1.metadata.sample_rate,
57 "trace2_rate": trace2.metadata.sample_rate,
58 },
59 )
61 # Get data as float64
62 data1 = trace1.data.astype(np.float64)
63 data2 = trace2.data.astype(np.float64)
65 # Handle length mismatch by truncating to shorter
66 min_len = min(len(data1), len(data2))
67 if len(data1) != len(data2):
68 data1 = data1[:min_len]
69 data2 = data2[:min_len]
71 return data1, data2, trace1.metadata
def add(
    trace1: WaveformTrace,
    trace2: TraceOrScalar,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Add two traces or add a scalar to a trace.

    Performs element-wise addition of two waveform traces, or of a trace and
    a sample array of the same length, or adds a scalar value to all samples
    of a trace. Python numbers, NumPy scalars (e.g. ``np.float32``,
    ``np.int64``) and 0-d arrays are all treated as scalars.

    Args:
        trace1: First trace (base trace).
        trace2: Second trace, sample array, or scalar value to add.
        channel_name: Name for the result trace (optional). Defaults to
            ``"<trace1 channel>_sum"`` (``"trace_sum"`` if unnamed).

    Returns:
        New float64 WaveformTrace containing the sum. Its metadata copies
        sample rate, vertical scale/offset, acquisition time, trigger info
        and source file from ``trace1``.

    Raises:
        AnalysisError: If traces have incompatible sample rates, or if an
            array operand's length differs from the trace length.

    Example:
        >>> combined = add(trace1, trace2)
        >>> offset_trace = add(trace, 0.5)  # Add 0.5V offset

    References:
        ARITH-001
    """
    # NumPy scalar types such as np.float32 / np.int64 are not subclasses of
    # Python int/float, so they must be listed explicitly; otherwise they would
    # fall through to the trace branch and fail on `.metadata`.
    if isinstance(trace2, (int, float, np.integer, np.floating)) or (
        isinstance(trace2, np.ndarray) and trace2.ndim == 0
    ):
        # Scalar addition
        result_data = trace1.data.astype(np.float64) + float(trace2)
        metadata = trace1.metadata
    elif isinstance(trace2, np.ndarray):
        # Array addition: the array must line up sample-for-sample.
        if len(trace2) != len(trace1.data):
            raise AnalysisError(
                "Array length must match trace length",
                details={"trace_len": len(trace1.data), "array_len": len(trace2)},  # type: ignore[arg-type]
            )
        result_data = trace1.data.astype(np.float64) + trace2.astype(np.float64)
        metadata = trace1.metadata
    else:
        # Trace addition (sample rates validated, lengths truncated to shorter)
        data1, data2, metadata = _ensure_compatible_traces(trace1, trace2)
        result_data = data1 + data2

    # Create new metadata with optional name
    new_metadata = TraceMetadata(
        sample_rate=metadata.sample_rate,
        vertical_scale=metadata.vertical_scale,
        vertical_offset=metadata.vertical_offset,
        acquisition_time=metadata.acquisition_time,
        trigger_info=metadata.trigger_info,
        source_file=metadata.source_file,
        channel_name=channel_name or f"{metadata.channel_name or 'trace'}_sum",
    )

    return WaveformTrace(data=result_data, metadata=new_metadata)
def subtract(
    trace1: WaveformTrace,
    trace2: TraceOrScalar,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Subtract second trace from first trace or subtract a scalar.

    Performs element-wise subtraction (trace1 - trace2) against another
    trace or a same-length sample array, or subtracts a scalar value from
    all samples. Python numbers, NumPy scalars (e.g. ``np.float32``) and
    0-d arrays are all treated as scalars.

    Args:
        trace1: Trace to subtract from.
        trace2: Trace, sample array, or scalar to subtract.
        channel_name: Name for the result trace (optional). Defaults to
            ``"<trace1 channel>_diff"`` (``"trace_diff"`` if unnamed).

    Returns:
        New float64 WaveformTrace containing the difference.

    Raises:
        AnalysisError: If traces have incompatible sample rates, or if an
            array operand's length differs from the trace length. (Traces of
            different lengths are truncated to the shorter one.)

    Example:
        >>> diff = subtract(trace1, trace2)  # trace1 - trace2
        >>> centered = subtract(trace, np.mean(trace.data))  # Remove DC

    References:
        ARITH-002
    """
    # Accept NumPy scalar types and 0-d arrays as scalars too; they are not
    # Python int/float subclasses and would otherwise hit the trace branch.
    if isinstance(trace2, (int, float, np.integer, np.floating)) or (
        isinstance(trace2, np.ndarray) and trace2.ndim == 0
    ):
        result_data = trace1.data.astype(np.float64) - float(trace2)
        metadata = trace1.metadata
    elif isinstance(trace2, np.ndarray):
        if len(trace2) != len(trace1.data):
            raise AnalysisError(
                "Array length must match trace length",
                details={"trace_len": len(trace1.data), "array_len": len(trace2)},  # type: ignore[arg-type]
            )
        result_data = trace1.data.astype(np.float64) - trace2.astype(np.float64)
        metadata = trace1.metadata
    else:
        data1, data2, metadata = _ensure_compatible_traces(trace1, trace2)
        result_data = data1 - data2

    new_metadata = TraceMetadata(
        sample_rate=metadata.sample_rate,
        vertical_scale=metadata.vertical_scale,
        vertical_offset=metadata.vertical_offset,
        acquisition_time=metadata.acquisition_time,
        trigger_info=metadata.trigger_info,
        source_file=metadata.source_file,
        channel_name=channel_name or f"{metadata.channel_name or 'trace'}_diff",
    )

    return WaveformTrace(data=result_data, metadata=new_metadata)
def multiply(
    trace1: WaveformTrace,
    trace2: TraceOrScalar,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Multiply two traces or multiply trace by a scalar.

    Performs element-wise multiplication of two waveform traces, or of a
    trace and a same-length sample array, or multiplies all samples by a
    scalar value. Python numbers, NumPy scalars (e.g. ``np.float32``) and
    0-d arrays are all treated as scalars.

    Args:
        trace1: First trace.
        trace2: Second trace, sample array, or scalar multiplier.
        channel_name: Name for the result trace (optional). Defaults to
            ``"<trace1 channel>_mult"`` (``"trace_mult"`` if unnamed).

    Returns:
        New float64 WaveformTrace containing the product.

    Raises:
        AnalysisError: If traces have incompatible sample rates, or if an
            array operand's length differs from the trace length. (Traces of
            different lengths are truncated to the shorter one.)

    Example:
        >>> product = multiply(voltage_trace, current_trace)  # Power = V * I
        >>> scaled = multiply(trace, 2.0)  # Double amplitude

    References:
        ARITH-003
    """
    # Accept NumPy scalar types and 0-d arrays as scalars too; they are not
    # Python int/float subclasses and would otherwise hit the trace branch.
    if isinstance(trace2, (int, float, np.integer, np.floating)) or (
        isinstance(trace2, np.ndarray) and trace2.ndim == 0
    ):
        result_data = trace1.data.astype(np.float64) * float(trace2)
        metadata = trace1.metadata
    elif isinstance(trace2, np.ndarray):
        if len(trace2) != len(trace1.data):
            raise AnalysisError(
                "Array length must match trace length",
                details={"trace_len": len(trace1.data), "array_len": len(trace2)},  # type: ignore[arg-type]
            )
        result_data = trace1.data.astype(np.float64) * trace2.astype(np.float64)
        metadata = trace1.metadata
    else:
        data1, data2, metadata = _ensure_compatible_traces(trace1, trace2)
        result_data = data1 * data2

    new_metadata = TraceMetadata(
        sample_rate=metadata.sample_rate,
        vertical_scale=metadata.vertical_scale,
        vertical_offset=metadata.vertical_offset,
        acquisition_time=metadata.acquisition_time,
        trigger_info=metadata.trigger_info,
        source_file=metadata.source_file,
        channel_name=channel_name or f"{metadata.channel_name or 'trace'}_mult",
    )

    return WaveformTrace(data=result_data, metadata=new_metadata)
def divide(
    trace1: WaveformTrace,
    trace2: TraceOrScalar,
    *,
    channel_name: str | None = None,
    fill_value: float = np.nan,
) -> WaveformTrace:
    """Divide first trace by second trace or by a scalar.

    Performs element-wise division (trace1 / trace2) by another trace, a
    same-length sample array, or a scalar. Python numbers, NumPy scalars
    (e.g. ``np.float32``) and 0-d arrays are all treated as scalars.

    Every sample whose denominator is exactly zero is set to ``fill_value``
    (default NaN). NaN/inf values that come from the operands themselves
    (or from floating-point overflow) are propagated unchanged. The same rule
    applies to all three operand kinds.

    Args:
        trace1: Numerator trace.
        trace2: Denominator trace, sample array, or scalar.
        channel_name: Name for the result trace (optional). Defaults to
            ``"<trace1 channel>_div"`` (``"trace_div"`` if unnamed).
        fill_value: Value to use for division by zero (default NaN).

    Returns:
        New float64 WaveformTrace containing the quotient.

    Raises:
        AnalysisError: If traces have incompatible sample rates, or if an
            array operand's length differs from the trace length. (Traces of
            different lengths are truncated to the shorter one.)

    Example:
        >>> ratio = divide(trace1, trace2)
        >>> normalized = divide(trace, np.max(trace.data))

    References:
        ARITH-004
    """
    numerator = trace1.data.astype(np.float64)

    # Accept NumPy scalar types and 0-d arrays as scalars too; e.g.
    # np.max() of float32 data returns np.float32, which is not a float.
    if isinstance(trace2, (int, float, np.integer, np.floating)) or (
        isinstance(trace2, np.ndarray) and trace2.ndim == 0
    ):
        denominator: float | NDArray[np.float64] = float(trace2)
        metadata = trace1.metadata
    elif isinstance(trace2, np.ndarray):
        if len(trace2) != len(trace1.data):
            raise AnalysisError(
                "Array length must match trace length",
                details={"trace_len": len(trace1.data), "array_len": len(trace2)},  # type: ignore[arg-type]
            )
        denominator = trace2.astype(np.float64)
        metadata = trace1.metadata
    else:
        numerator, denominator, metadata = _ensure_compatible_traces(trace1, trace2)

    # Divide everywhere (silencing the x/0 and 0/0 warnings), then overwrite
    # only the zero-denominator positions. np.where broadcasts a scalar
    # denominator over the whole trace.
    with np.errstate(divide="ignore", invalid="ignore"):
        quotient = numerator / denominator
    result_data = np.where(denominator == 0, fill_value, quotient)

    new_metadata = TraceMetadata(
        sample_rate=metadata.sample_rate,
        vertical_scale=metadata.vertical_scale,
        vertical_offset=metadata.vertical_offset,
        acquisition_time=metadata.acquisition_time,
        trigger_info=metadata.trigger_info,
        source_file=metadata.source_file,
        channel_name=channel_name or f"{metadata.channel_name or 'trace'}_div",
    )

    return WaveformTrace(data=result_data, metadata=new_metadata)
def scale(
    trace: WaveformTrace,
    factor: float,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Multiply every sample of ``trace`` by a constant ``factor``.

    Thin convenience wrapper around ``multiply(trace, factor)``. When no
    ``channel_name`` is supplied, the result is named ``"<channel>_scaled"``
    (``"trace_scaled"`` for an unnamed channel).

    Args:
        trace: Input trace.
        factor: Scale factor to apply.
        channel_name: Name for the result trace (optional).

    Returns:
        Scaled WaveformTrace.

    Example:
        >>> amplified = scale(trace, 2.0)  # Double amplitude
        >>> attenuated = scale(trace, 0.5)  # Halve amplitude
    """
    if not channel_name:
        channel_name = f"{trace.metadata.channel_name or 'trace'}_scaled"
    return multiply(trace, factor, channel_name=channel_name)
def offset(
    trace: WaveformTrace,
    value: float,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Shift every sample of ``trace`` by a constant ``value``.

    Convenience wrapper around ``add(trace, value)``. When no
    ``channel_name`` is supplied, the result is named ``"<channel>_offset"``
    (``"trace_offset"`` for an unnamed channel).

    Args:
        trace: Input trace.
        value: Constant added to each sample.
        channel_name: Name for the result trace (optional).

    Returns:
        Offset WaveformTrace.

    Example:
        >>> shifted = offset(trace, 1.0)  # Shift up by 1V
    """
    if not channel_name:
        channel_name = f"{trace.metadata.channel_name or 'trace'}_offset"
    return add(trace, value, channel_name=channel_name)
def invert(
    trace: WaveformTrace,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Flip the polarity of ``trace`` by scaling every sample by -1.

    When no ``channel_name`` is supplied, the result is named
    ``"<channel>_inverted"`` (``"trace_inverted"`` for an unnamed channel).

    Args:
        trace: Input trace.
        channel_name: Name for the result trace (optional).

    Returns:
        Inverted WaveformTrace.

    Example:
        >>> inverted = invert(trace)  # Flip polarity
    """
    if not channel_name:
        channel_name = f"{trace.metadata.channel_name or 'trace'}_inverted"
    return scale(trace, -1.0, channel_name=channel_name)
def absolute(
    trace: WaveformTrace,
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Return a trace holding the magnitude of every sample.

    The samples are converted to float64 before ``np.abs`` is applied.
    Metadata (sample rate, vertical scale/offset, acquisition time, trigger
    info, source file) is copied from ``trace``. The default name is
    ``"<channel>_abs"`` (``"trace_abs"`` for an unnamed channel).

    Args:
        trace: Input trace.
        channel_name: Name for the result trace (optional).

    Returns:
        WaveformTrace with absolute values.

    Example:
        >>> rectified = absolute(trace)  # Full-wave rectification
    """
    source_meta = trace.metadata
    magnitude = np.abs(trace.data.astype(np.float64))

    if not channel_name:
        channel_name = f"{source_meta.channel_name or 'trace'}_abs"

    result_meta = TraceMetadata(
        sample_rate=source_meta.sample_rate,
        vertical_scale=source_meta.vertical_scale,
        vertical_offset=source_meta.vertical_offset,
        acquisition_time=source_meta.acquisition_time,
        trigger_info=source_meta.trigger_info,
        source_file=source_meta.source_file,
        channel_name=channel_name,
    )
    return WaveformTrace(data=magnitude, metadata=result_meta)
def differentiate(
    trace: WaveformTrace,
    *,
    order: int = 1,
    method: str = "central",
    channel_name: str | None = None,
) -> WaveformTrace:
    """Compute numerical derivative of trace.

    Calculates the numerical derivative (rate of change) of the waveform,
    using ``trace.metadata.time_base`` as the sample spacing. The output has
    the same number of samples as the input. Higher orders apply the chosen
    first-order scheme repeatedly.

    Args:
        trace: Input trace.
        order: Order of derivative (1 = first derivative, 2 = second, etc.).
        method: Differentiation method:
            - "central": ``(x[i+1] - x[i-1]) / (2*dt)`` in the interior and
              one-sided differences at the two end samples (default, most
              accurate).
            - "forward": ``(x[i+1] - x[i]) / dt``. The last sample, which has
              no successor, uses the backward difference instead.
            - "backward": ``(x[i] - x[i-1]) / dt``. The first sample, which
              has no predecessor, uses the forward difference instead.
        channel_name: Name for the result trace (optional). Defaults to
            ``"<channel>_d<order>"``.

    Returns:
        Differentiated WaveformTrace (input units per unit of time_base,
        e.g. V/s). Vertical scale/offset are cleared because units change.

    Raises:
        InsufficientDataError: If trace has fewer than ``order + 1`` samples.
        ValueError: If order is not positive or method is unknown.

    Example:
        >>> velocity = differentiate(position_trace)  # dx/dt
        >>> acceleration = differentiate(position_trace, order=2)  # d2x/dt2

    References:
        ARITH-005, IEEE 181-2011
    """
    if order < 1:
        raise ValueError(f"Order must be positive, got {order}")

    data = trace.data.astype(np.float64)
    dt = trace.metadata.time_base

    if len(data) < order + 1:
        raise InsufficientDataError(
            f"Need at least {order + 1} samples for order-{order} derivative",
            required=order + 1,
            available=len(data),
            analysis_type="differentiate",
        )

    if method not in ("central", "forward", "backward"):
        raise ValueError(f"Unknown method: {method}")

    # Apply the first-order scheme `order` times. Every pass writes all
    # output samples, so np.empty_like is safe. `data` is never mutated.
    result = data
    for _ in range(order):
        diff = np.empty_like(result)
        if method == "central":
            diff[1:-1] = (result[2:] - result[:-2]) / (2 * dt)
            diff[0] = (result[1] - result[0]) / dt
            diff[-1] = (result[-1] - result[-2]) / dt
        elif method == "forward":
            # diff[i] = (x[i+1] - x[i]) / dt; the final sample reuses the
            # previous value, which is the backward difference at the end.
            diff[:-1] = np.diff(result) / dt
            diff[-1] = diff[-2]
        else:  # "backward"
            # diff[i] = (x[i] - x[i-1]) / dt; the first sample reuses the
            # next value, which is the forward difference at the start.
            diff[1:] = np.diff(result) / dt
            diff[0] = diff[1]
        result = diff

    new_metadata = TraceMetadata(
        sample_rate=trace.metadata.sample_rate,
        vertical_scale=None,  # Units changed
        vertical_offset=None,
        acquisition_time=trace.metadata.acquisition_time,
        trigger_info=trace.metadata.trigger_info,
        source_file=trace.metadata.source_file,
        channel_name=channel_name or f"{trace.metadata.channel_name or 'trace'}_d{order}",
    )

    return WaveformTrace(data=result, metadata=new_metadata)
def integrate(
    trace: WaveformTrace,
    *,
    method: str = "trapezoid",
    initial: float = 0.0,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Compute numerical integral of trace.

    Calculates the cumulative integral of the waveform using numerical
    integration, with ``trace.metadata.time_base`` as the sample spacing.
    The output has the same number of samples as the input. Returns
    integral(V dt) in units of volt-seconds.

    Args:
        trace: Input trace.
        method: Integration method:
            - "trapezoid": Cumulative trapezoidal rule (default). The first
              output sample equals ``initial``.
            - "simpson": Currently computed with the same cumulative
              trapezoidal rule as "trapezoid".
            - "cumsum": Simple cumulative sum, ``cumsum(x) * dt + initial``.
              The first output sample is ``x[0] * dt + initial``.
        initial: Constant added to the whole cumulative integral
            (integration constant, default 0).
        channel_name: Name for the result trace (optional). Defaults to
            ``"<channel>_integral"``.

    Returns:
        Integrated WaveformTrace in V*s. Vertical scale/offset are cleared
        because units change.

    Raises:
        InsufficientDataError: If trace has fewer than 2 samples.
        ValueError: If method is unknown.

    Example:
        >>> position = integrate(velocity_trace)
        >>> charge = integrate(current_trace)  # Q = integral(I dt)

    References:
        ARITH-006
    """
    data = trace.data.astype(np.float64)
    dt = trace.metadata.time_base

    if len(data) < 2:
        raise InsufficientDataError(
            "Need at least 2 samples for integration",
            required=2,
            available=len(data),
            analysis_type="integrate",
        )

    if method in ("trapezoid", "simpson"):
        # SciPy's `initial` argument only *prepends* a value; it is not added
        # to the running sum. Non-zero values are also deprecated and later
        # rejected by SciPy. So integrate from 0 and add the integration
        # constant explicitly.
        # NOTE: "simpson" has no dedicated cumulative implementation here
        # (scipy.integrate.cumulative_simpson only exists in recent SciPy), so
        # it intentionally shares the trapezoidal result.
        result = sp_integrate.cumulative_trapezoid(data, dx=dt, initial=0) + initial
    elif method == "cumsum":
        # Simple cumulative sum (rectangle rule)
        result = np.cumsum(data) * dt + initial
    else:
        raise ValueError(f"Unknown method: {method}")

    new_metadata = TraceMetadata(
        sample_rate=trace.metadata.sample_rate,
        vertical_scale=None,  # Units changed
        vertical_offset=None,
        acquisition_time=trace.metadata.acquisition_time,
        trigger_info=trace.metadata.trigger_info,
        source_file=trace.metadata.source_file,
        channel_name=channel_name or f"{trace.metadata.channel_name or 'trace'}_integral",
    )

    return WaveformTrace(data=result, metadata=new_metadata)
586class _SafeExpressionEvaluator(ast.NodeVisitor):
587 """Safe AST-based expression evaluator for math expressions.
589 This evaluator only allows safe operations:
590 - Binary operations: +, -, *, /, //, %, **
591 - Comparison operations: ==, !=, <, <=, >, >=
592 - Unary operations: +, -, not
593 - Function calls to whitelisted functions
594 - Variable names and constants
596 Security:
597 Uses AST parsing to avoid eval() security risks. Only explicitly
598 whitelisted operations are permitted.
599 """
601 def __init__(self, namespace: dict[str, Any]):
602 """Initialize evaluator with namespace.
604 Args:
605 namespace: Variable and function namespace
606 """
607 self.namespace = namespace
608 # Whitelisted operations
609 self.binary_ops: dict[type[ast.operator], Callable[[Any, Any], Any]] = {
610 ast.Add: operator.add,
611 ast.Sub: operator.sub,
612 ast.Mult: operator.mul,
613 ast.Div: operator.truediv,
614 ast.FloorDiv: operator.floordiv,
615 ast.Mod: operator.mod,
616 ast.Pow: operator.pow,
617 }
618 self.compare_ops: dict[type[ast.cmpop], Callable[[Any, Any], bool]] = {
619 ast.Eq: operator.eq,
620 ast.NotEq: operator.ne,
621 ast.Lt: operator.lt,
622 ast.LtE: operator.le,
623 ast.Gt: operator.gt,
624 ast.GtE: operator.ge,
625 }
626 self.unary_ops: dict[type[ast.unaryop], Callable[[Any], Any]] = {
627 ast.UAdd: operator.pos,
628 ast.USub: operator.neg,
629 }
631 def eval(self, expression: str) -> Any:
632 """Evaluate expression safely.
634 Args:
635 expression: Math expression string
637 Returns:
638 Evaluated result
640 Raises:
641 AnalysisError: If expression contains disallowed operations
642 """
643 try:
644 tree = ast.parse(expression, mode="eval")
645 return self.visit(tree.body)
646 except (SyntaxError, ValueError) as e:
647 raise AnalysisError(f"Invalid expression syntax: {e}") from e
649 def visit_BinOp(self, node: ast.BinOp) -> Any:
650 """Visit binary operation node."""
651 if type(node.op) not in self.binary_ops: 651 ↛ 652line 651 didn't jump to line 652 because the condition on line 651 was never true
652 raise AnalysisError(f"Operation {node.op.__class__.__name__} not allowed")
653 left = self.visit(node.left)
654 right = self.visit(node.right)
655 return self.binary_ops[type(node.op)](left, right)
657 def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
658 """Visit unary operation node."""
659 if type(node.op) not in self.unary_ops:
660 raise AnalysisError(f"Operation {node.op.__class__.__name__} not allowed")
661 operand = self.visit(node.operand)
662 return self.unary_ops[type(node.op)](operand)
664 def visit_Compare(self, node: ast.Compare) -> Any:
665 """Visit comparison operation node."""
666 left = self.visit(node.left)
667 for op, comparator in zip(node.ops, node.comparators, strict=True):
668 if type(op) not in self.compare_ops:
669 raise AnalysisError(f"Operation {op.__class__.__name__} not allowed")
670 right = self.visit(comparator)
671 if not self.compare_ops[type(op)](left, right):
672 return False
673 left = right
674 return True
676 def visit_Call(self, node: ast.Call) -> Any:
677 """Visit function call node."""
678 if isinstance(node.func, ast.Name): 678 ↛ 685line 678 didn't jump to line 685 because the condition on line 678 was always true
679 func_name = node.func.id
680 if func_name not in self.namespace:
681 raise AnalysisError(f"Function '{func_name}' not allowed")
682 func = self.namespace[func_name]
683 args = [self.visit(arg) for arg in node.args]
684 return func(*args)
685 elif isinstance(node.func, ast.Attribute):
686 # Handle np.function() style calls
687 obj = self.visit(node.func.value)
688 attr_name = node.func.attr
689 if not hasattr(obj, attr_name):
690 raise AnalysisError(f"Attribute '{attr_name}' not allowed")
691 func = getattr(obj, attr_name)
692 args = [self.visit(arg) for arg in node.args]
693 return func(*args)
694 else:
695 raise AnalysisError("Complex function calls not allowed")
697 def visit_Name(self, node: ast.Name) -> Any:
698 """Visit variable name node."""
699 if node.id not in self.namespace: 699 ↛ 700line 699 didn't jump to line 700 because the condition on line 699 was never true
700 raise AnalysisError(f"Variable '{node.id}' not defined")
701 return self.namespace[node.id]
703 def visit_Constant(self, node: ast.Constant) -> Any:
704 """Visit constant node (numbers, strings)."""
705 return node.value
707 def visit_Num(self, node: ast.Num) -> Any:
708 """Visit number node (Python <3.8 compatibility)."""
709 return node.n
711 def visit_Attribute(self, node: ast.Attribute) -> Any:
712 """Visit attribute access node."""
713 obj = self.visit(node.value)
714 return getattr(obj, node.attr)
716 def generic_visit(self, node: ast.AST) -> Any:
717 """Catch-all for disallowed node types."""
718 raise AnalysisError(f"AST node type {node.__class__.__name__} not allowed")
def math_expression(
    expression: str,
    traces: dict[str, WaveformTrace],
    *,
    channel_name: str | None = None,
) -> WaveformTrace:
    """Evaluate a mathematical expression on traces.

    Evaluates an expression string using named traces as variables. The
    evaluation namespace contains:
    - each trace's samples (as float64 arrays) under its dictionary key;
    - ``np`` and ``pi``;
    - the helpers ``abs``, ``sqrt``, ``sin``, ``cos``, ``tan``, ``exp``,
      ``log``, ``log10``, ``mean`` and ``std``;
    - ``max``/``min``, which map to the element-wise ``np.maximum`` and
      ``np.minimum``.

    A scalar result (including NumPy scalars and 0-d arrays) is broadcast to
    the common trace length.

    Args:
        expression: Math expression (e.g., "CH1 + CH2", "abs(CH1 - CH2)").
        traces: Dictionary mapping variable names to traces.
        channel_name: Name for the result trace (optional). Defaults to
            ``"expr(<first 20 characters of expression>)"``.

    Returns:
        Result WaveformTrace (float64). Sample rate, acquisition time,
        trigger info and source file come from the first trace in
        ``traces``; vertical scale/offset are cleared.

    Raises:
        AnalysisError: If no traces are given, or the traces differ in length
            or sample rate (by more than 0.1%). Also raised if the expression
            is invalid or fails to evaluate, or if its result cannot be
            converted to float64 samples.

    Example:
        >>> power = math_expression(
        ...     "voltage * current",
        ...     {"voltage": v_trace, "current": i_trace}
        ... )

    Security:
        The expression is parsed with :mod:`ast` and walked by
        ``_SafeExpressionEvaluator`` instead of being passed to ``eval()``.
        Only the constructs that evaluator whitelists are executed.
    """
    if not traces:
        raise AnalysisError("No traces provided for expression evaluation")

    # The first trace provides the reference length, rate and metadata.
    ref_trace = next(iter(traces.values()))
    sample_rate = ref_trace.metadata.sample_rate
    ref_len = len(ref_trace.data)

    # Namespace with safe helper functions; trace data is added below.
    safe_namespace: dict[str, Any] = {
        "np": np,
        "abs": np.abs,
        "sqrt": np.sqrt,
        "sin": np.sin,
        "cos": np.cos,
        "tan": np.tan,
        "exp": np.exp,
        "log": np.log,
        "log10": np.log10,
        "max": np.maximum,
        "min": np.minimum,
        "mean": np.mean,
        "std": np.std,
        "pi": np.pi,
    }

    # Validate every trace (same length, sample rate within 0.1%) and
    # expose its samples to the expression as a float64 array.
    for name, trace in traces.items():
        if len(trace.data) != ref_len:
            raise AnalysisError(
                f"Trace '{name}' has different length",
                details={"expected": ref_len, "got": len(trace.data)},  # type: ignore[arg-type]
            )
        rate_ratio = trace.metadata.sample_rate / sample_rate
        if not (0.999 <= rate_ratio <= 1.001):
            raise AnalysisError(
                f"Trace '{name}' has different sample rate",
                details={"expected": sample_rate, "got": trace.metadata.sample_rate},  # type: ignore[arg-type]
            )
        safe_namespace[name] = trace.data.astype(np.float64)

    # Use safe AST-based evaluator instead of eval()
    evaluator = _SafeExpressionEvaluator(safe_namespace)
    try:
        result = evaluator.eval(expression)
    except AnalysisError:
        raise  # Re-raise AnalysisError from evaluator unchanged
    except Exception as e:
        raise AnalysisError(
            f"Failed to evaluate expression: {e}",
            details={"expression": expression},  # type: ignore[arg-type]
        ) from e

    # Vector results keep their own length; scalars (Python/NumPy scalars and
    # 0-d arrays) are broadcast. Non-numeric results (strings, functions,
    # modules, huge ints) are reported as AnalysisError, not raw TypeError,
    # ValueError or OverflowError.
    try:
        if isinstance(result, np.ndarray) and result.ndim > 0:
            result_data = result.astype(np.float64)
        else:
            result_data = np.full(ref_len, result, dtype=np.float64)
    except (TypeError, ValueError, OverflowError) as e:
        raise AnalysisError(
            f"Expression result cannot be converted to float samples: {e}",
            details={"expression": expression},  # type: ignore[arg-type]
        ) from e

    new_metadata = TraceMetadata(
        sample_rate=sample_rate,
        vertical_scale=None,
        vertical_offset=None,
        acquisition_time=ref_trace.metadata.acquisition_time,
        trigger_info=ref_trace.metadata.trigger_info,
        source_file=ref_trace.metadata.source_file,
        channel_name=channel_name or f"expr({expression[:20]})",
    )

    return WaveformTrace(data=result_data, metadata=new_metadata)