"""
Batch Analysis Engine for Synaptipy.
Handles processing multiple files and aggregating results using a flexible registry-based pipeline.
The engine uses a registry-based architecture where analysis functions register
themselves via decorators, and the pipeline configuration defines what analyses
to run on which data scopes.
Output Design Principles
------------------------
1. Every row is fully traceable to its source (file, channel, trial, analysis).
2. Metadata columns appear first; analysis results in the middle; internal/debug last.
3. Scalar results live in their own columns; array values are summarised for tabular
compatibility (Excel, Origin, R, MATLAB) and the raw arrays are kept under
private ``_``-prefixed keys that are stripped during CSV export.
4. Channel physical units are always recorded so downstream scripts can auto-label axes.
5. Recording-level metadata (protocol, duration, session time) is propagated when available.
Author: Anzal K Shahul <anzal.ks@gmail.com>
"""
import gc
import logging
import multiprocessing
import traceback  # stack traces recorded in error-row debug_trace columns
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
# Import analysis package to trigger all registrations
import Synaptipy.core.analysis # noqa: F401 - Import triggers all registrations
from Synaptipy.core.analysis.registry import AnalysisRegistry
from Synaptipy.core.data_model import Recording
from Synaptipy.infrastructure.file_readers import NeoAdapter
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Column ordering constants — metadata first, results middle, debug last
# ---------------------------------------------------------------------------
_METADATA_COLUMNS_ORDER = [
"subject_id",
"cell_id",
"file_name",
"file_path",
"protocol",
"recording_duration_s",
"channel",
"channel_units",
"analysis",
"scope",
"trial_index",
"trial_count",
"sampling_rate",
]
_TRAILING_COLUMNS = [
"batch_timestamp",
"error",
"debug_trace",
]
# Human-readable aliases for result keys that lack biological context.
# Only applied as *additional* columns; originals are preserved for scripting.
_HUMAN_READABLE_ALIASES: Dict[str, str] = {
"cv": "coeff_of_variation",
"cv2": "local_cv2_holt",
"lv": "local_variation_shinomoto",
"fi_slope": "fi_gain_hz_per_pa",
"fi_r_squared": "fi_fit_r_squared",
"iv_r_squared": "iv_fit_r_squared",
}
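# Example: a result row containing ``cv = 0.31`` also gains
# ``coeff_of_variation = 0.31`` when sanitised for export; the original key
# is kept so existing scripts keep working.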
class BatchAnalysisEngine:
"""
Engine for running analysis across multiple files/recordings using a flexible pipeline.
The engine uses a registry-based architecture where analysis functions register
themselves via decorators, and the pipeline configuration defines what analyses
to run on which data scopes.
Example::
engine = BatchAnalysisEngine()
files = [Path("file1.abf"), Path("file2.abf")]
pipeline = [
{
'analysis': 'spike_detection',
'scope': 'all_trials',
'params': {'threshold': -15.0, 'refractory_ms': 2.0}
},
{
'analysis': 'rmp_analysis',
'scope': 'average',
'params': {'baseline_start': 0.0, 'baseline_end': 0.1}
}
]
results_df = engine.run_batch(files, pipeline)
"""
def __init__(self, neo_adapter: Optional[NeoAdapter] = None, max_workers: int = 1):
"""
Initialize the batch analysis engine.
Args:
neo_adapter: Optional NeoAdapter instance. If None, creates a new one.
max_workers: Number of parallel worker processes for :meth:`run_batch`.
1 (default) means fully sequential execution.
Values > 1 enable :class:`~concurrent.futures.ProcessPoolExecutor`
parallelism. Pass ``-1`` to use all available CPU cores.
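A minimal construction sketch (worker counts are illustrative)::

    engine_seq = BatchAnalysisEngine()                  # sequential (default)
    engine_par = BatchAnalysisEngine(max_workers=4)     # 4 worker processes
    engine_all = BatchAnalysisEngine(max_workers=-1)    # all CPU cores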
"""
self.neo_adapter = neo_adapter if neo_adapter else NeoAdapter()
self._cancelled = False
cpu_count = multiprocessing.cpu_count()
if max_workers < 0:
self.max_workers: int = cpu_count
else:
self.max_workers = max(1, int(max_workers))
def cancel(self):
"""Request cancellation of the current batch run."""
self._cancelled = True
log.debug("Batch analysis cancellation requested.")
@staticmethod
def list_available_analyses() -> List[str]:
"""
Get a list of all registered analysis function names.
Returns:
List of available analysis names.
"""
return AnalysisRegistry.list_registered()
@staticmethod
def get_analysis_info(name: str) -> Optional[Dict[str, Any]]:
"""
Get information about a registered analysis function.
Args:
name: The registered name of the analysis function.
Returns:
Dictionary with function info (docstring, etc.) or None if not found.
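Example (illustrative; the available names depend on what is registered)::

    for name in BatchAnalysisEngine.list_available_analyses():
        info = BatchAnalysisEngine.get_analysis_info(name)
        print(name, info["module"])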
"""
func = AnalysisRegistry.get_function(name)
if func is None:
return None
return {
"name": name,
"docstring": func.__doc__ or "No documentation available.",
"module": func.__module__,
}
# ------------------------------------------------------------------
# Output post-processing helpers
# ------------------------------------------------------------------
@staticmethod
def _sanitise_value(key: str, value: Any) -> Tuple[Any, Optional[Tuple[str, Any]]]:
"""Sanitise a single result value for export.
Returns:
Tuple of (replacement_value, optional (stash_key, stash_value)).
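Example (illustrative values)::

    _sanitise_value("threshold", -15.0)
    # -> (-15.0, None); scalars pass through unchanged
    _sanitise_value("spike_times", np.arange(100.0))
    # -> ("n=100, mean=49.5, min=0, max=99", ("_spike_times_raw", <ndarray>))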
"""
if isinstance(value, np.ndarray):
return BatchAnalysisEngine._sanitise_ndarray(key, value)
if isinstance(value, list) and len(value) > 5:
return BatchAnalysisEngine._sanitise_long_list(key, value)
if not isinstance(value, (int, float, str, bool, type(None))):
return f"{type(value).__name__}", (f"_{key}_obj", value)
return value, None
@staticmethod
def _sanitise_ndarray(key: str, value: np.ndarray) -> Tuple[Any, Optional[Tuple[str, Any]]]:
"""Summarise numpy arrays for CSV-friendly output."""
if value.size <= 5:
return value.tolist(), None
summary = f"n={value.size}"
if np.issubdtype(value.dtype, np.floating):
summary = (
f"n={value.size}, "
f"mean={np.nanmean(value):.4g}, "
f"min={np.nanmin(value):.4g}, "
f"max={np.nanmax(value):.4g}"
)
return summary, (f"_{key}_raw", value)
@staticmethod
def _sanitise_long_list(key: str, value: list) -> Tuple[Any, Optional[Tuple[str, Any]]]:
"""Summarise long lists for CSV-friendly output."""
try:
arr = np.asarray(value, dtype=float)
summary = (
f"n={arr.size}, "
f"mean={np.nanmean(arr):.4g}, "
f"min={np.nanmin(arr):.4g}, "
f"max={np.nanmax(arr):.4g}"
)
return summary, (f"_{key}_raw", arr)
except (ValueError, TypeError):
return f"[{len(value)} items]", None
# Keys whose arrays represent discrete events (spike times, PSP amplitudes,
# etc.) and should be preserved verbatim inside ``_raw_arrays`` for downstream
# plotting and long-format CSV export. Full continuous trace arrays (fitting
# curves, 10 kHz voltage traces) must NOT appear here to avoid memory spikes.
_EVENT_ARRAY_KEYS: frozenset = frozenset(
{
"event_times",
"event_amplitudes",
"event_iei",
"spike_times",
"spike_amplitudes",
"spike_counts",
"frequencies",
"current_steps",
"adaptation_ratios",
"broadening_indices",
"isi_array",
"event_rise_times",
"event_decay_times",
"event_charges",
}
)
@staticmethod
def _sanitise_result_for_export(result: Dict[str, Any]) -> Dict[str, Any]:
"""
Make a single result row export-friendly.
Dual-representation architecture
---------------------------------
1. Public array keys whose names appear in ``_EVENT_ARRAY_KEYS`` produce
both a human-readable summary string (kept in the main dict so UI
tables render correctly) *and* a copy of the raw NumPy array stored
under ``result["_raw_arrays"][key]``. Only discrete-event arrays
(spike times, PSP amplitudes, …) are stored this way; continuous
10 kHz trace buffers are excluded to avoid memory spikes.
2. All other non-scalar values are summarised as strings and the raw
object moved to a ``_``-prefixed key (existing behaviour).
3. Human-readable aliases are added for cryptic algorithm names.
4. Private ``_``-prefixed keys are left untouched.
Returns:
Cleaned result dict (modified in-place for efficiency).
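Example (abridged; values illustrative)::

    row = {"spike_times": np.linspace(0.0, 9.9, 100), "cv": 0.31}
    BatchAnalysisEngine._sanitise_result_for_export(row)
    # row["spike_times"]                 -> "n=100, mean=4.95, min=0, max=9.9"
    # row["_raw_arrays"]["spike_times"]  -> original ndarray (event-array key)
    # row["_spike_times_raw"]            -> original ndarray (stash key)
    # row["coeff_of_variation"]          -> 0.31 (alias for "cv")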
"""
raw_arrays: Dict[str, Any] = result.get("_raw_arrays", {}) # may already exist
keys_to_add: Dict[str, Any] = {}
for key, value in list(result.items()):
if key.startswith("_"):
continue
new_value, stash = BatchAnalysisEngine._sanitise_value(key, value)
result[key] = new_value
# For discrete-event arrays: also store raw data under _raw_arrays.
if key in BatchAnalysisEngine._EVENT_ARRAY_KEYS and isinstance(value, (np.ndarray, list)):
arr = np.asarray(value) if isinstance(value, list) else value
# Only store compact discrete-event arrays (<= 50 000 elements)
# to guard against accidentally persisting continuous traces.
if arr.ndim <= 2 and arr.size <= 50_000:
raw_arrays[key] = arr
if stash is not None:
keys_to_add[stash[0]] = stash[1]
if raw_arrays:
result["_raw_arrays"] = raw_arrays
result.update(keys_to_add)
# Add human-readable aliases for cryptic keys
for orig_key, alias in _HUMAN_READABLE_ALIASES.items():
if orig_key in result and alias not in result:
result[alias] = result[orig_key]
return result
@staticmethod
def _recording_metadata(recording: "Recording") -> Dict[str, Any]:
"""Extract recording-level metadata for result rows.
Includes ``subject_id`` and ``cell_id`` when set on the Recording so
that downstream hierarchical mixed-effects analyses can distinguish
between-subject (N) from within-subject (n) observations.
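Example return value (all values illustrative)::

    {"protocol": "IV_steps", "recording_duration_s": 12.5,
     "subject_id": "M-042", "cell_id": "c3"}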
"""
meta: Dict[str, Any] = {}
if recording is None:
return meta
if hasattr(recording, "protocol_name") and recording.protocol_name:
meta["protocol"] = recording.protocol_name
if hasattr(recording, "duration") and recording.duration is not None:
meta["recording_duration_s"] = round(float(recording.duration), 4)
if hasattr(recording, "subject_id"):
meta["subject_id"] = recording.subject_id
if hasattr(recording, "cell_id"):
meta["cell_id"] = recording.cell_id
return meta
@staticmethod
def _order_columns(df: pd.DataFrame) -> pd.DataFrame:
"""Reorder DataFrame columns: metadata → results → trailing/debug."""
if df.empty:
return df
all_cols = list(df.columns)
# 1. Leading metadata columns (in defined order)
leading = [c for c in _METADATA_COLUMNS_ORDER if c in all_cols]
# 2. Trailing debug/internal columns
trailing = [c for c in _TRAILING_COLUMNS if c in all_cols]
# 3. Private columns (underscore-prefixed)
private = sorted(c for c in all_cols if c.startswith("_"))
# 4. Everything else = result columns, alphabetically
used = set(leading) | set(trailing) | set(private)
results = sorted(c for c in all_cols if c not in used)
ordered = leading + results + trailing + private
return df[[c for c in ordered if c in all_cols]]
@staticmethod
def _append_batch_error_log(file_name: str, file_path_str: str, exc: Exception) -> None:
"""Append a one-line error entry to ``~/.synaptipy/logs/batch_errors.log``.
Writing errors to a dedicated log file ensures a 100-file batch is never
aborted by a single corrupted recording. Each line is ISO-8601 timestamped
so the analyst can correlate entries with their batch run.
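Example entry (file name, path and error are illustrative)::

    2024-05-01T14:03:22 | cell01.abf | /data/project/cell01.abf | OSError: file truncated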
Args:
file_name: Base name of the failed file.
file_path_str: Full path string for the failed file.
exc: The exception that caused the failure.
"""
try:
log_dir = Path.home() / ".synaptipy" / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
error_log_path = log_dir / "batch_errors.log"
timestamp = datetime.now().isoformat(timespec="seconds")
entry = f"{timestamp} | {file_name} | {file_path_str} | {type(exc).__name__}: {exc}\n"
with open(error_log_path, "a", encoding="utf-8") as fh:
fh.write(entry)
except Exception as write_exc: # noqa: BLE001
log.warning("Could not write to batch_errors.log: %s", write_exc)
# ------------------------------------------------------------------
# Parallel execution helpers
# ------------------------------------------------------------------
def _run_batch_parallel( # noqa: C901
self,
files: List[Union[Path, "Recording"]],
pipeline_config: List[Dict[str, Any]],
progress_callback: Optional[Callable[[int, int, str], None]],
channel_filter: Optional[List[str]],
rs_tolerance: float = 0.20,
) -> pd.DataFrame:
"""Distribute file-level processing across :attr:`max_workers` worker processes.
Each worker process receives a single file path, imports the full analysis
package to populate the registry, and returns a list of result-row dicts.
Pre-loaded in-memory Recording objects are not sent to workers; they are
processed sequentially in the parent process, since pickling them is rarely
worth the cost. Progress signals are emitted through the optional
*progress_callback* as each future completes.
OOM safety: every worker calls ``gc.collect()`` after processing its file.
"""
total_files = len(files)
batch_start_time = datetime.now()
# Separate paths from pre-loaded Recording objects.
# Pre-loaded Recording objects are processed sequentially (pickle cost not worth it).
path_tasks: List[Tuple[int, Path]] = []
inline_recordings: List[Tuple[int, Any]] = []
for idx, item in enumerate(files):
if isinstance(item, (str, Path)):
path_tasks.append((idx, Path(item)))
else:
inline_recordings.append((idx, item))
all_rows: List[List[Dict[str, Any]]] = [[] for _ in range(total_files)]
completed_count = 0
# Submit path-based tasks to the pool
future_to_idx: Dict[Any, int] = {}
pool_kwargs: Dict[str, Any] = {"max_workers": self.max_workers}
# Use spawn context on all platforms for process-safety with Qt/numpy
ctx = multiprocessing.get_context("spawn")
pool_kwargs["mp_context"] = ctx
with ProcessPoolExecutor(**pool_kwargs) as executor:
for orig_idx, file_path in path_tasks:
future = executor.submit(
_worker_process_file,
str(file_path),
pipeline_config,
channel_filter,
rs_tolerance,
)
future_to_idx[future] = orig_idx
for future in as_completed(future_to_idx):
orig_idx = future_to_idx[future]
file_path = files[orig_idx]
file_name = Path(str(file_path)).name
completed_count += 1
try:
rows = future.result()
all_rows[orig_idx] = rows
except Exception as exc: # noqa: BLE001
log.error("Worker failed for %s: %s", file_path, exc, exc_info=True)
self._append_batch_error_log(file_name, str(file_path), exc)
all_rows[orig_idx] = [
{
"file_name": file_name,
"file_path": str(file_path),
"error": str(exc),
"debug_trace": traceback.format_exc(),
}
]
finally:
if progress_callback:
progress_callback(completed_count, total_files, f"Processed {file_name}")
if self._cancelled:
executor.shutdown(wait=False, cancel_futures=True)
break
# Process in-memory recordings sequentially (they can't be pickled reliably)
for orig_idx, recording in inline_recordings:
if self._cancelled:
break
completed_count += 1
file_name = getattr(getattr(recording, "source_file", None), "name", f"InMemory_{orig_idx}")
if progress_callback:
progress_callback(completed_count, total_files, f"Processing {file_name}...")
try:
df_inline = self._run_batch_sequential([recording], pipeline_config, None, channel_filter, rs_tolerance)
all_rows[orig_idx] = df_inline.to_dict("records") if not df_inline.empty else []
except Exception as exc: # noqa: BLE001
log.error("Inline recording failed: %s", exc, exc_info=True)
all_rows[orig_idx] = [
{"file_name": file_name, "error": str(exc), "debug_trace": traceback.format_exc()}
]
if progress_callback:
msg = "Batch cancelled." if self._cancelled else "Batch analysis complete."
progress_callback(total_files, total_files, msg)
flat_rows = [row for rows in all_rows for row in rows]
df = pd.DataFrame(flat_rows)
if not df.empty:
df["batch_timestamp"] = batch_start_time.isoformat()
df = self._order_columns(df)
return df
def run_batch( # noqa: C901
self,
files: List[Union[Path, "Recording"]],
pipeline_config: List[Dict[str, Any]],
progress_callback: Optional[Callable[[int, int, str], None]] = None,
channel_filter: Optional[List[str]] = None,
rs_tolerance: float = 0.20,
) -> pd.DataFrame:
"""
Run analysis on a list of files/recordings using a flexible pipeline configuration.
When :attr:`max_workers` > 1 **and** *files* contains at least two items, the
file-level loop is distributed across worker processes via
:class:`~concurrent.futures.ProcessPoolExecutor`. The call itself blocks until
the batch finishes in either mode, so GUI callers should wrap it in a
:class:`~Synaptipy.application.gui.analysis_worker.BatchWorker` QThread to keep
the interface responsive.
Args:
files: List of file paths OR Recording objects to process.
pipeline_config: List of task dictionaries.
progress_callback: Optional callback (current, total, status_msg).
channel_filter: Optional list of channel names/IDs to process.
rs_tolerance: Maximum fractional increase in series resistance compared
to the first valid Rs measurement before a sweep is flagged with
``rs_qc_warning``. Default 0.20 (20 %). Set to ``float('inf')``
to disable the check.
Returns:
pandas DataFrame containing aggregated results with metadata.
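Example (sketch; file paths and the analysis name are placeholders)::

    def on_progress(current, total, message):
        print(f"[{current}/{total}] {message}")

    engine = BatchAnalysisEngine(max_workers=4)
    df = engine.run_batch(
        [Path("cell01.abf"), Path("cell02.abf")],
        [{"analysis": "spike_detection", "scope": "all_trials",
          "params": {"threshold": -15.0, "refractory_ms": 2.0}}],
        progress_callback=on_progress,
    )
    df.to_csv("batch_results.csv", index=False)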
"""
self._cancelled = False
total_files = len(files)
# Validate pipeline config
if not pipeline_config:
log.warning("Empty pipeline_config provided. No analyses will be run.")
return pd.DataFrame()
# Route to parallel executor when max_workers > 1 and we have multiple files
if self.max_workers > 1 and total_files > 1:
log.info(
"BatchAnalysisEngine: starting parallel batch (%d workers, %d files).", self.max_workers, total_files
)
return self._run_batch_parallel(files, pipeline_config, progress_callback, channel_filter, rs_tolerance)
return self._run_batch_sequential(files, pipeline_config, progress_callback, channel_filter, rs_tolerance)
def _run_batch_sequential( # noqa: C901
self,
files: List[Union[Path, "Recording"]],
pipeline_config: List[Dict[str, Any]],
progress_callback: Optional[Callable[[int, int, str], None]],
channel_filter: Optional[List[str]],
rs_tolerance: float = 0.20,
) -> pd.DataFrame:
"""Sequential (single-process) batch processing — the original implementation."""
results_list = []
total_files = len(files)
# Add batch metadata
batch_start_time = datetime.now()
for i, item in enumerate(files):
# Check for cancellation
if self._cancelled:
log.debug("Batch analysis cancelled by user.")
if progress_callback:
progress_callback(i, total_files, "Cancelled")
break
file_name = "Unknown"
file_path_str = "InMemory"
file_path = None # Initialize file_path
try:
# Determine if item is Path or Recording
recording = None
if isinstance(item, (str, Path)):
file_path = Path(item)
file_name = file_path.name
file_path_str = str(file_path)
if progress_callback:
progress_callback(i, total_files, f"Processing {file_name}...")
# Load recording from disk with whitelist (Memory Optimization)
recording = self.neo_adapter.read_recording(file_path, channel_whitelist=channel_filter)
if not recording:
log.warning(f"Failed to load {file_path}")
results_list.append(
{"file_name": file_name, "file_path": file_path_str, "error": "Failed to load recording"}
)
continue
else:
# Assume it is a Recording object
recording = item
if hasattr(recording, "source_file") and recording.source_file:
file_path = recording.source_file
file_name = recording.source_file.name
file_path_str = str(recording.source_file)
else:
# Fallback for purely in-memory recordings
file_path = Path(f"InMemory_Recording_{i}")
file_name = file_path.name
file_path_str = str(file_path)
if progress_callback:
progress_callback(i, total_files, f"Processing {file_name}...")
# Filter channels if specified
channels_to_process = recording.channels.items()
if channel_filter:
log.debug(f"Applying channel filter: {channel_filter}")
channels_to_process = [
(name, ch)
for name, ch in recording.channels.items()
if name in channel_filter or str(name) in channel_filter
]
if not channels_to_process:
log.warning(f"Channel filter {channel_filter} matched no channels in {file_name}.")
log.debug(f"Processing {len(channels_to_process)} channels: {[n for n, c in channels_to_process]}")
# Extract recording-level metadata once per file
rec_meta = self._recording_metadata(recording)
# Iterate through channels
for channel_key, channel in channels_to_process:
# Check for cancellation
if self._cancelled:
break
# Prefer the native channel name from the acquisition file header.
# Fall back to the channel key (ID) only when no name is available.
native_channel_name = getattr(channel, "name", None)
channel_name = native_channel_name if native_channel_name else channel_key
# Per-channel metadata available to every result row
ch_meta = {
"channel_units": getattr(channel, "units", "unknown"),
"trial_count": getattr(channel, "num_trials", 0),
}
ch_meta.update(rec_meta)
# Data Buffer for the pipeline (stores (data, time) tuples or lists)
pipeline_context = {
"scope": None, # Current scope of data in context
"data": None, # The data (array or list)
"time": None, # The time (array or list)
}
# Series-resistance stability tracker: reset for each new channel.
# rs_reference_mohm stores the Rs from the first valid sweep so
# all subsequent sweeps can be checked for drift.
rs_reference_mohm: Optional[float] = None
# Process each task in the pipeline
for task in pipeline_config:
if self._cancelled:
break
try:
# Pass the context to allow tasks to use/modify it
task_results, updated_context = self._process_task(
task=task,
channel=channel,
channel_name=channel_name,
file_path=file_path,
context=pipeline_context,
)
# Update context if the task modified it (e.g. preprocessing)
if updated_context:
pipeline_context = updated_context
# Enrich each result row with channel/recording metadata
for res in task_results:
for mk, mv in ch_meta.items():
res.setdefault(mk, mv)
# Sanitise for export (arrays → summaries, aliases)
self._sanitise_result_for_export(res)
# --- Series-Resistance Stability QC ---
# Track rs_mohm across trials. If Rs increases by more
# than rs_tolerance relative to Sweep 1, flag the row with
# a warning so analysts can exclude unstable patches.
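# Worked example: with the default rs_tolerance of 0.20, a reference Rs of
# 10 MOhm flags any later sweep whose Rs exceeds 10 * 1.20 = 12 MOhm.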
for res in task_results:
rs_val = res.get("rs_mohm")
if rs_val is None:
continue
try:
rs_float = float(rs_val)
except (TypeError, ValueError):
continue
if rs_float != rs_float: # NaN guard
continue
if rs_reference_mohm is None:
rs_reference_mohm = rs_float
log.debug(
"Rs reference %.2f MOhm set for %s / %s.",
rs_float,
file_name,
channel_name,
)
elif rs_float > rs_reference_mohm * (1.0 + rs_tolerance):
delta_pct = (rs_float - rs_reference_mohm) / rs_reference_mohm * 100.0
log.warning(
"Series resistance destabilized: Rs=%.2f MOhm "
"(ref=%.2f MOhm, +%.1f%%) in %s / %s trial %s "
"(tolerance=%.0f%%).",
rs_float,
rs_reference_mohm,
delta_pct,
file_name,
channel_name,
res.get("trial_index", "?"),
rs_tolerance * 100.0,
)
res["rs_qc_warning"] = (
f"Series resistance destabilized: "
f"Rs={rs_float:.1f} MOhm (ref={rs_reference_mohm:.1f} "
f"MOhm, +{delta_pct:.1f}%)"
)
# Extend results list with all results from this task
results_list.extend(task_results)
except Exception as e: # noqa: BLE001 - broad catch intentional for fault-tolerance
log.error(
f"Error processing task {task.get('analysis', 'unknown')} on "
f"{file_path.name}/{channel_name}: {e}",
exc_info=True,
)
# Add error row — include full metadata for filtering
error_row = {
"file_name": file_path.name,
"file_path": str(file_path),
"channel": channel_name,
"analysis": task.get("analysis", "unknown"),
"scope": task.get("scope", "unknown"),
"sampling_rate": getattr(channel, "sampling_rate", None),
"error": str(e),
"debug_trace": traceback.format_exc(),
}
error_row.update(ch_meta)
results_list.append(error_row)
continue
# Explicitly release the per-channel data buffer so NumPy arrays
# held in pipeline_context (which can be 10–200 MB for a long ABF)
# are freed before processing the next channel in this file.
pipeline_context = {"scope": None, "data": None, "time": None}
except Exception as e: # noqa: BLE001 - broad catch intentional; Domino Defense
# A single corrupted or unreadable file must never abort the entire batch run.
# Log the full traceback to batch_errors.log and continue to the next file.
log.error(f"Error processing batch file {file_path}: {e}", exc_info=True)
self._append_batch_error_log(file_name, file_path_str, e)
results_list.append(
{
"file_name": file_name,
"file_path": file_path_str,
"error": str(e),
"debug_trace": traceback.format_exc(),
}
)
continue
finally:
# Release the Recording object and collected data immediately after
# each file to prevent cumulative PySide6 / NumPy OOM in headless batch
# runs. gc.collect() ensures cyclic references are broken even when
# GC is otherwise disabled for test-mode offscreen stability.
recording = None # noqa: F841 # drop reference
gc.collect()
log.debug("gc.collect() called after processing item %d.", i)
if progress_callback:
if self._cancelled:
progress_callback(i, total_files, "Batch analysis cancelled.")
else:
progress_callback(total_files, total_files, "Batch analysis complete.")
# Create DataFrame and add batch metadata
df = pd.DataFrame(results_list)
if not df.empty:
df["batch_timestamp"] = batch_start_time.isoformat()
df = self._order_columns(df)
return df
def _process_task( # noqa: C901
self, task: Dict[str, Any], channel, channel_name: str, file_path: Path, context: Dict[str, Any]
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
"""
Process a single analysis task on a channel, supporting preprocessing.
Args:
task: Task configuration dict with ``analysis``, ``scope`` and ``params`` keys.
channel: Channel object providing trial data and metadata.
channel_name: Display name or ID of the channel being processed.
file_path: Path to the source recording file.
context: Data context carried over from previous pipeline steps.
Returns:
Tuple of (list of result-row dicts, updated context or ``None``).
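Example call (sketch; the engine, channel object, context dict and channel
name are placeholders supplied by the batch loop)::

    task = {"analysis": "spike_detection", "scope": "first_trial",
            "params": {"threshold": -15.0}}
    rows, new_ctx = engine._process_task(task, channel, "Vm_primary",
                                         Path("cell01.abf"), context)
    # rows    -> list of result dicts (one per trial for multi-trial scopes)
    # new_ctx -> None for analyses; an updated data context for preprocessing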
"""
analysis_name = task.get("analysis")
scope = task.get("scope", "first_trial")
params = task.get("params", {})
# Check metadata for type
meta = AnalysisRegistry.get_metadata(analysis_name)
is_preprocessing = meta.get("type") == "preprocessing"
# Get the registered analysis function
analysis_func = AnalysisRegistry.get_function(analysis_name)
if analysis_func is None:
log.error(f"Analysis function '{analysis_name}' not found in registry")
return [
{
"file_name": file_path.name,
"file_path": str(file_path),
"channel": channel_name,
"analysis": analysis_name,
"scope": scope,
"error": f"Analysis function '{analysis_name}' not registered",
}
], None
results = []
sampling_rate = channel.sampling_rate
# --- Data Retrieval Strategy ---
# 1. If context matches requested scope, use it.
# 2. If context exists but scope differs, try to adapt (e.g. average existing trials).
# 3. If no context, load from channel.
data = None
time = None
# Check if we can use context
if context["data"] is not None:
# If scope matches, use directly
if context["scope"] == scope:
data = context["data"]
time = context["time"]
# Adaptation: If we have 'all_trials' data but need 'average'
elif context["scope"] == "all_trials" and scope == "average":
# Compute average from cached trials. Guard against mixed-protocol
# files where trials can have different sample counts.
try:
if len(context["data"]) > 0:
lengths = {len(a) for a in context["data"]}
if len(lengths) > 1:
raise ValueError(
f"Cannot average trials with mismatched lengths "
f"{sorted(lengths)} in {file_path.name}/{channel_name}. "
"Ensure all sweeps in a batch file use the same protocol "
"duration."
)
data = np.mean(np.array(context["data"]), axis=0)
time = context["time"][0]
else:
log.warning("Context data empty, cannot average.")
except Exception as e:
log.warning(
"Could not average trials from context (%s/%s): %s. Reloading from source.",
file_path.name,
channel_name,
e,
)
# Adaptation: If we have 'all_trials' data but need 'selected_trials_average'
elif context["scope"] == "all_trials" and scope == "selected_trials_average":
try:
# Extract list of indices from task params, or default to all
trial_indices_str = params.get("trial_indices", "")
if trial_indices_str:
from Synaptipy.shared.utils import parse_trial_selection_string
parsed_indices = parse_trial_selection_string(trial_indices_str, len(context["data"]))
selected_indices = sorted(list(parsed_indices))
else:
selected_indices = list(range(len(context["data"])))
if selected_indices:
selected_data = [context["data"][i] for i in selected_indices if i < len(context["data"])]
lengths = {len(a) for a in selected_data}
if len(lengths) > 1:
raise ValueError(
f"Cannot average selected trials with mismatched lengths "
f"{sorted(lengths)} in {file_path.name}/{channel_name}. "
"Ensure all selected sweeps use the same protocol duration."
)
data = np.mean(np.array(selected_data), axis=0)
time = context["time"][0]
else:
log.warning("No valid trials selected for averaging from context.")
except Exception as e:
log.warning(
"Could not average selected trials from context (%s/%s): %s. Reloading from source.",
file_path.name,
channel_name,
e,
)
# If data is still None, load from channel
if data is None:
if scope == "average":
data = channel.get_averaged_data()
time = channel.get_relative_averaged_time_vector()
elif scope == "all_trials":
data = []
time = []
for i in range(channel.num_trials):
d = channel.get_data(i)
t = channel.get_relative_time_vector(i)
if d is not None:
data.append(d)
time.append(t)
# The context is only updated by preprocessing tasks; raw loads are not
# cached here, even though re-reading a large file is comparatively expensive.
elif scope == "selected_trials":
data = []
time = []
trial_indices_str = params.get("trial_indices", "")
if trial_indices_str:
from Synaptipy.shared.utils import parse_trial_selection_string
parsed_indices = parse_trial_selection_string(trial_indices_str, channel.num_trials)
selected_indices = sorted(list(parsed_indices))
else:
selected_indices = list(range(channel.num_trials))
for i in selected_indices:
d = channel.get_data(i)
t = channel.get_relative_time_vector(i)
if d is not None:
data.append(d)
time.append(t)
elif scope == "selected_trials_average":
trial_indices_str = params.get("trial_indices", "")
if trial_indices_str:
from Synaptipy.shared.utils import parse_trial_selection_string
parsed_indices = parse_trial_selection_string(trial_indices_str, channel.num_trials)
selected_indices = sorted(list(parsed_indices))
else:
selected_indices = None
data = channel.get_averaged_data(trial_indices=selected_indices)
time = channel.get_relative_averaged_time_vector()
elif scope == "first_trial":
data = channel.get_data(0)
time = channel.get_relative_time_vector(0)
elif scope == "specific_trial":
idx = int(params.get("trial_index", 0))
data = channel.get_data(idx)
time = channel.get_relative_time_vector(idx)
elif scope == "channel_set":
# channel_set usually implies list of all trials
data = []
time = []
for i in range(channel.num_trials):
d = channel.get_data(i)
t = channel.get_relative_time_vector(i)
if d is not None:
data.append(d)
time.append(t)
# Validation
if data is None or (isinstance(data, list) and len(data) == 0):
return [
{
"file_name": file_path.name,
"file_path": str(file_path),
"channel": channel_name,
"analysis": analysis_name,
"error": "No data available",
}
], None
# --- Execution ---
if is_preprocessing:
# Preprocessing: Modify data and return new context
try:
# Preprocessing functions typically take (data, time, fs, **kwargs)
# and return modified data.
# When the scope yields multiple trials (data is a list), apply the
# function to each trace individually.
if isinstance(data, list):
# Apply to each item
new_data = []
new_time = []
for d, t in zip(data, time):
# Preprocessing functions are expected to return only the modified data
# array; the time vector is assumed to pass through unchanged.
res = analysis_func(d, t, sampling_rate, **params)
new_data.append(res)
new_time.append(t) # Assume time unchanged
modified_data = new_data
modified_time = new_time
else:
# Single trace
modified_data = analysis_func(data, time, sampling_rate, **params)
modified_time = time
# Return empty results, but updated context
new_context = {"scope": scope, "data": modified_data, "time": modified_time}
return [], new_context
except Exception as e:
log.error(f"Preprocessing failed: {e}", exc_info=True)
return [
{
"file_name": file_path.name,
"file_path": str(file_path),
"channel": channel_name,
"analysis": analysis_name,
"scope": scope,
"sampling_rate": sampling_rate,
"error": f"Preprocessing failed: {e}",
"debug_trace": traceback.format_exc(),
}
], None
else:
# Standard Analysis
try:
# Helper to run analysis and format result
total_trials = getattr(channel, "num_trials", 0)
def run_single(d, t, trial_idx=None):
# Remove trial_index from params if present
p = params.copy()
p.pop("trial_index", None)
res = analysis_func(d, t, sampling_rate, **p)
# Flatten consolidated-module schema: {"module_used": ..., "metrics": {...}}
if "metrics" in res and isinstance(res.get("metrics"), dict):
metrics = res.pop("metrics")
res.update(metrics)
# Add metadata
res.update(
{
"file_name": file_path.name,
"file_path": str(file_path),
"channel": channel_name,
"analysis": analysis_name,
"scope": scope,
"sampling_rate": sampling_rate,
"trial_count": total_trials,
}
)
if trial_idx is not None:
res["trial_index"] = trial_idx
return res
if scope == "all_trials" or scope == "channel_set":
# For channel_set, some functions expect the list (e.g. F-I curve)
# others expect iteration.
# Check if function handles list?
# NOTE: Original code treated 'channel_set' as passing the list to func.
# 'all_trials' iterated.
if scope == "channel_set":
# Pass full list
res = analysis_func(data, time, sampling_rate, **params)
# Flatten consolidated-module schema: {"module_used": ..., "metrics": {...}}
if "metrics" in res and isinstance(res.get("metrics"), dict):
metrics = res.pop("metrics")
res.update(metrics)
res.update(
{
"file_name": file_path.name,
"file_path": str(file_path),
"channel": channel_name,
"analysis": analysis_name,
"scope": scope,
"sampling_rate": sampling_rate,
"trial_count": len(data) if isinstance(data, list) else 1,
}
)
results.append(res)
else:
# Iterate 'all_trials' or 'selected_trials'
# For 'selected_trials', extract the specific indices used
if scope == "selected_trials":
trial_indices_str = task.get("params", {}).get("trial_indices", "")
if trial_indices_str:
from Synaptipy.shared.utils import parse_trial_selection_string
parsed_indices = parse_trial_selection_string(trial_indices_str, total_trials)
indices_list = sorted(list(parsed_indices))
else:
indices_list = list(range(total_trials))
else:
indices_list = list(range(total_trials))
for i, (d, t) in enumerate(zip(data, time)):
# Ensure we output correct trial index
real_idx = indices_list[i] if i < len(indices_list) else i
results.append(run_single(d, t, real_idx))
elif scope == "specific_trial":
idx = int(params.get("trial_index", 0))
results.append(run_single(data, time, idx))
else:
# Single trace (average, first_trial)
results.append(run_single(data, time))
return results, None # No context update
except Exception as e:
log.error(f"Analysis failed: {e}", exc_info=True)
return [
{
"file_name": file_path.name,
"file_path": str(file_path),
"channel": channel_name,
"analysis": analysis_name,
"scope": scope,
"sampling_rate": sampling_rate,
"error": f"Analysis failed: {e}",
"debug_trace": traceback.format_exc(),
}
], None
# ---------------------------------------------------------------------------
# Module-level worker function for ProcessPoolExecutor
# ---------------------------------------------------------------------------
def _worker_process_file(
file_path_str: str,
pipeline_config: List[Dict[str, Any]],
channel_filter: Optional[List[str]],
rs_tolerance: float = 0.20,
) -> List[Dict[str, Any]]:
"""Process a single file in an isolated worker process.
This function is called by :class:`~concurrent.futures.ProcessPoolExecutor`
in a freshly spawned process. It re-imports the full analysis package so
that all ``@AnalysisRegistry.register`` decorators execute, then delegates
to :class:`BatchAnalysisEngine` with ``max_workers=1`` (sequential) to avoid
recursive parallelism.
OOM safety: ``gc.collect()`` is called explicitly after processing.
Args:
file_path_str: Absolute path to the recording file (str, pickle-safe).
pipeline_config: Serialised pipeline task list.
channel_filter: Optional channel whitelist.
rs_tolerance: Series-resistance drift tolerance forwarded to the sequential engine.
Returns:
List of result-row dicts ready for ``pd.DataFrame()``.
"""
import gc as _gc # avoid shadowing module-level gc import
from pathlib import Path as _Path
# Trigger all @AnalysisRegistry.register decorators in this new process
import Synaptipy.core.analysis # noqa: F401,F811
engine = BatchAnalysisEngine(max_workers=1)
try:
df = engine._run_batch_sequential(
[_Path(file_path_str)],
pipeline_config,
None, # progress_callback not serialisable across processes
channel_filter,
rs_tolerance,
)
return df.to_dict("records") if not df.empty else []
finally:
_gc.collect()