"""Epochers for ada.preprocessing: algorithms that convert raw actigraphic data into epoched data."""

from typing import Callable
import numpy as np
from tqdm import tqdm
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import warnings
from scipy.signal import resample_poly, sosfiltfilt, iirfilter
from scipy.stats import gamma
from scipy.interpolate import splrep, splev
from scipy.integrate import trapezoid
from math import ceil

from ada.data_containers._base import _Raw, _Epoched
from ada.data_containers.epoched import EpochedActivityIndex, EpochedCK, Resampled, EpochedMIMS, GenericMVM
from ada.io.geneactiv import RawGeneActiv, GeneActivMVM
from ada.data_containers.generic import GenericData
from ada.io.acti_eeg import ActiPSG


@dataclass(slots=True, eq=False)
class _Epocher(ABC):
    """Framework for different epoching methods.

    Subclasses implement :meth:`_to_epoch` with the actual algorithm and fill
    ``_epoching_method_metadata`` in their ``__post_init__``.

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
    """

    _epoch_length: int
    # Populated by each subclass's __post_init__; field(init=False) keeps it out of the
    # generated __init__ signature.
    _epoching_method_metadata: dict = field(init=False)

    @property
    def epoch_length(self) -> int:
        """Output data epoch length in seconds."""
        return self._epoch_length

    @abstractmethod
    def _to_epoch(self, raw: _Raw | GenericData) -> _Epoched:
        """Algorithm-specific epoching of raw data; implemented by subclasses."""
        pass

    def to_epoch(self, data: ActiPSG | _Raw | GenericData, wake_percentage: float = 0.8) -> ActiPSG | _Epoched:
        """Epoching actigraphic data. Works for both ActiData and ActiPSG.

        Actigraphic data will be epoched by the selected Epocher. PSG data (if present)
        will be epoched using the change_epoch method. In case of ActiPSG intended to
        work mostly with 20, 30 and 60 s epochs.

        Args:
            data (ActiPSG | _Raw | GenericData): Object containing actigraphic data and psg scoring
            wake_percentage (float, optional): Parameter describing how many points inside a new epoch
                must be wake, so the new epoch will also be wake. Meaningless when epoching only
                actigraphic data. Defaults to 0.8.

        Raises:
            ValueError: If actigraphic data is already scored. Don't do this.

        Returns:
            ActiPSG | _Epoched: Epoched actigraphic data and PSG scoring (if input was ActiPsg).
        """
        if not isinstance(data, ActiPSG):
            return self._to_epoch(data)

        # Reject already-epoched input: either acti_data is an _Epoched instance, or it is a
        # non-_Raw container (e.g. GenericData) that already carries epoching metadata.
        # Note the short-circuit: _Raw instances skip the metadata attribute access entirely.
        if not (isinstance(data.acti_data, _Raw) or data.acti_data.epoching_method_metadata is None) or isinstance(data.acti_data, _Epoched):
            raise ValueError("Do not epoch already epoched data.")
        
        new_acti = self._to_epoch(data.acti_data)
        new_psg = data.eeg_data.change_epoch(self._epoch_length, wake_percentage)
        out = ActiPSG(new_acti, new_psg, True)
        # Preserve the private trim flag of the source container on the new object.
        out._trim_to_eeg = data._trim_to_eeg
        return out


@dataclass(slots=True, eq=False)
class ActivityIndex(_Epocher):
    """A class for epoching actigraphic data using algorithm by Bai et al. (2013).

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
        normalized (bool): Whether to return data normalized (see paper for details). Defaults to True.
        stationary_variance (float | None): Variance of data recorded from stationary actigraph.
            Autodetected from device if None.
    """

    _normalized: bool = True
    _stationary_variance: float | None = None

    def __post_init__(self):
        self._epoching_method_metadata = {'epoching method': 'ActivityIndex', 'epoch length': self._epoch_length,
                                          'normalized': self._normalized,
                                          'stationary variance': self._stationary_variance,
                                          'to_score channel': 0}

    @property
    def normalized(self) -> bool:
        """Whether to return data normalized (see paper for details). Defaults to True."""
        return self._normalized

    @property
    def stationary_variance(self) -> float | None:
        """Variance of data recorded from stationary actigraph."""
        return self._stationary_variance

    def _to_epoch(self, raw: _Raw | GenericData) -> EpochedActivityIndex:
        """Epoch given data using Activity Index algorithm.

        Args:
            raw (_Raw | GenericData): Raw actigraphic data to be epoched.

        Raises:
            ValueError: If the provided data is not raw data.

        Returns:
            EpochedActivityIndex: Object containing epoched data.
        """
        if not isinstance(raw, _Raw):
            if raw.epoching_method_metadata is not None and raw.scoring_method_metadata is not None:
                raise ValueError("Provided data is not raw data.")
        stationary_variance = raw.stationary_variance if self._stationary_variance is None else self._stationary_variance
        n = raw.data.shape[1]  # length of non-epoch data
        N = int(self._epoch_length * raw.fs)  # window size in samples
        length = n - N
        # Hoisted out of the loop: channel lookups are invariant, so resolve the axis rows once.
        x_data = raw.data[raw.channel_names.index('x')]
        y_data = raw.data[raw.channel_names.index('y')]
        z_data = raw.data[raw.channel_names.index('z')]

        def haxis(x, i):
            # Per-axis windowed variance with the stationary (noise) variance subtracted.
            return np.var(x[i:i + N]) - stationary_variance

        out = np.zeros((2, length // N + 1))
        # The normalized variant additionally divides the mean of axis variances by the
        # stationary variance (see Bai et al. 2013); otherwise just average over 3 axes.
        denominator = 3 * stationary_variance if self._normalized else 3
        pbar = tqdm(total=length // N, position=0, leave=True, desc="Epoching...", dynamic_ncols=True)
        for i in range(out.shape[1]):
            # Clamp at zero before the square root: noise subtraction can go slightly negative.
            out[0, i] = np.sqrt(max((haxis(x_data, i * N) + haxis(y_data, i * N) + haxis(z_data, i * N)) / denominator, 0))
            out[1, i] = raw.timestamp[0] + i * self._epoch_length
            pbar.update(1)
        pbar.close()
        metadata = self._epoching_method_metadata.copy()
        # Record the variance actually used (may have been autodetected from the device).
        metadata['stationary variance'] = stationary_variance
        return EpochedActivityIndex(out, raw.metadata, 1 / self._epoch_length, ['activity index', 'timestamp'], metadata)
@dataclass(slots=True, eq=False)
class CKEpocher(_Epocher):
    """A class for epoching data in the style of Cole-Kripke, as described in Cole et al. (1992).

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
        subepoch_length (int | None): Length of subepochs in each epoch. If None, there are no
            subepochs. Defaults to None. See paper for details.
        metric (Callable): Metric to be calculated on each subepoch. Defaults to np.max.
            See paper for details.
    """

    _subepoch_length: int | None = None
    _metric: Callable = np.max

    def __post_init__(self):
        if self._subepoch_length is not None:
            if self._epoch_length % self._subepoch_length != 0:
                raise ValueError("Epoch length must be integer multiple of subepoch length!")
        self._epoching_method_metadata = {'epoching method': 'CKEpocher', 'epoch length': self._epoch_length,
                                          'subepoch length': self._subepoch_length,
                                          'metric': self._metric.__name__,
                                          'to_score channel': 0}  # This is temporary, value depends on raw shape

    @property
    def subepoch_length(self) -> int | None:
        """Length of subepochs in each epoch. If None, there are no subepochs. Defaults to None. See paper for details."""
        return self._subepoch_length

    # BUG FIX: the original property name was misspelled; kept as an alias so existing
    # callers keep working. Prefer `subepoch_length`.
    @property
    def subepcoh_length(self) -> int | None:
        """Deprecated misspelled alias of :attr:`subepoch_length`."""
        return self._subepoch_length

    @property
    def metric(self) -> str:
        """Metric to be calculated on each subepoch. Defaults to np.max. See paper for details."""
        return self._metric.__name__

    def _to_epoch(self, raw: _Raw | GenericData) -> EpochedCK:
        """Epoch given data using Cole-Kripke style algorithm.

        Args:
            raw (_Raw | GenericData): Raw actigraphic data to be epoched.

        Raises:
            ValueError: If the provided data is not raw data.

        Returns:
            EpochedCK: Object containing epoched data.
        """
        if not isinstance(raw, _Raw):
            if raw.epoching_method_metadata is not None and raw.scoring_method_metadata is not None:
                raise ValueError("Provided data is not raw data.")
        fs = raw.fs
        n = raw.data.shape[1]
        n_epoch = int(self._epoch_length * fs)
        N = int(n // n_epoch)  # number of epochs
        mod = n % n_epoch  # number of excessive points (last, shorter epoch)
        ts_index = raw.channel_names.index('timestamp')
        timestamp = raw.data[ts_index, :int(n - mod)]  # timestamp will be epoched differently
        epoched_data = np.zeros((raw.data.shape[0] + 1, N))
        vlen = raw.to_score[:int(n - mod)]
        # Append the to-score vector as an extra channel so it is epoched alongside the data.
        to_epoch_data = np.concatenate((raw.data[:, :int(n - mod)], vlen.reshape(1, -1)), axis=0)
        if self._subepoch_length is None:
            subN = 1
        else:
            subN = self._epoch_length // self._subepoch_length  # number of subepochs in each epoch
        pbar = tqdm(total=N, position=0, leave=True, desc='Epoching...', dynamic_ncols=True)
        for i in range(N):
            epoch = to_epoch_data[:, i * n_epoch:(i + 1) * n_epoch]
            epoch = np.array(np.split(epoch, subN, axis=1))  # dividing epoch into subepochs
            # Metric on each subepoch, then averaged over subepochs to yield the epoch value.
            epoched_data[:, i] = np.mean(self._metric(epoch, axis=2).T, axis=1)
            epoched_data[ts_index, i] = timestamp[i * n_epoch]  # timestamp of epoch is timestamp of its beginning
            pbar.update(1)
        pbar.close()
        # NOTE: mutates the shared metadata dict on the epocher instance — the to_score
        # channel index is only known once the raw channel count is known.
        self._epoching_method_metadata['to_score channel'] = epoched_data.shape[0] - 1
        channel_names = raw.channel_names.copy()
        channel_names.append('epoched vlen')
        return EpochedCK(epoched_data, raw.metadata, 1 / self._epoch_length, channel_names, self._epoching_method_metadata.copy())
@dataclass(slots=True, eq=False)
class MVM(_Epocher):
    """A class for epoching data in the style of GeneActiv software. See GeneActiv documentation for details.

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
    """

    def __post_init__(self):
        self._epoching_method_metadata = {'epoching method': 'MVM', 'epoch length': self._epoch_length, 'to_score channel': 7}

    def _to_epoch(self, raw: _Raw | GenericData) -> GeneActivMVM | GenericMVM:
        """Epoch given data using MVM algorithm as defined by GeneActiv documentation.

        Args:
            raw (_Raw | GenericData): Raw actigraphic data to be epoched.

        Raises:
            ValueError: If the provided data is not raw data.

        Returns:
            GeneActivMVM | GenericMVM: Object containing epoched data. If input was RawGeneActiv,
                GeneActivMVM is returned; GenericMVM otherwise.
        """
        if not isinstance(raw, _Raw):
            if raw.epoching_method_metadata is not None and raw.scoring_method_metadata is not None:
                raise ValueError("Provided data is not raw data.")
        n_epochs = int(raw.data.shape[1] / raw.fs / self._epoch_length)
        epoched_data = np.full((12, n_epochs), None, dtype=float)  # rows without source data stay NaN
        # Output row -> input channel row; x, y, z are mandatory.
        indexes = {0: raw.channel_names.index('x'), 1: raw.channel_names.index('y'), 2: raw.channel_names.index('z')}
        # formats different than GeneActiv might not contain these fields
        try:
            indexes[3] = raw.channel_names.index('lux')
        except ValueError:
            warnings.warn("Provided raw does not contain lux data.")
        try:
            button_idx = raw.channel_names.index('button')
        except ValueError:
            warnings.warn("Provided raw does not contain button data.")
            button_idx = None
        try:
            indexes[5] = raw.channel_names.index('temperature')
        except ValueError:
            warnings.warn("Provided raw does not contain temperature data.")
        print('Epoching...')
        n_points = int(n_epochs * raw.fs * self._epoch_length)
        trimmed_data = raw.data[:, :n_points]
        split_data = np.array(np.split(trimmed_data, n_epochs, axis=1))
        for idx in indexes.keys():
            epoched_data[idx] = np.mean(split_data[:, indexes[idx], :], axis=1)
        # BUG FIX: the timestamp source row used to be hardcoded as trimmed_data[6], which is
        # only correct for the GeneActiv channel layout; resolve it by name instead.
        ts_idx = raw.channel_names.index('timestamp')
        epoched_data[6] = np.linspace(trimmed_data[ts_idx, int(self._epoch_length * raw.fs) - 1],
                                      trimmed_data[ts_idx, -1], n_epochs)  # timestamp of epoch is timestamp of its end (FU GeneActiv)
        epoched_data[8] = np.std(split_data[:, indexes[0], :], axis=1)
        epoched_data[9] = np.std(split_data[:, indexes[1], :], axis=1)
        epoched_data[10] = np.std(split_data[:, indexes[2], :], axis=1)

        def mvm(x):
            # x is a numpy matrix containing x, y, z axes split into n_epochs chunks;
            # MVM is the absolute deviation of the acceleration vector magnitude from 1 g.
            return np.abs(np.sqrt(x[:, 0, :] ** 2 + x[:, 1, :] ** 2 + x[:, 2, :] ** 2) - 1)

        epoched_data[7] = np.sum(mvm(split_data[:, [indexes[0], indexes[1], indexes[2]], :]), axis=1)
        if button_idx is not None:
            epoched_data[4] = np.sum(split_data[:, button_idx, :], axis=1)
        if 'lux' in raw.channel_names:
            epoched_data[11] = np.max(split_data[:, raw.channel_names.index('lux'), :], axis=1)
        channel_names = ['x', 'y', 'z', 'lux', 'button', 'temperature', 'timestamp', 'mvm', 'x_std', 'y_std', 'z_std', 'peak_lux']
        if isinstance(raw, RawGeneActiv):
            return GeneActivMVM(epoched_data, raw.metadata, 1 / self._epoch_length, channel_names, self._epoching_method_metadata.copy())
        else:
            warnings.warn("Epoching raw data not generated by the GeneActiv. Data will be returned in generic format.")
            return GenericMVM(epoched_data, raw.metadata, 1 / self._epoch_length, channel_names, self._epoching_method_metadata.copy())
@dataclass(slots=True, eq=False)
class Downsampler(_Epocher):
    """Epocher that downsamples the signal with a standard anti-aliased resampling procedure.

    The only epoching method here which guarantees no aliasing artifacts.

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
    """

    def __post_init__(self):
        # 'to_score channel' is temporary; the real value depends on the raw data shape.
        self._epoching_method_metadata = {'epoching method': 'Downsampler', 'epoch length': self._epoch_length, 'to_score channel': 0}

    def _to_epoch(self, raw: _Raw | GenericData) -> Resampled:
        """Epoch given data by polyphase downsampling to one sample per epoch.

        Args:
            raw (_Raw | GenericData): Raw actigraphic data to be epoched.

        Raises:
            ValueError: If the provided data is not raw data.
            RuntimeError: If fs * epoch_length is not an integer decimation factor.

        Returns:
            Resampled: Object containing epoched data.
        """
        if not isinstance(raw, _Raw):
            if raw.epoching_method_metadata is not None and raw.scoring_method_metadata is not None:
                raise ValueError("Provided data is not raw data.")
        decimation = raw.fs * self._epoch_length
        if int(decimation) != decimation:
            raise RuntimeError("WTF is that sampling frequency and epoch length combination? Find another.")
        print("Epoching...")
        downsampled = resample_poly(raw.data, 1, int(decimation), axis=1)
        n_out = downsampled.shape[1]
        # The filtered timestamp channel is meaningless; rebuild it as an even grid.
        downsampled[raw.channel_names.index('timestamp')] = np.linspace(
            raw.timestamp[0], raw.timestamp[0] + (n_out - 1) * self._epoch_length, n_out)
        # Scoring channel: deviation of the acceleration vector magnitude from 1 g.
        vlen = np.abs(np.sqrt(downsampled[raw.channel_names.index('x')] ** 2
                              + downsampled[raw.channel_names.index('y')] ** 2
                              + downsampled[raw.channel_names.index('z')] ** 2) - 1)
        downsampled = np.concatenate((downsampled, vlen.reshape(1, -1)), axis=0)
        del vlen
        self._epoching_method_metadata['to_score channel'] = downsampled.shape[0] - 1
        channel_names = raw.channel_names.copy()
        channel_names.append('resampled vlen')
        return Resampled(downsampled, raw.metadata, 1 / self._epoch_length, channel_names, self._epoching_method_metadata.copy())
@dataclass(slots=True, eq=False)
class MIMS(_Epocher):
    """A class for epoching (downsampling) the data using (slightly modified) MIMS algorithm by Dinesh et al. (2019).

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
        stationary_variance (float | None): Variance of data recorded from stationary actigraph.
            Autodetected from device if None.
        dynamic_range (float | None): Maximum acceleration recorded by sensor. Autodetected from device if None.
        extrapolation_samples (int): Number of samples taken into account in the extrapolation step (see article).
        extrapolation_shape (float): Shape parameter of splines used during extrapolation step.
    """

    _stationary_variance: float | None = None
    _dynamic_range: float | None = None
    _extrapolation_samples: int = 5
    _extrapolation_shape: float = .6

    def __post_init__(self):
        self._epoching_method_metadata = {'epoching method': 'MIMS', 'epoch length': self._epoch_length,
                                          'stationary variance': self._stationary_variance,
                                          'dynamic range': self._dynamic_range,
                                          'extrapolation samples': self._extrapolation_samples,
                                          'extrapolation shape': self._extrapolation_shape,
                                          'to_score channel': 4}

    @property
    def stationary_variance(self) -> float | None:
        """Variance of data recorded from stationary actigraph."""
        return self._stationary_variance

    def _to_epoch(self, raw: _Raw | GenericData) -> _Epoched:
        """Epoch given data using MIMS algorithm.

        Args:
            raw (_Raw | GenericData): Raw actigraphic data to be epoched.

        Raises:
            ValueError: If the provided data is not raw data.

        Returns:
            EpochedMIMS: Object containing epoched data.
        """
        # CONSISTENCY FIX: every other epocher rejects already-epoched GenericData; this one
        # previously skipped the check.
        if not isinstance(raw, _Raw):
            if raw.epoching_method_metadata is not None and raw.scoring_method_metadata is not None:
                raise ValueError("Provided data is not raw data.")
        stationary_variance = raw.stationary_variance if self._stationary_variance is None else self._stationary_variance
        dynamic_range = raw.dynamic_range if self._dynamic_range is None else self._dynamic_range
        # preparations: pick the gamma shape k whose CDF reaches 0.95 at 3 sigma of device noise
        test_k = np.linspace(0.01, 0.5, 1000)  # TODO this range might need extension, depends on available devices and their stationary variance
        k = test_k[np.argmin(np.abs(gamma.cdf(3 * stationary_variance ** .5, test_k, scale=1) - 0.95))]
        gamma_cdf = np.vectorize(lambda x: gamma.cdf(x, a=k, scale=1) if x >= 0. else 0.)
        # 0.2-5 Hz band-pass for human movement, designed at the 100 Hz working rate
        sos = iirfilter(4, [.2, 5], btype='bandpass', ftype='butter', fs=100, output='sos')
        # interpolate: resample each axis to a common 100 Hz rate
        down, up = (raw.fs / 100).as_integer_ratio()
        data = np.empty((3, ceil(raw.data.shape[1] * up / down)))
        time = np.linspace(raw.timestamp[0], raw.timestamp[-1], data.shape[1])
        data[0] = resample_poly(raw.x, up, down)
        data[1] = resample_poly(raw.y, up, down)
        data[2] = resample_poly(raw.z, up, down)
        for ch in tqdm(range(3), desc="Calculating MIMS epochs"):
            # extrapolate: detect regions where the signal likely hit the sensor's dynamic
            # range and reconstruct them with weighted smoothing splines (see article)
            temp = np.abs(data[ch]) - dynamic_range + 3 * stationary_variance ** .5
            maxed_probability = gamma_cdf(temp)
            shifted = np.zeros_like(maxed_probability)
            shifted[1:] = maxed_probability[:-1]
            diff = maxed_probability - shifted
            # Pair each rising edge (entering saturation) with the next falling edge.
            idx = zip(np.nonzero(diff > 0.5)[0], np.nonzero(diff < -0.5)[0])
            for e in idx:
                samples_before = data[ch, e[0] - self._extrapolation_samples:e[0]]
                samples_after = data[ch, e[1]:e[1] + self._extrapolation_samples]
                time_before = time[e[0] - self._extrapolation_samples:e[0]]
                time_after = time[e[1]:e[1] + self._extrapolation_samples]
                # Weight samples by how unlikely they are to be saturated themselves.
                weights_before = 1 - maxed_probability[e[0] - self._extrapolation_samples:e[0]]
                weights_after = 1 - maxed_probability[e[1]:e[1] + self._extrapolation_samples]
                time_middle = (time_after[0] + time_before[-1]) / 2
                spline_before = splev(time_middle, splrep(time_before, samples_before,
                                                          weights_before, s=self._extrapolation_shape))  # TODO in newer scipy it is make_splrep
                spline_after = splev(time_middle, splrep(time_after, samples_after,
                                                         weights_after, s=self._extrapolation_shape))
                sample_middle = (spline_after + spline_before) / 2
                spline_extrapolated = splrep(np.concatenate((time_before, [time_middle], time_after)),
                                             np.concatenate((samples_before, [sample_middle], samples_after)),
                                             s=self._extrapolation_shape)
                data[ch, e[0]:e[1]] = splev(time[e[0]:e[1]], spline_extrapolated)
            # filter
            data[ch] = sosfiltfilt(sos, data[ch])
        # aggregate: integrate |signal| over each epoch (dx = 1/100 s at the working rate)
        epoch_len = self._epoch_length * 100
        n = data.shape[1] // epoch_len
        data = trapezoid(np.array(np.split(np.abs(data[:, :n * epoch_len]), n, axis=1)), axis=2, dx=.01).T
        data[data < .001] = 0  # suppress numerical noise below the reporting threshold
        output = np.empty((5, data.shape[1]))
        output[:3] = data
        output[4] = np.sum(data, axis=0)  # summary MIMS = sum over the three axes
        output[3] = np.linspace(raw.timestamp[0], raw.timestamp[0] + (n - 1) * self._epoch_length, output.shape[1])
        channel_names = ['x', 'y', 'z', 'timestamp', 'summary MIMS']
        metadata = self._epoching_method_metadata.copy()
        # Record the values actually used (may have been autodetected from the device).
        metadata['stationary variance'] = stationary_variance
        metadata['dynamic range'] = dynamic_range
        return EpochedMIMS(output, raw.metadata, 1 / self._epoch_length, channel_names, metadata)
[docs] def main(): """A script for epoching a batch of files. Run with -h or --help for details. """ import argparse import os from pathlib import Path from ada.io.file_manager import FileManager methods = {'ai': ActivityIndex, 'mvm': MVM, 'downsample': Downsampler, 'ck': CKEpocher, 'mims': MIMS} parser = argparse.ArgumentParser(description="Epoch given files and export epoched.", add_help=False) parser.add_argument('--files', '-f', nargs='+', type=str, required=True, help="Path to raw files.") parser.add_argument('--outdir', '-o', type=str, default=None, help="Directory in which epoched files will be saved.") parser.add_argument('--epocher', '-e', choices=methods.keys(), default='resample', help="Epoching algorithm.") parser.add_argument('--epoch-length', '-l', type=int, default=30, help="Epoch length in seconds.") namespace = parser.parse_known_args()[0] if namespace.epocher == 'ai': parser.add_argument('--stationary-variance', type=float, default=.0004563, help="Variance of stationary actigraph.") parser.add_argument('--normalized', type=bool, default=True, help="Algorithm with normalization.") elif namespace.epocher == 'ck': parser.add_argument('--subepoch-length', type=int, default=None, help="Subepoch length in seconds. 
If not given, None.") parser.add_argument('-h', '--help', action='help', default='==SUPPRESS==') namespace = parser.parse_args() if len(namespace.files) == 1: files_to_work = [str(e) for e in Path(os.path.dirname(namespace.files[0])).glob(os.path.basename(namespace.files[0]))] else: files_to_work = namespace.files # Epocher-specific settings if namespace.epocher == 'ai': epocher = methods[namespace.epocher](namespace.epoch_length, namespace.normalized, namespace.stationary_variance) elif namespace.epocher == 'ck': epocher = methods[namespace.epocher](namespace.epoch_length, namespace.subepoch_length) else: epocher = methods[namespace.epocher](namespace.epoch_length) for file in files_to_work: filename = os.path.basename(file).split('.')[0] if namespace.outdir is None: path = os.path.dirname(file) folder = "{}_{}s".format(namespace.epocher, namespace.epoch_length) outdir = os.path.join(path, folder) else: outdir = namespace.outdir os.makedirs(outdir, exist_ok=True) outfile = os.path.join(outdir, filename + '.ada') if os.path.exists(outfile): print('File already exists, skipping', outfile) continue acti = FileManager.load_file(file) epoched = epocher.to_epoch(acti) epoched = GenericData.from_nongeneric(epoched) # just to make sure mvm will not save as .csv epoched.export(outfile) del epoched del acti
if __name__ == '__main__': main()