Source code for Synaptipy.core.analysis.single_spike

# src/Synaptipy/core/analysis/single_spike.py
# -*- coding: utf-8 -*-
"""
Core Protocol Module 2: Single Spike Analysis.

Consolidates: Spike Detection, AP Characterisation (threshold, amplitude,
half-width, rise/decay times, AHP) and Phase Plane (dV/dt vs V) analysis.

All registry wrapper functions return::

    {
        "module_used": "single_spike",
        "metrics": { ... flat result keys ... }
    }
"""

import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from scipy.signal import savgol_filter

from Synaptipy.core.analysis.passive_properties import apply_ljp_correction
from Synaptipy.core.analysis.registry import AnalysisRegistry
from Synaptipy.core.results import SpikeTrainResult

log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Spike Detection
# ---------------------------------------------------------------------------


[docs] def detect_spikes_threshold( # noqa: C901 data: np.ndarray, time: np.ndarray, threshold: float, refractory_samples: int, peak_search_window_samples: int = None, parameters: Dict[str, Any] = None, dvdt_threshold: float = 20.0, ) -> SpikeTrainResult: """ Detect action potentials using a two-stage dV/dt-threshold crossing algorithm. Algorithm --------- 1. **First-derivative computation**: :func:`numpy.gradient` is applied to *data* with sample spacing ``dt = time[1] - time[0]`` (s), yielding dV/dt in mV s⁻¹. 2. **dV/dt crossing detection**: candidate spike onsets are identified as upward crossings of ``dvdt_threshold * 1000`` (mV s⁻¹). Each crossing is the sample where dV/dt transitions from strictly below to at-or-above the threshold. 3. **Refractory period enforcement**: candidate crossings separated by fewer than *refractory_samples* are suppressed, retaining only the first crossing in each refractory interval (greedy forward scan). 4. **Peak localisation**: for each accepted onset, the voltage maximum within the next *peak_search_window_samples* is found. The candidate is accepted as a spike only if ``data[peak_idx] >= threshold`` (mV). Parameters ---------- data : np.ndarray 1-D voltage array (mV). time : np.ndarray 1-D time array aligned with *data* (s). threshold : float Minimum voltage a candidate peak must reach to be accepted as a spike (mV). Guards against sub-threshold dV/dt transients. refractory_samples : int Minimum number of samples between successive accepted spike onsets. Convert from time: ``int(refractory_period_s * sampling_rate_hz)``. peak_search_window_samples : int, optional Number of samples to search forward from each onset crossing for the voltage peak. Defaults to *refractory_samples* when ``None``. parameters : dict, optional Arbitrary parameter dict stored verbatim in the returned :class:`~Synaptipy.core.results.SpikeTrainResult` for provenance. dvdt_threshold : float, optional dV/dt threshold for onset detection (V s⁻¹, default 20.0). Converted internally to mV s⁻¹ by multiplication with 1000. Returns ------- SpikeTrainResult Attributes populated on success: * ``value`` (int) – total spike count. * ``spike_times`` (np.ndarray) – peak times (s). * ``spike_indices`` (np.ndarray) – peak sample indices. * ``mean_frequency`` (float) – mean instantaneous firing rate ``(n_spikes - 1) / (t_last - t_first)`` (Hz); 0.0 for ≤ 1 spike. * ``is_valid`` (bool) – ``False`` when input arrays are malformed. """ if not isinstance(data, np.ndarray) or data.ndim != 1 or data.size < 2: return SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message="Invalid data array", parameters=parameters or {} ) if not isinstance(time, np.ndarray) or time.shape != data.shape: return SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message="Time and data mismatch", parameters=parameters or {}, ) if not isinstance(threshold, (int, float)): return SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message="Threshold must be numeric", parameters=parameters or {}, ) if not isinstance(refractory_samples, int) or refractory_samples < 0: return SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message="Invalid refractory period", parameters=parameters or {}, ) try: dt = time[1] - time[0] if len(time) > 1 else 1.0 dvdt = np.gradient(data, dt) dvdt_thresh_mvs = dvdt_threshold * 1000.0 crossings = np.where((dvdt[:-1] < dvdt_thresh_mvs) & (dvdt[1:] >= dvdt_thresh_mvs))[0] + 1 if crossings.size == 0: return SpikeTrainResult( value=0, unit="spikes", spike_times=np.array([]), spike_indices=np.array([]), parameters=parameters or {}, ) if refractory_samples <= 0: valid_crossing_indices = crossings else: valid_crossings_list = [crossings[0]] last_crossing_idx = crossings[0] for idx in crossings[1:]: if (idx - last_crossing_idx) >= refractory_samples: valid_crossings_list.append(idx) last_crossing_idx = idx valid_crossing_indices = np.array(valid_crossings_list) if valid_crossing_indices.size == 0: return SpikeTrainResult( value=0, unit="spikes", spike_times=np.array([]), spike_indices=np.array([]), parameters=parameters or {}, ) peak_indices_list = [] if peak_search_window_samples is None: peak_search_window_samples = ( refractory_samples if refractory_samples > 0 else int(0.005 / (time[1] - time[0])) ) for crossing_idx in valid_crossing_indices: search_start = crossing_idx search_end = min(crossing_idx + peak_search_window_samples, len(data)) if search_start >= search_end: peak_idx = crossing_idx else: try: relative_peak_idx = np.argmax(data[search_start:search_end]) peak_idx = search_start + relative_peak_idx except ValueError: peak_idx = crossing_idx if data[peak_idx] >= threshold: peak_indices_list.append(peak_idx) peak_indices_arr = np.array(peak_indices_list).astype(int) peak_times_arr = time[peak_indices_arr] mean_freq = 0.0 if len(peak_times_arr) > 1: spike_span = peak_times_arr[-1] - peak_times_arr[0] if spike_span > 0: mean_freq = (len(peak_times_arr) - 1) / spike_span return SpikeTrainResult( value=len(peak_indices_arr), unit="spikes", spike_times=peak_times_arr, spike_indices=peak_indices_arr, mean_frequency=mean_freq, parameters=parameters or {}, ) except (ValueError, TypeError, KeyError, IndexError) as e: log.error(f"Error during spike detection: {e}", exc_info=True) return SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message=str(e), parameters=parameters or {} )
# --------------------------------------------------------------------------- # AP Feature Extraction # ---------------------------------------------------------------------------
[docs] def calculate_spike_features( # noqa: C901 data: np.ndarray, time: np.ndarray, spike_indices: np.ndarray, dvdt_threshold: float = 20.0, ahp_window_sec: float = 0.05, onset_lookback: float = 0.01, fahp_window_ms: Tuple[float, float] = (1.0, 5.0), mahp_window_ms: Tuple[float, float] = (10.0, 50.0), ) -> List[Dict[str, Any]]: """ Calculate detailed features for each detected spike (vectorised NumPy). Returns list of dicts per spike: ap_threshold, amplitude, half_width, rise_time_10_90, decay_time_90_10, fahp_depth, mahp_depth, ahp_duration_half, adp_amplitude, max_dvdt, min_dvdt. AP threshold is detected via the peak of d2V/dt2 in the pre-spike lookback window (maximum curvature method). Falls back to the first dV/dt crossing above ``dvdt_threshold`` when d2V/dt2 gives a boundary result. Args: data: 1-D voltage array (mV). time: Corresponding time array (s). spike_indices: Array of sample indices for each spike peak. dvdt_threshold: Fallback dV/dt threshold for AP onset (V/s). ahp_window_sec: Duration of AHP/ADP search window (s). onset_lookback: Lookback window before each spike peak (s). fahp_window_ms: (start, end) of fast-AHP window after peak (ms). mahp_window_ms: (start, end) of medium-AHP window after peak (ms). """ if spike_indices is None or spike_indices.size == 0: return [] spike_indices = np.asarray(spike_indices, dtype=int) n_spikes = len(spike_indices) n_data = len(data) if n_data < 2: return [] dt = time[1] - time[0] if dt <= 0: log.warning("Invalid time vector (dt <= 0). Cannot calculate features.") return [] dvdt = np.gradient(data, dt) d2vdt2 = np.gradient(dvdt, dt) lookback_samples = int(onset_lookback / dt) post_peak_samples = int(0.01 / dt) ahp_max_samples = int(ahp_window_sec / dt) # --- AP Threshold (onset) via d2V/dt2 peak (maximum curvature method) --- # The kink in the voltage trace where the AP upstroke begins corresponds to # the peak of the second derivative. Falls back to the first dV/dt crossing # when the d2V/dt2 peak falls at the window boundary (unreliable estimate). lookback_range = np.arange(-lookback_samples, 0) onset_window_indices = spike_indices[:, None] + lookback_range np.clip(onset_window_indices, 0, n_data - 1, out=onset_window_indices) onset_d2vdt2_windows = d2vdt2[onset_window_indices] d2vdt2_peak_rel = np.argmax(onset_d2vdt2_windows, axis=1) thresh_indices_d2 = onset_window_indices[np.arange(n_spikes), d2vdt2_peak_rel] # Fallback: first dV/dt crossing above a dynamic per-spike threshold. # Using 20% of the per-spike peak rising dV/dt avoids a hardcoded 20 V/s # that is invalid for smaller or inactivated spikes in fast trains. onset_dvdt_windows = dvdt[onset_window_indices] onset_max_dvdt = np.max(onset_dvdt_windows, axis=1) # (n_spikes,) in mV/s # Floor at 2 V/s (2000 mV/s) to prevent triggering at rest. dynamic_thresh_mvs = np.maximum(0.2 * onset_max_dvdt, 2000.0) crossings_mask = onset_dvdt_windows > dynamic_thresh_mvs[:, None] has_crossing = np.any(crossings_mask, axis=1) first_crossing_rel_idx = np.argmax(crossings_mask, axis=1) fallback_indices = np.maximum(0, spike_indices - int(0.001 / dt)) found_thresh_indices = onset_window_indices[np.arange(n_spikes), first_crossing_rel_idx] dvdt_thresh_indices = np.where(has_crossing, found_thresh_indices, fallback_indices) # Use d2V/dt2 peak unless it sits at the edge of the lookback window at_edge = (d2vdt2_peak_rel == 0) | (d2vdt2_peak_rel >= lookback_samples - 1) thresh_indices = np.where(at_edge, dvdt_thresh_indices, thresh_indices_d2) ap_thresholds = data[thresh_indices] # Biological QC on fallback-detected thresholds: flag as NaN when the # per-spike peak rising rate exceeds 300 V/s (artifact ceiling) or the # threshold-to-peak rising phase is shorter than 0.2 ms (false detection). # onset_max_dvdt is in mV/s; 300 V/s = 300_000 mV/s. rising_phase_s = (spike_indices - thresh_indices) * dt artifact_flag = at_edge & ((onset_max_dvdt > 300_000.0) | (rising_phase_s < 0.0002)) ap_thresholds = np.where(artifact_flag, np.nan, ap_thresholds) peak_vals = data[spike_indices] amplitudes = peak_vals - ap_thresholds # --- Full waveform window --- full_window_len = lookback_samples + post_peak_samples full_window_range = np.arange(-lookback_samples, post_peak_samples) full_window_indices = spike_indices[:, None] + full_window_range np.clip(full_window_indices, 0, n_data - 1, out=full_window_indices) waveforms = data[full_window_indices] amp_50 = ap_thresholds + 0.5 * amplitudes amp_10 = ap_thresholds + 0.1 * amplitudes amp_90 = ap_thresholds + 0.9 * amplitudes half_widths = np.full(n_spikes, np.nan) rise_times = np.full(n_spikes, np.nan) decay_times = np.full(n_spikes, np.nan) rel_peak = lookback_samples col_indices = np.arange(full_window_len) is_pre_peak = col_indices < rel_peak is_post_peak = col_indices > rel_peak lev_50 = amp_50[:, None] idxs = np.tile(col_indices, (n_spikes, 1)) temp_mask = is_pre_peak & (waveforms <= lev_50) has_pre_50 = np.any(temp_mask, axis=1) masked_idxs_pre = np.where(temp_mask, idxs, -1) idx_rise_50_rel = np.max(masked_idxs_pre, axis=1) temp_mask_post = is_post_peak & (waveforms <= lev_50) has_post_50 = np.any(temp_mask_post, axis=1) masked_idxs_post = np.where(temp_mask_post, idxs, 999999) idx_fall_50_rel = np.min(masked_idxs_post, axis=1) valid_width = has_pre_50 & has_post_50 & (idx_rise_50_rel != -1) & (idx_fall_50_rel != 999999) lev_50_flat = lev_50.ravel() rise_frac = np.zeros(n_spikes) fall_frac = np.zeros(n_spikes) for k in np.where(valid_width)[0]: ri = idx_rise_50_rel[k] if ri + 1 < waveforms.shape[1]: y_lo, y_hi = waveforms[k, ri], waveforms[k, ri + 1] denom = y_hi - y_lo rise_frac[k] = (lev_50_flat[k] - y_lo) / denom if abs(denom) > 1e-12 else 0.5 fi = idx_fall_50_rel[k] if fi - 1 >= 0: y_hi2, y_lo2 = waveforms[k, fi - 1], waveforms[k, fi] denom2 = y_hi2 - y_lo2 fall_frac[k] = (lev_50_flat[k] - y_lo2) / denom2 if abs(denom2) > 1e-12 else 0.5 half_widths[valid_width] = ( ( (idx_fall_50_rel[valid_width] - fall_frac[valid_width]) - (idx_rise_50_rel[valid_width] + rise_frac[valid_width]) ) * dt * 1000.0 ) lev_10 = amp_10[:, None] lev_90 = amp_90[:, None] mask_10 = is_pre_peak & (waveforms <= lev_10) valid_10 = np.any(mask_10, axis=1) idx_10_rel = np.max(np.where(mask_10, idxs, -1), axis=1) mask_90 = is_pre_peak & (waveforms <= lev_90) valid_90 = np.any(mask_90, axis=1) idx_90_rel = np.max(np.where(mask_90, idxs, -1), axis=1) valid_rise = valid_10 & valid_90 & (idx_90_rel > idx_10_rel) lev_10_flat = amp_10 lev_90_flat = amp_90 rise_frac_10 = np.zeros(n_spikes) rise_frac_90 = np.zeros(n_spikes) for k in np.where(valid_rise)[0]: ri10 = idx_10_rel[k] if ri10 + 1 < waveforms.shape[1]: y_lo, y_hi = waveforms[k, ri10], waveforms[k, ri10 + 1] denom = y_hi - y_lo rise_frac_10[k] = (lev_10_flat[k] - y_lo) / denom if abs(denom) > 1e-12 else 0.5 ri90 = idx_90_rel[k] if ri90 + 1 < waveforms.shape[1]: y_lo, y_hi = waveforms[k, ri90], waveforms[k, ri90 + 1] denom = y_hi - y_lo rise_frac_90[k] = (lev_90_flat[k] - y_lo) / denom if abs(denom) > 1e-12 else 0.5 rise_times[valid_rise] = ( ((idx_90_rel[valid_rise] + rise_frac_90[valid_rise]) - (idx_10_rel[valid_rise] + rise_frac_10[valid_rise])) * dt * 1000.0 ) mask_dec_90 = is_post_peak & (waveforms <= lev_90) valid_dec_90 = np.any(mask_dec_90, axis=1) idx_dec_90_rel = np.min(np.where(mask_dec_90, idxs, 999999), axis=1) mask_dec_10 = is_post_peak & (waveforms <= lev_10) valid_dec_10 = np.any(mask_dec_10, axis=1) idx_dec_10_rel = np.min(np.where(mask_dec_10, idxs, 999999), axis=1) valid_decay = valid_dec_90 & valid_dec_10 & (idx_dec_10_rel > idx_dec_90_rel) decay_frac_90 = np.zeros(n_spikes) decay_frac_10 = np.zeros(n_spikes) for k in np.where(valid_decay)[0]: di90 = idx_dec_90_rel[k] if di90 - 1 >= 0: y_hi, y_lo = waveforms[k, di90 - 1], waveforms[k, di90] denom = y_hi - y_lo decay_frac_90[k] = (lev_90_flat[k] - y_lo) / denom if abs(denom) > 1e-12 else 0.5 di10 = idx_dec_10_rel[k] if di10 - 1 >= 0: y_hi, y_lo = waveforms[k, di10 - 1], waveforms[k, di10] denom = y_hi - y_lo decay_frac_10[k] = (lev_10_flat[k] - y_lo) / denom if abs(denom) > 1e-12 else 0.5 decay_times[valid_decay] = ( ( (idx_dec_10_rel[valid_decay] - decay_frac_10[valid_decay]) - (idx_dec_90_rel[valid_decay] - decay_frac_90[valid_decay]) ) * dt * 1000.0 ) # --- AHP --- ahp_max_samples_per_spike = np.full(n_spikes, ahp_max_samples) if n_spikes > 1: dist_to_next = spike_indices[1:] - spike_indices[:-1] ahp_max_samples_per_spike[:-1] = np.minimum(ahp_max_samples, dist_to_next) ahp_range = np.arange(0, ahp_max_samples) ahp_indices = spike_indices[:, None] + ahp_range np.clip(ahp_indices, 0, n_data - 1, out=ahp_indices) ahp_waveforms = data[ahp_indices] col_idxs_ahp = np.tile(np.arange(ahp_max_samples), (n_spikes, 1)) valid_ahp_mask = col_idxs_ahp < ahp_max_samples_per_spike[:, None] window_length = int(0.005 / dt) if window_length % 2 == 0: window_length += 1 window_length = max(5, window_length) # Cap to trace width; if cap makes it even, step down to next odd so the # Savitzky-Golay constraint (window > polyorder=3) is preserved. n_cols = ahp_waveforms.shape[1] max_win = n_cols if n_cols % 2 == 1 else max(1, n_cols - 1) window_length = min(window_length, max_win) if window_length % 2 == 0: window_length = max(1, window_length - 1) if ahp_waveforms.shape[1] >= window_length and window_length >= 5: smoothed_ahp = savgol_filter(ahp_waveforms, window_length, 3, axis=1) else: smoothed_ahp = ahp_waveforms temp_ahp = smoothed_ahp.copy() temp_ahp[~valid_ahp_mask] = np.inf ahp_min_rel_indices = np.argmin(temp_ahp, axis=1) mean_window = int(0.001 / dt) ahp_min_vals = np.zeros(n_spikes) for i in range(n_spikes): idx = ahp_min_rel_indices[i] start = max(0, idx - mean_window) end = min(ahp_max_samples_per_spike[i], idx + mean_window + 1) ahp_min_vals[i] = np.mean(ahp_waveforms[i, start:end]) rec_targets = ap_thresholds - 0.1 * amplitudes rec_target_bcast = rec_targets[:, None] is_after_min = col_idxs_ahp > ahp_min_rel_indices[:, None] is_recovered = ahp_waveforms >= rec_target_bcast valid_recovery = is_after_min & is_recovered & valid_ahp_mask has_recovery = np.any(valid_recovery, axis=1) rec_rel_indices = np.where(has_recovery, np.argmax(valid_recovery, axis=1), ahp_max_samples) thresh_bcast = ap_thresholds[:, None] is_below_thresh_ahp = ahp_waveforms < thresh_bcast has_ap_end = np.any(is_below_thresh_ahp, axis=1) ap_end_rel_indices = np.where(has_ap_end, np.argmax(is_below_thresh_ahp, axis=1), 0) ahp_durations = np.full(n_spikes, np.nan) valid_ahp_dur = has_recovery & has_ap_end & (rec_rel_indices > ap_end_rel_indices) ahp_durations[valid_ahp_dur] = (rec_rel_indices[valid_ahp_dur] - ap_end_rel_indices[valid_ahp_dur]) * dt * 1000.0 # --- ADP --- adp_amplitudes = np.full(n_spikes, np.nan) if ahp_max_samples > 2: val_mid = ahp_waveforms[:, 1:-1] val_left = ahp_waveforms[:, :-2] val_right = ahp_waveforms[:, 2:] is_local_max_inner = (val_mid > val_left) & (val_mid > val_right) is_local_max = np.pad(is_local_max_inner, ((0, 0), (1, 1)), mode="constant", constant_values=False) col_idxs2 = np.tile(np.arange(ahp_max_samples), (n_spikes, 1)) valid_adp_mask = is_local_max & (col_idxs2 > ap_end_rel_indices[:, None]) has_adp = np.any(valid_adp_mask, axis=1) temp_vals = ahp_waveforms.copy() temp_vals[~valid_adp_mask] = -np.inf adp_peaks = np.max(temp_vals, axis=1) calced_adps = adp_peaks - ahp_min_vals adp_amplitudes = np.where(has_adp, calced_adps, np.nan) # --- fAHP and mAHP (separate physiological windows) --- # fAHP: fast AHP (default 1-5 ms post-peak): Na+ channel-mediated repolarisation overshoot # mAHP: medium AHP (default 10-50 ms post-peak): K+ channel-mediated hyperpolarisation fahp_start = max(1, int(fahp_window_ms[0] / 1000.0 / dt)) fahp_end = max(fahp_start + 1, int(fahp_window_ms[1] / 1000.0 / dt)) mahp_start = max(1, int(mahp_window_ms[0] / 1000.0 / dt)) mahp_end = max(mahp_start + 1, int(mahp_window_ms[1] / 1000.0 / dt)) def _window_min(start_s: int, end_s: int) -> np.ndarray: """Return per-spike min voltage in [peak+start_s, peak+end_s).""" w_len = end_s - start_s if w_len <= 0: return np.full(n_spikes, np.nan) w_range = np.arange(start_s, end_s) w_indices = spike_indices[:, None] + w_range np.clip(w_indices, 0, n_data - 1, out=w_indices) return np.min(data[w_indices], axis=1) fahp_min_vals = _window_min(fahp_start, fahp_end) mahp_min_vals = _window_min(mahp_start, mahp_end) fahp_depths = ap_thresholds - fahp_min_vals mahp_depths = ap_thresholds - mahp_min_vals # --- max/min dV/dt --- full_dvdt = np.gradient(waveforms, axis=1) / dt / 1000.0 pre_peak_dvdt = np.where(is_pre_peak, full_dvdt, -np.inf) post_peak_dvdt = np.where(is_post_peak, full_dvdt, np.inf) max_dvdts = np.max(pre_peak_dvdt, axis=1) min_dvdts = np.min(post_peak_dvdt, axis=1) features_list = [] for i in range(n_spikes): features_list.append( { "ap_threshold": float(ap_thresholds[i]), "amplitude": float(amplitudes[i]), "half_width": float(half_widths[i]), "rise_time_10_90": float(rise_times[i]), "decay_time_90_10": float(decay_times[i]), "fahp_depth": float(fahp_depths[i]), "mahp_depth": float(mahp_depths[i]), "ahp_duration_half": float(ahp_durations[i]), "adp_amplitude": float(adp_amplitudes[i]), "max_dvdt": float(max_dvdts[i]), "min_dvdt": float(min_dvdts[i]), "absolute_peak_mv": float(peak_vals[i]), "overshoot_mv": float(max(0.0, peak_vals[i])), } ) return features_list
[docs] def calculate_isi(spike_times: np.ndarray) -> np.ndarray: """Return inter-spike intervals from spike_times array.""" if len(spike_times) < 2: return np.array([]) return np.diff(spike_times)
[docs] def analyze_multi_sweep_spikes( data_trials: List[np.ndarray], time_vector: np.ndarray, threshold: float, refractory_samples: int, dvdt_threshold: float = 20.0, ) -> List[SpikeTrainResult]: """Detect spikes across multiple sweeps.""" results = [] for i, trial_data in enumerate(data_trials): try: result = detect_spikes_threshold( trial_data, time_vector, threshold, refractory_samples, dvdt_threshold=dvdt_threshold ) result.metadata["sweep_index"] = i results.append(result) except (ValueError, TypeError, KeyError, IndexError) as e: log.error(f"Error analyzing sweep {i}: {e}") error_result = SpikeTrainResult( value=0, unit="spikes", is_valid=False, error_message=f"Sweep {i}: {str(e)}" ) error_result.metadata["sweep_index"] = i results.append(error_result) return results
# --------------------------------------------------------------------------- # Phase Plane (dV/dt vs V) # ---------------------------------------------------------------------------
[docs] def calculate_dvdt(voltage: np.ndarray, sampling_rate: float, sigma_ms: float = 0.1) -> np.ndarray: """ Calculate dV/dt (V/s) with optional Savitzky-Golay smoothing. Computes the raw derivative first, then applies a Savitzky-Golay filter (polynomial order 3) directly to the derivative array. This preserves the true max dV/dt better than pre-smoothing the voltage with a Gaussian, which attenuates the sharp upstroke of action potentials. Args: voltage: 1D voltage array (mV). sampling_rate: Sampling rate (Hz). sigma_ms: Smoothing window (ms). The SG window length is derived as ``max(5, int(sigma_ms / 1000 * sampling_rate))``, rounded up to the next odd integer. Set to 0 for no smoothing. Returns: 1D array of dV/dt in V/s. """ dt = 1.0 / sampling_rate dvdt = np.gradient(voltage, dt) / 1000.0 # mV/s -> V/s if sigma_ms > 0 and len(dvdt) >= 5: # Dynamic window length derived from sigma_ms and sampling rate (must be odd >= 5) window_samples = max(5, int(sigma_ms / 1000.0 * sampling_rate)) if window_samples % 2 == 0: window_samples += 1 # Cap at signal length (savgol_filter requires window <= len) window_samples = min(window_samples, len(dvdt) if len(dvdt) % 2 == 1 else len(dvdt) - 1) if window_samples >= 5: dvdt = savgol_filter(dvdt, window_samples, 3) return dvdt
[docs] def get_phase_plane_trajectory( voltage: np.ndarray, sampling_rate: float, sigma_ms: float = 0.1 ) -> Tuple[np.ndarray, np.ndarray]: """Return (voltage, dvdt) phase-plane trajectory.""" dvdt = calculate_dvdt(voltage, sampling_rate, sigma_ms) return voltage, dvdt
[docs] def detect_threshold_kink( voltage: np.ndarray, sampling_rate: float, dvdt_threshold: float = 20.0, kink_slope: float = 10.0, search_window_ms: float = 5.0, peak_indices: Optional[np.ndarray] = None, ) -> np.ndarray: """ Detect AP threshold using the dV/dt kink method. Returns array of threshold indices. """ if peak_indices is None: res = detect_spikes_threshold( voltage, np.arange(len(voltage)) / sampling_rate, -20.0, int(0.002 * sampling_rate) ) peak_indices = res.spike_indices dvdt = calculate_dvdt(voltage, sampling_rate, sigma_ms=0.1) threshold_indices = [] search_samples = int((search_window_ms / 1000.0) * sampling_rate) for peak_idx in peak_indices: start_search = max(0, peak_idx - search_samples) dvdt_slice = dvdt[start_search:peak_idx] crossings = np.where(dvdt_slice > dvdt_threshold)[0] if crossings.size > 0: thresh_idx = start_search + crossings[0] else: thresh_idx = max(0, peak_idx - int(0.001 * sampling_rate)) threshold_indices.append(thresh_idx) return np.array(threshold_indices)
# --------------------------------------------------------------------------- # Registry Wrappers # ---------------------------------------------------------------------------
[docs] @AnalysisRegistry.register( "spike_detection", label="Spike Detection", ui_params=[ { "name": "threshold", "label": "Threshold (mV):", "type": "float", "default": -20.0, "min": -1e9, "max": 1e9, "decimals": 4, }, { "name": "refractory_period", "label": "Refractory (s):", "type": "float", "default": 0.002, "min": 0.0, "max": 1e9, "decimals": 4, }, { "name": "peak_search_window", "label": "Peak Search (s):", "type": "float", "default": 0.005, "min": 0.0, "max": 1.0, "decimals": 4, }, { "name": "dvdt_threshold", "label": "dV/dt Thresh (V/s):", "type": "float", "default": 20.0, "min": 0.0, "max": 1e6, "decimals": 1, }, { "name": "ahp_window", "label": "AHP Window (s):", "type": "float", "default": 0.05, "min": 0.0, "max": 10.0, "decimals": 3, }, { "name": "onset_lookback", "label": "Onset Lookback (s):", "type": "float", "default": 0.01, "min": 0.0, "max": 0.1, "decimals": 3, }, { "name": "ljp_correction_mv", "label": "LJP Correction (mV):", "type": "float", "default": 0.0, "min": -100.0, "max": 100.0, "decimals": 2, "tooltip": "Liquid Junction Potential in mV. V_true = V_recorded - LJP.", }, ], plots=[ {"type": "hlines", "data": ["threshold"], "color": "r", "styles": ["dash"]}, {"type": "markers", "x": "spike_times", "y": "spike_voltages", "color": "r"}, ], ) def run_spike_detection_wrapper( data: np.ndarray, time: np.ndarray, sampling_rate: float, threshold: float = -20.0, refractory_period: float = 0.002, peak_search_window: float = 0.005, dvdt_threshold: float = 20.0, ahp_window: float = 0.05, onset_lookback: float = 0.01, **kwargs, ) -> Dict[str, Any]: """Wrapper for spike detection. Returns namespaced schema.""" try: ljp_mv = float(kwargs.get("ljp_correction_mv", 0.0)) data = apply_ljp_correction(data, ljp_mv) refractory_samples = int(refractory_period * sampling_rate) peak_window_samples = int(peak_search_window * sampling_rate) params = { "threshold": threshold, "refractory_period": refractory_period, "peak_search_window": peak_search_window, "dvdt_threshold": dvdt_threshold, "ahp_window": ahp_window, "onset_lookback": onset_lookback, } result = detect_spikes_threshold( data, time, threshold, refractory_samples, peak_search_window_samples=peak_window_samples, parameters=params, dvdt_threshold=dvdt_threshold, ) if result.is_valid: features_list = calculate_spike_features( data, time, result.spike_indices, dvdt_threshold=dvdt_threshold, ahp_window_sec=ahp_window, onset_lookback=onset_lookback, ) stats: Dict[str, Any] = {} if features_list: for key in features_list[0].keys(): values = [f[key] for f in features_list if not np.isnan(f[key])] if values: stats[f"{key}_mean"] = float(np.mean(values)) stats[f"{key}_std"] = float(np.std(values)) else: stats[f"{key}_mean"] = np.nan stats[f"{key}_std"] = np.nan v_data = ( data[result.spike_indices] if result.spike_indices is not None and len(result.spike_indices) > 0 else np.array([]) ) metrics: Dict[str, Any] = { "spike_count": len(result.spike_indices) if result.spike_indices is not None else 0, "mean_freq_hz": result.mean_frequency if result.mean_frequency is not None else 0.0, "spike_times": result.spike_times, "spike_indices": result.spike_indices, "spike_voltages": v_data, "threshold": threshold, "parameters": params, } metrics.update(stats) else: metrics = { "spike_count": 0, "mean_freq_hz": 0.0, "threshold": threshold, "spike_error": result.error_message or "Unknown error", "parameters": params, } return {"module_used": "single_spike", "metrics": metrics} except (ValueError, TypeError, KeyError, IndexError) as e: log.error(f"Error in run_spike_detection_wrapper: {e}", exc_info=True) return { "module_used": "single_spike", "metrics": {"spike_count": 0, "mean_freq_hz": 0.0, "spike_error": str(e)}, }
[docs] @AnalysisRegistry.register( "phase_plane_analysis", label="Phase Plane", plots=[ {"name": "Trace", "type": "trace"}, {"type": "popup_phase", "title": "Phase Plane"}, ], ui_params=[ { "name": "sigma_ms", "label": "Smoothing (ms):", "type": "float", "default": 0.1, "min": 0.0, "max": 1e9, "decimals": 4, }, { "name": "dvdt_threshold", "label": "dV/dt Thresh (V/s):", "type": "float", "default": 20.0, "min": 0.0, "max": 1e9, "decimals": 4, }, { "name": "spike_threshold", "label": "Spike Detect Thresh (mV):", "type": "float", "default": -20.0, "min": -1000.0, "max": 1000.0, "decimals": 2, }, {"name": "kink_slope", "label": "Kink Slope:", "type": "float", "default": 10.0, "hidden": True}, { "name": "search_window_ms", "label": "Search Window (ms):", "type": "float", "default": 5.0, "min": 0.1, "max": 100.0, "decimals": 2, }, { "name": "ljp_correction_mv", "label": "LJP Correction (mV):", "type": "float", "default": 0.0, "min": -100.0, "max": 100.0, "decimals": 2, "tooltip": "Liquid Junction Potential in mV. V_true = V_recorded - LJP.", }, ], ) def phase_plane_analysis_wrapper( voltage: np.ndarray, time: np.ndarray, sampling_rate: float, sigma_ms: float = 0.1, dvdt_threshold: float = 20.0, **kwargs, ) -> Dict[str, Any]: """Wrapper for Phase Plane analysis. Returns namespaced schema.""" spike_threshold = kwargs.get("spike_threshold", -20.0) search_window_ms = kwargs.get("search_window_ms", 5.0) kink_slope = kwargs.get("kink_slope", 10.0) ljp_mv = float(kwargs.get("ljp_correction_mv", 0.0)) voltage = apply_ljp_correction(voltage, ljp_mv) v, dvdt = get_phase_plane_trajectory(voltage, sampling_rate, sigma_ms) spike_res = detect_spikes_threshold(voltage, time, spike_threshold, int(0.002 * sampling_rate)) thresh_indices = detect_threshold_kink( voltage, sampling_rate, dvdt_threshold=dvdt_threshold, kink_slope=kink_slope, search_window_ms=search_window_ms, peak_indices=spike_res.spike_indices, ) threshold_vals = voltage[thresh_indices] if thresh_indices.size > 0 else [] metrics = { "voltage": v, "dvdt": dvdt, "threshold_indices": thresh_indices, "threshold_vals": threshold_vals, "threshold_v": float(np.mean(threshold_vals)) if len(threshold_vals) > 0 else np.nan, "threshold_dvdt": float(dvdt_threshold), "max_dvdt": float(np.max(dvdt)) if len(dvdt) > 0 else 0.0, "threshold_mean": float(np.mean(threshold_vals)) if len(threshold_vals) > 0 else np.nan, } return {"module_used": "single_spike", "metrics": metrics}
# Keep the original function name as an alias so existing code and tests still work phase_plane_analysis = phase_plane_analysis_wrapper # --------------------------------------------------------------------------- # Module-level tab aggregator # ---------------------------------------------------------------------------
[docs] @AnalysisRegistry.register( "single_spike", label="Spike Analysis", method_selector={ "Spike Detection": "spike_detection", "Phase Plane": "phase_plane_analysis", }, ui_params=[], plots=[], ) def single_spike_module(**kwargs): """Module-level aggregator tab for single-spike analyses.""" return {}