from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from typing import Tuple, Any
import numpy as np
import scipy.signal as ss
from tqdm import tqdm
from multiprocess.pool import Pool
from math import ceil, isclose
import tomllib
import os
from functools import cache
# from scipy.optimize import minimize
from ada.data_containers.scored import ScoredShort
from ada.data_containers._base import _Epoched, _Raw
from ada.data_containers.generic import GenericData
from ada.io.acti_eeg import ActiPSG
def matthews_corrcoef(psg, acti):
    """Compute the Matthews correlation coefficient between two sleep/wake scorings.

    Stage 0 is treated as the "positive" class (tp counts agreement on 0).
    Comparisons against NaN evaluate to False, so NaN samples never count
    toward any confusion-matrix cell.

    Args:
        psg: Array of reference (PSG) sleep/wake labels (0/1, NaN allowed).
        acti: Array of actigraphy-derived sleep/wake labels (0/1).

    Returns:
        float | int: MCC in [-1, 1]. For a degenerate confusion matrix
        (zero denominator) the usual convention is applied: a single
        non-zero diagonal cell -> 1, a single non-zero off-diagonal
        cell -> -1, otherwise 0.
    """
    # Cast to Python ints: arbitrary precision avoids int64 overflow in the
    # four-way product of marginals below for very long recordings.
    tp = int(np.nansum(np.logical_and(psg == 0, acti == 0)))
    tn = int(np.nansum(np.logical_and(psg == 1, acti == 1)))
    fp = int(np.nansum(np.logical_and(psg == 1, acti == 0)))
    fn = int(np.nansum(np.logical_and(psg == 0, acti == 1)))
    denominator = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    if denominator == 0:
        # Degenerate matrix: at least one marginal is zero.
        nonzero = [c for c in (tp, tn, fp, fn) if c != 0]
        if len(nonzero) == 1:
            # Lone diagonal cell = perfect agreement; lone off-diagonal
            # cell = perfect disagreement.
            return 1 if (tp or tn) else -1
        return 0
    return (tp * tn - fp * fn) / denominator ** .5
@dataclass(slots=True, eq=False)
class _ShortScorer(ABC):
    """Abstract base class for epoch-level (short) sleep/wake scorers.

    Attributes:
        _threshold: Sleep/wake separation threshold; None means it is
            interpolated from the bundled defaults at scoring time.
        _filter: Filter specification used by the concrete scorer.
        _scoring_method_metadata: Set by subclasses in __post_init__.
        _default_thresholds: Default-threshold table loaded from disk.
    """
    _threshold: float | None
    _filter: Any
    _scoring_method_metadata: dict = field(init=False)
    _default_thresholds: dict = field(init=False)

    @property
    def scoring_method_metadata(self) -> dict:
        """Metadata associated with the scoring algorithm and its parameters."""
        return self._scoring_method_metadata

    @abstractmethod
    def to_score(self, acti_data: Any) -> ScoredShort:
        pass

    @abstractmethod
    def transmittance(self, n_points: int = 2048, db: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        pass

    def fit_threshold(self, acti_psg: list[ActiPSG], thresholds: Tuple[float, float] = (.001, .03), n_thresholds: int = 1000, full_data: bool = False) -> float | list[Tuple[float, float]]:
        """Fit a threshold that is maximizing mean correlation between PSG scorings and actigraphic scorings.

        Args:
            acti_psg (list[ActiPSG]): List of objects containing actigraphic data and PSG sleep/wake scorings.
            thresholds (Tuple[float, float], optional): Range of thresholds to fit. Defaults to (.001, .03).
            n_thresholds (int, optional): Number of thresholds to fit. Defaults to 1000.
            full_data (bool, optional): If True, list with all correlations will be returned. If False, the maximal correlation will be returned. Defaults to False.

        Returns:
            float | list[Tuple[float, float]]: List of all correlations (with corresponding thresholds) or maximal correlation.
        """
        def find_mcc(threshold: float) -> Tuple[float, float]:
            # Runs in a worker process: mutating self._threshold here only
            # affects the worker's pickled copy, never the parent instance.
            self._threshold = threshold
            mcc = 0
            for e in acti_psg:
                score = self.to_score(e.acti_data).score
                # Only compare epochs that actually have a PSG scoring.
                idx = np.argwhere(~np.isnan(e.eeg_data.psg_stages))
                mcc += matthews_corrcoef(e.eeg_data.psg_stages[idx], score[idx])
            return (mcc / len(acti_psg), threshold)

        ts = np.linspace(thresholds[0], thresholds[1], n_thresholds)
        # Context manager guarantees the worker pool is terminated even when
        # scoring raises (the previous version leaked the pool).
        with Pool() as pool:
            # One chunk per worker process keeps IPC overhead minimal.
            chunksize = ceil(n_thresholds / len(pool._pool))
            mcc = list(tqdm(pool.imap(find_mcc, ts, chunksize=chunksize), total=len(ts), desc="Fitting threshold to data"))
        mccs = [e[0] for e in mcc]
        # Persist the best threshold on this instance for later scoring.
        self._threshold = mcc[np.argmax(mccs)][1]
        if full_data:
            return mcc
        return max(mccs)

    def _interpolate_threshold(self, data: _Raw | _Epoched | GenericData) -> float:
        """Interpolate a default threshold for *data* from the bundled table.

        Raises:
            RuntimeError: For raw (non-epoched) data, or when no defaults
                exist for the data's device or epoching method.
        """
        device = GenericData._get_device(data).__name__
        if not isinstance(data, _Raw) and data.epoching_method_metadata is not None:
            epocher = data.epoching_method_metadata['epoching method']
        else:
            raise RuntimeError("No default thresholds for Raw data.")
        try:
            defaults = self._default_thresholds[device][epocher][type(self).__name__]
            if isclose(data.fs, 1 / 60):
                # One-minute epochs have a dedicated, non-interpolated value.
                return defaults['long']
            # Linear fit of 1/threshold against epoch length (1/fs).
            return 1 / (defaults['slope'] / data.fs + defaults['const'])
        except KeyError:
            if device in self._default_thresholds.keys():
                raise RuntimeError(f"Threshold interpolation for this particular epoching method ({epocher}) is not supported.")
            else:
                raise RuntimeError(f"Threshold interpolation for this particular device ({device}) is not supported.")

    @staticmethod
    @cache
    def _load_defaults():
        """Load (and memoize) the default-threshold table shipped next to this module."""
        with open(os.path.join(os.path.dirname(__file__), 'thresholds.toml'), 'rb') as f:
            return tomllib.load(f)
@dataclass(slots=True, eq=False)
class _UnifiedScorer(_ShortScorer):
    """Scorer that band-filters activity data with IIR filters, then thresholds it.

    Attributes:
        _filter: List of filter-spec dicts accepted by scipy.signal.iirdesign.
    """
    _filter: list[dict]

    def __post_init__(self):
        self._scoring_method_metadata = {'algorithm': type(self).__name__,
                                         'threshold': self._threshold,
                                         'filter data': self._filter}
        self._default_thresholds = _ShortScorer._load_defaults()

    @property
    def filter(self) -> list[dict]:
        """Characteristics of the filters used by scoring algorithm."""
        return self._filter

    def to_score(self, acti_data: _Epoched | GenericData | _Raw) -> ScoredShort:
        """Apply scoring algorithm to the input data.

        Args:
            acti_data (_Epoched | GenericData | _Raw): Data to be scored by the algorithm. Scoring by unified filter does not require previous epoching.

        Returns:
            ScoredShort: Object containing scored data.
        """
        if isinstance(acti_data, _Raw):
            epoching_method_metadata = None
        else:
            epoching_method_metadata = acti_data.epoching_method_metadata
        if self._threshold is None:
            threshold = self._interpolate_threshold(acti_data)
        else:
            threshold = self._threshold
        # Filters are (re)designed on every call because the design depends on
        # the sampling frequency of the incoming data.
        # NOTE(review): unlike transmittance(), the 'analog' entry of the
        # filter spec is not forwarded here -- confirm digital-only is intended.
        filters = []
        for spec in self._filter:
            sos = ss.iirdesign(spec['wp'], spec['ws'], spec['gpass'], spec['gstop'],
                               ftype=spec['ftype'], output='sos', fs=acti_data.fs)
            filters.append(sos)
        out = np.empty((2, acti_data.data.shape[1]))
        out[1, :] = acti_data.timestamp
        temp = acti_data.to_score
        for sos in filters:
            # Zero-phase filtering keeps scored epochs aligned in time.
            temp = ss.sosfiltfilt(sos, temp)
        out[0, :] = temp >= (1 / threshold)
        del temp  # release the filtered signal before building the result
        metadata = self._scoring_method_metadata.copy()
        metadata['threshold'] = threshold
        return ScoredShort(out, acti_data.metadata, acti_data.fs, epoching_method_metadata, metadata)

    def transmittance(self, n_points: int = 2048, db: bool = True, fs: float = 1 / 30) -> Tuple[np.ndarray, np.ndarray]:
        """Transmittance of filter used in the algorithm.

        Args:
            n_points (int, optional): Number of points at which filter will be evaluated. Defaults to 2048.
            db (bool, optional): Whether to return in dB scale. Defaults to True.
            fs (float, optional): Sampling frequency for which transmittance will be created.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Vector of frequencies and vector of frequency responses.
        """
        f = np.linspace(0, fs / 2, n_points, endpoint=False)
        h = np.ones(n_points)
        for filt in self._filter:
            sos = ss.iirdesign(filt['wp'], filt['ws'], filt['gpass'], filt['gstop'], filt['analog'], filt['ftype'], 'sos', fs=fs)
            _, h_temp = ss.sosfreqz(sos, f, fs=fs)
            # Cascaded filters: total response is the product of stages.
            h = h * h_temp
        if db:
            return f, 20 * np.log10(np.abs(h))
        return f, np.abs(h)
class GenericUnified(_UnifiedScorer):
    """A class for scoring actigraphic data using a custom filter, designed via scipy.signal.iirdesign.

    Attributes:
        threshold (float): Threshold value separating sleep and wake.
        wp, ws (float | tuple[float, float]): Passband and stopband edge frequencies. Possible values are scalars (for lowpass and highpass filters) or ranges (for bandpass and bandstop filters). For digital filters, these are in the same units as fs. By default, fs is 2 half-cycles/sample, so these are normalized from 0 to 1, where 1 is the Nyquist frequency.
        gpass (float): The maximum loss in the passband (dB).
        gstop (float): The minimum attenuation in the stopband (dB).
        analog (bool): When True, return an analog filter, otherwise a digital filter is returned. Defaults to False.
        ftype (str): The type of IIR filter to design.
    """

    @staticmethod
    def _make_spec(wp: float | tuple[float, float], ws: float | tuple[float, float], gpass: float, gstop: float, analog: bool, ftype: str) -> dict:
        # Single place defining the filter-spec schema consumed by _UnifiedScorer.
        return {'wp': wp,
                'ws': ws,
                'gpass': gpass,
                'gstop': gstop,
                'analog': analog,
                'ftype': ftype}

    def __init__(self, threshold: float, wp: float | tuple[float, float], ws: float | tuple[float, float], gpass: float, gstop: float, analog: bool = False, ftype: str = 'ellip'):
        super().__init__(threshold, [self._make_spec(wp, ws, gpass, gstop, analog, ftype)])

    def add_filter(self, wp: float | tuple[float, float], ws: float | tuple[float, float], gpass: float, gstop: float, analog: bool = False, ftype: str = 'ellip'):
        """Add filter to the filter algorithm. Data will be filtered using all the filters provided.

        Args:
            wp (float | tuple[float, float]): Passband edge frequency.
            ws (float | tuple[float, float]): Stopband edge frequency.
            gpass (float): The maximum loss in the passband (dB).
            gstop (float): The minimum attenuation in the stopband (dB).
            analog (bool, optional): When True, return an analog filter, otherwise a digital filter is returned. Defaults to False.
            ftype (str, optional): The type of IIR filter to design. Defaults to 'ellip'.
        """
        self._filter.append(self._make_spec(wp, ws, gpass, gstop, analog, ftype))
class Dummy(_UnifiedScorer):
    """A dummy class for scoring actigraphic data without any filtration.

    Attributes:
        threshold (float): Threshold value separating sleep and wake.
    """

    def __init__(self, threshold: float):
        # Placeholder spec: no filtering is performed, but the parent class
        # expects a non-empty list of filter dicts for its metadata.
        super().__init__(threshold, [{None: None}])

    def to_score(self, acti_data: _Epoched | GenericData | _Raw) -> ScoredShort:
        """Apply scoring algorithm (plain thresholding, no filtering) to the input data.

        Args:
            acti_data (_Epoched | GenericData | _Raw): Data to be scored by the algorithm.

        Returns:
            ScoredShort: Object containing scored data.
        """
        if isinstance(acti_data, _Raw):
            epoching_method_metadata = None
        else:
            epoching_method_metadata = acti_data.epoching_method_metadata
        if self._threshold is None:
            threshold = self._interpolate_threshold(acti_data)
        else:
            threshold = self._threshold
        out = np.empty((2, acti_data.data.shape[1]))
        out[1, :] = acti_data.timestamp
        out[0, :] = acti_data.to_score >= (1 / threshold)
        metadata = self._scoring_method_metadata.copy()
        metadata['threshold'] = threshold
        return ScoredShort(out, acti_data.metadata, acti_data.fs, epoching_method_metadata, metadata)
class UFA(_UnifiedScorer):
    """A class for scoring actigraphic data using Unified Filter Approach (for details see Biegański et al. 2026).

    Attributes:
        threshold (float | None = None): Threshold value separating sleep and wake. Autodetected based on input data if None. Defaults to None.
    """

    def __init__(self, threshold: float | None = None):
        # Fixed elliptic design of the unified filter approach; the numeric
        # constants come from the published fit.
        ufa_filter = dict(wp=0.0006410425509217069,
                          ws=0.006805922349578073,
                          gpass=3.5661765768415044,
                          gstop=14.594650542815454,
                          analog=False,
                          ftype='ellip')
        super().__init__(threshold, [ufa_filter])