Source code for ada.short.cole_kripke

from dataclasses import dataclass
from typing import Tuple
import numpy as np
import scipy.signal as ss
from math import isclose

from ada.data_containers.scored import ScoredShort
from ada.data_containers._base import _Epoched
from ada.data_containers.generic import GenericData
from ada.short.unified import _ShortScorer


@dataclass(slots=True, eq=False)
class _ColeKripkeGeneralized(_ShortScorer):
    """Shared implementation of the Cole-Kripke family of scoring algorithms.

    Each epoch is scored by a weighted sum of the activity in a window of
    neighbouring epochs, computed as a convolution with ``threshold * _filter``.
    Epochs whose result is ``>= 1`` are scored 1, all others 0. Subclasses
    supply the weights, the epoch length the weights were designed for and,
    optionally, rescoring rules (``_rescore``).

    Attributes:
        _filter (np.ndarray): Filter weights in convolution order, i.e.
            ``_filter[-1]`` weights the earliest epoch of the window and
            ``_filter[0]`` the latest one.
        _rescoring (bool): Whether ``_rescore`` is applied after thresholding.
        _input_epoch (int): Epoch length (in seconds) the input must have.
        _shift (int): Number of epochs preceding the scored epoch that enter
            the window; the remaining ``len(_filter) - _shift - 1`` epochs
            follow it. Must satisfy ``0 <= _shift < len(_filter)``.
    """

    _filter: np.ndarray
    _rescoring: bool = False
    _input_epoch: int = 60
    _shift: int = 0

    def __post_init__(self):
        # Outside this range the window would not contain the scored epoch and
        # the slice taken in ``to_score`` would be misaligned.
        if not 0 <= self._shift < len(self._filter):
            raise ValueError("Filter shift must satisfy 0 <= shift < len(filter).")
        self._scoring_method_metadata = {'algorithm': type(self).__name__,
                                         'threshold': self._threshold}
        self._default_thresholds = _ShortScorer._load_defaults()

    def _rescore(self, data):
        """Apply rescoring rules to thresholded scores; identity by default.

        Subclasses with rescoring rules override this. Returning ``data``
        (rather than ``None``) keeps ``to_score`` working when rescoring is
        enabled on a subclass that defines no rules.

        Args:
            data (np.ndarray): Thresholded scores.

        Returns:
            np.ndarray: ``data`` unchanged.
        """
        return data

    def to_score(self, epoched: _Epoched | GenericData) -> ScoredShort:
        """Apply scoring algorithm to the input data.

        Args:
            epoched (_Epoched | GenericData): Epoched data to be scored by the algorithm. Scoring by Cole-Kripke family requires previous epoching.

        Returns:
            ScoredShort: Object containing scored data (row 0) and timestamps (row 1).

        Raises:
            ValueError: If the data was not epoched, or was epoched with an
                epoch length other than ``_input_epoch`` seconds.
        """
        if epoched.epoching_method_metadata is None:
            raise ValueError("Scoring by Cole-Kripke family requires previous epoching.")

        if not isclose(1 / self._input_epoch, epoched.fs):
            raise ValueError("This scoring algorithm is designed for data epoched in {} s windows. Please use correct input data".format(self._input_epoch))

        if self._threshold is None:
            threshold = self._interpolate_threshold(epoched)
        else:
            threshold = self._threshold

        out = np.empty((2, epoched.data.shape[1]))
        out[1, :] = epoched.timestamp
        weights = threshold * self._filter
        full = ss.fftconvolve(epoched.to_score, weights, mode='full')
        # Keep one value per input epoch, aligned so that index i is centred on
        # epoch i. The end index is computed explicitly because ``full[a:-0]``
        # would be empty when ``_shift == 0``.
        start = len(self._filter) - self._shift - 1
        out[0, :] = full[start:len(full) - self._shift] >= 1
        if self._rescoring:
            out[0, :] = self._rescore(out[0, :])
        metadata = self._scoring_method_metadata.copy()
        metadata['threshold'] = threshold
        return ScoredShort(out, epoched.metadata, epoched.fs, epoched.epoching_method_metadata, metadata)

    def transmittance(self, n_points: int = 2048, db: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """Transmittance of filter used in the algorithm. Evaluated for correct sampling frequency (as defined by given algorithm).

        Args:
            n_points (int, optional): Number of points at which the filter will be evaluated. Defaults to 2048.
            db (bool, optional): Whether to return in dB scale. Defaults to True.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Vector of frequencies and vector of frequency responses.
        """
        f, h = ss.freqz(self._filter, 1, n_points, fs=1 / self._input_epoch)
        if db:
            return f, 20 * np.log10(np.abs(h))
        return f, np.abs(h)


class ColeKripke(_ColeKripkeGeneralized):
    """Class for scoring data by Cole-Kripke algorithm.

    See Cole et al. (1992) for details.

    Attributes:
        threshold (float | None): Threshold value discriminating sleep and wake epochs. Autodetected based on input data if None. Defaults to None.
        rescoring (bool): Whether to apply rescoring rules. Defaults to True.
    """

    # Rules (a)-(c): (minimum preceding wake run, number of epochs rescored as
    # wake starting at the first sleep epoch), checked from the longest run.
    _AFTER_WAKE_RULES = ((15, 4), (10, 3), (4, 1))

    def __init__(self, threshold: float | None = None, rescoring: bool = True) -> None:
        # Weights listed from 4 epochs before to 2 epochs after the scored
        # epoch, flipped into convolution order.
        weights = np.flip(np.array([404, 598, 326, 441, 1408, 508, 350]))
        super().__init__(threshold, weights, rescoring, 60, 4)
        self._scoring_method_metadata['rescoring'] = self._rescoring

    def _rescore(self, data):
        """Apply rescoring rules (a)-(e) in place; 1 marks wake, 0 sleep.

        (a)-(c): after a wake run of at least 4/10/15 epochs, 1/3/4 epochs
        starting at the first following sleep epoch are rescored as wake.
        (d): sleep runs of 1-6 epochs with at least 10 wake epochs on both
        sides are rescored as wake.
        (e): sleep runs of 1-10 epochs with at least 20 wake epochs on both
        sides are rescored as wake.

        Args:
            data (np.ndarray): 1-D array of thresholded scores; modified in place.

        Returns:
            np.ndarray: The same array, rescored.
        """
        wake_run = 0
        for i in range(len(data)):
            if data[i] == 1:
                wake_run += 1
                continue
            for min_wake, n_rescored in self._AFTER_WAKE_RULES:
                if wake_run >= min_wake:
                    # Slices clip at the array end, so the tail is handled too.
                    data[i:i + n_rescored] = 1
                    break
            wake_run = 0

        data = self._rescore_enclosed_sleep(data, max_sleep=6, min_wake=10)
        return self._rescore_enclosed_sleep(data, max_sleep=10, min_wake=20)

    @staticmethod
    def _rescore_enclosed_sleep(data, max_sleep, min_wake):
        """Rescore, in place, sleep runs of 1..max_sleep epochs as wake when at
        least ``min_wake`` wake epochs both precede and follow them."""
        sleep_run = 0
        for i in range(len(data)):
            if data[i] != 1:
                sleep_run += 1
                continue
            run_start = i - sleep_run
            before = run_start - min_wake
            # ``before >= 0`` makes the "enough wake before" test explicit
            # instead of relying on a negative slice bound yielding nothing.
            if (1 <= sleep_run <= max_sleep
                    and before >= 0
                    and np.sum(data[i:i + min_wake]) == min_wake
                    and np.sum(data[before:run_start]) == min_wake):
                data[run_start:i] = 1
            sleep_run = 0
        return data
class Sazonov(_ColeKripkeGeneralized):
    """Scorer implementing the Sazonov algorithm.

    See Sazonov et al. (2004) for details.

    Attributes:
        threshold (float | None): Threshold separating sleep from wake epochs.
            When None, it is autodetected from the input data. Defaults to None.
    """

    def __init__(self, threshold: float | None = None) -> None:
        # Not flipped: weights[0] applies to the scored epoch and weights[j]
        # to the j-th preceding epoch (shift 8 with 9 taps: no future epochs).
        weights = np.array([
            .19450, .09746, .09975, .10194, .08917,
            .08108, .07494, .07300, .10207,
        ])
        super().__init__(threshold, _filter=weights, _rescoring=False,
                         _input_epoch=30, _shift=8)
class Webster(_ColeKripkeGeneralized):
    """Class for scoring data by Webster algorithm.

    See Webster et al. (1989) for details.

    Attributes:
        threshold (float | None): Threshold value discriminating sleep and wake epochs. Autodetected based on input data if None. Defaults to None.
    """

    def __init__(self, threshold: float | None = None) -> None:
        # Weights listed from 5 epochs before to 4 epochs after the scored
        # epoch, flipped into convolution order.
        weights = np.flip(np.array([.07, .08, .1, .11, .12, .14, .09, .09, .09, .1]))
        super().__init__(threshold, weights, False, 60, 5)
class Ucsd(_ColeKripkeGeneralized):
    """Class for scoring data by UCSD algorithm.

    See Jean-Louis et al. (2001) for details.

    Attributes:
        threshold (float | None): Threshold value discriminating sleep and wake epochs. Autodetected based on input data if None. Defaults to None.
    """

    def __init__(self, threshold: float | None = None) -> None:
        # Weights listed from 4 epochs before to 2 epochs after the scored
        # epoch, flipped into convolution order.
        weights = np.flip(np.array([.01, .015, .028, .031, .085, .015, .01]))
        super().__init__(threshold, weights, False, 60, 4)
class Scripps(_ColeKripkeGeneralized):
    """Class for scoring data by Scripps Clinic algorithm.

    See Kripke et al. (2012) for details.

    Attributes:
        threshold (float | None): Threshold value discriminating sleep and wake epochs. Autodetected based on input data if None. Defaults to None.
        rescoring (bool): Whether to apply rescoring rules. Defaults to True.
    """

    # Length (in epochs) of the first sleep run that ends initial rescoring.
    _MIN_SLEEP_RUN = 10

    def __init__(self, threshold: float | None = None, rescoring: bool = True) -> None:
        # Already in convolution order: weights[0] applies to 2 epochs after
        # the scored epoch, weights[2] to the scored epoch and weights[12] to
        # 10 epochs before it.
        weights = np.array([.01, .0112, .03, .0664, .028, .0188, .0128,
                            .0118, .0118, .0112, .0112, .0074, .0064])
        super().__init__(threshold, weights, rescoring, 30, 10)
        self._scoring_method_metadata['rescoring'] = self._rescoring

    def _rescore(self, data):
        """Rescore as wake (1) every epoch preceding the first run of
        ``_MIN_SLEEP_RUN`` consecutive sleep (0) epochs.

        If no such run exists (e.g. the recording contains no sleep), the
        data is left unchanged.

        Args:
            data (np.ndarray): 1-D array of thresholded scores; modified in place.

        Returns:
            np.ndarray: The same array, rescored.
        """
        run = 0
        for i, value in enumerate(data):
            if value != 0:
                run = 0
                continue
            run += 1
            if run == self._MIN_SLEEP_RUN:
                # The run occupies data[i - run + 1 : i + 1]; rescore before it.
                data[:i + 1 - run] = 1
                break
        return data