from typing import Callable
import numpy as np
from tqdm import tqdm
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import warnings
from scipy.signal import resample_poly, sosfiltfilt, iirfilter
from scipy.stats import gamma
from scipy.interpolate import splrep, splev
from scipy.integrate import trapezoid
from math import ceil
from ada.data_containers._base import _Raw, _Epoched
from ada.data_containers.epoched import EpochedActivityIndex, EpochedCK, Resampled, EpochedMIMS, GenericMVM
from ada.io.geneactiv import RawGeneActiv, GeneActivMVM
from ada.data_containers.generic import GenericData
from ada.io.acti_eeg import ActiPSG
@dataclass(slots=True, eq=False)
class _Epocher(ABC):
    """Framework for different epoching methods.

    Subclasses implement :meth:`_to_epoch` with a concrete algorithm and fill
    ``_epoching_method_metadata`` in their ``__post_init__``.

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
    """
    _epoch_length: int
    # Filled in by each subclass's __post_init__; describes the algorithm and its parameters.
    _epoching_method_metadata: dict = field(init=False)

    @property
    def epoch_length(self) -> int:
        """Output data epoch length in seconds."""
        return self._epoch_length

    @abstractmethod
    def _to_epoch(self, raw: _Raw | GenericData) -> _Epoched:
        # Concrete epoching algorithm, provided by each subclass.
        pass

    def to_epoch(self, data: ActiPSG | _Raw | GenericData, wake_percentage: float = 0.8) -> ActiPSG | _Epoched:
        """Epoching actigraphic data. Works for both ActiData and ActiPSG. Actigraphic data will be epoched by the selected Epocher. PSG data (if present) will be epoched using the change_epoch method. In case of ActiPSG intended to work mostly with 20, 30 and 60 s epochs.

        Args:
            data (ActiPSG | _Raw | GenericData): Object containing actigraphic data and psg scoring
            wake_percentage (float, optional): Parameter describing how many points inside a new epoch must be wake, so the new epoch will also be wake. Meaningless when epoching only actigraphic data. Defaults to 0.8.

        Raises:
            ValueError: If actigraphic data is already scored. Don't do this.

        Returns:
            ActiPSG | _Epoched: Epoched actigraphic data and PSG scoring (if input was ActiPsg).
        """
        # Plain actigraphic data (no PSG container): epoch it directly.
        if not isinstance(data, ActiPSG):
            return self._to_epoch(data)
        # Reject re-epoching: raise when acti_data is already an _Epoched instance,
        # or is non-raw data that carries epoching metadata.
        if not (isinstance(data.acti_data, _Raw) or data.acti_data.epoching_method_metadata is None) or isinstance(data.acti_data, _Epoched):
            raise ValueError("Do not epoch already epoched data.")
        new_acti = self._to_epoch(data.acti_data)
        # Re-epoch the PSG scoring to the same epoch length; an epoch is labelled
        # wake when at least `wake_percentage` of its source points are wake.
        new_psg = data.eeg_data.change_epoch(self._epoch_length, wake_percentage)
        out = ActiPSG(new_acti, new_psg, True)
        out._trim_to_eeg = data._trim_to_eeg  # carry the trimming flag over from the input container
        return out
@dataclass(slots=True, eq=False)
class ActivityIndex(_Epocher):
    """A class for epoching actigraphic data using algorithm by Bai et al. (2013).

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
        normalized (bool): Whether to return data normalized (see paper for details). Defaults to True.
        stationary_variance (float | None): Variance of data recorded from stationary actigraph. Autodetected from device if None.
    """
    _normalized: bool = True
    _stationary_variance: float | None = None

    def __post_init__(self):
        self._epoching_method_metadata = {'epoching method': 'ActivityIndex',
                                          'epoch length': self._epoch_length,
                                          'normalized': self._normalized,
                                          'stationary variance': self._stationary_variance,
                                          'to_score channel': 0}

    @property
    def normalized(self) -> bool:
        """Whether to return data normalized (see paper for details). Defaults to True."""
        return self._normalized

    @property
    def stationary_variance(self) -> float | None:
        """Variance of data recorded from stationary actigraph."""
        return self._stationary_variance

    def _to_epoch(self, raw: _Raw | GenericData) -> EpochedActivityIndex:
        """Epoch given data using Activity Index algorithm.

        Args:
            raw (_Raw | GenericData): Raw actigraphic data to be epoched.

        Raises:
            ValueError: If the provided generic data is already epoched or scored.

        Returns:
            EpochedActivityIndex: Object containing epoched data.
        """
        if not isinstance(raw, _Raw):
            # NOTE(review): rejects the input only when BOTH metadata fields are set — confirm `and` (vs `or`) is intended.
            if raw.epoching_method_metadata is not None and raw.scoring_method_metadata is not None:
                raise ValueError("Provided data is not raw data.")
        stationary_variance = raw.stationary_variance if self._stationary_variance is None else self._stationary_variance
        n = raw.data.shape[1]                 # length of non-epoched data in samples
        N = int(self._epoch_length * raw.fs)  # window size in samples
        length = n - N
        # Hoisted out of the loop: channel index lookups and axis views are loop-invariant.
        x = raw.data[raw.channel_names.index('x')]
        y = raw.data[raw.channel_names.index('y')]
        z = raw.data[raw.channel_names.index('z')]
        # The normalized variant additionally divides the mean excess variance
        # by the stationary (noise) variance — see Bai et al. (2013).
        denominator = 3 * stationary_variance if self._normalized else 3

        def haxis(axis, i):
            # Window variance of one axis minus the systematic (device) variance.
            return np.var(axis[i:i + N]) - stationary_variance

        out = np.zeros((2, length // N + 1))
        for i in tqdm(range(out.shape[1]), position=0, leave=True, desc="Epoching...", dynamic_ncols=True):
            start = i * N
            # Negative excess variance is clipped to zero before the square root.
            out[0, i] = np.sqrt(max((haxis(x, start) + haxis(y, start) + haxis(z, start)) / denominator, 0))
            out[1, i] = raw.timestamp[0] + i * self._epoch_length  # epoch timestamp = recording start + offset
        metadata = self._epoching_method_metadata.copy()
        metadata['stationary variance'] = stationary_variance  # record the value actually used
        return EpochedActivityIndex(out, raw.metadata, 1 / self._epoch_length, ['activity index', 'timestamp'], metadata)
@dataclass(slots=True, eq=False)
class CKEpocher(_Epocher):
    """A class for epoching data in the style of Cole-Kripke, as described in Cole et al. (1992).

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
        subepoch_length (int | None): Length of subepochs in each epoch. If None, there are no subepochs. Defaults to None. See paper for details.
        metric (Callable): Metric to be calculated on each subepoch. Defaults to np.max. See paper for details.
    """
    _subepoch_length: int | None = None
    _metric: Callable = np.max

    def __post_init__(self):
        if self._subepoch_length is not None:
            if self._epoch_length % self._subepoch_length != 0:
                raise ValueError("Epoch length must be integer multiple of subepoch length!")
        self._epoching_method_metadata = {'epoching method': 'CKEpocher',
                                          'epoch length': self._epoch_length,
                                          'subepoch length': self._subepoch_length,
                                          'metric': self._metric.__name__,
                                          'to_score channel': 0}  # This is temporary, value depends on raw shape

    @property
    def subepoch_length(self) -> int | None:
        """Length of subepochs in each epoch. If None, there are no subepochs. Defaults to None. See paper for details."""
        return self._subepoch_length

    @property
    def subepcoh_length(self) -> int | None:
        """Deprecated misspelled alias of :attr:`subepoch_length`, kept for backward compatibility."""
        return self._subepoch_length

    @property
    def metric(self) -> str:
        """Name of the metric calculated on each subepoch. Defaults to np.max. See paper for details."""
        return self._metric.__name__

    def _to_epoch(self, raw: _Raw | GenericData) -> EpochedCK:
        """Epoch given data using Cole-Kripke style algorithm.

        Args:
            raw (_Raw | GenericData): Raw actigraphic data to be epoched.

        Raises:
            ValueError: If the provided generic data is already epoched or scored.

        Returns:
            EpochedCK: Object containing epoched data.
        """
        if not isinstance(raw, _Raw):
            # NOTE(review): rejects the input only when BOTH metadata fields are set — confirm `and` (vs `or`) is intended.
            if raw.epoching_method_metadata is not None and raw.scoring_method_metadata is not None:
                raise ValueError("Provided data is not raw data.")
        fs = raw.fs
        n = raw.data.shape[1]                   # total number of raw samples
        n_epoch = int(self._epoch_length * fs)  # samples per epoch
        N = int(n // n_epoch)                   # number of complete epochs
        mod = n % n_epoch                       # excessive points (last, shorter epoch) — trimmed off below
        ts_index = raw.channel_names.index('timestamp')
        timestamp = raw.data[ts_index, :int(n - mod)]  # timestamp will be epoched differently
        epoched_data = np.zeros((raw.data.shape[0] + 1, N))
        vlen = raw.to_score[:int(n - mod)]
        # Append the to-score channel so it is epoched together with the rest.
        to_epoch_data = np.concatenate((raw.data[:, :int(n - mod)], vlen.reshape(1, -1)), axis=0)
        if self._subepoch_length is None:
            subN = 1
        else:
            subN = self._epoch_length // self._subepoch_length  # number of subepochs in each epoch
        pbar = tqdm(total=N, position=0, leave=True, desc='Epoching...', dynamic_ncols=True)
        for i in range(N):
            epoch = to_epoch_data[:, i * n_epoch:(i + 1) * n_epoch]
            epoch = np.array(np.split(epoch, subN, axis=1))  # dividing epoch into subepochs
            # Metric over each subepoch, then the mean of subepoch values gives the epoch's value.
            epoched_data[:, i] = np.mean(self._metric(epoch, axis=2).T, axis=1)
            epoched_data[ts_index, i] = timestamp[i * n_epoch]  # timestamp of epoch is timestamp of its beginning
            pbar.update(1)
        pbar.close()
        self._epoching_method_metadata['to_score channel'] = epoched_data.shape[0] - 1
        channel_names = raw.channel_names.copy()
        channel_names.append('epoched vlen')
        return EpochedCK(epoched_data, raw.metadata, 1 / self._epoch_length, channel_names, self._epoching_method_metadata.copy())
@dataclass(slots=True, eq=False)
class MVM(_Epocher):
    """A class for epoching data in the style of GeneActiv software. See GeneActiv documentation for details.

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
    """

    def __post_init__(self):
        self._epoching_method_metadata = {'epoching method': 'MVM',
                                          'epoch length': self._epoch_length,
                                          'to_score channel': 7}  # output row 7 holds the MVM values

    def _to_epoch(self, raw: _Raw | GenericData) -> GeneActivMVM | GenericMVM:
        """Epoch given data using MVM algorithm as defined by GeneActiv documentation.

        Args:
            raw (_Raw | GenericData): Raw actigraphic data to be epoched.

        Raises:
            ValueError: If the provided generic data is already epoched or scored.

        Returns:
            GeneActivMVM | GenericMVM: Object containing epoched data. If input was RawGeneActiv, GeneActivMVM is returned; GenericMVM otherwise.
        """
        if not isinstance(raw, _Raw):
            # NOTE(review): rejects the input only when BOTH metadata fields are set — confirm `and` (vs `or`) is intended.
            if raw.epoching_method_metadata is not None and raw.scoring_method_metadata is not None:
                raise ValueError("Provided data is not raw data.")
        n_epochs = int(raw.data.shape[1] / raw.fs / self._epoch_length)
        # Output layout (rows): x, y, z, lux, button, temperature, timestamp, mvm, x_std, y_std, z_std, peak_lux.
        # Rows without a source channel stay NaN (np.full(..., None, dtype=float)).
        epoched_data = np.full((12, n_epochs), None, dtype=float)
        # Map output row -> input channel index; these rows get per-epoch means below.
        indexes = {0: raw.channel_names.index('x'),
                   1: raw.channel_names.index('y'),
                   2: raw.channel_names.index('z')}
        # formats different than GeneActiv might not contain these fields
        try:
            indexes[3] = raw.channel_names.index('lux')
        except ValueError:
            warnings.warn("Provided raw does not contain lux data.")
        try:
            button_idx = raw.channel_names.index('button')
        except ValueError:
            warnings.warn("Provided raw does not contain button data.")
            button_idx = None
        try:
            indexes[5] = raw.channel_names.index('temperature')
        except ValueError:
            warnings.warn("Provided raw does not contain temperature data.")
        print('Epoching...')
        n_points = int(n_epochs * raw.fs * self._epoch_length)
        trimmed_data = raw.data[:, :n_points]  # drop the trailing partial epoch
        split_data = np.array(np.split(trimmed_data, n_epochs, axis=1))  # shape: (n_epochs, channels, samples)
        for out_row, in_row in indexes.items():
            epoched_data[out_row] = np.mean(split_data[:, in_row, :], axis=1)
        # Timestamp of an epoch is the timestamp of its end (GeneActiv convention).
        # Bugfix: look the timestamp channel up by name instead of assuming it sits
        # at row 6 of the raw data — generic inputs may order channels differently.
        ts_raw_idx = raw.channel_names.index('timestamp')
        epoched_data[6] = np.linspace(trimmed_data[ts_raw_idx, int(self._epoch_length * raw.fs) - 1], trimmed_data[ts_raw_idx, -1], n_epochs)
        epoched_data[8] = np.std(split_data[:, indexes[0], :], axis=1)
        epoched_data[9] = np.std(split_data[:, indexes[1], :], axis=1)
        epoched_data[10] = np.std(split_data[:, indexes[2], :], axis=1)

        def mvm(x):
            # x contains the x, y, z axes split into n_epochs chunks; MVM is |vector magnitude - 1 g|.
            return np.abs(np.sqrt(x[:, 0, :] ** 2 + x[:, 1, :] ** 2 + x[:, 2, :] ** 2) - 1)

        epoched_data[7] = np.sum(mvm(split_data[:, [indexes[0], indexes[1], indexes[2]], :]), axis=1)
        if button_idx is not None:
            epoched_data[4] = np.sum(split_data[:, button_idx, :], axis=1)
        if 3 in indexes:  # lux channel present
            epoched_data[11] = np.max(split_data[:, indexes[3], :], axis=1)
        channel_names = ['x', 'y', 'z', 'lux', 'button', 'temperature', 'timestamp', 'mvm', 'x_std', 'y_std', 'z_std', 'peak_lux']
        if isinstance(raw, RawGeneActiv):
            return GeneActivMVM(epoched_data, raw.metadata, 1 / self._epoch_length, channel_names, self._epoching_method_metadata.copy())
        warnings.warn("Epoching raw data not generated by the GeneActiv. Data will be returned in generic format.")
        return GenericMVM(epoched_data, raw.metadata, 1 / self._epoch_length, channel_names, self._epoching_method_metadata.copy())
@dataclass(slots=True, eq=False)
class Downsampler(_Epocher):
    """A class for epoching (downsampling) the data using standard downsampling procedure. The only one which guarantees no aliasing artifacts.

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
    """

    def __post_init__(self):
        self._epoching_method_metadata = {'epoching method': 'Downsampler',
                                          'epoch length': self._epoch_length,
                                          'to_score channel': 0}  # This is temporary, value depends on raw shape

    def _to_epoch(self, raw: _Raw | GenericData) -> Resampled:
        """Epoch given data using standard downsampling procedure.

        Args:
            raw (_Raw | GenericData): Raw actigraphic data to be epoched.

        Raises:
            ValueError: If the provided generic data is already epoched or scored.
            RuntimeError: If fs * epoch_length is not an integer decimation factor.

        Returns:
            Resampled: Object containing epoched data.
        """
        if not isinstance(raw, _Raw):
            # NOTE(review): rejects the input only when BOTH metadata fields are set — confirm `and` (vs `or`) is intended.
            if raw.epoching_method_metadata is not None and raw.scoring_method_metadata is not None:
                raise ValueError("Provided data is not raw data.")
        # One output sample per epoch, so the decimation factor is fs * epoch_length; it must be an integer.
        if int(raw.fs * self._epoch_length) == raw.fs * self._epoch_length:
            down = int(raw.fs * self._epoch_length)
            up = 1
        else:
            raise RuntimeError("Sampling frequency ({}) times epoch length ({} s) must be an integer decimation factor. Choose a different combination.".format(raw.fs, self._epoch_length))
        print("Epoching...")
        resampled_data = resample_poly(raw.data, up, down, axis=1)
        # The timestamp channel must not be low-pass filtered; rebuild it as a uniform grid instead.
        resampled_data[raw.channel_names.index('timestamp')] = np.linspace(raw.timestamp[0], raw.timestamp[0] + (resampled_data.shape[1] - 1) * self._epoch_length, resampled_data.shape[1])
        # resampled_vlen = resample_poly(raw.to_score, up, down) # this seems to emphasize floating / position changes?
        # Vector length deviation from 1 g, computed from the already-downsampled axes.
        resampled_vlen = np.abs(np.sqrt(resampled_data[raw.channel_names.index('x')] ** 2
                                        + resampled_data[raw.channel_names.index('y')] ** 2
                                        + resampled_data[raw.channel_names.index('z')] ** 2) - 1)
        resampled_data = np.concatenate((resampled_data, resampled_vlen.reshape(1, -1)), axis=0)
        self._epoching_method_metadata['to_score channel'] = resampled_data.shape[0] - 1
        channel_names = raw.channel_names.copy()
        channel_names.append('resampled vlen')
        return Resampled(resampled_data, raw.metadata, 1 / self._epoch_length, channel_names, self._epoching_method_metadata.copy())
@dataclass(slots=True, eq=False)
class MIMS(_Epocher):
    """A class for epoching (downsampling) the data using (slightly modified) MIMS algorithm by Dinesh et al. (2019).

    Attributes:
        epoch_length (int): Length of output epochs in seconds.
        stationary_variance (float | None): Variance of data recorded from stationary actigraph. Autodetected from device if None.
        dynamic_range (float | None): Maximum acceleration recorded by sensor. Autodetected from device if None.
        extrapolation_samples (int): Number of samples taken into account in the extrapolation step (see article).
        extrapolation_shape (float): Shape parameter of splines used during extrapolation step.
    """
    _stationary_variance: float | None = None
    _dynamic_range: float | None = None
    _extrapolation_samples: int = 5
    _extrapolation_shape: float = .6

    def __post_init__(self):
        self._epoching_method_metadata = {'epoching method': 'MIMS',
                                          'epoch length': self._epoch_length,
                                          'stationary variance': self._stationary_variance,
                                          'dynamic range': self._dynamic_range,
                                          'extrapolation samples': self._extrapolation_samples,
                                          'extrapolation shape': self._extrapolation_shape,
                                          'to_score channel': 4}  # output row 4 holds the summary MIMS values

    @property
    def stationary_variance(self) -> float | None:
        """Variance of data recorded from stationary actigraph."""
        return self._stationary_variance

    def _to_epoch(self, raw: _Raw | GenericData) -> _Epoched:
        """Epoch given data using MIMS algorithm.

        Pipeline: resample each axis to 100 Hz, extrapolate regions clipped at
        the sensor's dynamic range with weighted splines, band-pass filter,
        then integrate |acceleration| over each epoch.

        Args:
            raw (_Raw | GenericData): Raw actigraphic data to be epoched.

        Returns:
            EpochedMIMS: Object containing epoched data.
        """
        # NOTE(review): unlike the other epochers, this method does not reject already-epoched GenericData — confirm that is intended.
        stationary_variance = raw.stationary_variance if self._stationary_variance is None else self._stationary_variance
        dynamic_range = raw.dynamic_range if self._dynamic_range is None else self._dynamic_range
        # preparations
        # Pick gamma shape k such that P(X <= 3*sigma) ~= 0.95 for the device noise;
        # the resulting CDF scores how likely a sample is clipped ("maxed out").
        test_k = np.linspace(0.01, 0.5, 1000)  # TODO this range might need extension, depends on available devices and their stationary variance
        k = test_k[np.argmin(np.abs(gamma.cdf(3 * stationary_variance ** .5, test_k, scale=1) - 0.95))]
        gamma_cdf = np.vectorize(lambda x: gamma.cdf(x, a=k, scale=1) if x >= 0. else 0.)
        # 0.2-5 Hz 4th-order Butterworth band-pass at the 100 Hz working rate.
        sos = iirfilter(4, [.2, 5], btype='bandpass', ftype='butter', fs=100, output='sos')
        # interpolate
        down, up = (raw.fs / 100).as_integer_ratio()
        data = np.empty((3, ceil(raw.data.shape[1] * up / down)))
        time = np.linspace(raw.timestamp[0], raw.timestamp[-1], data.shape[1])
        data[0] = resample_poly(raw.x, up, down)
        data[1] = resample_poly(raw.y, up, down)
        data[2] = resample_poly(raw.z, up, down)
        for ch in tqdm(range(3), desc="Calculating MIMS epochs"):
            # extrapolate
            # Distance of |sample| from the clipping threshold (dynamic range minus 3*sigma of noise).
            temp = np.abs(data[ch]) - dynamic_range + 3 * stationary_variance ** .5
            maxed_probability = gamma_cdf(temp)
            shifted = np.zeros_like(maxed_probability)
            shifted[1:] = maxed_probability[:-1]
            diff = maxed_probability - shifted  # large +/- jumps mark entry/exit of a clipped region
            # NOTE(review): pairing rising and falling edges positionally assumes they alternate and match in count —
            # a recording that starts or ends inside a clipped region would misalign the pairs; confirm.
            idx = zip(np.nonzero(diff > 0.5)[0], np.nonzero(diff < -0.5)[0])
            for e in idx:
                # Context windows of samples/times just before and after the clipped span.
                samples_before = data[ch, e[0] - self._extrapolation_samples:e[0]]
                samples_after = data[ch, e[1]:e[1] + self._extrapolation_samples]
                time_before = time[e[0] - self._extrapolation_samples:e[0]]
                time_after = time[e[1]:e[1] + self._extrapolation_samples]
                # Weight samples by confidence that they are NOT clipped.
                weights_before = 1 - maxed_probability[e[0] - self._extrapolation_samples:e[0]]
                weights_after = 1 - maxed_probability[e[1]:e[1] + self._extrapolation_samples]
                time_middle = (time_after[0] + time_before[-1]) / 2
                # Estimate the midpoint value as the mean of the two one-sided spline extrapolations.
                spline_before = splev(time_middle, splrep(time_before, samples_before, weights_before, s=self._extrapolation_shape))  # TODO in newer scipy it is make_splrep
                spline_after = splev(time_middle, splrep(time_after, samples_after, weights_after, s=self._extrapolation_shape))
                sample_middle = (spline_after + spline_before) / 2
                # Fit a spline through before-window + estimated midpoint + after-window, and fill the clipped span with it.
                spline_extrapolated = splrep(np.concatenate((time_before, [time_middle], time_after)), np.concatenate((samples_before, [sample_middle], samples_after)), s=self._extrapolation_shape)
                data[ch, e[0]:e[1]] = splev(time[e[0]:e[1]], spline_extrapolated)
            # filter
            data[ch] = sosfiltfilt(sos, data[ch])
        # aggregate
        epoch_len = self._epoch_length * 100  # samples per epoch at the 100 Hz working rate
        n = data.shape[1] // epoch_len  # number of complete epochs; trailing remainder is dropped
        # Trapezoidal integration of |acceleration| over each epoch (dx = 1/100 s).
        data = trapezoid(np.array(np.split(np.abs(data[:, :n * epoch_len]), n, axis=1)), axis=2, dx=.01).T
        data[data < .001] = 0  # suppress numerical noise near zero
        output = np.empty((5, data.shape[1]))
        output[:3] = data
        output[4] = np.sum(data, axis=0)  # summary MIMS = sum over the three axes
        # output[3] = np.arange(raw.timestamp[0], raw.timestamp[0] + n * self._epoch_length, self._epoch_length)
        output[3] = np.linspace(raw.timestamp[0], raw.timestamp[0] + (n - 1) * self._epoch_length, output.shape[1])  # epoch timestamp = recording start + offset
        channel_names = ['x', 'y', 'z', 'timestamp', 'summary MIMS']
        metadata = self._epoching_method_metadata.copy()
        metadata['stationary variance'] = stationary_variance  # record the values actually used
        metadata['dynamic range'] = dynamic_range
        return EpochedMIMS(output, raw.metadata, 1 / self._epoch_length, channel_names, metadata)
def main():
    """A script for epoching a batch of files. Run with -h or --help for details."""
    import argparse
    import os
    from pathlib import Path
    from ada.io.file_manager import FileManager

    def str2bool(value: str) -> bool:
        # argparse's type=bool is a trap: bool("False") is True because the string is non-empty.
        return value.strip().lower() in ('true', '1', 'yes', 'y')

    methods = {'ai': ActivityIndex, 'mvm': MVM, 'downsample': Downsampler, 'ck': CKEpocher, 'mims': MIMS}
    parser = argparse.ArgumentParser(description="Epoch given files and export epoched.", add_help=False)
    parser.add_argument('--files', '-f', nargs='+', type=str, required=True, help="Path to raw files.")
    parser.add_argument('--outdir', '-o', type=str, default=None, help="Directory in which epoched files will be saved.")
    # Bugfix: the default used to be 'resample', which is not a key of `methods` and
    # raised a KeyError below whenever --epocher was omitted.
    parser.add_argument('--epocher', '-e', choices=methods.keys(), default='downsample', help="Epoching algorithm.")
    parser.add_argument('--epoch-length', '-l', type=int, default=30, help="Epoch length in seconds.")
    # First pass: learn which epocher was requested so its specific options can be registered.
    namespace = parser.parse_known_args()[0]
    if namespace.epocher == 'ai':
        parser.add_argument('--stationary-variance', type=float, default=.0004563, help="Variance of stationary actigraph.")
        parser.add_argument('--normalized', type=str2bool, default=True, help="Algorithm with normalization (true/false).")
    elif namespace.epocher == 'ck':
        parser.add_argument('--subepoch-length', type=int, default=None, help="Subepoch length in seconds. If not given, None.")
    parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS)
    namespace = parser.parse_args()
    # A single path may be a glob pattern; expand it relative to its directory.
    if len(namespace.files) == 1:
        files_to_work = [str(e) for e in Path(os.path.dirname(namespace.files[0])).glob(os.path.basename(namespace.files[0]))]
    else:
        files_to_work = namespace.files
    # Epocher-specific settings
    if namespace.epocher == 'ai':
        epocher = methods[namespace.epocher](namespace.epoch_length, namespace.normalized, namespace.stationary_variance)
    elif namespace.epocher == 'ck':
        epocher = methods[namespace.epocher](namespace.epoch_length, namespace.subepoch_length)
    else:
        epocher = methods[namespace.epocher](namespace.epoch_length)
    for file in files_to_work:
        filename = os.path.basename(file).split('.')[0]
        if namespace.outdir is None:
            # Default output directory: <input dir>/<epocher>_<epoch length>s
            path = os.path.dirname(file)
            folder = "{}_{}s".format(namespace.epocher, namespace.epoch_length)
            outdir = os.path.join(path, folder)
        else:
            outdir = namespace.outdir
        os.makedirs(outdir, exist_ok=True)
        outfile = os.path.join(outdir, filename + '.ada')
        if os.path.exists(outfile):
            print('File already exists, skipping', outfile)
            continue
        acti = FileManager.load_file(file)
        epoched = epocher.to_epoch(acti)
        epoched = GenericData.from_nongeneric(epoched)  # just to make sure mvm will not save as .csv
        epoched.export(outfile)
        del epoched
        del acti
if __name__ == '__main__':
    # Run the batch-epoching CLI when this module is executed as a script.
    main()