# src/Synaptipy/core/processing_pipeline.py
# -*- coding: utf-8 -*-
"""
Signal Processing Pipeline.
Formalizes the order of operations for signal processing (e.g., Baseline -> Filter).
Ensures that both visualization and analysis use the exact same processing sequence.
"""
import logging
from typing import Any, Dict, List, Optional
import numpy as np
from Synaptipy.core import signal_processor
log = logging.getLogger(__name__)
[docs]
class SignalProcessingPipeline:
    """
    Manages an ordered list of signal processing steps.

    Each step is a plain dictionary whose ``'type'`` key selects the
    operation (``'baseline'``, ``'filter'`` or ``'artifact'``) and whose
    remaining keys parameterize it. Steps run in list order when
    :meth:`process` is called. The pipeline keeps its own shallow copies of
    the step dictionaries, so callers can neither mutate the pipeline through
    a dict they passed in nor through one they got back.
    """

    def __init__(self):
        self._steps: List[Dict[str, Any]] = []

    def add_step(self, step_config: Dict[str, Any], index: Optional[int] = None):
        """
        Add a processing step to the pipeline.

        Args:
            step_config: Dictionary defining the step (e.g., {'type': 'baseline', 'method': 'mean'})
            index: Optional index to insert at. If None, appends to end.
        """
        # Store a shallow copy (as set_steps/get_steps do) so that later
        # mutation of the caller's dict cannot silently alter the pipeline.
        step = step_config.copy()
        if index is not None:
            self._steps.insert(index, step)
        else:
            self._steps.append(step)
        log.debug("Added pipeline step: %s", step_config)

    def remove_step_by_type(self, step_type: str):
        """Remove all steps of a specific type (e.g. 'baseline')."""
        original_count = len(self._steps)
        self._steps = [s for s in self._steps if s.get("type") != step_type]
        if len(self._steps) < original_count:
            log.debug("Removed steps of type '%s'", step_type)

    def clear(self):
        """Clear all steps."""
        self._steps.clear()
        log.debug("Pipeline cleared")

    def get_steps(self) -> List[Dict[str, Any]]:
        """Return a copy of the current steps."""
        return [s.copy() for s in self._steps]

    def set_steps(self, steps: List[Dict[str, Any]]):
        """Replace all steps."""
        self._steps = [s.copy() for s in steps]
        log.debug("Pipeline steps set to: %s", self._steps)

    def process(
        self, data: np.ndarray, fs: float, time_vector: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """
        Apply all steps in order to the data.

        Processing is best-effort: a step that raises is logged and skipped,
        and the remaining steps continue with the last good result. Steps
        with an unrecognized type or method are skipped with a warning.

        Args:
            data: Input signal array
            fs: Sampling rate in Hz
            time_vector: Optional time vector (required for region-based baseline)

        Returns:
            Processed data array. ``None`` or empty input is returned as-is;
            otherwise the input array itself is never modified.
        """
        if data is None or len(data) == 0:
            return data
        result = data.copy()
        for step in self._steps:
            try:
                op_type = step.get("type")
                if op_type == "baseline":
                    result = self._apply_baseline(result, step, time_vector)
                elif op_type == "filter":
                    result = self._apply_filter(result, step, fs)
                elif op_type == "artifact":
                    result = self._apply_artifact(result, step, time_vector)
                else:
                    log.warning("Unknown pipeline step type %r. Skipping.", op_type)
                # Check for bad data after each step
                if result is not None and not np.all(np.isfinite(result)):
                    log.error(
                        "Step %s/%s produced invalid data (NaN/Inf)", op_type, step.get("method")
                    )
            except Exception as e:
                # Deliberate best-effort boundary: keep the last good result
                # and carry on, but record the traceback for diagnosis.
                log.error("Error processing step %s: %s", step, e, exc_info=True)
        return result

    @staticmethod
    def _required_float(step: Dict[str, Any], key: str) -> float:
        """Return ``step[key]`` as a float, raising a clear error when it is missing."""
        value = step.get(key)
        if value is None:
            raise ValueError(f"step is missing required parameter '{key}'")
        return float(value)

    @staticmethod
    def _region_mask(time_vector: np.ndarray, start_t: Any, end_t: Any) -> np.ndarray:
        """Return a boolean mask of samples with ``start_t <= t <= end_t``."""
        return (time_vector >= float(start_t)) & (time_vector <= float(end_t))

    def _apply_baseline(
        self, signal: np.ndarray, step: Dict[str, Any], time_vector: Optional[np.ndarray]
    ) -> np.ndarray:
        """Apply one 'baseline' step and return the baseline-subtracted signal."""
        method = step.get("method", "mode")
        start_t = step.get("start_t")
        end_t = step.get("end_t")
        # mean/median/mode are computed over [start_t, end_t] only when both
        # bounds and a time vector are available; otherwise over the whole trace.
        use_region = (
            start_t is not None
            and end_t is not None
            and time_vector is not None
            and method in ("mean", "median", "mode")
        )

        if method == "mode":
            decimals = int(step.get("decimals", 1))
            if use_region:
                mask = self._region_mask(time_vector, start_t, end_t)
                if np.any(mask):
                    # Mode of the window after rounding to `decimals` places.
                    rounded = np.round(signal[mask], decimals)
                    vals, counts = np.unique(rounded, return_counts=True)
                    return signal - float(vals[np.argmax(counts)])
            # No region requested, or the region holds no samples.
            return signal_processor.subtract_baseline_mode(signal, decimals=decimals)

        if method == "mean":
            if use_region:
                return signal_processor.subtract_baseline_region(
                    signal, time_vector, float(start_t), float(end_t)
                )
            return signal_processor.subtract_baseline_mean(signal)

        if method == "median":
            if use_region:
                mask = self._region_mask(time_vector, start_t, end_t)
                if np.any(mask):
                    return signal - float(np.median(signal[mask]))
            # No region requested, or the region holds no samples.
            return signal_processor.subtract_baseline_median(signal)

        if method == "linear":
            return signal_processor.subtract_baseline_linear(signal)

        if method == "region":
            if time_vector is None:
                log.warning("Region baseline requested but no time vector provided. Skipping.")
                return signal
            st = float(step.get("start_t", 0.0))
            et = float(step.get("end_t", 0.0))
            return signal_processor.subtract_baseline_region(signal, time_vector, st, et)

        log.warning("Unknown baseline method %r. Skipping.", method)
        return signal

    def _apply_filter(self, signal: np.ndarray, step: Dict[str, Any], fs: float) -> np.ndarray:
        """Apply one 'filter' step and return the filtered signal."""
        method = step.get("method")
        order = int(step.get("order", 5))
        if method == "lowpass":
            return signal_processor.lowpass_filter(
                signal, self._required_float(step, "cutoff"), fs, order=order
            )
        if method == "highpass":
            return signal_processor.highpass_filter(
                signal, self._required_float(step, "cutoff"), fs, order=order
            )
        if method == "bandpass":
            return signal_processor.bandpass_filter(
                signal,
                self._required_float(step, "low_cut"),
                self._required_float(step, "high_cut"),
                fs,
                order=order,
            )
        if method == "notch":
            return signal_processor.notch_filter(
                signal, self._required_float(step, "freq"), self._required_float(step, "q_factor"), fs
            )
        log.warning("Unknown filter method %r. Skipping.", method)
        return signal

    def _apply_artifact(
        self, signal: np.ndarray, step: Dict[str, Any], time_vector: Optional[np.ndarray]
    ) -> np.ndarray:
        """Apply one 'artifact' blanking step; requires a time vector."""
        if time_vector is None:
            log.warning("Artifact blanking requested but no time " "vector provided. Skipping.")
            return signal
        onset = float(step.get("onset_time", 0.0))
        duration = float(step.get("duration_ms", 0.5))
        method = step.get("method", "hold")
        return signal_processor.blank_artifact(signal, time_vector, onset, duration, method=method)
# ---------------------------------------------------------------------------
# Immutable Trace Correction Pipeline
# ---------------------------------------------------------------------------
def _apply_pn_subtraction(result: np.ndarray, pn_traces, pn_scale: float) -> np.ndarray:
    """Step B: P/N leak subtraction helper.

    Averages the P/N sweeps sample-by-sample, scales the average by
    *pn_scale* and subtracts it from *result*.

    Args:
        result: Trace to correct (1-D along its first axis).
        pn_traces: Array-like of P/N sweeps. A 1-D input is treated as a
            single sweep; a 2-D input is ``(n_sweeps, n_samples)``.
        pn_scale: Scalar multiplier applied to the averaged sweep.

    Returns:
        The corrected trace as a new array, or *result* unchanged when the
        step is skipped (no sweeps, unsupported dimensionality, or a sample
        count that does not match *result*). Every skip is logged as a warning.
    """
    pn_arr = np.asarray(pn_traces, dtype=float)
    if pn_arr.ndim == 1:
        pn_arr = pn_arr[np.newaxis, :]

    # Averaging zero sweeps yields NaN and would poison the whole trace;
    # 0-D or >2-D input has no meaningful (sweeps, samples) layout.
    if pn_arr.ndim != 2 or pn_arr.shape[0] == 0:
        log.warning(
            "apply_trace_corrections: Step B skipped — pn_traces must hold at least one "
            "1-D sweep (got shape %s).",
            pn_arr.shape,
        )
        return result

    n_sweeps, n_samples = pn_arr.shape
    if n_samples != result.shape[0]:
        log.warning(
            "apply_trace_corrections: Step B skipped — pn_traces length %d != data length %d.",
            n_samples,
            result.shape[0],
        )
        return result

    pn_mean = pn_arr.mean(axis=0) * float(pn_scale)
    corrected = result - pn_mean
    log.debug(
        "apply_trace_corrections: Step B — P/N leak subtracted (%d sweeps, scale=%.3f).",
        n_sweeps,
        pn_scale,
    )
    return corrected
def _apply_noise_floor_zeroing(result: np.ndarray, time: np.ndarray, pre_event_window_s: tuple) -> np.ndarray:
    """Step C: scalar pre-event noise-floor zeroing helper.

    Subtracts the median of the samples whose time stamp lies in the
    half-open window ``[pre_event_window_s[0], pre_event_window_s[1])``.

    Args:
        result: Trace to zero.
        time: Time stamps used to select the window; indexed against
            *result*, so it must line up with it sample for sample.
        pre_event_window_s: Two-element ``(start, end)`` window.

    Returns:
        A new array with the window median subtracted, or *result* unchanged
        (with a warning logged) when no sample falls inside the window.
    """
    window_start = float(pre_event_window_s[0])
    window_end = float(pre_event_window_s[1])

    # Start bound is inclusive, end bound is exclusive.
    in_window = (time >= window_start) & (time < window_end)
    if not np.any(in_window):
        log.warning(
            "apply_trace_corrections: Step C skipped — no samples in window %.3f-%.3f s.",
            window_start,
            window_end,
        )
        return result

    floor_offset = float(np.median(result[in_window]))
    zeroed = result - floor_offset
    log.debug(
        "apply_trace_corrections: Step C — noise floor zeroed (%.4f mV, window %.3f-%.3f s).",
        floor_offset,
        window_start,
        window_end,
    )
    return zeroed
[docs]
def apply_trace_corrections(
    data: np.ndarray,
    time: np.ndarray,
    fs: float,
    *,
    ljp_mv: float = 0.0,
    pn_traces: Optional[np.ndarray] = None,
    pn_scale: float = 1.0,
    pre_event_window_s: Optional[tuple] = None,
    artifact_interp_steps: Optional[List[Dict[str, Any]]] = None,
    filter_steps: Optional[List[Dict[str, Any]]] = None,
) -> np.ndarray:
    """Apply the immutable five-step trace correction in a guaranteed order.

    Regardless of the order the user toggles settings in the GUI, **this
    function must be used as the single entry point for all backend
    corrections** so that the execution order is always:

    Step A - LJP Voltage Offset
        ``V_true = V_recorded - ljp_mv``

    Step B - P/N Leak Subtraction
        If *pn_traces* is supplied, compute the per-sample mean across the
        sub-threshold repetitions, scale by *pn_scale*, and subtract from the
        corrected trace. This removes capacitive transients and steady-state
        leak currents without affecting the signal of interest.

    Step C - Scalar Noise-Floor Zeroing
        Subtract the median of the user-specified pre-event window
        ``pre_event_window_s = (t_start, t_end)``. Because the LJP and
        P/N corrections have already been applied, this median reflects only
        the residual noise floor, not a physiological offset.

    Step D - Pre-filter Artifact Interpolation
        Linearly interpolate across each stimulus artifact defined in
        *artifact_interp_steps*. Running this **after** A-C and **before**
        filtering prevents Gibbs ringing: the DSP filter operates on an
        already-flat waveform without sharp transient edges.

    Step E - DSP Filtering
        Apply any filters listed in *filter_steps* (same dict schema as
        ``SignalProcessingPipeline``: ``{'type': 'filter', 'method': 'lowpass',
        'cutoff': 1000, 'order': 5}``). Running filters **after** A-D
        prevents edge artefacts from the transient subtraction from being
        smeared across the waveform.

    Args:
        data: Raw (uncorrected) signal array.
        time: Time vector aligned with *data* (seconds). May be
            ``None``, in which case the time-dependent steps
            C and D are skipped with a warning.
        fs: Sampling rate in Hz.
        ljp_mv: Liquid Junction Potential in mV. Step A only
            runs when ``ljp_mv != 0.0``.
        pn_traces: 2-D array of shape ``(n_sweeps, n_samples)``
            containing the sub-threshold P/N sweeps. Step B
            is skipped when *pn_traces* is ``None``.
        pn_scale: Scalar factor applied to the averaged P/N
            template before subtraction (default 1.0).
        pre_event_window_s: ``(t_start, t_end)`` tuple in seconds. Step C
            is skipped when this is ``None``, or (with a
            warning) when *time* is missing or its size
            differs from the data.
        artifact_interp_steps: List of artifact dicts with keys
            ``onset_time`` (s) and ``duration_ms`` (ms).
            Each defines a stimulus artifact to linearly
            interpolate. Step D is skipped when ``None``.
        filter_steps: List of filter dicts consumed by
            ``SignalProcessingPipeline.process()``. Step E
            is skipped when the list is empty or ``None``.

    Returns:
        Corrected signal array. The input is never mutated; a ``None`` or
        empty *data* is handed back unchanged rather than copied.
    """
    if data is None or data.size == 0:
        return data
    result: np.ndarray = data.copy()

    # ------------------------------------------------------------------
    # Step A — Liquid Junction Potential subtraction
    # ------------------------------------------------------------------
    if ljp_mv != 0.0:
        result = result - float(ljp_mv)
        log.debug("apply_trace_corrections: Step A — LJP %.4f mV subtracted.", ljp_mv)

    # ------------------------------------------------------------------
    # Step B — P/N Leak Subtraction
    # ------------------------------------------------------------------
    if pn_traces is not None:
        result = _apply_pn_subtraction(result, pn_traces, pn_scale)

    # ------------------------------------------------------------------
    # Step C — Pre-event Scalar Zeroing (median of pre-event window)
    # ------------------------------------------------------------------
    if pre_event_window_s is not None:
        if time is not None and time.size == result.size:
            result = _apply_noise_floor_zeroing(result, time, pre_event_window_s)
        else:
            # The caller asked for zeroing; do not drop the request silently.
            log.warning(
                "apply_trace_corrections: Step C skipped — time vector missing or "
                "not the same size as the data."
            )

    # ------------------------------------------------------------------
    # Step D — Pre-filter Artifact Interpolation
    # Linear interpolation across stimulus artifacts must occur AFTER
    # baseline zeroing (Step C) and BEFORE filtering (Step E). This
    # ordering prevents Gibbs ringing: the filter operates on an already
    # flat waveform without the sharp transient edges of the artifact.
    # ------------------------------------------------------------------
    if artifact_interp_steps:
        if time is None:
            # Same guard as the pipeline's own artifact step: blanking
            # cannot locate the artifact without time stamps.
            log.warning("apply_trace_corrections: Step D skipped — no time vector provided.")
        else:
            for art in artifact_interp_steps:
                onset = float(art.get("onset_time", 0.0))
                duration_ms = float(art.get("duration_ms", 0.5))
                result = signal_processor.blank_artifact(
                    result, time, onset, duration_ms, method="linear"
                )
            log.debug(
                "apply_trace_corrections: Step D — %d artifact interpolation step(s) applied.",
                len(artifact_interp_steps),
            )

    # ------------------------------------------------------------------
    # Step E — Signal Filtering
    # ------------------------------------------------------------------
    if filter_steps:
        pipeline = SignalProcessingPipeline()
        pipeline.set_steps(filter_steps)
        result = pipeline.process(result, fs, time_vector=time)
        log.debug("apply_trace_corrections: Step E — %d filter step(s) applied.", len(filter_steps))
    return result