Source code for Synaptipy.core.analysis.firing_dynamics

# src/Synaptipy/core/analysis/firing_dynamics.py
# -*- coding: utf-8 -*-
"""
Core Protocol Module 3: Firing Dynamics.

Consolidates: Excitability (F-I curve), Burst Analysis, and Spike Train
    Dynamics into one self-contained module.

All registry wrapper functions return::

    {
        "module_used": "firing_dynamics",
        "metrics": { ... flat result keys ... }
    }
"""

import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import numpy as np
from scipy.stats import linregress

from Synaptipy.core.analysis.registry import AnalysisRegistry
from Synaptipy.core.analysis.single_spike import calculate_spike_features, detect_spikes_threshold
from Synaptipy.core.results import AnalysisResult, BurstResult

log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Excitability (F-I Curve)
# ---------------------------------------------------------------------------


[docs] def calculate_fi_curve( # noqa: C901 sweeps: List[np.ndarray], time_vectors: List[np.ndarray], current_steps: Optional[List[float]] = None, threshold: float = -20.0, refractory_ms: float = 2.0, ) -> Dict[str, Any]: """ Calculate F-I Curve properties from a set of sweeps. Args: sweeps: List of voltage traces (1D arrays). time_vectors: List of corresponding time vectors. current_steps: List of current amplitudes for each sweep. If None, inferred. threshold: Spike detection threshold (mV). refractory_ms: Refractory period (ms). Returns: Dictionary with rheobase_pa, fi_slope, max_freq, spike_counts, frequencies, adaptation_ratios, current_steps. """ num_sweeps = len(sweeps) if num_sweeps == 0: return {"error": "No sweeps provided"} if current_steps is None: log.warning("Current steps not provided. Using sweep indices as proxy for current steps.") current_steps = list(range(num_sweeps)) if len(current_steps) != num_sweeps: log.warning(f"Mismatch between sweeps ({num_sweeps}) and current_steps ({len(current_steps)}). Truncating.") min_len = min(num_sweeps, len(current_steps)) sweeps = sweeps[:min_len] time_vectors = time_vectors[:min_len] current_steps = current_steps[:min_len] spike_counts = [] frequencies = [] adaptation_ratios = [] broadening_indices = [] # Width_last / Width_first within each sweep for i, (data, time) in enumerate(zip(sweeps, time_vectors)): dt = time[1] - time[0] if len(time) > 1 else 1e-4 sampling_rate = 1.0 / dt refractory_samples = int((refractory_ms / 1000.0) * sampling_rate) result = detect_spikes_threshold(data, time, threshold, refractory_samples) count = len(result.spike_indices) if result.spike_indices is not None else 0 freq = result.mean_frequency if result.mean_frequency is not None else 0.0 spike_counts.append(count) frequencies.append(freq) if count >= 3 and result.spike_times is not None: isis = np.diff(result.spike_times) if isis[0] > 0: adaptation_ratios.append(float(isis[-1] / isis[0])) else: adaptation_ratios.append(np.nan) else: adaptation_ratios.append(np.nan) # Spike Broadening Index: half-width of last spike / half-width of first spike broadening_idx = np.nan if count >= 2 and result.spike_indices is not None and len(result.spike_indices) >= 2: try: spike_idx_arr = result.spike_indices features = calculate_spike_features(data, time, spike_idx_arr) widths = [f.get("half_width") for f in features if f.get("half_width") is not None] valid_widths = [w for w in widths if w is not None and not np.isnan(w) and w > 0] if len(valid_widths) >= 2: broadening_idx = float(valid_widths[-1] / valid_widths[0]) except (ValueError, TypeError, IndexError): pass broadening_indices.append(broadening_idx) sorted_indices = np.argsort(current_steps) sorted_currents = np.array(current_steps)[sorted_indices] sorted_counts = np.array(spike_counts)[sorted_indices] sorted_freqs = np.array(frequencies)[sorted_indices] rheobase_pa = None rheobase_idx = -1 for i, count in enumerate(sorted_counts): if count > 0: rheobase_pa = float(sorted_currents[i]) rheobase_idx = i break fi_slope = None r_squared = None if rheobase_idx != -1 and rheobase_idx < len(sorted_counts) - 1: valid_slice = slice(rheobase_idx, None) x = sorted_currents[valid_slice] y = sorted_freqs[valid_slice] # Truncate at the maximum firing frequency to exclude the depolarisation # block region where frequency drops back toward 0 Hz. Including those # points would produce a spuriously flat or negative slope. peak_idx = int(np.argmax(y)) + 1 # +1 so the peak point itself is included x = x[:peak_idx] y = y[:peak_idx] if len(x) >= 2: try: slope, _intercept, r_value, _p, _se = linregress(x, y) fi_slope = float(slope) r_squared = float(r_value**2) except (ValueError, TypeError) as e: log.warning(f"Linear regression failed: {e}") return { "rheobase_pa": rheobase_pa, "fi_slope": fi_slope, "fi_r_squared": r_squared, "max_freq": float(np.max(frequencies)) if frequencies else 0.0, "spike_counts": spike_counts, "frequencies": frequencies, "adaptation_ratios": adaptation_ratios, "broadening_indices": broadening_indices, "current_steps": current_steps, }
[docs] @AnalysisRegistry.register( "excitability_analysis", label="Excitability", requires_multi_trial=True, plots=[ { "type": "popup_xy", "title": "F-I Curve", "x": "current_steps", "y": "frequencies", "x_label": "Current (pA)", "y_label": "Frequency (Hz)", }, ], ui_params=[ { "name": "threshold", "label": "Threshold (mV):", "type": "float", "default": -20.0, "min": -1e9, "max": 1e9, "decimals": 4, }, { "name": "start_current", "label": "Start Current (pA):", "type": "float", "default": 0.0, "min": -1e9, "max": 1e9, "decimals": 4, }, { "name": "step_current", "label": "Step Current (pA):", "type": "float", "default": 10.0, "min": -1e9, "max": 1e9, "decimals": 4, }, { "name": "refractory_ms", "label": "Refractory (ms):", "type": "float", "default": 2.0, "min": 0.0, "max": 1000.0, "decimals": 2, }, ], ) def run_excitability_analysis_wrapper( data_list: List[np.ndarray], time_list: List[np.ndarray], sampling_rate: float, **kwargs ) -> Dict[str, Any]: """Wrapper for Excitability Analysis (F-I Curve).""" try: threshold = kwargs.get("threshold", -20.0) start_current = kwargs.get("start_current", 0.0) step_current = kwargs.get("step_current", 10.0) refractory_ms = kwargs.get("refractory_ms", 2.0) if isinstance(data_list, np.ndarray): if data_list.ndim == 1: data_list = [data_list] time_list = [time_list] if isinstance(time_list, np.ndarray) else time_list elif data_list.ndim == 2: data_list = [data_list[i] for i in range(data_list.shape[0])] if isinstance(time_list, np.ndarray) and time_list.ndim == 1: time_list = [time_list for _ in range(len(data_list))] elif isinstance(time_list, np.ndarray) and time_list.ndim == 2: time_list = [time_list[i] for i in range(time_list.shape[0])] if isinstance(time_list, np.ndarray): time_list = [time_list] num_sweeps = len(data_list) current_steps = [start_current + i * step_current for i in range(num_sweeps)] results = calculate_fi_curve( sweeps=data_list, time_vectors=time_list, current_steps=current_steps, threshold=threshold, refractory_ms=refractory_ms, ) if "error" in results: return {"module_used": "firing_dynamics", "metrics": {"excitability_error": results["error"]}} return { "module_used": "firing_dynamics", "metrics": { "rheobase_pa": results["rheobase_pa"], "fi_slope": results["fi_slope"], "fi_r_squared": results["fi_r_squared"], "max_freq_hz": results["max_freq"], "frequencies": results["frequencies"], "adaptation_ratios": results["adaptation_ratios"], "broadening_indices": results.get("broadening_indices", []), "current_steps": results["current_steps"], }, } except (ValueError, TypeError, KeyError, IndexError) as e: log.error(f"Error in run_excitability_analysis_wrapper: {e}", exc_info=True) return {"module_used": "firing_dynamics", "metrics": {"excitability_error": str(e)}}
# --------------------------------------------------------------------------- # Burst Analysis # ---------------------------------------------------------------------------
[docs] def calculate_bursts_logic( spike_times: np.ndarray, max_isi_start: float = 0.01, max_isi_end: float = 0.2, min_spikes: int = 2, dynamic_burst: bool = False, burst_isi_fraction: float = 0.3, parameters: Optional[Dict[str, Any]] = None, ) -> BurstResult: """ Detect bursts in a spike train. Args: spike_times: 1D array of spike times (seconds). max_isi_start: Max ISI to start a burst (s). Ignored when dynamic_burst=True. max_isi_end: Max ISI to continue a burst (s). Ignored when dynamic_burst=True. min_spikes: Minimum spikes per burst. dynamic_burst: When True, compute the mean ISI of the whole train and define the burst boundary as ``burst_isi_fraction * mean_isi``. This abandons hardcoded thresholds in favour of the train's own temporal structure. burst_isi_fraction: Fraction of mean ISI used as burst boundary when ``dynamic_burst=True`` (default 0.3, i.e. 30%). Returns: BurstResult object. """ if spike_times is None or len(spike_times) < min_spikes: return BurstResult( value=0, unit="bursts", is_valid=True, burst_count=0, spikes_per_burst_avg=0.0, burst_duration_avg=0.0, burst_freq_hz=0.0, bursts=[], parameters=parameters or {}, ) isis = np.diff(spike_times) # Dynamic threshold: fraction of the global mean ISI if dynamic_burst and len(isis) >= 1: mean_isi = float(np.mean(isis)) dyn_threshold = burst_isi_fraction * mean_isi max_isi_start = dyn_threshold max_isi_end = dyn_threshold bursts = [] current_burst: List[float] = [] in_burst = False for i, isi in enumerate(isis): if not in_burst: if isi <= max_isi_start: in_burst = True current_burst = [spike_times[i], spike_times[i + 1]] else: if isi <= max_isi_end: current_burst.append(spike_times[i + 1]) else: in_burst = False if len(current_burst) >= min_spikes: bursts.append(current_burst) current_burst = [] if in_burst and len(current_burst) >= min_spikes: bursts.append(current_burst) num_bursts = len(bursts) if num_bursts == 0: return BurstResult( value=0, unit="bursts", is_valid=True, burst_count=0, bursts=[], parameters=parameters or {}, ) spikes_per_burst = [len(b) for b in bursts] burst_durations = [b[-1] - b[0] for b in bursts] duration = spike_times[-1] - spike_times[0] if len(spike_times) > 0 else 0 burst_freq = num_bursts / duration if duration > 0 else 0.0 return BurstResult( value=num_bursts, unit="bursts", is_valid=True, burst_count=num_bursts, spikes_per_burst_avg=float(np.mean(spikes_per_burst)), burst_duration_avg=float(np.mean(burst_durations)), burst_freq_hz=burst_freq, bursts=bursts, parameters=parameters or {}, )
[docs] def analyze_spikes_and_bursts( data: np.ndarray, time: np.ndarray, sampling_rate: float, threshold: float, max_isi_start: float, max_isi_end: float, refractory_ms: float = 2.0, dynamic_burst: bool = False, burst_isi_fraction: float = 0.3, parameters: Optional[Dict[str, Any]] = None, ) -> BurstResult: """Detect spikes then detect bursts.""" refractory_samples = int((refractory_ms / 1000.0) * sampling_rate) spike_result = detect_spikes_threshold(data, time, threshold, refractory_samples, parameters=parameters) if not spike_result.is_valid: return BurstResult(value=0, unit="bursts", is_valid=False, error_message=spike_result.error_message) if spike_result.spike_times is None: return BurstResult(value=0, unit="bursts", is_valid=True, burst_count=0, bursts=[]) return calculate_bursts_logic( spike_result.spike_times, max_isi_start=max_isi_start, max_isi_end=max_isi_end, dynamic_burst=dynamic_burst, burst_isi_fraction=burst_isi_fraction, parameters=parameters, )
[docs] @AnalysisRegistry.register( "burst_analysis", label="Burst", ui_params=[ { "name": "threshold", "label": "Threshold (mV):", "type": "float", "default": -20.0, "min": -1e9, "max": 1e9, "decimals": 4, }, { "name": "max_isi_start", "label": "Max ISI Start (s):", "type": "float", "default": 0.01, "min": 0.0, "max": 1e9, "decimals": 4, }, { "name": "max_isi_end", "label": "Max ISI End (s):", "type": "float", "default": 0.1, "min": 0.0, "max": 1e9, "decimals": 4, }, {"name": "min_spikes", "label": "Min Spikes:", "type": "int", "default": 2, "min": 2, "max": 1000}, {"name": "dynamic_burst", "label": "Dynamic ISI Threshold", "type": "bool", "default": False}, { "name": "burst_isi_fraction", "label": "Burst ISI Fraction:", "type": "float", "default": 0.3, "min": 0.01, "max": 1.0, "decimals": 2, "tooltip": "Spikes are in a burst if ISI < this fraction of the train mean ISI.", "visible_when": {"param": "dynamic_burst", "value": True}, }, ], plots=[{"type": "brackets", "data": "bursts", "color": "r"}], ) def run_burst_analysis_wrapper(data: np.ndarray, time: np.ndarray, sampling_rate: float, **kwargs) -> Dict[str, Any]: """Wrapper for Burst Analysis.""" threshold = kwargs.get("threshold", -20.0) max_isi_start = kwargs.get("max_isi_start", 0.01) max_isi_end = kwargs.get("max_isi_end", 0.1) dynamic_burst = kwargs.get("dynamic_burst", False) burst_isi_fraction = float(kwargs.get("burst_isi_fraction", 0.3)) result = analyze_spikes_and_bursts( data=data, time=time, sampling_rate=sampling_rate, threshold=threshold, max_isi_start=max_isi_start, max_isi_end=max_isi_end, dynamic_burst=dynamic_burst, burst_isi_fraction=burst_isi_fraction, parameters=kwargs, ) if not result.is_valid: return {"module_used": "firing_dynamics", "metrics": {"burst_error": result.error_message}} return { "module_used": "firing_dynamics", "metrics": { "burst_count": result.burst_count, "spikes_per_burst_avg": result.spikes_per_burst_avg, "burst_duration_avg": result.burst_duration_avg, "burst_freq_hz": result.burst_freq_hz, "bursts": result.bursts, "_result_obj": result, }, }
# --------------------------------------------------------------------------- # Spike Train Dynamics # ---------------------------------------------------------------------------
[docs] @dataclass class TrainDynamicsResult(AnalysisResult): """Result object for spike train dynamics analysis.""" spike_count: int = 0 mean_isi_s: Optional[float] = None cv: Optional[float] = None cv2: Optional[float] = None lv: Optional[float] = None adaptation_index: Optional[float] = None isis: Optional[np.ndarray] = None parameters: Dict[str, Any] = field(default_factory=dict) def __repr__(self): if self.is_valid: cv_str = f"{self.cv:.3f}" if self.cv is not None else "N/A" lv_str = f"{self.lv:.3f}" if self.lv is not None else "N/A" return f"TrainDynamicsResult(Spikes={self.spike_count}, CV={cv_str}, LV={lv_str})" return f"TrainDynamicsResult(Error: {self.error_message})"
[docs] def calculate_train_dynamics(spike_times: np.ndarray) -> TrainDynamicsResult: """ Compute native spike train statistical metrics. Args: spike_times: 1D NumPy array of spike times in seconds. Returns: TrainDynamicsResult. """ spike_count = len(spike_times) if spike_count < 2: return TrainDynamicsResult( value=None, unit="", is_valid=False, error_message="Requires at least 2 spikes for ISI calculations.", spike_count=spike_count, ) isis = np.diff(spike_times) mean_isi = float(np.mean(isis)) cv = float(np.std(isis) / mean_isi) if mean_isi > 0 else np.nan if spike_count < 3: return TrainDynamicsResult( value=mean_isi, unit="s", is_valid=True, spike_count=spike_count, mean_isi_s=mean_isi, cv=cv, cv2=np.nan, lv=np.nan, isis=isis, ) isis = isis[isis > 0] if len(isis) < 2: return TrainDynamicsResult( value=mean_isi, unit="s", is_valid=True, spike_count=spike_count, mean_isi_s=mean_isi, cv=cv, cv2=np.nan, lv=np.nan, isis=isis, ) isi_i = isis[:-1] isi_next = isis[1:] cv2_array = 2.0 * np.abs(isi_next - isi_i) / (isi_next + isi_i) cv2_val = float(np.mean(cv2_array)) lv_array = 3.0 * ((isi_i - isi_next) ** 2) / ((isi_i + isi_next) ** 2) lv_val = float(np.mean(lv_array)) # Adaptation index: ISI_last / ISI_first (>1 = adapting, <1 = bursting) adaptation_index = float(isis[-1] / isis[0]) if isis[0] > 0 else float(np.nan) return TrainDynamicsResult( value=cv, unit="", is_valid=True, spike_count=spike_count, mean_isi_s=mean_isi, cv=cv, cv2=cv2_val, lv=lv_val, adaptation_index=adaptation_index, isis=isis, )
[docs] @AnalysisRegistry.register( name="train_dynamics", label="Spike Train Dynamics", ui_params=[ { "name": "spike_threshold", "type": "float", "label": "AP Threshold (mV)", "default": -20.0, "min": -100.0, "max": 50.0, "step": 1.0, "tooltip": "Threshold to detect action potentials. Lower this for blunted or dendritic spikes.", } ], plots=[ {"name": "Trace", "type": "trace", "show_spikes": True}, { "type": "popup_xy", "title": "ISI Plot", "x": "isi_numbers", "y": "isi_ms", "x_label": "ISI Number", "y_label": "ISI (ms)", }, ], ) def run_train_dynamics_wrapper(data: np.ndarray, time: np.ndarray, sampling_rate: float, **kwargs) -> Dict[str, Any]: """Wrapper for Spike Train Dynamics.""" from Synaptipy.core.analysis.single_spike import calculate_spike_features ap_threshold = kwargs.get("spike_threshold", 0.0) ap_times = kwargs.get("action_potential_times", None) spike_indices = None if ap_times is None: refractory_samples = max(1, int(0.002 * sampling_rate)) spike_result = detect_spikes_threshold( data, time, threshold=ap_threshold, refractory_samples=refractory_samples ) if spike_result.spike_indices is not None and len(spike_result.spike_indices) > 0: spike_indices = spike_result.spike_indices ap_times = time[spike_indices] else: spike_indices = np.array([], dtype=int) ap_times = np.array([]) result = calculate_train_dynamics(ap_times) if not result.is_valid: return {"module_used": "firing_dynamics", "metrics": {"train_dynamics_error": result.error_message}} isi_ms = (result.isis * 1000.0).tolist() if result.isis is not None and len(result.isis) > 0 else [] isi_numbers = list(range(1, len(isi_ms) + 1)) # Spike broadening index: half_width_last / half_width_first for trains >= 3 spikes spike_broadening_index = float(np.nan) if spike_indices is not None and len(spike_indices) >= 3: try: features_list = calculate_spike_features(data, time, spike_indices) widths = [f.get("half_width") for f in features_list if f.get("half_width") is not None] valid_widths = [w for w in widths if w is not None and not np.isnan(w)] if len(valid_widths) >= 3: spike_broadening_index = ( float(valid_widths[-1] / valid_widths[0]) if valid_widths[0] > 0 else float(np.nan) ) except Exception as e: log.warning(f"Could not compute spike broadening index: {e}") metrics: Dict[str, Any] = { "spike_count": result.spike_count, "mean_isi_s": result.mean_isi_s, "cv": result.cv, "cv2": result.cv2, "lv": result.lv, "adaptation_index": result.adaptation_index, "spike_broadening_index": spike_broadening_index, "isi_numbers": isi_numbers, "isi_ms": isi_ms, } if spike_indices is not None: metrics["spike_indices"] = spike_indices return {"module_used": "firing_dynamics", "metrics": metrics}
# --------------------------------------------------------------------------- # Module-level tab aggregator # ---------------------------------------------------------------------------
[docs] @AnalysisRegistry.register( "firing_dynamics", label="Excitability", method_selector={ "Excitability": "excitability_analysis", "Burst Analysis": "burst_analysis", "Spike Train Dynamics": "train_dynamics", }, ui_params=[], plots=[], ) def firing_dynamics_module(**kwargs): """Module-level aggregator tab for firing-dynamics analyses.""" return {}