Source code for scitex_seizure_metrics.surrogates
"""Surrogate alarm-time generators for chance-baseline comparisons (IoC).
Each surrogate takes a description of the recording and returns synthetic
alarm times. Used by forecasting metrics to estimate Improvement-over-
Chance.
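Built-in surrogates:
- poisson: uniform-random alarm times at matched rate
- periodic: alarms on a fixed period (circadian-style)
- persistence: a single alarm at a uniform-random time
- circadian: alarms on a 24-h cycle at a random phase
- multidien: alarms on a multi-day cycle (default 7 days)
- from_history: alarms every mean inter-seizure interval after the last
  known seizure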
"""
from __future__ import annotations
from typing import Callable
import numpy as np
# Call signature shared by the built-in surrogate generators:
# (n_alarms, total_seconds, rng) -> array of alarm times.
SurrogateFn = Callable[[int, float, "np.random.Generator"], np.ndarray]
# Name -> generator lookup table; written by `register`, read by `get`.
_REGISTRY: dict[str, SurrogateFn] = {}
[docs]
def register(name: str):
    """Return a decorator that files a surrogate generator under ``name``.

    The decorated function is stored in the module-level ``_REGISTRY``
    (replacing any entry already stored under the same name) and is
    handed back unchanged, so it remains directly callable.
    """
    def _store(generator):
        _REGISTRY[name] = generator
        return generator

    return _store
[docs]
def get(name: str) -> SurrogateFn:
    """Look up a registered surrogate generator by name.

    Returns the callable stored in ``_REGISTRY`` under ``name``.

    Raises:
        KeyError: if ``name`` was never registered; the message lists
            the registered names in sorted order.
    """
    if name in _REGISTRY:
        return _REGISTRY[name]
    raise KeyError(f"unknown surrogate {name!r}; "
                   f"registered: {sorted(_REGISTRY)}")
[docs]
@register("poisson")
def poisson(n_alarms: int, total_seconds: float,
            rng: np.random.Generator) -> np.ndarray:
    """Uniform-random alarm times. The classic null model.

    Args:
        n_alarms: number of alarm times to draw.
        total_seconds: length of the recording; alarms are drawn from
            ``rng.uniform(0, total_seconds)``.
        rng: NumPy random generator used for the draw.

    Returns:
        Sorted 1-D array of ``n_alarms`` alarm times, or an empty array
        when no alarms are requested or the recording has no positive
        duration.
    """
    # Same degenerate-input guard as the other surrogates in this module:
    # a recording of non-positive length has no valid alarm times.
    if n_alarms == 0 or total_seconds <= 0:
        return np.array([])
    return np.sort(rng.uniform(0, total_seconds, size=n_alarms))
[docs]
@register("periodic")
def periodic(n_alarms: int, total_seconds: float,
             rng: np.random.Generator) -> np.ndarray:
    """Return ``n_alarms`` evenly spaced alarm times with a random phase.

    The spacing is ``total_seconds / n_alarms`` and the whole grid is
    shifted by one offset drawn with ``rng.uniform(0, spacing)``. An
    empty array is returned when no alarms are requested or the
    recording duration is not positive.

    Stand-in for circadian-only forecasters (Karoly 2017): if a method
    can't beat regular ticks, the forecaster has no signal beyond the
    daily clock.
    """
    if total_seconds <= 0 or n_alarms == 0:
        return np.array([])

    spacing = total_seconds / n_alarms
    offset = rng.uniform(0, spacing)
    grid = spacing * np.arange(n_alarms) + offset
    # Clamp every alarm into the recording window [0, total_seconds].
    return np.clip(grid, 0, total_seconds)
[docs]
@register("persistence")
def persistence(n_alarms: int, total_seconds: float,
                rng: np.random.Generator) -> np.ndarray:
    """Return one alarm placed at a uniform-random time.

    Models 'always predict at last known seizure' style baselines
    collapsed to one decision. Whatever non-zero ``n_alarms`` is passed,
    the result holds exactly one time drawn with
    ``rng.uniform(0, total_seconds)``; an empty array is returned when
    ``n_alarms`` is zero or the recording duration is not positive.
    """
    if total_seconds <= 0 or n_alarms == 0:
        return np.array([])

    alarm_time = rng.uniform(0, total_seconds)
    return np.array([alarm_time])
# Seconds in one day (24 h * 3600 s). Default period of `circadian` and
# the unit in which the `multidien` default period (7 days) is expressed.
SECONDS_PER_DAY = 86_400.0
[docs]
@register("circadian")
def circadian(n_alarms: int, total_seconds: float,
              rng: np.random.Generator,
              period_seconds: float = SECONDS_PER_DAY) -> np.ndarray:
    """Alarms locked to a repeating cycle (24 h by default) at a random phase.

    One tick is laid down per period, starting at an offset drawn with
    ``rng.uniform(0, period_seconds)``; ticks at or beyond
    ``total_seconds`` are discarded. If at least ``n_alarms`` ticks
    remain, a subset of ``n_alarms`` ticks evenly spread by index is
    returned. Otherwise all ticks are kept and the missing alarms are
    drawn with ``rng.uniform(0, total_seconds)``.

    Stand-in for circadian-only forecasters (Karoly 2017): if a method
    cannot beat regular daily ticks, it has no signal beyond the clock.

    Returns:
        Sorted 1-D array of alarm times; empty when ``n_alarms`` is zero
        or ``total_seconds`` is not positive.
    """
    if total_seconds <= 0 or n_alarms == 0:
        return np.array([])

    cycle_count = max(1, int(np.ceil(total_seconds / period_seconds)))
    offset = rng.uniform(0, period_seconds)
    ticks = offset + period_seconds * np.arange(cycle_count)
    ticks = ticks[ticks < total_seconds]

    shortfall = n_alarms - ticks.size
    if shortfall > 0:
        # Too few cycle ticks inside the recording: top up with alarms
        # placed uniformly at random within the recording.
        filler = rng.uniform(0, total_seconds, size=shortfall)
        return np.sort(np.concatenate([ticks, filler]))

    # Enough ticks: keep n_alarms of them, evenly spread by index.
    picks = np.linspace(0, ticks.size - 1, n_alarms).astype(int)
    return np.sort(ticks[picks])
[docs]
@register("multidien")
def multidien(n_alarms: int, total_seconds: float,
              rng: np.random.Generator,
              period_seconds: float = 7 * SECONDS_PER_DAY) -> np.ndarray:
    """Alarms locked to a multi-day cycle (7 days unless overridden).

    Delegates to ``circadian`` with ``period_seconds`` as the cycle
    length, so alarm placement follows exactly the same rules.

    Karoly 2018 / Proix 2021 baseline — a method that exploits no per-
    seizure signal but tracks the multidien rhythm.
    """
    alarm_times = circadian(
        n_alarms,
        total_seconds,
        rng,
        period_seconds=period_seconds,
    )
    return alarm_times
[docs]
@register("from_history")
def from_history(n_alarms: int, total_seconds: float,
                 rng: np.random.Generator,
                 seizure_history_seconds: np.ndarray | None = None
                 ) -> np.ndarray:
    """Alarm at every (last_seizure + k * mean_interval) — Karoly 2019 style.

    Requires `seizure_history_seconds` (the past seizure timestamps in the
    same recording). Falls back to uniform-random (`poisson`) if no
    history is given, if it holds fewer than two timestamps, or if the
    mean inter-seizure interval is not positive.

    Args:
        n_alarms: number of alarm times to return.
        total_seconds: length of the recording; scheduled alarms at or
            beyond this time are dropped.
        rng: NumPy random generator used for padding and for the
            fallback.
        seizure_history_seconds: past seizure timestamps, in any order.

    Returns:
        Sorted 1-D array of alarm times. Scheduled alarms come first in
        construction; if fewer than ``n_alarms`` fit inside the
        recording, the remainder is drawn with
        ``rng.uniform(0, total_seconds)``. Empty when ``n_alarms`` is
        zero or ``total_seconds`` is not positive.
    """
    # Degenerate request, handled up front (as in the sibling surrogates)
    # so neither the fallback nor the padding draws from a window of
    # non-positive length.
    if n_alarms == 0 or total_seconds <= 0:
        return np.array([])
    if seizure_history_seconds is None or len(seizure_history_seconds) < 2:
        return poisson(n_alarms, total_seconds, rng)

    history = np.sort(np.asarray(seizure_history_seconds, dtype=float))
    interval = float(np.mean(np.diff(history)))
    if interval <= 0:
        return poisson(n_alarms, total_seconds, rng)

    # k-th alarm computed directly as last + k * interval (k = 1..n_alarms)
    # rather than by repeated addition, so rounding error does not build up.
    steps = np.arange(1, n_alarms + 1)
    scheduled = history[-1] + interval * steps
    scheduled = scheduled[scheduled < total_seconds]

    shortfall = n_alarms - scheduled.size
    if shortfall > 0:
        # Not enough scheduled alarms fit: pad with uniform-random times.
        filler = rng.uniform(0, total_seconds, size=shortfall)
        scheduled = np.concatenate([scheduled, filler])
    return np.sort(scheduled)