from dataclasses import dataclass
from typing import Tuple
import numpy as np
import scipy.signal as ss
from math import isclose
from ada.data_containers.scored import ScoredShort
from ada.data_containers._base import _Epoched
from ada.data_containers.generic import GenericData
from ada.short.unified import _ShortScorer
@dataclass(slots=True, eq=False)
class _ColeKripkeGeneralized(_ShortScorer):
    """Shared implementation for Cole-Kripke-family weighted-sum scorers.

    Each concrete algorithm (Cole-Kripke, Sazonov, Webster, UCSD, Scripps)
    is a FIR filter convolved with epoched activity counts, thresholded at 1.

    Attributes:
        _filter (np.ndarray): Filter weights convolved with the epoched counts.
        _rescoring (bool): Whether to apply the subclass rescoring rules. Defaults to False.
        _input_epoch (int): Epoch length in seconds the algorithm is designed for. Defaults to 60.
        _shift (int): Number of "future" epochs the filter looks ahead of the scored epoch. Defaults to 0.
    """
    _filter: np.ndarray
    _rescoring: bool = False
    _input_epoch: int = 60
    _shift: int = 0

    def __post_init__(self):
        self._scoring_method_metadata = {'algorithm': type(self).__name__,
                                         'threshold': self._threshold}
        self._default_thresholds = _ShortScorer._load_defaults()

    def _rescore(self, data):
        # Default hook: no rescoring. Return the data unchanged (rather than
        # None via a bare `pass`) so `to_score` stays correct even if
        # `_rescoring` is enabled without an overriding implementation.
        return data

    def to_score(self, epoched: _Epoched | GenericData) -> ScoredShort:
        """Apply scoring algorithm to the input data.

        Args:
            epoched (_Epoched | GenericData): Epoched data to be scored by the algorithm. Scoring by Cole-Kripke family requires previous epoching.

        Raises:
            ValueError: If the input was not epoched, or its sampling rate does
                not match the epoch length this algorithm was designed for.

        Returns:
            ScoredShort: Object containing scored data.
        """
        if epoched.epoching_method_metadata is None:
            raise ValueError("Scoring by Cole-Kripke family requires previous epoching.")
        if not isclose(1 / self._input_epoch, epoched.fs):
            raise ValueError("This scoring algorithm is designed for data epoched in {} s windows. Please use correct input data".format(self._input_epoch))
        if self._threshold is None:
            threshold = self._interpolate_threshold(epoched)
        else:
            threshold = self._threshold
        out = np.empty((2, epoched.data.shape[1]))
        out[1, :] = epoched.timestamp
        # Scaling the filter by the threshold lets us compare against a fixed 1.
        scaled_filter = threshold * self._filter
        # Full convolution has length N + len(filter) - 1; keep the N samples
        # aligned so each scored epoch uses `_shift` epochs of future data.
        # When `_shift == 0` the stop index must be None — a literal `-0`
        # would evaluate to 0 and yield an empty slice.
        stop = -self._shift if self._shift else None
        scores = ss.fftconvolve(epoched.to_score, scaled_filter, mode='full')
        out[0, :] = scores[len(self._filter) - self._shift - 1:stop] >= 1
        if self._rescoring:
            out[0, :] = self._rescore(out[0, :])
        metadata = self._scoring_method_metadata.copy()
        metadata['threshold'] = threshold
        return ScoredShort(out, epoched.metadata, epoched.fs, epoched.epoching_method_metadata, metadata)

    def transmittance(self, n_points: int = 2048, db: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """Transmittance of filter used in the algorithm. Evaluated for correct sampling frequency (as defined by given algorithm).

        Args:
            n_points (int, optional): Number of points at which filter will be evaluated. Defaults to 2048.
            db (bool, optional): Whether to return in dB scale. Defaults to True.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Vector of frequencies and vector of frequency responses.
        """
        f, h = ss.freqz(self._filter, 1, n_points, fs=1 / self._input_epoch)
        if db:
            return f, 20 * np.log10(np.abs(h))
        return f, np.abs(h)
class ColeKripke(_ColeKripkeGeneralized):
    """Class for scoring data by Cole-Kripke algorithm. See Cole et al. (1992) for details.

    Attributes
        threshold (float | None): Threshold value discriminating sleep and wake epochs. Autodetected based on input data if None. Defaults to None.
        rescoring (bool): Whether to apply rescoring rules. Defaults to True.
    """
    def __init__(self, threshold: float | None = None, rescoring: bool = True) -> None:
        # Filter weights; flipped so the base-class convolution applies them
        # in chronological order. 60 s epochs, 4 epochs of look-ahead.
        filter = np.flip(np.array([404, 598, 326, 441, 1408, 508, 350]))
        super().__init__(threshold, filter, rescoring, 60, 4)
        self._scoring_method_metadata['rescoring'] = self._rescoring

    def _rescore(self, data):  # TODO can this be more optimal? Looks inefficient
        # In-place rescoring of the 0/1 score vector. Based on the increment
        # condition below, 1 appears to denote wake and 0 sleep — consistent
        # with the "after N minutes of wake" phrasing of the published rules.
        # NOTE(review): rescored epochs are seen again by later iterations of
        # the same loop (data is modified in place) — confirm this matches the
        # intended reading of Webster/Cole rescoring rules.
        # Rescoring rules (a)-(c): after a long enough run of wake epochs,
        # the first sleep epoch(s) following the run are rescored as wake.
        wake_count = 0
        for i in range(0, len(data) - 4):
            if data[i] == 1:
                wake_count += 1
            else:
                if wake_count >= 15:
                    # >= 15 min of wake: rescore the next 4 epochs as wake.
                    data[i:i + 4] = 1
                elif wake_count >= 10:
                    # >= 10 min of wake: rescore the next 3 epochs as wake.
                    data[i:i + 3] = 1
                elif wake_count >= 4:
                    # >= 4 min of wake: rescore this single epoch as wake.
                    data[i] = 1
                wake_count = 0
        # Rescoring rule (d): a short sleep run (1-6 epochs) surrounded by
        # 10 all-wake epochs on each side is rescored as wake.
        # NOTE(review): `wake_count` is incremented here but never read in
        # this loop (or the next) — possibly leftover code; verify.
        wake_count = 0
        sleep_count = 0
        for i in range(0, len(data) - 10):
            if data[i] == 1:
                wake_count += 1
                if sleep_count in range(1, 7) and np.sum(data[i:i + 10]) == 10 and np.sum(data[i - sleep_count - 10:i - sleep_count]) == 10:
                    data[i - sleep_count:i] = 1
                sleep_count = 0
            else:
                sleep_count += 1
        # Rescoring rule (e): as rule (d) but for sleep runs of 1-10 epochs
        # surrounded by 20 all-wake epochs on each side.
        wake_count = 0
        sleep_count = 0
        for i in range(0, len(data) - 20):
            if data[i] == 1:
                wake_count += 1
                if sleep_count in range(1, 11) and np.sum(data[i:i + 20]) == 20 and np.sum(data[i - sleep_count - 20:i - sleep_count]) == 20:
                    data[i - sleep_count:i] = 1
                sleep_count = 0
            else:
                sleep_count += 1
        return data
class Sazonov(_ColeKripkeGeneralized):
    """Score epoched activity data with the Sazonov algorithm.

    See Sazonov et al. (2004) for details.

    Attributes
        threshold (float | None): Threshold value discriminating sleep and wake epochs. Autodetected based on input data if None. Defaults to None.
    """

    def __init__(self, threshold: float | None = None) -> None:
        # Published filter weights; 30 s epochs with 8 epochs of look-ahead.
        weights = np.array(
            [.19450, .09746, .09975, .10194, .08917, .08108, .07494, .07300, .10207]
        )
        super().__init__(threshold, weights, False, 30, 8)
class Webster(_ColeKripkeGeneralized):
    """Class for scoring data by Webster algorithm. See Webster et al. (1989) for details.

    Attributes
        threshold (float | None): Threshold value discriminating sleep and wake epochs. Autodetected based on input data if None. Defaults to None.
    """
    def __init__(self, threshold: float | None = None) -> None:
        # Filter weights flipped so the base-class convolution applies them in
        # chronological order; 60 s epochs with 5 epochs of look-ahead.
        # (Renamed from `filter` to avoid shadowing the builtin.)
        weights = np.flip(np.array([.07, .08, .1, .11, .12, .14, .09, .09, .09, .1]))
        super().__init__(threshold, weights, False, 60, 5)
class Ucsd(_ColeKripkeGeneralized):
    """Class for scoring data by UCSD algorithm. See Jean-Louis et al. (2001) for details.

    Attributes
        threshold (float | None): Threshold value discriminating sleep and wake epochs. Autodetected based on input data if None. Defaults to None.
    """
    def __init__(self, threshold: float | None = None) -> None:
        # Filter weights flipped so the base-class convolution applies them in
        # chronological order; 60 s epochs with 4 epochs of look-ahead.
        # (Renamed from `filter` to avoid shadowing the builtin.)
        weights = np.flip(np.array([.01, .015, .028, .031, .085, .015, .01]))
        super().__init__(threshold, weights, False, 60, 4)
class Scripps(_ColeKripkeGeneralized):
    """Class for scoring data by Scripps Clinic algorithm. See Kripke et al. (2012) for details.

    Attributes
        threshold (float | None): Threshold value discriminating sleep and wake epochs. Autodetected based on input data if None. Defaults to None.
        rescoring (bool): Whether to apply rescoring rules. Defaults to True.
    """
    def __init__(self, threshold: float | None = None, rescoring: bool = True) -> None:
        # Published filter weights; 30 s epochs with 10 epochs of look-ahead.
        # (Renamed from `filter` to avoid shadowing the builtin.)
        weights = np.array([.01, .0112, .03, .0664, .028, .0188, .0128, .0118, .0118, .0112, .0112, .0074, .0064])
        super().__init__(threshold, weights, rescoring, 30, 10)
        self._scoring_method_metadata['rescoring'] = self._rescoring

    def _rescore(self, data):  # TODO rewrite without loop?
        """Rescore as wake every epoch before the first run of 10 consecutive
        sleep (0) epochs. The data is modified in place and returned.
        """
        sleep = 0
        c = 0
        # Scan until 10 consecutive sleep epochs are found or the data ends.
        # BUGFIX: the previous condition (`sleep <= 10`) ran one epoch too
        # far, so `sleep` usually ended at 11 and the `sleep == 10` check
        # below almost never fired — the rule was effectively dead.
        while sleep < 10 and c < len(data):
            if data[c] == 0:
                sleep += 1
            else:
                sleep = 0
            c += 1
        # If no run of 10 exists (e.g. no sleep at all), leave data unchanged.
        # Slice assignment cannot raise IndexError, so no try/except needed.
        if sleep == 10:
            # `c` points just past the 10th consecutive sleep epoch; mark
            # everything before that run as wake.
            data[:c - 10] = 1
        return data