# src/Synaptipy/core/data_model.py
# -*- coding: utf-8 -*-
"""
Core Domain Data Models for Synaptipy.
Defines the central classes representing electrophysiology concepts like
Recording sessions and individual data Channels.
"""
import logging
import uuid
from datetime import datetime # Required for Recording timestamp
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple # Added Any for metadata dict
import numpy as np
from Synaptipy.core.source_interfaces import SourceHandle
# Configure logger for this module
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Undo / Command support
# ---------------------------------------------------------------------------
class UndoStack:
"""Lightweight state-history stack for non-destructive editing (Command pattern).
Stores deep-copy snapshots of channel data before destructive operations so
that a single :meth:`Channel.undo` call instantly restores the previous state.
Memory depth is capped at *max_depth* entries (oldest entries are evicted).
Usage::
channel.push_undo("apply lowpass 300 Hz")
channel.data_trials = filtered_trials
# ...
channel.undo() # restores data_trials to state before the filter
"""
def __init__(self, max_depth: int = 20):
"""
Initialise the undo stack.
Args:
max_depth: Maximum number of undo levels retained (default 20).
"""
self._max_depth = max(1, int(max_depth))
self._states: List[Tuple[str, Dict[str, Any]]] = []
def push(self, label: str, state: Dict[str, Any]) -> None:
"""Save a named state snapshot.
Args:
label: Human-readable description of the pending change (e.g. ``"notch 50 Hz"``).
            state: Arbitrary serializable dict representing the channel state to restore.
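
        Usage (illustrative sketch of the eviction behaviour)::

            stack = UndoStack(max_depth=2)
            stack.push("step 1", {"x": 1})
            stack.push("step 2", {"x": 2})
            stack.push("step 3", {"x": 3})   # oldest entry ("step 1") is evicted
            stack.depth   # -> 2
            stack.pop()   # -> ("step 3", {"x": 3})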
"""
self._states.append((label, state))
if len(self._states) > self._max_depth:
self._states.pop(0)
def pop(self) -> Optional[Tuple[str, Dict[str, Any]]]:
"""Remove and return the most recently saved state.
Returns:
``(label, state)`` tuple, or ``None`` if the stack is empty.
"""
return self._states.pop() if self._states else None
def can_undo(self) -> bool:
"""Return ``True`` if at least one undo level is available."""
return bool(self._states)
@property
def depth(self) -> int:
"""Number of undo levels currently stored."""
return len(self._states)
def clear(self) -> None:
"""Discard all saved states."""
self._states.clear()
def __repr__(self) -> str:
labels = [lbl for lbl, _ in self._states]
return f"UndoStack(depth={self.depth}, labels={labels})"
class Channel:
"""
Represents a single channel of recorded data, potentially across multiple
trials or segments.
"""
def __init__(
self,
id: str,
name: str,
units: str,
sampling_rate: float,
data_trials: List[np.ndarray],
loader: Optional[Any] = None, # Callable[[int], Optional[np.ndarray]]
):
"""
Initializes a Channel object.
Args:
id: A unique identifier for the channel (e.g., '0', '1', 'Vm').
name: A descriptive name for the channel (e.g., 'Voltage', 'IN_0').
units: The physical units of the data (e.g., 'mV', 'pA', 'V').
sampling_rate: The sampling frequency in Hz.
data_trials: A list where each element is a 1D NumPy array
representing the data for one trial/segment.
loader: Optional callable/object with load_trial(index) method for lazy loading.
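
        Usage (illustrative)::

            import numpy as np
            ch = Channel(
                id="0",
                name="Vm",
                units="mV",
                sampling_rate=10000.0,
                data_trials=[np.zeros(1000), np.zeros(1000)],
            )
            ch.num_trials   # -> 2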
"""
# --- Core Attributes (Assigned ONCE) ---
self.id: str = id
self.name: str = name
self.units: str = units if units else "unknown" # Ensure units is a string
self.sampling_rate: float = sampling_rate
self.t_start: float = 0.0 # Absolute start time relative to recording start (set by adapter)
        # --- Data Trials Validation and Assignment ---
        # For lazy loading, data_trials may be empty initially.
        if not isinstance(data_trials, list):
            log.warning(f"Channel '{name}' received non-list data. Attempting conversion.")
            try:
                # Wrap a single array-like in a one-trial list; coerce to float array.
                self.data_trials: List[np.ndarray] = (
                    [np.asarray(data_trials, dtype=float)] if data_trials is not None else []
                )
            except (TypeError, ValueError) as e:
                log.error(f"Could not convert data_trials for channel '{name}' to list of arrays: {e}")
                self.data_trials = []  # Assign empty list on failure
        elif data_trials and all(isinstance(t, np.ndarray) for t in data_trials):
            # Normal case: data is already loaded as NumPy arrays.
            self.data_trials: List[np.ndarray] = [np.asarray(t) for t in data_trials]
        elif data_trials:
            # Elements are array-like but not ndarrays (e.g. lists): convert per
            # trial rather than silently discarding the provided data.
            try:
                self.data_trials = [np.asarray(t, dtype=float) for t in data_trials]
            except (TypeError, ValueError) as e:
                log.error(f"Could not convert trials for channel '{name}' to arrays: {e}")
                self.data_trials = []
        else:
            # Lazy-loading case: an empty list is safer than [[]], which would
            # imply one empty trial. Slots are filled as trials are loaded.
            self.data_trials: List[Optional[np.ndarray]] = []
# --- ADDED: Attributes for Associated Current Data ---
self.current_data_trials: List[np.ndarray] = [] # Populated by adapter if current signal found
self.current_units: Optional[str] = None # Populated by adapter
# --- END ADDED ---
# --- Optional Electrode Metadata (Populated by NeoAdapter) ---
self.electrode_description: Optional[str] = None
self.electrode_location: Optional[str] = None
self.electrode_filtering: Optional[str] = None
self.electrode_gain: Optional[float] = None # Gain applied by amplifier/acquisition
self.electrode_offset: Optional[float] = None # ADC offset or baseline offset
        self.electrode_resistance: Optional[str] = None  # Pipette resistance (e.g., "10 MOhm") - requires parsing for NWB
self.electrode_seal: Optional[str] = None # Seal resistance (e.g., "5 GOhm") - Requires parsing for NWB
# --- Lazy Loading Support ---
self.loader = loader
self.metadata: Dict[str, Any] = {} # General metadata dictionary
# --- Undo stack (non-destructive editing) ---
self._undo_stack: UndoStack = UndoStack()
@property
def num_trials(self) -> int:
"""Returns the number of trials/segments available for this channel."""
        # Lazy loading may record the trial count in metadata before any data exists.
        if "num_trials" in self.metadata:
            return self.metadata["num_trials"]
        return len(self.data_trials)
@property
def num_samples(self) -> int:
"""
Returns the number of samples in the first trial.
WARNING: This property only checks the first trial. If trials have variable lengths,
this value may be misleading. Use `get_consistent_samples()` for strict validation.
Returns 0 if no trials are present.
"""
if not self.data_trials:
return 0
# Ensure the first trial is valid before accessing shape
if not isinstance(self.data_trials[0], np.ndarray) or self.data_trials[0].ndim == 0:
log.warning(f"Channel '{self.name}': First trial is not a valid NumPy array.")
return 0
first_trial_len = self.data_trials[0].shape[0]
# Check other trials more carefully
lengths = set()
valid_trial_found = False
for arr in self.data_trials:
if isinstance(arr, np.ndarray) and arr.ndim > 0:
lengths.add(arr.shape[0])
valid_trial_found = True
else:
log.warning(f"Channel '{self.name}' contains invalid trial data type: {type(arr)}")
if not valid_trial_found:
log.warning(f"Channel '{self.name}' contains no valid NumPy array trials.")
return 0
if len(lengths) > 1:
log.warning(
f"Channel '{self.name}' has trials with varying lengths: {lengths}. "
f"`num_samples` returning length of first trial ({first_trial_len})."
)
        # Lengths were validated above; report the first trial's length.
        return first_trial_len
def get_consistent_samples(self) -> int:
"""
Returns the number of samples per trial, ensuring all trials have the same length.
Raises ValueError if trials have different lengths.
Returns 0 if no trials.
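
        Usage (illustrative)::

            try:
                n = channel.get_consistent_samples()
            except ValueError:
                n = channel.num_samples  # fall back to the first trial's length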
"""
if not self.data_trials:
return 0
lengths = set()
for arr in self.data_trials:
if isinstance(arr, np.ndarray) and arr.ndim > 0:
lengths.add(arr.shape[0])
if len(lengths) == 0:
return 0
if len(lengths) > 1:
raise ValueError(f"Channel '{self.name}' has inconsistent trial lengths: {lengths}")
        return next(iter(lengths))
# --- Data Retrieval Methods ---
def get_data(self, trial_index: int) -> Optional[np.ndarray]:
"""
Returns the raw data for a specific trial.
For lazy loading, this method will load the data from disk if not already loaded.
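
        Usage (illustrative; ``my_loader`` is a hypothetical object exposing
        ``load_trial(index)``)::

            ch = Channel("0", "Vm", "mV", 10000.0, [], loader=my_loader)
            trial0 = ch.get_data(0)   # loaded on first access
            trial0 = ch.get_data(0)   # served from the in-memory cache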
"""
# Check if data is already loaded
if self.data_trials and 0 <= trial_index < len(self.data_trials):
data = self.data_trials[trial_index]
if data is not None:
return data
# For lazy loading, try to load data using the loader
if self.loader:
try:
# If loader is a callable or has load_trial method
if hasattr(self.loader, "load_trial"):
data = self.loader.load_trial(trial_index)
elif callable(self.loader):
data = self.loader(trial_index)
else:
log.error(f"Channel {self.id}: Invalid loader object.")
return None
if data is not None:
# Store valid data to avoid re-loading
# Ensure data_trials list is long enough
while len(self.data_trials) <= trial_index:
self.data_trials.append(None)
self.data_trials[trial_index] = data
return data
except (TypeError, ValueError, IndexError) as e:
log.error(f"Failed to load trial {trial_index} data lazily for channel {self.id}: {e}")
return None
return None
# _load_trial_data_lazy removed (moved to ChannelLoader strategy)
def get_time_vector(self, trial_index: int) -> Optional[np.ndarray]:
        """Return the absolute time vector for a specific trial.

        The trial start time is estimated as ``t_start + trial_index * duration``,
        which assumes contiguous, equal-length trials.
        """
data = self.get_data(trial_index)
if data is not None and self.sampling_rate and self.sampling_rate > 0:
num_samples = len(data)
duration = num_samples / self.sampling_rate
trial_t_start = self.t_start + trial_index * duration # Approximate start time
return np.linspace(trial_t_start, trial_t_start + duration, num_samples, endpoint=False)
return None
def get_relative_time_vector(self, trial_index: int) -> Optional[np.ndarray]:
        """Return the time vector relative to the start of the trial (starts at 0)."""
data = self.get_data(trial_index)
if data is not None and self.sampling_rate and self.sampling_rate > 0:
num_samples = len(data)
duration = num_samples / self.sampling_rate
return np.linspace(0, duration, num_samples, endpoint=False)
return None
def get_averaged_data(self, trial_indices: Optional[List[int]] = None) -> Optional[np.ndarray]:
        """Return the mean trace across all trials, or across ``trial_indices`` when given.
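
        Usage (illustrative)::

            avg_all = channel.get_averaged_data()           # average every trial
            avg_sub = channel.get_averaged_data([0, 1, 2])  # average a subset
        """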
if self.data_trials:
try:
# Determine which trials to use
if trial_indices is not None and len(trial_indices) > 0:
# Validate indices
valid_indices = [i for i in trial_indices if 0 <= i < len(self.data_trials)]
trials_to_avg = [self.data_trials[i] for i in valid_indices if self.data_trials[i] is not None]
else:
trials_to_avg = [t for t in self.data_trials if t is not None]
if not trials_to_avg:
return None
# Ensure all trials have the same length for simple averaging
first_len = len(trials_to_avg[0])
if all(len(trial) == first_len for trial in trials_to_avg):
return np.mean(np.array(trials_to_avg), axis=0)
else:
# Handle differing lengths (e.g., pad or error)
log.warning(f"Channel {self.id}: Trials have different lengths, cannot average directly.")
return None
except (TypeError, ValueError, IndexError) as e:
log.error(f"Channel {self.id}: Error averaging trials: {e}")
return None
return None
def get_averaged_time_vector(self) -> Optional[np.ndarray]:
        """Return the absolute time vector for the averaged data (uses the first trial's time base)."""
avg_data = self.get_averaged_data()
if avg_data is not None and self.sampling_rate and self.sampling_rate > 0:
num_samples = len(avg_data)
duration = num_samples / self.sampling_rate
return np.linspace(self.t_start, self.t_start + duration, num_samples, endpoint=False)
return None
def get_relative_averaged_time_vector(self) -> Optional[np.ndarray]:
        """Return the time vector relative to the start of the averaged data (starts at 0)."""
avg_data = self.get_averaged_data()
if avg_data is not None and self.sampling_rate and self.sampling_rate > 0:
num_samples = len(avg_data)
duration = num_samples / self.sampling_rate
return np.linspace(0, duration, num_samples, endpoint=False)
return None
def get_current_data(self, trial_index: int) -> Optional[np.ndarray]:
        """Return the associated current data for a specific trial, if available."""
if self.current_data_trials and 0 <= trial_index < len(self.current_data_trials):
return self.current_data_trials[trial_index]
return None
def get_averaged_current_data(self) -> Optional[np.ndarray]:
        """Return the averaged current data across all trials, if available."""
if self.current_data_trials:
try:
first_len = len(self.current_data_trials[0])
if all(len(trial) == first_len for trial in self.current_data_trials):
return np.mean(np.array(self.current_data_trials), axis=0)
else:
log.warning(f"Channel {self.id}: Current trials have different lengths, cannot average.")
return None
except (TypeError, ValueError, IndexError) as e:
log.error(f"Channel {self.id}: Error averaging current trials: {e}")
return None
return None
# --- ADDED HELPER FOR PLOT LABELS ---
def get_primary_data_label(self) -> str:
"""Determines a suitable label ('Voltage', 'Current', 'Signal') based on units."""
        if self.units:
            units_lower = self.units.lower()
            if "v" in units_lower:  # e.g. 'V', 'mV'
                return "Voltage"
            elif "a" in units_lower:  # e.g. 'A', 'pA', 'nA'
                return "Current"
        return "Signal"  # Default if units are missing or unrecognized
# --- END ADDED HELPER ---
    def get_data_bounds(self) -> Optional[Tuple[float, float]]:
        """Returns the min and max values across all loaded trials for this channel."""
        # Skip unloaded (None) placeholders and empty trials.
        valid = [t for t in self.data_trials if isinstance(t, np.ndarray) and t.size > 0]
        if not valid:
            return None
        min_val = min(float(np.min(t)) for t in valid)
        max_val = max(float(np.max(t)) for t in valid)
        return min_val, max_val
def get_finite_data_bounds(self) -> Optional[Tuple[float, float]]:
"""
Returns the min and max values across all trials, ensuring they are finite.
Returns None if no finite data is found.
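
        Usage (illustrative)::

            bounds = channel.get_finite_data_bounds()   # NaN/Inf samples ignored
            if bounds is not None:
                y_min, y_max = bounds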
"""
        # Skip unloaded (None) placeholders and empty trials.
        valid = [t for t in self.data_trials if isinstance(t, np.ndarray) and t.size > 0]
        if not valid:
            return None
        try:
            # Concatenate the finite samples from every loaded trial.
            all_finite_data = np.concatenate([t[np.isfinite(t)] for t in valid])
if all_finite_data.size == 0:
return None
min_val = np.min(all_finite_data)
max_val = np.max(all_finite_data)
return float(min_val), float(max_val)
except (ValueError, TypeError):
# Handles cases where there's no data left after filtering
return None
# --- Undo support (non-destructive editing) ---
def push_undo(self, label: str = "") -> None:
"""Save the current ``data_trials`` state so that :meth:`undo` can restore it.
Call this *before* any destructive operation (filter, event deletion, …).
Args:
label: Short human-readable description of the upcoming change
(e.g. ``"lowpass 300 Hz"``). Stored for UI display only.
"""
snapshot = {
"data_trials": [t.copy() if isinstance(t, np.ndarray) else t for t in self.data_trials],
}
self._undo_stack.push(label, snapshot)
log.debug("Channel '%s': pushed undo state '%s' (stack depth %d).", self.name, label, self._undo_stack.depth)
def undo(self) -> bool:
"""Restore ``data_trials`` to the last saved state.
Returns:
``True`` if a state was restored, ``False`` if the stack was empty.
"""
entry = self._undo_stack.pop()
if entry is None:
log.debug("Channel '%s': undo requested but stack is empty.", self.name)
return False
label, snapshot = entry
self.data_trials = snapshot["data_trials"]
log.debug("Channel '%s': undid '%s' (stack depth now %d).", self.name, label, self._undo_stack.depth)
return True
@property
def can_undo(self) -> bool:
"""``True`` when at least one undo level is available."""
return self._undo_stack.can_undo()
def __repr__(self):
return f"Channel(id='{self.id}', name='{self.name}', units='{self.units}', trials={self.num_trials})"
class Recording:
"""
Represents data and metadata loaded from a single recording file.
Contains multiple Channel objects.
"""
def __init__(self, source_file: Path):
"""
Initializes a Recording object.
Args:
source_file: The Path object pointing to the original data file.
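
        Usage (illustrative; the path is a made-up example)::

            from pathlib import Path
            rec = Recording(Path("/data/Mouse_01/cell_a.abf"))
            rec.subject_id = "Mouse_01"   # see the hierarchy notes below
            rec.cell_id = "Cell_A"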
"""
if not isinstance(source_file, Path):
log.warning(f"Invalid source_file type ({type(source_file)}), setting to placeholder.")
self.source_file: Path = Path("./unknown_source_file") # Or raise error
else:
self.source_file: Path = source_file
self.channels: Dict[str, Channel] = {}
self.sampling_rate: Optional[float] = None
self.duration: Optional[float] = None
self.t_start: Optional[float] = None
self.session_start_time_dt: Optional[datetime] = None
self.protocol_name: Optional[str] = None
self.injected_current: Optional[float] = None
self.metadata: Dict[str, Any] = {} # Use Any for metadata flexibility
# --- Nested data hierarchy (n vs N) ---
# subject_id identifies the biological subject (e.g. "Mouse_01").
# cell_id identifies the recorded cell within that subject (e.g. "Cell_A").
# Both fields are optional and must be set by the caller after loading.
# The BatchAnalysisEngine propagates them into every result row so that
# downstream mixed-effects / hierarchical ANOVA tools can distinguish
# between-subject variance (N) from within-subject replication (n).
self.subject_id: Optional[str] = None
self.cell_id: Optional[str] = None
# Recording temperature in degrees Celsius. Downstream tools can use
# this value to apply Q10 kinetic scaling when comparing across labs or
# temperature conditions. Defaults to 22.0 (room temperature) when not
# explicitly set by the file reader or the experimenter.
self.recording_temperature: float = 22.0
        # --- Lazy Loading Support ---
self.source_handle: Optional[SourceHandle] = None # Decoupled handle for lazy loading
def close(self) -> None:
"""Release any underlying file handles held by the source handle.
Must be called when a recording is removed from the workspace to
prevent Neo IO readers from keeping the source file locked.
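
        Usage (illustrative)::

            recording.close()   # safe even when no source handle is open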
"""
if self.source_handle is not None and hasattr(self.source_handle, "close"):
try:
self.source_handle.close()
except Exception as exc:
log.debug("Error closing source handle for %s: %s", self.source_file, exc)
@property
def num_channels(self) -> int:
"""Returns the number of channels in this recording."""
return len(self.channels)
@property
def channel_names(self) -> List[str]:
"""Returns a list of the names of all channels."""
return [ch.name for ch in self.channels.values() if hasattr(ch, "name")]
@property
def max_trials(self) -> int:
"""
Returns the maximum number of trials found across all channels in this recording.
Returns 0 if there are no channels or no trials.
"""
if not self.channels:
return 0
num_trials_list = [ch.num_trials for ch in self.channels.values()]
return max(num_trials_list) if num_trials_list else 0
class Experiment:
"""
Optional container representing a collection of Recordings, potentially
from a single experimental session or related set. (Currently basic).
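
    Usage (illustrative; ``some_recording`` is a previously loaded :class:`Recording`)::

        exp = Experiment()
        exp.recordings.append(some_recording)
        exp.metadata["experimenter"] = "A. Researcher"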
"""
def __init__(self):
self.recordings: List[Recording] = []
self.metadata: Dict[str, Any] = {} # Use Any for metadata flexibility
        self.identifier: str = str(uuid.uuid4())  # Unique identifier for this experiment