# Source code for scitex_seizure_metrics.detection
"""Detection regime — per-window classification metrics.
This is what most of the literature calls "sample-based" evaluation
(Andrade 2024). Built on sklearn for the threshold-free metrics
(AUROC, AUPRC, Brier) and on timescoring.SampleScoring for the
threshold-dependent confusion-matrix metrics.
Use this when:
- Your model emits one probability per fixed-length window (clip-level
classification, e.g. 10-min preictal-vs-interictal).
- You want the imbalance-aware metrics (AUPRC + balanced accuracy +
MCC) that no NeuroVista paper has reported (gap noted in the
literature review).
"""
from __future__ import annotations
import numpy as np
from sklearn.metrics import (
average_precision_score,
balanced_accuracy_score,
brier_score_loss,
matthews_corrcoef,
roc_auc_score,
)
from timescoring.annotations import Annotation
from timescoring.scoring import SampleScoring
from .report import MetricsReport
[docs]
def evaluate(
    y_true,
    y_proba,
    *,
    threshold: float = 0.5,
    fs: int = 1,
    name: str = "",
) -> MetricsReport:
    """Evaluate a continuous-output classifier on per-window labels.

    Threshold-free metrics (AUROC, AUPRC, Brier) are computed with sklearn;
    the confusion-matrix metrics are computed on the mask obtained by
    thresholding ``y_proba`` and scored sample-by-sample with
    ``timescoring.SampleScoring``.

    Args:
        y_true: 1-D binary array of ground-truth labels (1 = pre-ictal).
            Values are cast to ``int`` and must then all be 0 or 1.
        y_proba: 1-D continuous score / probability matched to y_true.
        threshold: cutoff for the binary-mask metrics (sensitivity etc.);
            a window is predicted positive when ``y_proba >= threshold``.
            AUROC/AUPRC/Brier are threshold-free.
        fs: sampling frequency of the label vector (Hz). Defaults to 1
            (one label per second, or per window if windows are 1-s
            spaced — adjust accordingly).
        name: identifier carried into the report.

    Returns:
        MetricsReport with the detection-regime fields populated.
        ``roc_auc``, ``pr_auc``, ``balanced_accuracy`` and ``mcc`` are only
        set when both classes occur in ``y_true``; ``sensitivity``,
        ``precision`` and ``f1`` are ``None`` when timescoring reports NaN.

    Raises:
        ValueError: if the inputs are not matching 1-D arrays, are empty,
            or if ``y_true`` contains labels other than 0 and 1.
    """
    y_true = np.asarray(y_true).astype(int)
    y_proba = np.asarray(y_proba).astype(float)
    if y_true.shape != y_proba.shape or y_true.ndim != 1:
        raise ValueError(
            f"y_true and y_proba must be matching 1-D arrays; got "
            f"{y_true.shape} vs {y_proba.shape}"
        )
    # Fail early with a readable message instead of numpy's
    # "zero-size array to reduction operation" raised by .min() below.
    if y_true.size == 0:
        raise ValueError("y_true and y_proba must not be empty")
    # Non-binary labels (e.g. -1/+1) would all become True under
    # .astype(bool) and silently corrupt the confusion-matrix metrics.
    invalid = np.setdiff1d(y_true, (0, 1))
    if invalid.size:
        raise ValueError(
            "y_true must be binary (0/1); got unexpected label(s) "
            f"{invalid[:5].tolist()}"
        )

    rep = MetricsReport(name=name, regime="detection")

    # Ranking metrics and the class-balanced metrics below are only set when
    # both classes are present in the labels.
    has_both = y_true.min() == 0 and y_true.max() == 1

    # Threshold-free
    if has_both:
        rep.roc_auc = float(roc_auc_score(y_true, y_proba))
        rep.pr_auc = float(average_precision_score(y_true, y_proba))
    # NOTE(review): brier_score_loss is defined for probabilities; scores
    # outside [0, 1] are presumably rejected by sklearn — confirm.
    rep.brier = float(brier_score_loss(y_true, y_proba))

    # Threshold-dependent
    y_pred = (y_proba >= threshold).astype(int)
    if has_both:
        rep.balanced_accuracy = float(balanced_accuracy_score(y_true, y_pred))
        rep.mcc = float(matthews_corrcoef(y_true, y_pred))

    # Sample-based confusion-matrix metrics via timescoring
    ref = Annotation(y_true.astype(bool), fs=fs)
    hyp = Annotation(y_pred.astype(bool), fs=fs)
    ss = SampleScoring(ref, hyp, fs=fs)
    rep.sensitivity = _nan_to_none(ss.sensitivity)
    rep.precision = _nan_to_none(ss.precision)
    rep.f1 = _nan_to_none(ss.f1)
    # ss.fpRate is stored as the per-day rate; the hourly rate is derived
    # from it by dividing by 24.
    rep.fp_per_day = float(ss.fpRate)
    rep.fp_per_hour = float(ss.fpRate / 24.0)
    rep.n_ref_events = int(ss.refTrue)
    rep.n_tp = int(ss.tp)
    rep.n_fp = int(ss.fp)
    return rep


def _nan_to_none(value):
    """Return *value* as a float, or None when it is NaN (metric undefined)."""
    return None if np.isnan(value) else float(value)