Source code for scitex_seizure_metrics.forecasting

"""Forecasting regime — alarm-based metrics with explicit AlarmPolicy.

Two entry points:
- evaluate(alarms, seizures, policy, ...) — accepts pre-computed alarm
  times; useful when alarms come from a different pipeline.
- evaluate_stream(proba, times, seizures, policy, ...) — accepts a
  continuous prediction stream (proba per timepoint); applies the
  policy's threshold + refractory + merging to produce alarms, then
  evaluates. This is the gold-standard entry point — every reported
  number is reproducible from (proba, times, policy).

Also provides:
- sweep_thresholds — sensitivity-vs-FP/hr operating curve across a
                     threshold sweep.
- sweep_policies   — same metrics across multiple AlarmPolicy values
                     (e.g. cadence ablation).
- bootstrap_ci     — percentile bootstrap confidence interval of the mean.

All metrics use the AlarmPolicy.fp_denominator convention. The Mormann-
tradition default ("interictal") excludes [seizure - sop - sph,
seizure + sop] windows from the FP/hr denominator.
"""
from __future__ import annotations

from typing import Iterable

import numpy as np

from . import surrogates as _surrogates
from . import _alarm
from .policy import AlarmPolicy
from .report import MetricsReport


def _time_in_warning_seconds(alarms, sph, sop, total_recording_time):
    """Total length of the union of alarm warning windows inside the recording.

    Each alarm ``a`` opens the window ``[a + sph, a + sph + sop]``. Windows
    are clipped to ``[0, total_recording_time]`` so warning time falling
    outside the recording is not counted.
    """
    if alarms.size == 0:
        return 0.0
    starts = alarms + sph
    ends = starts + sop
    windows = np.clip(np.column_stack([starts, ends]), 0.0,
                      total_recording_time)
    return float(_alarm.union_length(windows))


def _fp_denominator_seconds(policy, seizures, total_recording_time):
    """Seconds used as the FP/hr denominator under ``policy.fp_denominator``."""
    if policy.fp_denominator == "interictal":
        return _alarm.interictal_seconds(
            total_recording_time, seizures,
            policy.sop_seconds, policy.sph_seconds)
    return total_recording_time


def _surrogate_sensitivity(n_alarms, seizures, policy, total_recording_time,
                           *, n_surrogate, surrogate, rng_seed):
    """Mean sensitivity over ``n_surrogate`` chance alarm sets (NaN if none)."""
    # Resolve the surrogate first so an unknown name is reported even when
    # no surrogates are drawn.
    surro_fn = _surrogates.get(surrogate)
    if n_surrogate == 0:
        # No baseline was estimated; reporting 0.0 would make IoC equal
        # the raw sensitivity.
        return float("nan")
    rng = np.random.default_rng(rng_seed)
    # Loop-invariant arguments (presumably the span over which surrogate
    # alarms are placed — confirm against the surrogate registry).
    horizon = max(total_recording_time - policy.sop_seconds
                  - policy.sph_seconds, 1.0)
    n_seizures = max(1, seizures.size)
    sens = []
    for _ in range(n_surrogate):
        rand_alarms = surro_fn(n_alarms, horizon, rng)
        caught, _ = _alarm.alarm_match(rand_alarms, seizures,
                                       policy.sph_seconds, policy.sop_seconds)
        sens.append(caught.sum() / n_seizures)
    return float(np.mean(sens))


def evaluate(alarm_times, seizure_times, policy: AlarmPolicy, *,
             total_recording_time: float, n_surrogate: int = 1000,
             surrogate: str = "poisson", rng_seed: int = 0,
             name: str = "") -> MetricsReport:
    """Alarm-based forecasting evaluation under an explicit AlarmPolicy.

    Args:
        alarm_times: seconds-since-recording-start timestamps of alarms.
        seizure_times: seconds-since-recording-start timestamps of true
            seizure onsets.
        policy: AlarmPolicy describing SPH/SOP/cadence/refractory/etc.
        total_recording_time: total recording duration (seconds); must be
            > 0.
        n_surrogate: number of random surrogates used to estimate the chance
            baseline for IoC; must be >= 0. With 0 surrogates no baseline is
            available and ``surrogate_sensitivity`` / ``ioc`` are NaN.
        surrogate: name of registered surrogate ("poisson", "periodic",
            "persistence", or any user-registered).
        rng_seed: surrogate RNG seed.
        name: identifier carried into the report.

    Returns:
        MetricsReport with the forecasting-regime fields populated.
        ``sensitivity`` (and therefore ``ioc``) is NaN when there are no
        seizures. ``time_in_warning_frac`` only counts warning time inside
        ``[0, total_recording_time]``, so it never exceeds 1.

    Raises:
        ValueError: if ``total_recording_time <= 0`` or ``n_surrogate < 0``.
    """
    alarms = np.sort(np.asarray(alarm_times, dtype=float).ravel())
    seizures = np.sort(np.asarray(seizure_times, dtype=float).ravel())
    if total_recording_time <= 0:
        raise ValueError("total_recording_time must be > 0")
    if n_surrogate < 0:
        raise ValueError("n_surrogate must be >= 0")

    sph, sop = policy.sph_seconds, policy.sop_seconds
    # NOTE(review): alarm_match appears to return (per-seizure "caught" mask,
    # per-alarm "used" mask); alarms matching no seizure are false positives.
    seizure_caught, alarm_used = _alarm.alarm_match(alarms, seizures, sph, sop)
    tp = int(seizure_caught.sum())
    fp = int((~alarm_used).sum())

    tiw_seconds = _time_in_warning_seconds(alarms, sph, sop,
                                           total_recording_time)
    tiw_frac = float(tiw_seconds / total_recording_time)

    denom_seconds = _fp_denominator_seconds(policy, seizures,
                                            total_recording_time)
    # Floor the denominator so an empty interictal span cannot divide by 0.
    fp_per_hour = fp / max(1e-9, denom_seconds / 3600.0)
    sens = float(tp / seizures.size) if seizures.size > 0 else float("nan")

    surro_mean = _surrogate_sensitivity(
        int(alarms.size), seizures, policy, total_recording_time,
        n_surrogate=n_surrogate, surrogate=surrogate, rng_seed=rng_seed)

    return MetricsReport(
        name=name,
        regime="forecasting",
        sensitivity=sens,
        n_ref_events=int(seizures.size),
        n_tp=tp,
        n_fp=fp,
        fp_per_hour=fp_per_hour,
        fp_per_day=fp_per_hour * 24.0,
        sph_seconds=sph,
        sop_seconds=sop,
        time_in_warning_frac=tiw_frac,
        ioc=sens - surro_mean,
        surrogate_sensitivity=surro_mean,
        extras={"policy": policy.describe(),
                "interictal_seconds": denom_seconds,
                "n_alarms": int(alarms.size)},
    )
def evaluate_stream(proba, times, seizure_times, policy: AlarmPolicy, *,
                    total_recording_time: float | None = None,
                    n_surrogate: int = 1000, surrogate: str = "poisson",
                    rng_seed: int = 0, name: str = "") -> MetricsReport:
    """Continuous-stream evaluation entry point — gold standard.

    Threshold + dedupe (per AlarmPolicy) the proba stream, then call
    evaluate() on the derived alarms. Use this when you have per-window
    predictions and want fully-reproducible alarm-based metrics.

    Args:
        proba: per-timepoint prediction scores; flattened to 1-D.
        times: timestamp (seconds) of each entry of ``proba``; flattened to
            1-D and must have the same length as ``proba``.
        seizure_times: true seizure onset timestamps (seconds).
        policy: AlarmPolicy supplying threshold, refractory and merging
            settings, plus the SPH/SOP used by evaluate().
        total_recording_time: recording duration in seconds. When None it is
            inferred as ``times.max() - times.min()`` (0.0 for an empty
            stream, which evaluate() rejects).
        n_surrogate, surrogate, rng_seed, name: forwarded to evaluate().

    Returns:
        The MetricsReport produced by evaluate().

    Raises:
        ValueError: if ``proba`` and ``times`` differ in length, or (via
            evaluate()) if the recording duration is not positive.
    """
    proba = np.asarray(proba, dtype=float).ravel()
    times = np.asarray(times, dtype=float).ravel()
    if proba.size != times.size:
        raise ValueError(
            "proba and times must have the same length "
            f"(got {proba.size} and {times.size})")
    if total_recording_time is None:
        # NOTE(review): max - min excludes any time before the first
        # prediction and the duration of the last window — confirm this is
        # the intended denominator.
        total_recording_time = (float(times.max() - times.min())
                                if times.size else 0.0)
    alarms = _alarm.proba_stream_to_alarms(
        proba, times,
        threshold=policy.alarm_threshold,
        refractory_seconds=policy.refractory_seconds,
        merge_consecutive=policy.merge_consecutive,
    )
    return evaluate(alarms, seizure_times, policy,
                    total_recording_time=total_recording_time,
                    n_surrogate=n_surrogate, surrogate=surrogate,
                    rng_seed=rng_seed, name=name)
def sweep_thresholds(proba, times, seizure_times, policy: AlarmPolicy, *,
                     thresholds: Iterable[float] | None = None,
                     total_recording_time: float | None = None,
                     n_surrogate: int = 200, surrogate: str = "poisson",
                     rng_seed: int = 0, name: str = ""):
    """Generate the sensitivity-vs-FP/hr operating curve.

    Every point re-uses ``policy`` with only ``alarm_threshold`` replaced.

    Args:
        proba, times, seizure_times: forwarded to evaluate_stream().
        policy: base AlarmPolicy.
        thresholds: alarm thresholds to evaluate; defaults to 19 evenly
            spaced values from 0.05 to 0.95.
        total_recording_time, n_surrogate, surrogate, rng_seed: forwarded to
            evaluate_stream().
        name: prefix for each row's report name (``"{name}@thresh={t:.2f}"``).

    Returns:
        pandas DataFrame with one row per threshold. Columns are the
        MetricsReport.to_dict() fields (sensitivity, fp_per_hour,
        time_in_warning_frac, ioc, ...) plus ``threshold``.
    """
    import dataclasses

    import pandas as pd

    if thresholds is None:
        thresholds = np.linspace(0.05, 0.95, 19)

    def _at_threshold(threshold):
        # When AlarmPolicy is a dataclass, copy *all* of its fields (and keep
        # a subclass's type) instead of only the hand-listed ones below.
        if dataclasses.is_dataclass(policy) and not isinstance(policy, type):
            return dataclasses.replace(policy, alarm_threshold=threshold)
        return AlarmPolicy(
            sph_seconds=policy.sph_seconds,
            sop_seconds=policy.sop_seconds,
            cadence_seconds=policy.cadence_seconds,
            refractory_seconds=policy.refractory_seconds,
            alarm_threshold=threshold,
            merge_consecutive=policy.merge_consecutive,
            fp_denominator=policy.fp_denominator,
        )

    rows = []
    for t in thresholds:
        threshold = float(t)
        rep = evaluate_stream(proba, times, seizure_times,
                              _at_threshold(threshold),
                              total_recording_time=total_recording_time,
                              n_surrogate=n_surrogate, surrogate=surrogate,
                              rng_seed=rng_seed,
                              name=f"{name}@thresh={t:.2f}")
        d = rep.to_dict()
        d.update({"threshold": threshold})
        rows.append(d)
    return pd.DataFrame(rows)
def sweep_policies(proba, times, seizure_times, policies: list[AlarmPolicy],
                   *, total_recording_time: float | None = None,
                   n_surrogate: int = 200, surrogate: str = "poisson",
                   rng_seed: int = 0, name: str = ""):
    """Evaluate one prediction stream under each policy in ``policies``.

    Intended for ablations (e.g. cadence or refractory). Each policy is run
    through evaluate_stream() with the shared keyword arguments; the report
    for the policy at position ``idx`` is named ``"{name}_pol{idx}"``.

    Returns:
        pandas DataFrame with one row per policy. Each row holds the
        MetricsReport.to_dict() fields, updated with ``policy.describe()``;
        describe() entries win on key collisions.
    """
    import pandas as pd

    def _row(idx, pol):
        report = evaluate_stream(
            proba, times, seizure_times, pol,
            total_recording_time=total_recording_time,
            n_surrogate=n_surrogate,
            surrogate=surrogate,
            rng_seed=rng_seed,
            name=f"{name}_pol{idx}",
        )
        record = report.to_dict()
        record.update(pol.describe())
        return record

    return pd.DataFrame([_row(idx, pol) for idx, pol in enumerate(policies)])
def bootstrap_ci(values, n_boot: int = 1000, ci: float = 0.95,
                 rng_seed: int = 0):
    """Percentile bootstrap confidence interval of the mean.

    Args:
        values: sample values; flattened to 1-D.
        n_boot: number of bootstrap resamples; must be >= 1.
        ci: confidence level in (0, 1]; e.g. 0.95 gives the 2.5th and 97.5th
            percentiles of the bootstrap means.
        rng_seed: seed for ``np.random.default_rng``.

    Returns:
        ``(mean, lower, upper)`` as floats; all NaN when ``values`` is empty.

    Raises:
        ValueError: if ``n_boot < 1`` or ``ci`` is not in (0, 1].
    """
    if n_boot < 1:
        raise ValueError("n_boot must be >= 1")
    if not 0.0 < ci <= 1.0:
        raise ValueError("ci must be in (0, 1]")
    rng = np.random.default_rng(rng_seed)
    # Flatten so N-D input is resampled element-wise, not row-wise.
    values = np.asarray(values, dtype=float).ravel()
    if values.size == 0:
        nan = float("nan")
        return nan, nan, nan
    # One resample per iteration keeps the RNG stream (and hence the results
    # for a given seed) identical to earlier versions.
    boots = np.empty(n_boot)
    for i in range(n_boot):
        boots[i] = rng.choice(values, size=values.size, replace=True).mean()
    alpha = (1 - ci) / 2
    lo = np.quantile(boots, alpha)
    hi = np.quantile(boots, 1 - alpha)
    return float(values.mean()), float(lo), float(hi)