Source code for scitex_ml.metrics._calc_seizure_prediction_metrics

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-10-03 01:56:15 (ywatanabe)"
# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/ml/metrics/_calc_seizure_prediction_metrics.py
"""Calculate clinical seizure prediction metrics.

This module provides both window-based and event-based seizure prediction
metrics following FDA/clinical guidelines.

Two approaches:
  1. Window-based (``calc_seizure_window_prediction_metrics``): measures the
     percentage of seizure time windows that were flagged as seizure.
  2. Event-based (``calc_seizure_event_prediction_metrics``): measures the
     percentage of seizure events with at least one alarm (≥1 alarm per event).

Key metrics:
  - seizure_sensitivity: % detected (windows or events, depending on approach)
  - fp_per_hour: false positives per hour of interictal recording
  - time_in_warning: % of total time in the alarm state

Clinical targets (FDA guidelines):
  - Sensitivity ≥ 90%
  - FP/h ≤ 0.2
  - Time in warning ≤ 20%
"""

# The future import must follow the docstring: placing any statement before
# the string literal would stop it from becoming the module's ``__doc__``.
from __future__ import annotations


from typing import Dict

import numpy as np
import pandas as pd


def calc_seizure_window_prediction_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    metadata: pd.DataFrame,
    window_duration_min: float = 1.0,
) -> Dict[str, float]:
    """Calculate clinical seizure prediction metrics (window-based).

    This function calculates window-based sensitivity, meaning it measures the
    percentage of seizure time windows that were correctly identified. This is
    NOT event-based sensitivity (which would measure % of seizure events
    detected regardless of how many windows within each event).

    All inputs are aligned by *position*: element ``i`` of ``y_true`` and
    ``y_pred`` and row ``i`` of ``metadata`` describe the same window. Index
    labels of pandas inputs are ignored.

    Parameters
    ----------
    y_true : array-like of str
        True labels ('seizure' or 'interictal_control'). Any value other than
        'seizure' is treated as non-seizure.
    y_pred : array-like of str
        Predicted labels ('seizure' or 'interictal_control'). Any value other
        than 'seizure' is treated as "no alarm".
    metadata : pd.DataFrame
        Metadata with a 'seizure_type' column. Rows equal to 'seizure' are
        seizure windows, rows equal to 'interictal_control' are interictal
        windows; any other (or missing) value belongs to neither group.
    window_duration_min : float, optional
        Duration of each time window in minutes (default: 1.0).

    Returns
    -------
    Dict[str, float]
        Dictionary containing:
        - seizure_sensitivity: % of seizure *time windows* detected (NOT event-based)
        - fp_per_hour: False positives per hour during interictal periods
        - time_in_warning: % of total time in alarm state
        - n_seizure_windows: Number of seizure windows
        - n_interictal_windows: Number of interictal windows
        - n_true_positives: Correctly predicted seizure windows
        - n_false_positives: Incorrectly predicted as seizure
        - n_false_negatives: Missed seizure windows
        - n_true_negatives: Correctly predicted as interictal
        - meets_sensitivity_target: Whether sensitivity ≥ 90%
        - meets_fp_target: Whether FP/h ≤ 0.2
        - meets_tiw_target: Whether time in warning ≤ 20%
        Rates are plain Python floats rounded to 3 decimals; counts are ints.

    Raises
    ------
    KeyError
        If ``metadata`` has no 'seizure_type' column.
    ValueError
        If ``y_true``, ``y_pred`` and ``metadata`` do not all have the same
        length.

    Notes
    -----
    - False positives are calculated only during interictal periods.
    - True positives/false negatives are calculated only during seizure periods.
    - Time in warning counts alarms over *all* windows, whatever their
      seizure_type.
    - With no seizure windows, sensitivity is reported as 0.0; with no
      interictal windows, FP/h is reported as 0.0.
    - Clinical targets based on FDA guidance for seizure prediction devices.
    - For event-based sensitivity, use calc_seizure_event_prediction_metrics
      instead.

    Example
    -------
    >>> # 1 seizure spanning 20 windows, detect 5 windows
    >>> # Window-based sensitivity: 5/20 = 25%
    >>> # This measures temporal coverage of the seizure

    References
    ----------
    FDA guidance on seizure prediction devices
    """
    # Reduce every input to a plain positional boolean NumPy array. Combining
    # pandas objects whose indexes differ would align on index labels instead
    # of position and silently drop windows from every count; converting via
    # dtype=object also makes plain lists work.
    true_is_seizure = np.asarray(y_true, dtype=object) == "seizure"
    pred_is_alarm = np.asarray(y_pred, dtype=object) == "seizure"
    seizure_type = metadata["seizure_type"]
    seizure_mask = (seizure_type == "seizure").to_numpy(dtype=bool, na_value=False)
    interictal_mask = (seizure_type == "interictal_control").to_numpy(
        dtype=bool, na_value=False
    )

    n_rows = len(metadata)
    if true_is_seizure.shape != (n_rows,) or pred_is_alarm.shape != (n_rows,):
        raise ValueError(
            "y_true, y_pred and metadata must all have the same length "
            f"(got shapes {true_is_seizure.shape}, {pred_is_alarm.shape} "
            f"and {n_rows} metadata rows)"
        )

    # Confusion counts: TP/FN only inside seizure windows, FP/TN only inside
    # interictal windows.
    tp = int(np.count_nonzero(true_is_seizure & pred_is_alarm & seizure_mask))
    fn = int(np.count_nonzero(true_is_seizure & ~pred_is_alarm & seizure_mask))
    fp = int(np.count_nonzero(~true_is_seizure & pred_is_alarm & interictal_mask))
    tn = int(np.count_nonzero(~true_is_seizure & ~pred_is_alarm & interictal_mask))

    n_seizure_windows = int(np.count_nonzero(seizure_mask))
    n_interictal = int(np.count_nonzero(interictal_mask))

    # Sensitivity (seizure detection rate) - WINDOW-BASED
    seizure_sensitivity = (
        tp / n_seizure_windows * 100.0 if n_seizure_windows > 0 else 0.0
    )

    # False alarms normalized by interictal recording time in hours.
    total_interictal_hours = (n_interictal * window_duration_min) / 60.0
    fp_per_hour = fp / total_interictal_hours if total_interictal_hours > 0 else 0.0

    # Time in warning: share of all windows spent in the alarm state.
    total_windows = pred_is_alarm.size
    alarm_windows = int(np.count_nonzero(pred_is_alarm))
    time_in_warning = (
        alarm_windows / total_windows * 100.0 if total_windows > 0 else 0.0
    )

    return {
        # Primary prediction metrics
        "seizure_sensitivity": round(float(seizure_sensitivity), 3),
        "fp_per_hour": round(float(fp_per_hour), 3),
        "time_in_warning": round(float(time_in_warning), 3),
        # Counts (time windows, not events)
        "n_seizure_windows": n_seizure_windows,
        "n_interictal_windows": n_interictal,
        "n_true_positives": tp,
        "n_false_positives": fp,
        "n_false_negatives": fn,
        "n_true_negatives": tn,
        # Clinical targets (FDA/clinical guidelines)
        "meets_sensitivity_target": bool(seizure_sensitivity >= 90.0),
        "meets_fp_target": bool(fp_per_hour <= 0.2),
        "meets_tiw_target": bool(time_in_warning <= 20.0),
    }
def calc_seizure_event_prediction_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    metadata: pd.DataFrame,
    window_duration_min: float = 1.0,
) -> Dict[str, float]:
    """Calculate clinical seizure prediction metrics (event-based).

    This function calculates event-based sensitivity, meaning it measures
    whether each seizure EVENT was detected (at least one alarm raised),
    regardless of how many windows within that event were predicted. This
    matches the clinical question "Did the system raise an alarm for this
    seizure?" -- one timely alarm per seizure event is sufficient for
    intervention.

    All inputs are aligned by *position*: element ``i`` of ``y_true`` and
    ``y_pred`` and row ``i`` of ``metadata`` describe the same window. Index
    labels of pandas inputs are ignored.

    Parameters
    ----------
    y_true : array-like of str
        True labels ('seizure' or 'interictal_control'). Any value other than
        'seizure' is treated as non-seizure.
    y_pred : array-like of str
        Predicted labels ('seizure' or 'interictal_control'). Any value other
        than 'seizure' is treated as "no alarm".
    metadata : pd.DataFrame
        Metadata with 'seizure_type' and 'seizure_id' columns.
        seizure_id: Unique identifier for each seizure event (e.g., 'sz_001',
        'sz_002'). Required for every window whose seizure_type is 'seizure';
        should be NaN/None for interictal periods.
    window_duration_min : float, optional
        Duration of each time window in minutes (default: 1.0).

    Returns
    -------
    Dict[str, float]
        Dictionary containing:
        - seizure_sensitivity: % of seizure *events* detected (event-based)
        - fp_per_hour: False positives per hour during interictal periods
        - time_in_warning: % of total time in alarm state
        - n_seizure_events: Number of unique seizure events
        - n_detected_events: Number of events with at least one alarm
        - n_missed_events: Number of events with zero alarms
        - n_interictal_windows: Number of interictal windows
        - n_false_positives: Incorrectly predicted as seizure
        - n_true_negatives: Correctly predicted as interictal
        - meets_sensitivity_target: Whether sensitivity ≥ 90%
        - meets_fp_target: Whether FP/h ≤ 0.2
        - meets_tiw_target: Whether time in warning ≤ 20%
        Rates are plain Python floats rounded to 3 decimals; counts are ints.

    Raises
    ------
    ValueError
        If ``metadata`` has no 'seizure_id' column, if the inputs differ in
        length, or if any 'seizure' window has a missing seizure_id.
    KeyError
        If ``metadata`` has no 'seizure_type' column.

    Notes
    -----
    - Events are the distinct seizure_id values found on 'seizure' windows.
      When checking an event for alarms, every window carrying that
      seizure_id is considered, whatever its seizure_type.
    - False positives are calculated only during interictal periods.
    - Event detection requires at least one window predicted as seizure.
    - With no seizure events, sensitivity is reported as 0.0; with no
      interictal windows, FP/h is reported as 0.0.
    - Clinical targets based on FDA guidance for seizure prediction devices.
    - For window-based sensitivity, use calc_seizure_window_prediction_metrics
      instead.

    Example
    -------
    >>> # 1 seizure spanning 20 windows, detect just 1 window
    >>> # Event-based sensitivity: 1/1 = 100% (event was detected!)
    >>> # This measures "did we catch the seizure at all?"

    References
    ----------
    FDA guidance on seizure prediction devices
    """
    # Validate required column
    if "seizure_id" not in metadata.columns:
        raise ValueError(
            "metadata must contain 'seizure_id' column for event-based metrics. "
            "Use calc_seizure_window_prediction_metrics for window-based metrics."
        )

    # Reduce every input to a plain positional boolean NumPy array. Combining
    # pandas objects whose indexes differ would align on index labels instead
    # of position and silently drop windows from every count; converting via
    # dtype=object also makes plain lists work.
    true_is_seizure = np.asarray(y_true, dtype=object) == "seizure"
    pred_is_alarm = np.asarray(y_pred, dtype=object) == "seizure"
    seizure_type = metadata["seizure_type"]
    seizure_mask = (seizure_type == "seizure").to_numpy(dtype=bool, na_value=False)
    interictal_mask = (seizure_type == "interictal_control").to_numpy(
        dtype=bool, na_value=False
    )
    seizure_ids = metadata["seizure_id"]

    n_rows = len(metadata)
    if true_is_seizure.shape != (n_rows,) or pred_is_alarm.shape != (n_rows,):
        raise ValueError(
            "y_true, y_pred and metadata must all have the same length "
            f"(got shapes {true_is_seizure.shape}, {pred_is_alarm.shape} "
            f"and {n_rows} metadata rows)"
        )

    # Events are the distinct ids found on seizure windows (boolean ndarray
    # indexing is positional).
    seizure_window_ids = seizure_ids[seizure_mask]
    if seizure_window_ids.isna().any():
        # A missing id can never match any row, so unique() would turn it into
        # a phantom event that is always counted as missed.
        raise ValueError(
            "every window with seizure_type == 'seizure' must have a "
            "non-missing 'seizure_id' for event-based metrics"
        )
    event_ids = seizure_window_ids.unique()
    n_seizure_events = len(event_ids)

    # An event is detected if at least one window carrying its id alarmed.
    detected_events = 0
    for event_id in event_ids:
        in_event = (seizure_ids == event_id).to_numpy(dtype=bool, na_value=False)
        if pred_is_alarm[in_event].any():
            detected_events += 1
    missed_events = n_seizure_events - detected_events

    # Event-based sensitivity: % of events detected
    seizure_sensitivity = (
        detected_events / n_seizure_events * 100.0 if n_seizure_events > 0 else 0.0
    )

    # FP/TN are counted only inside interictal windows.
    fp = int(np.count_nonzero(~true_is_seizure & pred_is_alarm & interictal_mask))
    tn = int(np.count_nonzero(~true_is_seizure & ~pred_is_alarm & interictal_mask))

    # False alarms normalized by interictal recording time in hours.
    n_interictal = int(np.count_nonzero(interictal_mask))
    total_interictal_hours = (n_interictal * window_duration_min) / 60.0
    fp_per_hour = fp / total_interictal_hours if total_interictal_hours > 0 else 0.0

    # Time in warning: share of all windows spent in the alarm state.
    total_windows = pred_is_alarm.size
    alarm_windows = int(np.count_nonzero(pred_is_alarm))
    time_in_warning = (
        alarm_windows / total_windows * 100.0 if total_windows > 0 else 0.0
    )

    return {
        # Primary prediction metrics
        "seizure_sensitivity": round(float(seizure_sensitivity), 3),
        "fp_per_hour": round(float(fp_per_hour), 3),
        "time_in_warning": round(float(time_in_warning), 3),
        # Counts (events, not windows)
        "n_seizure_events": int(n_seizure_events),
        "n_detected_events": int(detected_events),
        "n_missed_events": int(missed_events),
        "n_interictal_windows": n_interictal,
        "n_false_positives": fp,
        "n_true_negatives": tn,
        # Clinical targets (FDA/clinical guidelines)
        "meets_sensitivity_target": bool(seizure_sensitivity >= 90.0),
        "meets_fp_target": bool(fp_per_hour <= 0.2),
        "meets_tiw_target": bool(time_in_warning <= 20.0),
    }
# Backward compatibility aliases
calc_seizure_prediction_metrics = calc_seizure_window_prediction_metrics
calculate_seizure_prediction_metrics = calc_seizure_window_prediction_metrics


def parse_args():
    """Parse command line arguments."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Demonstrate seizure prediction metrics calculation"
    )
    parser.add_argument(
        "--n-windows",
        type=int,
        default=1000,
        help="Number of time windows to simulate (default: %(default)s)",
    )
    parser.add_argument(
        "--window-duration",
        type=float,
        default=1.0,
        help="Duration of each window in minutes (default: %(default)s)",
    )
    parser.add_argument(
        "--sensitivity",
        type=float,
        default=0.9,
        help="Target sensitivity to simulate (default: %(default)s)",
    )
    args = parser.parse_args()
    return args


def _log_clinical_targets(logger, metrics):
    """Log whether a metrics dict meets each clinical target."""
    logger.info("Clinical Targets (FDA Guidelines):")
    logger.info(
        f" Meets sensitivity target (≥90%): {metrics['meets_sensitivity_target']}"
    )
    logger.info(f" Meets FP/h target (≤0.2): {metrics['meets_fp_target']}")
    logger.info(
        f" Meets time in warning target (≤20%): {metrics['meets_tiw_target']}"
    )


def main(args):
    """Demonstrate seizure prediction metrics with synthetic data.

    Builds a synthetic recording with two seizure events (windows 100-119 and
    500-529). It simulates a detector that flags one window of event 1 plus
    a prefix of event 2 sized from ``args.sensitivity``, and injects five
    interictal false alarms. It then logs window-based and event-based
    metrics side by side.

    Returns
    -------
    int
        0 on success, 1 if ``args.n_windows`` is too small for the fixed
        synthetic layout.
    """
    import scitex_logging as logging

    logger = logging.getLogger(__name__)

    logger.info("Creating synthetic seizure prediction data")
    logger.info(f" n_windows: {args.n_windows}")
    logger.info(f" window_duration: {args.window_duration} min")
    logger.info(f" target_sensitivity: {args.sensitivity * 100}%")

    n_windows = args.n_windows
    window_duration_min = args.window_duration

    # Fixed synthetic layout (absolute window indices).
    event1_indices = list(range(100, 120))  # 20 windows
    event2_indices = list(range(500, 530))  # 30 windows
    seizure_indices = event1_indices + event2_indices
    fp_indices = [200, 300, 400, 600, 700]

    # The layout is hard-coded, so a shorter recording would otherwise fail
    # with an IndexError part-way through the demo.
    min_windows = max(seizure_indices + fp_indices) + 1
    if n_windows < min_windows:
        logger.error(
            f"--n-windows must be at least {min_windows} for this demo "
            f"(got {n_windows})"
        )
        return 1

    # Create labels and metadata with seizure_id for event-based metrics
    y_true = np.array(["interictal_control"] * n_windows)
    y_pred = np.array(["interictal_control"] * n_windows)
    metadata = pd.DataFrame(
        {
            "seizure_type": ["interictal_control"] * n_windows,
            "seizure_id": [None] * n_windows,  # seizure_id for event-based metrics
        }
    )

    # Add TWO seizure events (event 1: 100-119, event 2: 500-529)
    y_true[seizure_indices] = "seizure"
    for seizure_id, indices in (
        ("sz_001", event1_indices),
        ("sz_002", event2_indices),
    ):
        metadata.loc[indices, "seizure_type"] = "seizure"
        metadata.loc[indices, "seizure_id"] = seizure_id

    logger.info(
        f"Created 2 seizure events spanning {len(seizure_indices)} windows total"
    )
    logger.info(f" Event 1 (sz_001): {len(event1_indices)} windows")
    logger.info(f" Event 2 (sz_002): {len(event2_indices)} windows")

    # Predict some seizures correctly based on target sensitivity.
    # For the event-based demo, detect exactly 1 window of event 1 and give the
    # remaining budget to event 2. Clamp that share to [0, len(event 2)] so the
    # slice never goes negative (which would select from the end) and the log
    # never reports more detections than event 2 has windows.
    n_target = int(len(seizure_indices) * args.sensitivity)
    n_event2_detected = min(max(n_target - 1, 0), len(event2_indices))
    detected_indices = [event1_indices[0]] + event2_indices[:n_event2_detected]
    y_pred[detected_indices] = "seizure"
    logger.info(
        f"Simulating detection of {len(detected_indices)}/{len(seizure_indices)} "
        f"seizure windows"
    )
    logger.info(f" Event 1: 1/{len(event1_indices)} windows detected")
    logger.info(
        f" Event 2: {n_event2_detected}/{len(event2_indices)} windows detected"
    )

    # Add some false positives
    y_pred[fp_indices] = "seizure"
    logger.info(f"Added {len(fp_indices)} false positive alarms")

    # Calculate WINDOW-BASED metrics
    logger.info("")
    logger.info("Calculating WINDOW-BASED seizure prediction metrics")
    metrics_window = calc_seizure_window_prediction_metrics(
        y_true, y_pred, metadata, window_duration_min
    )

    logger.info("=" * 70)
    logger.info("WINDOW-BASED Metrics (How well did we cover seizure duration?)")
    logger.info("=" * 70)
    logger.info(f"Seizure Sensitivity: {metrics_window['seizure_sensitivity']:.1f}%")
    logger.info(f"False Positives/Hour: {metrics_window['fp_per_hour']:.3f}")
    logger.info(f"Time in Warning: {metrics_window['time_in_warning']:.1f}%")
    logger.info("")
    logger.info("Counts:")
    logger.info(f" Seizure windows: {metrics_window['n_seizure_windows']}")
    logger.info(f" Interictal windows: {metrics_window['n_interictal_windows']}")
    logger.info(f" True positives: {metrics_window['n_true_positives']}")
    logger.info(f" False positives: {metrics_window['n_false_positives']}")
    logger.info(f" False negatives: {metrics_window['n_false_negatives']}")
    logger.info(f" True negatives: {metrics_window['n_true_negatives']}")
    logger.info("")
    _log_clinical_targets(logger, metrics_window)
    logger.info("=" * 70)

    # Calculate EVENT-BASED metrics
    logger.info("")
    logger.info("Calculating EVENT-BASED seizure prediction metrics")
    metrics_event = calc_seizure_event_prediction_metrics(
        y_true, y_pred, metadata, window_duration_min
    )

    logger.info("=" * 70)
    logger.info("EVENT-BASED Metrics (Did we detect each seizure event?)")
    logger.info("=" * 70)
    logger.info(f"Seizure Sensitivity: {metrics_event['seizure_sensitivity']:.1f}%")
    logger.info(f"False Positives/Hour: {metrics_event['fp_per_hour']:.3f}")
    logger.info(f"Time in Warning: {metrics_event['time_in_warning']:.1f}%")
    logger.info("")
    logger.info("Counts:")
    logger.info(f" Seizure events: {metrics_event['n_seizure_events']}")
    logger.info(f" Detected events: {metrics_event['n_detected_events']}")
    logger.info(f" Missed events: {metrics_event['n_missed_events']}")
    logger.info(f" Interictal windows: {metrics_event['n_interictal_windows']}")
    logger.info(f" False positives: {metrics_event['n_false_positives']}")
    logger.info(f" True negatives: {metrics_event['n_true_negatives']}")
    logger.info("")
    _log_clinical_targets(logger, metrics_event)
    logger.info("=" * 70)

    # Comparison summary
    logger.info("")
    logger.info("=" * 70)
    logger.info("KEY DIFFERENCE DEMONSTRATION")
    logger.info("=" * 70)
    logger.info(
        f"Window-based sensitivity: {metrics_window['seizure_sensitivity']:.1f}% "
        f"(detected {metrics_window['n_true_positives']}/"
        f"{metrics_window['n_seizure_windows']} windows)"
    )
    logger.info(
        f"Event-based sensitivity: {metrics_event['seizure_sensitivity']:.1f}% "
        f"(detected {metrics_event['n_detected_events']}/"
        f"{metrics_event['n_seizure_events']} events)"
    )
    logger.info("")
    logger.info("Interpretation:")
    logger.info(
        " - Window-based: Detected only 1 window from Event 1 → Low sensitivity"
    )
    # Derived from the results: with a very low --sensitivity, event 2 may get
    # no detections, so "both events / 100%" must not be hard-coded.
    logger.info(
        f" - Event-based: Detected at least 1 window from "
        f"{metrics_event['n_detected_events']}/{metrics_event['n_seizure_events']} "
        f"events → {metrics_event['seizure_sensitivity']:.1f}% sensitivity"
    )
    logger.info(" - Clinical relevance: One timely alarm per seizure is sufficient")
    logger.info("=" * 70)

    return 0


def run_main():
    """Initialize scitex framework, run main function, and cleanup."""
    global CONFIG, CC, sys, plt, rng

    import sys

    import matplotlib.pyplot as plt
    import scitex as stx

    args = parse_args()

    CONFIG, sys.stdout, sys.stderr, plt, CC, rng = stx.session.start(
        sys,
        plt,
        args=args,
        file=__file__,
        sdir_suffix=None,
        verbose=False,
        agg=True,
    )

    exit_status = main(args)

    stx.session.close(
        CONFIG,
        verbose=False,
        notify=False,
        message="",
        exit_status=exit_status,
    )


if __name__ == "__main__":
    run_main()

# EOF