Source code for Synaptipy.core.data_model

# src/Synaptipy/core/data_model.py
# -*- coding: utf-8 -*-
"""
Core Domain Data Models for Synaptipy.

Defines the central classes representing electrophysiology concepts like
Recording sessions and individual data Channels.
"""

import logging
import uuid
from datetime import datetime  # Required for Recording timestamp
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple  # Added Any for metadata dict

import numpy as np

from Synaptipy.core.source_interfaces import SourceHandle

# Configure logger for this module
log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Undo / Command support
# ---------------------------------------------------------------------------


class UndoStack:
    """Lightweight state-history stack for non-destructive editing (Command pattern).

    Stores deep-copy snapshots of channel data before destructive operations so
    that a single :meth:`Channel.undo` call instantly restores the previous state.
    Memory depth is capped at *max_depth* entries (oldest entries are evicted).

    Usage::

        channel.push_undo("apply lowpass 300 Hz")
        channel.data_trials = filtered_trials
        # ...
        channel.undo()  # restores data_trials to state before the filter
    """

    def __init__(self, max_depth: int = 20):
        """
        Initialise the undo stack.

        Args:
            max_depth: Maximum number of undo levels retained (default 20).
        """
        self._max_depth = max(1, int(max_depth))
        self._states: List[Tuple[str, Dict[str, Any]]] = []
    def push(self, label: str, state: Dict[str, Any]) -> None:
        """Save a named state snapshot.

        Args:
            label: Human-readable description of the pending change
                (e.g. ``"notch 50 Hz"``).
            state: Arbitrary serialisable dict representing the channel state to restore.
        """
        self._states.append((label, state))
        if len(self._states) > self._max_depth:
            self._states.pop(0)
    def pop(self) -> Optional[Tuple[str, Dict[str, Any]]]:
        """Remove and return the most recently saved state.

        Returns:
            ``(label, state)`` tuple, or ``None`` if the stack is empty.
        """
        return self._states.pop() if self._states else None
    def can_undo(self) -> bool:
        """Return ``True`` if at least one undo level is available."""
        return bool(self._states)
    @property
    def depth(self) -> int:
        """Number of undo levels currently stored."""
        return len(self._states)
    def clear(self) -> None:
        """Discard all saved states."""
        self._states.clear()
    def __repr__(self) -> str:
        labels = [lbl for lbl, _ in self._states]
        return f"UndoStack(depth={self.depth}, labels={labels})"
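
The sketch below exercises UndoStack directly, independently of any Channel. It is illustrative only and not part of the module; the labels and the ``data_trials`` dict key are arbitrary example values.

# Illustrative sketch (not part of the module): a small UndoStack with eviction.
stack = UndoStack(max_depth=2)
stack.push("baseline subtract", {"data_trials": [np.zeros(3)]})
stack.push("notch 50 Hz", {"data_trials": [np.ones(3)]})
stack.push("lowpass 300 Hz", {"data_trials": [np.full(3, 2.0)]})  # oldest entry is evicted
assert stack.depth == 2
label, state = stack.pop()   # -> ("lowpass 300 Hz", {...})
assert stack.can_undo()      # one level ("notch 50 Hz") remains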
class Channel:
    """
    Represents a single channel of recorded data, potentially across multiple
    trials or segments.
    """

    def __init__(
        self,
        id: str,
        name: str,
        units: str,
        sampling_rate: float,
        data_trials: List[np.ndarray],
        loader: Optional[Any] = None,  # Callable[[int], Optional[np.ndarray]]
    ):
        """
        Initializes a Channel object.

        Args:
            id: A unique identifier for the channel (e.g., '0', '1', 'Vm').
            name: A descriptive name for the channel (e.g., 'Voltage', 'IN_0').
            units: The physical units of the data (e.g., 'mV', 'pA', 'V').
            sampling_rate: The sampling frequency in Hz.
            data_trials: A list where each element is a 1D NumPy array
                representing the data for one trial/segment.
            loader: Optional callable/object with load_trial(index) method for lazy loading.
        """
        # --- Core Attributes (Assigned ONCE) ---
        self.id: str = id
        self.name: str = name
        self.units: str = units if units else "unknown"  # Ensure units is a string
        self.sampling_rate: float = sampling_rate
        self.t_start: float = 0.0  # Absolute start time relative to recording start (set by adapter)

        # --- Data Trials Validation and Assignment ---
        # For lazy loading, data_trials may be empty initially
        if not isinstance(data_trials, list):
            log.warning(f"Channel '{name}' received non-list data. Attempting conversion.")
            try:
                # Ensure data is numpy array and handle potential conversion errors
                self.data_trials: List[np.ndarray] = (
                    [np.asarray(data_trials, dtype=float)] if data_trials is not None else []
                )
            except (TypeError, ValueError) as e:
                log.error(f"Could not convert data_trials for channel '{name}' to list of arrays: {e}")
                self.data_trials = []  # Assign empty list on failure
        else:
            # For lazy loading, data_trials may be empty or contain actual data
            if data_trials and all(isinstance(t, np.ndarray) for t in data_trials):
                # Normal case: data is already loaded
                self.data_trials: List[np.ndarray] = [np.asarray(t) for t in data_trials]
            else:
                # Lazy loading case: empty data_trials.
                # Ensure it's a list so we can append later.
                # We might want to pre-allocate None slots if we know num_trials (from metadata?),
                # but for now an empty list is safer than [[]], which implies 1 empty trial.
                self.data_trials: List[Optional[np.ndarray]] = []

        # --- ADDED: Attributes for Associated Current Data ---
        self.current_data_trials: List[np.ndarray] = []  # Populated by adapter if current signal found
        self.current_units: Optional[str] = None  # Populated by adapter
        # --- END ADDED ---

        # --- Optional Electrode Metadata (Populated by NeoAdapter) ---
        self.electrode_description: Optional[str] = None
        self.electrode_location: Optional[str] = None
        self.electrode_filtering: Optional[str] = None
        self.electrode_gain: Optional[float] = None  # Gain applied by amplifier/acquisition
        self.electrode_offset: Optional[float] = None  # ADC offset or baseline offset
        self.electrode_resistance: Optional[str] = None  # Pipette resistance (e.g., "10 MOhm") - requires parsing for NWB
        self.electrode_seal: Optional[str] = None  # Seal resistance (e.g., "5 GOhm") - requires parsing for NWB

        # --- Lazy Loading Support ---
        self.loader = loader
        self.metadata: Dict[str, Any] = {}  # General metadata dictionary

        # --- Undo stack (non-destructive editing) ---
        self._undo_stack: UndoStack = UndoStack()

    @property
    def num_trials(self) -> int:
        """Returns the number of trials/segments available for this channel."""
        # For lazy loading, check metadata first, then data_trials
        if hasattr(self, "metadata") and "num_trials" in self.metadata:
            return self.metadata["num_trials"]
        return len(self.data_trials)

    @property
    def num_samples(self) -> int:
        """
        Returns the number of samples in the first trial.

        WARNING: This property only checks the first trial. If trials have
        variable lengths, this value may be misleading. Use
        `get_consistent_samples()` for strict validation.

        Returns 0 if no trials are present.
        """
        if not self.data_trials:
            return 0
        # Ensure the first trial is valid before accessing shape
        if not isinstance(self.data_trials[0], np.ndarray) or self.data_trials[0].ndim == 0:
            log.warning(f"Channel '{self.name}': First trial is not a valid NumPy array.")
            return 0
        first_trial_len = self.data_trials[0].shape[0]

        # Check other trials more carefully
        lengths = set()
        valid_trial_found = False
        for arr in self.data_trials:
            if isinstance(arr, np.ndarray) and arr.ndim > 0:
                lengths.add(arr.shape[0])
                valid_trial_found = True
            else:
                log.warning(f"Channel '{self.name}' contains invalid trial data type: {type(arr)}")
        if not valid_trial_found:
            log.warning(f"Channel '{self.name}' contains no valid NumPy array trials.")
            return 0
        if len(lengths) > 1:
            log.warning(
                f"Channel '{self.name}' has trials with varying lengths: {lengths}. "
                f"`num_samples` returning length of first trial ({first_trial_len})."
            )
        # Return length of first valid trial if lengths are consistent or vary
        return first_trial_len if lengths else 0
    def get_consistent_samples(self) -> int:
        """
        Returns the number of samples per trial, ensuring all trials have the same length.

        Raises ValueError if trials have different lengths. Returns 0 if no trials.
        """
        if not self.data_trials:
            return 0
        lengths = set()
        for arr in self.data_trials:
            if isinstance(arr, np.ndarray) and arr.ndim > 0:
                lengths.add(arr.shape[0])
        if len(lengths) == 0:
            return 0
        if len(lengths) > 1:
            raise ValueError(f"Channel '{self.name}' has inconsistent trial lengths: {lengths}")
        return list(lengths)[0]
    # --- Data Retrieval Methods ---
    def get_data(self, trial_index: int) -> Optional[np.ndarray]:
        """
        Returns the raw data for a specific trial.

        For lazy loading, this method will load the data from disk if not already loaded.
        """
        # Check if data is already loaded
        if self.data_trials and 0 <= trial_index < len(self.data_trials):
            data = self.data_trials[trial_index]
            if data is not None:
                return data

        # For lazy loading, try to load data using the loader
        if self.loader:
            try:
                # If loader is a callable or has a load_trial method
                if hasattr(self.loader, "load_trial"):
                    data = self.loader.load_trial(trial_index)
                elif callable(self.loader):
                    data = self.loader(trial_index)
                else:
                    log.error(f"Channel {self.id}: Invalid loader object.")
                    return None
                if data is not None:
                    # Store valid data to avoid re-loading;
                    # ensure data_trials list is long enough first
                    while len(self.data_trials) <= trial_index:
                        self.data_trials.append(None)
                    self.data_trials[trial_index] = data
                    return data
            except (TypeError, ValueError, IndexError) as e:
                log.error(f"Failed to load trial {trial_index} data lazily for channel {self.id}: {e}")
                return None
        return None
    # _load_trial_data_lazy removed (moved to ChannelLoader strategy)
    def get_time_vector(self, trial_index: int) -> Optional[np.ndarray]:
        """Returns the absolute time vector for a specific trial."""
        data = self.get_data(trial_index)
        if data is not None and self.sampling_rate and self.sampling_rate > 0:
            num_samples = len(data)
            duration = num_samples / self.sampling_rate
            trial_t_start = self.t_start + trial_index * duration  # Approximate start time
            return np.linspace(trial_t_start, trial_t_start + duration, num_samples, endpoint=False)
        return None
    def get_relative_time_vector(self, trial_index: int) -> Optional[np.ndarray]:
        """Returns the time vector relative to the start of the trial (starts at 0)."""
        data = self.get_data(trial_index)
        if data is not None and self.sampling_rate and self.sampling_rate > 0:
            num_samples = len(data)
            duration = num_samples / self.sampling_rate
            return np.linspace(0, duration, num_samples, endpoint=False)
        return None
    def get_averaged_data(self, trial_indices: Optional[List[int]] = None) -> Optional[np.ndarray]:
        """Returns the averaged data across all (or the specified) trials."""
        if self.data_trials:
            try:
                # Determine which trials to use
                if trial_indices is not None and len(trial_indices) > 0:
                    # Validate indices
                    valid_indices = [i for i in trial_indices if 0 <= i < len(self.data_trials)]
                    trials_to_avg = [self.data_trials[i] for i in valid_indices if self.data_trials[i] is not None]
                else:
                    trials_to_avg = [t for t in self.data_trials if t is not None]

                if not trials_to_avg:
                    return None

                # Ensure all trials have the same length for simple averaging
                first_len = len(trials_to_avg[0])
                if all(len(trial) == first_len for trial in trials_to_avg):
                    return np.mean(np.array(trials_to_avg), axis=0)
                else:
                    # Handle differing lengths (e.g., pad or error)
                    log.warning(f"Channel {self.id}: Trials have different lengths, cannot average directly.")
                    return None
            except (TypeError, ValueError, IndexError) as e:
                log.error(f"Channel {self.id}: Error averaging trials: {e}")
                return None
        return None
    def get_averaged_time_vector(self) -> Optional[np.ndarray]:
        """Returns the absolute time vector for the averaged data (assumes first trial time base)."""
        avg_data = self.get_averaged_data()
        if avg_data is not None and self.sampling_rate and self.sampling_rate > 0:
            num_samples = len(avg_data)
            duration = num_samples / self.sampling_rate
            return np.linspace(self.t_start, self.t_start + duration, num_samples, endpoint=False)
        return None
    def get_relative_averaged_time_vector(self) -> Optional[np.ndarray]:
        """Returns the time vector relative to the start of the averaged data (starts at 0)."""
        avg_data = self.get_averaged_data()
        if avg_data is not None and self.sampling_rate and self.sampling_rate > 0:
            num_samples = len(avg_data)
            duration = num_samples / self.sampling_rate
            return np.linspace(0, duration, num_samples, endpoint=False)
        return None
    def get_current_data(self, trial_index: int) -> Optional[np.ndarray]:
        """Returns the current data for a specific trial, if available."""
        if self.current_data_trials and 0 <= trial_index < len(self.current_data_trials):
            return self.current_data_trials[trial_index]
        return None
    def get_averaged_current_data(self) -> Optional[np.ndarray]:
        """Returns the averaged current data across all trials, if available."""
        if self.current_data_trials:
            try:
                first_len = len(self.current_data_trials[0])
                if all(len(trial) == first_len for trial in self.current_data_trials):
                    return np.mean(np.array(self.current_data_trials), axis=0)
                else:
                    log.warning(f"Channel {self.id}: Current trials have different lengths, cannot average.")
                    return None
            except (TypeError, ValueError, IndexError) as e:
                log.error(f"Channel {self.id}: Error averaging current trials: {e}")
                return None
        return None
    # --- ADDED HELPER FOR PLOT LABELS ---
    def get_primary_data_label(self) -> str:
        """Determines a suitable label ('Voltage', 'Current', 'Signal') based on units."""
        if self.units:
            units_lower = self.units.lower()
            if "v" in units_lower:
                return "Voltage"
            elif "a" in units_lower:  # Check for 'amp' or 'a'
                return "Current"
        return "Signal"  # Default if no units or not recognized
    # --- END ADDED HELPER ---
    def get_data_bounds(self) -> Optional[Tuple[float, float]]:
        """Returns the min and max values across all trials for this channel."""
        if not self.data_trials or not any(trial.size > 0 for trial in self.data_trials):
            return None
        min_val = np.min([np.min(trial) for trial in self.data_trials if trial.size > 0])
        max_val = np.max([np.max(trial) for trial in self.data_trials if trial.size > 0])
        return float(min_val), float(max_val)
    def get_finite_data_bounds(self) -> Optional[Tuple[float, float]]:
        """
        Returns the min and max values across all trials, ensuring they are finite.

        Returns None if no finite data is found.
        """
        if not self.data_trials or not any(trial.size > 0 for trial in self.data_trials):
            return None
        try:
            # Concatenate all finite data from all trials
            all_finite_data = np.concatenate(
                [trial[np.isfinite(trial)] for trial in self.data_trials if trial.size > 0]
            )
            if all_finite_data.size == 0:
                return None
            min_val = np.min(all_finite_data)
            max_val = np.max(all_finite_data)
            return float(min_val), float(max_val)
        except (ValueError, TypeError):
            # Handles cases where there's no data left after filtering
            return None
    # --- Undo support (non-destructive editing) ---
    def push_undo(self, label: str = "") -> None:
        """Save the current ``data_trials`` state so that :meth:`undo` can restore it.

        Call this *before* any destructive operation (filter, event deletion, …).

        Args:
            label: Short human-readable description of the upcoming change
                (e.g. ``"lowpass 300 Hz"``). Stored for UI display only.
        """
        snapshot = {
            "data_trials": [t.copy() if isinstance(t, np.ndarray) else t for t in self.data_trials],
        }
        self._undo_stack.push(label, snapshot)
        log.debug("Channel '%s': pushed undo state '%s' (stack depth %d).", self.name, label, self._undo_stack.depth)
    def undo(self) -> bool:
        """Restore ``data_trials`` to the last saved state.

        Returns:
            ``True`` if a state was restored, ``False`` if the stack was empty.
        """
        entry = self._undo_stack.pop()
        if entry is None:
            log.debug("Channel '%s': undo requested but stack is empty.", self.name)
            return False
        label, snapshot = entry
        self.data_trials = snapshot["data_trials"]
        log.debug("Channel '%s': undid '%s' (stack depth now %d).", self.name, label, self._undo_stack.depth)
        return True
    @property
    def can_undo(self) -> bool:
        """``True`` when at least one undo level is available."""
        return self._undo_stack.can_undo()

    def __repr__(self):
        return f"Channel(id='{self.id}', name='{self.name}', units='{self.units}', trials={self.num_trials})"
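
The following minimal sketch illustrates the Channel workflow described above: construction from in-memory trials, averaging, a relative time axis, and the undo cycle. It is illustrative only (not part of the module); the data, names, and the "zero trial 1" edit are made-up examples.

# Illustrative sketch (not part of the module): a Channel built from two trials.
trials = [np.sin(np.linspace(0, 1, 1000)), np.cos(np.linspace(0, 1, 1000))]
ch = Channel(id="0", name="Vm", units="mV", sampling_rate=1000.0, data_trials=trials)

avg = ch.get_averaged_data()        # mean across both trials (equal lengths, so this succeeds)
t = ch.get_relative_time_vector(0)  # time axis for trial 0, starting at 0 s

ch.push_undo("zero trial 1")        # snapshot before a destructive edit
ch.data_trials[1] = np.zeros(1000)  # some in-place modification
ch.undo()                           # restores the original trial data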
class Recording:
    """
    Represents data and metadata loaded from a single recording file.
    Contains multiple Channel objects.
    """

    def __init__(self, source_file: Path):
        """
        Initializes a Recording object.

        Args:
            source_file: The Path object pointing to the original data file.
        """
        if not isinstance(source_file, Path):
            log.warning(f"Invalid source_file type ({type(source_file)}), setting to placeholder.")
            self.source_file: Path = Path("./unknown_source_file")  # Or raise error
        else:
            self.source_file: Path = source_file

        self.channels: Dict[str, Channel] = {}
        self.sampling_rate: Optional[float] = None
        self.duration: Optional[float] = None
        self.t_start: Optional[float] = None
        self.session_start_time_dt: Optional[datetime] = None
        self.protocol_name: Optional[str] = None
        self.injected_current: Optional[float] = None
        self.metadata: Dict[str, Any] = {}  # Use Any for metadata flexibility

        # --- Nested data hierarchy (n vs N) ---
        # subject_id identifies the biological subject (e.g. "Mouse_01").
        # cell_id identifies the recorded cell within that subject (e.g. "Cell_A").
        # Both fields are optional and must be set by the caller after loading.
        # The BatchAnalysisEngine propagates them into every result row so that
        # downstream mixed-effects / hierarchical ANOVA tools can distinguish
        # between-subject variance (N) from within-subject replication (n).
        self.subject_id: Optional[str] = None
        self.cell_id: Optional[str] = None

        # Recording temperature in degrees Celsius. Downstream tools can use
        # this value to apply Q10 kinetic scaling when comparing across labs or
        # temperature conditions. Defaults to 22.0 (room temperature) when not
        # explicitly set by the file reader or the experimenter.
        self.recording_temperature: float = 22.0

        # --- Lazy Loading Support ---
        self.source_handle: Optional[SourceHandle] = None  # Decoupled handle for lazy loading
    def close(self) -> None:
        """Release any underlying file handles held by the source handle.

        Must be called when a recording is removed from the workspace to prevent
        Neo IO readers from keeping the source file locked.
        """
        if self.source_handle is not None and hasattr(self.source_handle, "close"):
            try:
                self.source_handle.close()
            except Exception as exc:
                log.debug("Error closing source handle for %s: %s", self.source_file, exc)
    @property
    def num_channels(self) -> int:
        """Returns the number of channels in this recording."""
        return len(self.channels)

    @property
    def channel_names(self) -> List[str]:
        """Returns a list of the names of all channels."""
        return [ch.name for ch in self.channels.values() if hasattr(ch, "name")]

    @property
    def max_trials(self) -> int:
        """
        Returns the maximum number of trials found across all channels in this
        recording. Returns 0 if there are no channels or no trials.
        """
        if not self.channels:
            return 0
        num_trials_list = [ch.num_trials for ch in self.channels.values()]
        return max(num_trials_list) if num_trials_list else 0
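
A minimal sketch of assembling a Recording and querying its aggregate properties. Illustrative only; the file path and hierarchy labels below are placeholders, not values produced by any loader.

# Illustrative sketch (not part of the module): populate a Recording manually.
rec = Recording(Path("example.abf"))  # placeholder path
ch = Channel("0", "Vm", "mV", 1000.0, [np.zeros(100), np.ones(100)])
rec.channels[ch.id] = ch
rec.sampling_rate = ch.sampling_rate
rec.subject_id, rec.cell_id = "Mouse_01", "Cell_A"  # n-vs-N hierarchy labels

print(rec.num_channels, rec.channel_names, rec.max_trials)  # -> 1 ['Vm'] 2
rec.close()  # no-op here because no source_handle is attached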
class Experiment:
    """
    Optional container representing a collection of Recordings, potentially from
    a single experimental session or related set. (Currently basic).
    """

    def __init__(self):
        self.recordings: List[Recording] = []
        self.metadata: Dict[str, Any] = {}  # Use Any for metadata flexibility
        self.identifier: str = str(uuid.uuid4())  # Example unique ID for the experiment
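
Finally, a brief illustrative sketch (not part of the module) showing how recordings can be grouped under an Experiment; it reuses the ``rec`` object from the Recording sketch above and an arbitrary metadata key.

# Illustrative sketch: collect related recordings under one Experiment.
exp = Experiment()
exp.recordings.append(rec)                  # Recording from the sketch above
exp.metadata["condition"] = "control"       # arbitrary example metadata
print(exp.identifier, len(exp.recordings))  # UUID string and recording count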