from dataclasses import dataclass
# import scipy.signal as ss
import numpy as np
# import warnings
from zipfile import ZipFile, ZIP_DEFLATED
import os
from tempfile import TemporaryDirectory
from typing import Callable
from ada.data_containers._base import _ActiEEG, _Raw, _Epoched
from ada.data_containers.generic import GenericData
from ada.data_containers.scored import PSGScore
# try:
# from readmanager.signal_processing.read_manager import ReadManager
# from readmanager.signal_processing.signal.data_generic_write_proxy import SamplePacket
# from readmanager.signal_processing.signal.data_raw_write_proxy import DataRawWriteProxy
# from readmanager.signal_processing.signal.info_file_proxy import InfoFileWriteProxy
# from readmanager.signal_processing.signal.read_info_source import MemoryInfoSource
# from mne.io import RawArray as MNERaw
# from mne.channels import make_standard_montage
# from mne import create_info
# except ImportError:
# print("Install full version with ...")
# @dataclass(slots=True, eq=False)
# class ObciActiRaw(_ActiEEG):
# """A class for storing and handling actigraphic data with corresponding PSG data. Intended mainly for visualisation of the data in SVAROG.
# Attributes:
# acti_data (_Raw | GenericData): Actigraphic data. Must be raw.
# eeg_data (RawArray): MNE RawArray with EEG data.
# same_sampling (bool): If True, actigraphic data will be resampled to the EEG sampling frequency after initialization. Defaults to True.
# """
# _acti_data: _Raw | GenericData
# _eeg_data: MNERaw
# _same_sampling: bool = True
# def __post_init__(self):
# # check if data is raw
# if not isinstance(self._acti_data, _Raw):
# if self._acti_data.epoching_method_metadata is not None or self._acti_data.scoring_method_metadata is not None:
# raise ValueError("Provided actigraphic data is not Raw.")
# # synchronize recordings
# device = GenericData._get_device(self._acti_data)
# acti_start = device._convert_timestamp(self._acti_data.timestamp[0], self._acti_data.metadata, True)
# acti_end = device._convert_timestamp(self._acti_data.timestamp[-1], self._acti_data.metadata, True)
# eeg_start = self._eeg_data.info['meas_date'].timestamp()
# eeg_end = len(self._eeg_data) / self._eeg_data.info['sfreq'] + eeg_start
# common_start = max([acti_start, eeg_start])
# common_end = min([acti_end, eeg_end])
# if common_end - eeg_start > self._eeg_data.times[-1]: # measurements might desynchronize by few ms
# common_end -= common_end - eeg_start - self._eeg_data.times[-1]
# self._acti_data = self._acti_data.cut_by_timestamp(common_start, common_end)
# self._eeg_data.crop(common_start - eeg_start, common_end - eeg_start)
# seconds, microseconds = str(common_start).split('.')
# self._eeg_data.set_meas_date((float(seconds), float(microseconds)))
# # upsample acti_data to eeg_data fs
# if self._same_sampling and self._acti_data.fs != self._eeg_data.info['sfreq']:
# self._common_resample(None)
# def _common_resample(self, fs: float | None):
# print("Resampling actigraphic and EEG data to the same sampling frequency...")
# if fs is None:
# fs = float(self._eeg_data.info['sfreq'])
# else:
# self._eeg_data = self._eeg_data.resample(fs)
# ts = np.linspace(self._acti_data.timestamp[0], self._acti_data.timestamp[-1], self._eeg_data.n_times)
# acti_upsampled = ss.resample(self._acti_data.data, len(self._eeg_data), axis=1)
# acti_upsampled[self._acti_data.channel_names.index('timestamp'), :] = ts
# self._acti_data._data = acti_upsampled
# self._acti_data._fs = fs
# del acti_upsampled, ts
# def export(self, path: str, include_vlen: bool = False):
# """Exports both EEG and actigraphic data to one obci.raw file. Actigraphic metadata are lost. If actigraphic data has different sampling frequency than EEG, it will be resampled.
# Args:
# path (str): Path to the .raw file (SVAROG-readable).
# include_vlen (bool, optional): If True, vlen will be saved as a separate channel. Useful for data visualisation in SVAROG. Defaults to False.
# """
# if self._acti_data.data.shape[1] != len(self._eeg_data):
# warnings.warn("Different sampling frequencies for actigraph and eeg.")
# self._common_resample(None)
# rm = ReadManager.from_mne(self._eeg_data)
# info = MemoryInfoSource(rm.get_params())
# outfile_raw = path
# outfile_info = path[:-3] + 'xml'
# samples = rm.get_microvolt_samples().T
# for i, name in enumerate(rm.get_param('channels_names')):
# if name.startswith('Resp'):
# samples[:, i] *= 1e-6
# info_proxy = InfoFileWriteProxy(outfile_info)
# if include_vlen:
# ch_names = self._acti_data.channel_names + ['vlen']
# acti_data = np.concatenate((self._acti_data.data, self._acti_data.to_score.reshape(1, -1)), axis=0)
# else:
# ch_names = self._acti_data.channel_names
# acti_data = self._acti_data.data
# acti_data[self._acti_data.channel_names.index("timestamp"), :] = np.arange(0, self._acti_data.data.shape[1] / self._acti_data.fs, 1 / self._acti_data._fs)
# acti_data = acti_data.T
# ch_names = [e + ' (ACTI)' for e in ch_names]
# info.set_param('channels_names', rm.get_param('channels_names') + ch_names)
# gains = list(np.ones(len(ch_names)) * 1e3) # This 1e3 is just for visibility in SVAROG. Might remove later.
# offsets = list(np.zeros(len(ch_names), dtype=np.intc))
# gains = [str(i) for i in gains]
# offsets = [str(i) for i in offsets]
# info.set_param('channels_gains', rm.get_param('channels_gains') + gains)
# info.set_param('channels_offsets', rm.get_param('channels_offsets') + offsets)
# info.set_param('number_of_samples', str(acti_data.shape[0]))
# info.set_param('number_of_channels', str(acti_data.shape[1] + len(self._eeg_data.ch_names) - 1)) # This -1 is here or it doesn't work :)))
# info.set_param('first_sample_timestamp', self._eeg_data.info['meas_date'].timestamp())
# params = info.get_params()
# info_proxy.finish_saving(params)
# data_proxy = DataRawWriteProxy(outfile_raw, False, rm.get_param('sample_type'))
# merged_data = np.concatenate((samples, acti_data), axis=1)
# packet = SamplePacket(merged_data, np.linspace(0, 1, int(rm.get_param('number_of_samples'))))
# data_proxy.data_received(packet)
# data_proxy.finish_saving()
# del rm, merged_data, acti_data, samples
# @staticmethod
# def load_file(path: str) -> "ObciActiRaw":
# """Loads obci.raw file with merged EEG and actigraphic data.
# Args:
# path (str): Path to the .raw file.
# Returns:
# ObciActiRaw: Container with both EEG and actigraphic data.
# """
# def chtype_heuristic(chname):
# montage = make_standard_montage('standard_1005')
# norm_names = [i.lower() for i in montage.ch_names]
# if chname.lower() in norm_names:
# return 'eeg'
# if 'emg' in chname.lower():
# return 'emg'
# if 'eog' in chname.lower():
# return 'eog'
# if 'ekg' in chname.lower() or 'ecg' in chname.lower():
# return 'ecg'
# return 'misc'
# rm = ReadManager(path[:-3] + 'xml', path, None)
# data = rm.get_microvolt_samples()
# first_timestamp = rm.get_param('first_sample_timestamp')
# seconds, microseconds = str(first_timestamp).split('.')
# sampling_freq = float(rm.get_param('sampling_frequency'))
# chnames = rm.get_param('channels_names')
# eeg_channels = [(i, e) for i, e in enumerate(chnames) if "(ACTI)" not in e]
# acti_channels = [(i, e) for i, e in enumerate(chnames) if "(ACTI)" in e and 'vlen' not in e]
# channel_types = [chtype_heuristic(e[1]) for e in eeg_channels]
# info = create_info(ch_names=[e[1] for e in eeg_channels], sfreq=sampling_freq, ch_types=channel_types)
# eeg_data = data[[e[0] for e in eeg_channels], :]
# eeg_data = eeg_data * 1e-6
# eeg_data = MNERaw(eeg_data, info)
# montage = make_standard_montage('standard_1005')
# eeg_data.set_montage(montage)
# eeg_data.set_meas_date((float(seconds), float(microseconds)))
# acti_data = data[[e[0] for e in acti_channels], :]
# acti_data = acti_data * 1e-3 # Cause gain is set as 1e3 to data be visible in svarog
# acti_chnames = [e[1].replace("(ACTI)", '').strip() for e in acti_channels]
# acti_metadata = {"Source": "obci.raw merged with EEG",
# "Start timestamp": float(first_timestamp)} # TODO something else should go here?
# acti_data = GenericData(acti_data, acti_metadata, sampling_freq, acti_chnames)
# del data
# return ObciActiRaw(acti_data, eeg_data, True)
# def cut_by_timestamp(self, start_ts: float, end_ts: float | None) -> "ObciActiRaw": # TODO this does not work, and I don't really care rn
# """Create new object with the data cut by given timestamps.
# Args:
# start_ts (float): Unix timestamp of output data beginning.
# end_ts (float | None): Unix timestamp of output data end. If None, last sample of output data will be last sample of input data.
# Returns:
# ObciActiRaw: Object containing the cut data.
# """
# new_acti = self._acti_data.cut_by_timestamp(start_ts, end_ts)
# new_eeg = self._eeg_data.copy()
# new_eeg.crop(start_ts, end_ts)
# return ObciActiRaw(new_acti, new_eeg, self._same_sampling)
# @property
# def acti_data(self) -> _Raw | GenericData:
# return self._acti_data
# @property
# def eeg_data(self) -> MNERaw:
# return self._eeg_data
@dataclass(slots=True, eq=False)
class ActiPSG(_ActiEEG):
    """A class for storing and handling actigraphic data with corresponding PSG staging.

    On initialization the PSG scoring is converted to epochs of one actigraphy sample period and aligned
    with the actigraphic recording (see ``trim_to_eeg``).

    Attributes:
        acti_data (_Raw | _Epoched | GenericData): Actigraphic data.
        eeg_data (PSGScore): Container with PSG stages converted to sleep/wake scoring.
        trim_to_eeg (bool): If True, actigraphic data will be trimmed to start and end in the same time as
            PSG scoring. If False, time before and after PSG scoring is assumed to be wake. Defaults to False.
    """
    _acti_data: _Raw | _Epoched | GenericData
    _eeg_data: PSGScore
    _trim_to_eeg: bool = False

    def __post_init__(self):
        """Convert the scoring to the actigraphy sampling rate and align both recordings.

        With ``trim_to_eeg`` the actigraphy is cut to the scored period. Otherwise the scoring is cut to the
        actigraphic recording and padded with wake (stage value 1) so that it has one entry per actigraphy
        sample.
        """
        acti_fs = self._acti_data.fs
        try:
            self._eeg_data = self._eeg_data.change_epoch(1 / acti_fs, 0.8)
        except AssertionError:
            # change_epoch rejects some conversions with an assertion; fall back to resampling the scoring.
            self._eeg_data = self._eeg_data.resample(acti_fs)

        if self._trim_to_eeg:
            # Keep only actigraphy inside the scored period. The extra sample period presumably makes the
            # epoch starting at end_timestamp inclusive.
            self._acti_data = self._acti_data.cut_by_timestamp(
                self._eeg_data._start_timestamp, self._eeg_data._end_timestamp + 1 / acti_fs
            )
            # If the scoring is longer than the actigraphy, drop the surplus stages from the end.
            # TODO: add a test for this.
            len_dif = len(self._acti_data.timestamp) - len(self._eeg_data.psg_stages)
            if len_dif < 0:
                self._eeg_data._psg_stages = self._eeg_data._psg_stages[:len_dif]
                self._eeg_data._end_timestamp = (
                    self._eeg_data._end_timestamp + len_dif * self._eeg_data._epoch_length
                )
        else:
            device = GenericData._get_device(self._acti_data)
            acti_start = device._convert_timestamp(self._acti_data.timestamp[0], self._acti_data.metadata, True)
            acti_len = len(self._acti_data)
            acti_end = acti_start + (acti_len - 1) / acti_fs
            if acti_start > self._eeg_data.start_timestamp or acti_end < self._eeg_data.end_timestamp:
                # The scoring extends past the actigraphy on at least one side: drop those stages.
                self._eeg_data = self._cut(acti_start, acti_end, acti_len)

            scoring = self._eeg_data
            # NOTE(review): int() truncates, so floating-point error in large Unix timestamps may shift this
            # by one sample; the total length is reconciled below either way.
            samples_before_psg = max(int((scoring._start_timestamp - acti_start) * acti_fs), 0)
            stages = np.pad(scoring._psg_stages, (samples_before_psg, 0), 'constant', constant_values=1)
            # Positive: wake padding is needed at the end. Negative: the front padding made the scoring longer
            # than the actigraphy (_cut keeps up to acti_len stages counted from the scoring start), so the
            # surplus stages, which lie after the last actigraphy sample, are dropped.
            samples_after_psg = acti_len - len(stages)
            if samples_after_psg > 0:
                stages = np.pad(stages, (0, samples_after_psg), 'constant', constant_values=1)
            elif samples_after_psg < 0:
                stages = stages[:acti_len]
            scoring._psg_stages = stages
            scoring._start_timestamp = scoring._start_timestamp - samples_before_psg * scoring._epoch_length
            scoring._end_timestamp = scoring._end_timestamp + samples_after_psg * scoring._epoch_length

    def cut_by_timestamp(self, start_ts: float, end_ts: float | None) -> "ActiPSG":
        """Create new object with the data cut by given timestamps.

        Args:
            start_ts (float): Unix timestamp of output data beginning.
            end_ts (float | None): Unix timestamp of output data end. If None, the last actigraphy sample
                timestamp is used.
        Returns:
            ActiPSG: Object containing the cut data (always created with ``trim_to_eeg=True``).
        """
        acti_cut = self._acti_data.cut_by_timestamp(start_ts, end_ts)
        if end_ts is None:
            end_ts = self._acti_data.last_sample_timestamp
        eeg_cut = self._cut(start_ts, end_ts, len(acti_cut))
        return ActiPSG(acti_cut, eeg_cut, True)

    def _cut(self, start_ts: float, end_ts: float, acti_len: int) -> PSGScore:
        """Return a new scoring restricted to the requested time span.

        Args:
            start_ts (float): Unix timestamp of the requested beginning.
            end_ts (float): Unix timestamp of the requested end; only used to decide whether the end of the
                scoring has to be cut.
            acti_len (int): Number of actigraphy samples the cut scoring should cover.
        Returns:
            PSGScore: New scoring object; the stored scoring is left unchanged.
        """
        scoring = self._eeg_data
        epoch_length = scoring.epoch_length
        if start_ts > scoring.start_timestamp:
            start_idx = int((start_ts - scoring.start_timestamp) / epoch_length)
            # When cutting via cut_by_timestamp the actigraphy is shortened and (per the original design) starts
            # after the requested timestamp; when cutting in __post_init__ it starts at the requested timestamp.
            if acti_len != len(self._acti_data):
                start_idx += 1
            psg_start_ts = scoring.start_timestamp + start_idx * epoch_length
        else:
            start_idx = 0
            psg_start_ts = scoring.start_timestamp

        end_idx = start_idx + acti_len if scoring.end_timestamp > end_ts else None
        stages = scoring.psg_stages[start_idx:end_idx]
        if end_idx is None:
            psg_end_ts = scoring.end_timestamp
        else:
            # end timestamp is the beginning of the last kept epoch (same convention as from_continous_stages).
            psg_end_ts = psg_start_ts + (len(stages) - 1) * epoch_length

        out = PSGScore(stages, psg_start_ts, psg_end_ts, epoch_length, False)
        out._collapse_stages = scoring._collapse_stages
        return out

    @staticmethod
    def from_continous_stages(acti_data: _Raw | _Epoched | GenericData, stages: list[dict], stages_first_ts: float, collapse_stages: bool | Callable = True, trim_to_eeg: bool = False) -> "ActiPSG":
        """Create data container from continous PSG tags of constant length (YASA-like).

        Args:
            acti_data (_Raw | _Epoched | GenericData): Actigraphic data corresponding to the provided staging.
            stages (list[dict]): List of tags. Each tag is a dictionary with at least following fields: name,
                start_timestamp, end_timestamp. Timestamps should be relative to the EEG start.
            stages_first_ts (float): Unix timestamp of the first epoch (tag) beginning.
            collapse_stages (bool | Callable, optional): If True, PSG stages will be converted to binary
                sleep/wake scorings using built-in heuristic. If callable, it is a custom function used to convert
                PSG stages to binary sleep/wake scorings. Defaults to True.
            trim_to_eeg (bool, optional): If True, actigraphic data will be trimmed to start and end in the same
                time as PSG scoring. If False, time before and after PSG scoring is assumed to be wake.
                Defaults to False.
        Returns:
            ActiPSG: Object containing actigraphic data and corresponding PSG stages.
        Raises:
            ValueError: If ``stages`` is empty.
        """
        if not stages:
            raise ValueError("At least one PSG stage is required to build ActiPSG.")
        first_tag = stages[0]
        epoch_length = first_tag['end_timestamp'] - first_tag['start_timestamp']
        # Whole-second epochs (e.g. 30 s) stay int as before, but fractional lengths are no longer truncated
        # (int(0.5) == 0, int(29.9999999) == 29).
        if abs(epoch_length - round(epoch_length)) < 1e-6:
            epoch_length = int(round(epoch_length))
        stages_last_ts = stages_first_ts + epoch_length * (len(stages) - 1)
        stage_names = [tag['name'] for tag in stages]
        eeg_data = PSGScore(stage_names, stages_first_ts, stages_last_ts, epoch_length, collapse_stages)
        return ActiPSG(acti_data, eeg_data, trim_to_eeg)

    def export(self, path: str):
        """Save data to a file. Output is an ordinary zip archive containing actigraphic data and psg score
        in the data-specific format.

        Args:
            path (str): Path to the output file.
        """
        trimmed = '_trimmed' if self._trim_to_eeg else ''
        acti_data = GenericData.from_nongeneric(self.acti_data)  # ensure it's not in device format, it'll mess up reading
        # Context managers remove the temporary files and close the archive even if an export step fails.
        with TemporaryDirectory() as temp_dir:
            acti_temp = os.path.join(temp_dir, f'actigraphy{trimmed}.ada')
            eeg_temp = os.path.join(temp_dir, f'psg{trimmed}.ada')
            acti_data.export(acti_temp)
            self.eeg_data.export(eeg_temp)
            with ZipFile(path, mode='w', compression=ZIP_DEFLATED) as zipf:
                zipf.write(acti_temp, arcname=os.path.basename(acti_temp))
                zipf.write(eeg_temp, arcname=os.path.basename(eeg_temp))

    @staticmethod
    def load_file(path: str) -> "ActiPSG":
        """Load zip file created by the export method.

        Args:
            path (str): Path to the file.
        Returns:
            ActiPSG: Data loaded from file.
        Raises:
            ValueError: If the archive does not contain the members written by ``export``.
        """
        with TemporaryDirectory() as temp_dir:
            with ZipFile(path, mode='r') as zipf:
                files = sorted(zipf.namelist())
                acti_name = next((name for name in files if name.startswith('actigraphy')), None)
                eeg_name = next((name for name in files if name.startswith('psg')), None)
                if acti_name is None or eeg_name is None:
                    raise ValueError(f"{path} does not look like a file created by ActiPSG.export.")
                acti_path = zipf.extract(acti_name, temp_dir)
                eeg_path = zipf.extract(eeg_name, temp_dir)
            acti_data = GenericData.load_file(acti_path)
            eeg_data = PSGScore.load_file(eeg_path)
        # Built with trim_to_eeg=False so the already-aligned data is not cut again; the flag is restored below.
        out = ActiPSG(acti_data, eeg_data, False)
        if 'trimmed' in acti_name.split('_')[-1]:
            out._trim_to_eeg = True
        return out

    @property
    def acti_data(self) -> _Raw | _Epoched | GenericData:
        """Actigraphic data held in the container."""
        return self._acti_data

    @property
    def eeg_data(self) -> PSGScore:
        """Sleep/wake scoring held in the container."""
        return self._eeg_data

    @property
    def timestamp(self) -> np.ndarray:
        """Vector with relative timestamp for each sample."""
        return self._acti_data.timestamp