# Source code for bci_framework.extensions.timelock_analysis.timelock_analysis

import os
import sys
import logging
import copy
import math
from abc import ABCMeta, abstractmethod

import mne
import numpy as np
# from scipy.fftpack import rfft, rfftfreq
from scipy.signal import welch, decimate
from scipy.signal import decimate, welch

from cycler import cycler
import matplotlib
from matplotlib import pyplot
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure

from PySide6.QtCore import Qt
from PySide6 import QtWidgets
from PySide6.QtUiTools import QUiLoader
from PySide6.QtWidgets import QSpacerItem, QSizePolicy

from gcpds.filters import frequency as flt
from gcpds.filters import frequency as flt
from bci_framework.framework.dialogs import Dialogs

# from bci_framework.extensions.data_analysis.utils import thread_this, subprocess_this

from PySide6.QtGui import QCursor
from PySide6.QtWidgets import QApplication
from PySide6.QtCore import Qt


# Set logger
# MNE is silenced up to CRITICAL, matplotlib's font manager logger is
# disabled, and the root logger (renamed for this extension) reports from
# WARNING upwards.
logger = logging.getLogger("mne")
logger.setLevel(logging.CRITICAL)
logging.getLogger('matplotlib.font_manager').disabled = True
logging.getLogger().setLevel(logging.WARNING)
logging.root.name = "TimelockAnalysis"

# The dark matplotlib style is used unless a 'light' theme is requested,
# either on the command line or through the QTMATERIAL_THEME variable.
_light_theme = ('light' in sys.argv) or (
    'light' in os.environ.get('QTMATERIAL_THEME', ''))
if not _light_theme:
    pyplot.style.use('dark_background')

try:
    # `matplotlib.cm.get_cmap` no longer exists in recent Matplotlib
    # releases; with the previous bare `except` that made this whole
    # configuration block a silent no-op. `pyplot.get_cmap` is available in
    # both old and new releases.
    q = pyplot.get_cmap('cool')
    matplotlib.rcParams['axes.prop_cycle'] = cycler(
        color=[q(m) for m in np.linspace(0, 1, 16)])
    matplotlib.rcParams['figure.dpi'] = 70
    matplotlib.rcParams['font.family'] = 'monospace'
    matplotlib.rcParams['font.size'] = 15
    matplotlib.rcParams['axes.titlecolor'] = '#000000'
    matplotlib.rcParams['xtick.color'] = '#000000'
    matplotlib.rcParams['ytick.color'] = '#000000'
    # matplotlib.rcParams['legend.facecolor'] = 'red'
except Exception as error:
    # Best effort: e.g. 'rcParams' object does not support item assignment.
    # Kept non-fatal, but no longer completely silent.
    logging.debug('matplotlib rcParams were not configured: %s', error)

# Keyword arguments shared by the `legend()` calls of the plots below.
LEGEND_KWARGS = {'labelcolor': '#000000',
                 'fontsize': 12,
                 }


# ----------------------------------------------------------------------
def wait_for_it(fn):
    """Decorate `fn` so that it runs while the Qt wait cursor is shown.

    Any `Exception` raised by `fn` is logged as a warning instead of being
    propagated, so a failing callback does not abort the caller.

    Parameters
    ----------
    fn
        Callable to wrap.

    Returns
    -------
    callable
        Wrapper that forwards all arguments to `fn` and returns its result,
        or `None` when `fn` raised an `Exception`.
    """
    # Imported locally so this block does not depend on file-level imports.
    from functools import wraps

    # ----------------------------------------------------------------------
    @wraps(fn)
    def wrap(*args, **kwargs):
        QApplication.setOverrideCursor(QCursor(Qt.WaitCursor))
        try:
            # The result used to be discarded; returning it is harmless for
            # existing callers and makes the decorator transparent.
            return fn(*args, **kwargs)
        except Exception as e:
            logging.warning(e)
        finally:
            # Always pair the override with a restore, even when something
            # that is not an `Exception` (e.g. KeyboardInterrupt) escapes.
            QApplication.restoreOverrideCursor()

    return wrap


########################################################################
class Canvas(FigureCanvasQTAgg):
    """Qt canvas that builds its own matplotlib `Figure`."""

    # ----------------------------------------------------------------------
    def __init__(self, *args, **kwargs):
        """Create the figure from the given arguments and attach the canvas.

        All positional and keyword arguments are forwarded to `Figure`.
        """
        self.figure = Figure(*args, **kwargs)
        self.configure()
        super().__init__(self.figure)

    # ----------------------------------------------------------------------
    def configure(self):
        """Set tick and axis-label font sizes on every axes of the figure."""
        for axes in self.figure.axes:
            for which in ('x', 'y'):
                axes.tick_params(axis=which, labelsize=12)
            for axis in (axes.xaxis, axes.yaxis):
                axis.label.set_size(14)
########################################################################
class TimelockWidget(metaclass=ABCMeta):
    """Base class of a pipeline step made of a Qt widget and a canvas.

    The widget is loaded from ``locktime_widget.ui``. Controls are added to
    its ``<area>Layout`` layouts (``top``, ``top2``, ``bottom``,
    ``bottom2``, ``left``, ``right``) through the ``add_*`` helpers, which
    also record one stretch factor per added control in ``<area>_stretch``.
    Steps are chained with `next_pipeline` / `previous_pipeline`;
    subclasses implement `fit`.
    """

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Load the UI file and embed a `Canvas` in its grid layout.

        Parameters
        ----------
        height
            Minimum height of the widget; ignored when falsy.
        *args, **kwargs
            Forwarded to `Canvas` (and from there to the figure).
        """
        self.title = ''
        self.bottom_stretch = []
        self.bottom2_stretch = []
        self.top_stretch = []
        self.top2_stretch = []
        self.right_stretch = []
        self.left_stretch = []
        self._pipeline_output = None

        ui = os.path.realpath(os.path.join(
            os.environ['BCISTREAM_ROOT'], 'framework', 'qtgui',
            'locktime_widget.ui'))
        self.widget = QUiLoader().load(ui)

        if height:
            self.widget.setMinimumHeight(height)

        self.canvas = Canvas(*args, **kwargs)
        self.figure = self.canvas.figure
        self.widget.gridLayout.addWidget(self.canvas)

    # ----------------------------------------------------------------------
    def draw(self):
        """Re-apply the canvas font configuration and redraw it."""
        self.canvas.configure()
        self.canvas.draw()

    # ----------------------------------------------------------------------
    def _add_spacers(self):
        """Apply the recorded stretch factors to every layout, by index."""
        for area in ('bottom', 'top', 'bottom2', 'top2', 'right', 'left'):
            stretches = getattr(self, f'{area}_stretch')
            if not stretches:
                continue
            layout = getattr(self.widget, f'{area}Layout')
            for index, stretch in enumerate(stretches):
                layout.setStretch(index, stretch)

    # ----------------------------------------------------------------------
    def add_spacer(self, area='top', fixed=None, stretch=0):
        """Append a spacer item to the layout of `area`.

        Parameters
        ----------
        area
            Target layout name; areas other than the six known ones get no
            spacer.
        fixed
            When truthy, size of a non-expanding spacer along the layout
            direction; otherwise an expanding spacer is added.
        stretch
            Stretch factor recorded for the area, only when truthy.
        """
        vertical = area in ('left', 'right')
        horizontal = area in ('top', 'bottom', 'top2', 'bottom2')

        if vertical or horizontal:
            if fixed:
                size = (20, fixed) if vertical else (fixed, 20)
                policy = (QSizePolicy.Minimum, QSizePolicy.Minimum)
            elif vertical:
                size = (20, 20000)
                policy = (QSizePolicy.Minimum, QSizePolicy.Expanding)
            else:
                size = (20000, 20)
                policy = (QSizePolicy.Expanding, QSizePolicy.Minimum)
            getattr(self.widget, f'{area}Layout').addItem(
                QSpacerItem(*size, *policy))

        # NOTE(review): unlike the other `add_*` helpers, nothing is
        # recorded for stretch=0, so later entries of `<area>_stretch` no
        # longer line up with layout indices in `_add_spacers` -- confirm
        # whether that is intended.
        if stretch:
            getattr(self, f'{area}_stretch').append(stretch)

    # ----------------------------------------------------------------------
    def clear_layout(self, layout):
        """Schedule widgets for deletion and remove spacers, recursively."""
        index = -1
        for _ in range(layout.count()):
            index += 1
            item = layout.itemAt(index)
            if item is None:
                continue

            if child := item.widget():  # widget
                child.deleteLater()

            if item.spacerItem():  # spacer
                layout.removeItem(item)
                # The following items moved one position back.
                index -= 1

            if nested := item.layout():
                self.clear_layout(nested)

    # ----------------------------------------------------------------------
    def clear_widgets(self, areas=('left', 'right', 'top', 'bottom', 'top2', 'bottom2')):
        """Clear the layouts of the given areas (all six by default).

        The default is an immutable tuple instead of a shared list.
        """
        for area in areas:
            self.clear_layout(getattr(self.widget, f'{area}Layout'))

    # ----------------------------------------------------------------------
    @staticmethod
    def _bound_callback(callback, *args):
        """Return a one-argument slot that calls ``callback(*args)``.

        The single parameter receives (and ignores) the value emitted by
        the `clicked` signal.
        """
        def wrap(fn):
            return callback(*args)
        return wrap

    # ----------------------------------------------------------------------
    def add_textarea(self, content='', area='top', stretch=0):
        """Add a read-only `QTextEdit` to `area` and return it."""
        textarea = QtWidgets.QTextEdit(content)
        textarea.setProperty('class', 'clear')
        textarea.setMinimumWidth(500)
        textarea.setReadOnly(True)

        getattr(self.widget, f'{area}Layout').addWidget(textarea)
        getattr(self, f'{area}_stretch').append(stretch)
        return textarea

    # ----------------------------------------------------------------------
    def add_button(self, label, callback=None, area='top', stretch=0):
        """Add a push button to `area`, optionally connected to `callback`."""
        button = QtWidgets.QPushButton(label)
        if callback:
            button.clicked.connect(callback)

        getattr(self.widget, f'{area}Layout').addWidget(button)
        getattr(self, f'{area}_stretch').append(stretch)
        return button

    # ----------------------------------------------------------------------
    def add_radios(self, group_name, radios, cols=None, rows=None,
                   callback=None, area='top', stretch=1):
        """Add a group box of radio buttons laid out in a grid of rows.

        Parameters
        ----------
        group_name
            Title of the group box; also first argument of `callback`.
        radios
            Labels of the radio buttons; the first one starts checked.
        cols, rows
            Buttons per row; `rows`, when given, takes precedence and the
            column count is derived from it.
        callback
            Called as ``callback(group_name, label)`` on click.
        """
        group = QtWidgets.QGroupBox(group_name)
        group.setProperty('class', 'fill_background')
        vbox = QtWidgets.QVBoxLayout()
        group.setLayout(vbox)

        if cols is None:
            cols = len(radios)
        if rows:
            cols = math.ceil(len(radios) / rows)

        for i, radio in enumerate(radios):
            if (i % cols) == 0:
                hbox = QtWidgets.QHBoxLayout()
                vbox.addLayout(hbox)

            r = QtWidgets.QRadioButton()
            r.setText(radio)
            r.setChecked(i == 0)

            if callback:
                r.clicked.connect(
                    self._bound_callback(callback, group_name, radio))

            hbox.addWidget(r)

        getattr(self.widget, f'{area}Layout').addWidget(group)
        getattr(self, f'{area}_stretch').append(stretch)

    # ----------------------------------------------------------------------
    def add_checkbox(self, group_name, checkboxes, cols=None, rows=None,
                     callback=None, area='top', stretch=1):
        """Add a group box of checkboxes and return the created widgets.

        Same layout rules and callback convention as `add_radios`; the
        first checkbox starts checked.
        """
        group = QtWidgets.QGroupBox(group_name)
        group.setProperty('class', 'fill_background')
        vbox = QtWidgets.QVBoxLayout()
        group.setLayout(vbox)

        if cols is None:
            cols = len(checkboxes)
        if rows:
            cols = math.ceil(len(checkboxes) / rows)

        list_radios = []
        for i, checkbox in enumerate(checkboxes):
            if (i % cols) == 0:
                hbox = QtWidgets.QHBoxLayout()
                vbox.addLayout(hbox)

            r = QtWidgets.QCheckBox()
            r.setText(checkbox)
            r.setChecked(i == 0)
            list_radios.append(r)

            if callback:
                r.clicked.connect(
                    self._bound_callback(callback, group_name, checkbox))

            hbox.addWidget(r)

        getattr(self.widget, f'{area}Layout').addWidget(group)
        getattr(self, f'{area}_stretch').append(stretch)
        return list_radios

    # ----------------------------------------------------------------------
    def add_channels(self, group_name, channels, callback=None, area='top', stretch=1):
        """Add one checked checkbox per channel, split in three columns.

        Channels whose name ends in an odd digit go to the first column,
        those ending in an even digit to the last one, and every other name
        to the middle column. Returns the list of checkboxes.
        """
        group = QtWidgets.QGroupBox(group_name)
        group.setProperty('class', 'fill_background')
        columns = QtWidgets.QHBoxLayout()
        group.setLayout(columns)

        first_column = QtWidgets.QVBoxLayout()   # names ending in odd digit
        middle_column = QtWidgets.QVBoxLayout()  # names without final digit
        last_column = QtWidgets.QVBoxLayout()    # names ending in even digit

        columns.addLayout(first_column)
        columns.addLayout(middle_column)
        columns.addLayout(last_column)

        list_radios = []
        for channel in channels:
            r = QtWidgets.QCheckBox()
            r.setText(channel)
            r.setChecked(True)
            list_radios.append(r)

            # `[-1:]` and `isdecimal` keep empty names and non-decimal
            # numeric characters out of `int()`.
            last = channel[-1:]
            if last.isdecimal() and int(last) % 2 != 0:  # odd
                first_column.addWidget(r)
            elif last.isdecimal():  # even
                last_column.addWidget(r)
            else:
                middle_column.addWidget(r)

            if callback:
                r.clicked.connect(
                    self._bound_callback(callback, group_name, channel))

        getattr(self.widget, f'{area}Layout').addWidget(group)
        getattr(self, f'{area}_stretch').append(stretch)
        return list_radios

    # ----------------------------------------------------------------------
    def add_scroll(self, callback=None, area='bottom', stretch=0):
        """Add a horizontal scroll bar; `callback` receives `sliderMoved`."""
        scroll = QtWidgets.QScrollBar()
        scroll.setOrientation(Qt.Horizontal)
        # `callback` defaults to None, so only connect a real callable
        # (same guard as `add_button` / `add_spin`).
        if callback:
            scroll.sliderMoved.connect(callback)
        scroll.setProperty('class', 'big')

        getattr(self.widget, f'{area}Layout').addWidget(scroll)
        getattr(self, f'{area}_stretch').append(stretch)
        return scroll

    # ----------------------------------------------------------------------
    def add_slider(self, callback=None, area='bottom', stretch=0):
        """Add a horizontal slider in the range 0-500, initially at 500."""
        slider = QtWidgets.QSlider()
        slider.setOrientation(Qt.Horizontal)
        # The original called `setMaximum(0)` immediately followed by
        # `setMaximum(500)`; the first call was evidently meant to set the
        # minimum.
        slider.setMinimum(0)
        slider.setMaximum(500)
        slider.setValue(500)
        if callback:
            slider.valueChanged.connect(callback)

        getattr(self.widget, f'{area}Layout').addWidget(slider)
        getattr(self, f'{area}_stretch').append(stretch)
        return slider

    # ----------------------------------------------------------------------
    def add_spin(self, label, value, decimals=1, step=0.1, prefix='', suffix='',
                 min_=0, max_=999, callback=None, area='top', stretch=0):
        """Add an optionally labelled `QDoubleSpinBox` and return it.

        The initial value is set before `callback` is connected to
        `valueChanged`, so creating the spin box does not trigger it.
        """
        spin = QtWidgets.QDoubleSpinBox()
        spin.setDecimals(decimals)
        spin.setSingleStep(step)
        spin.setMinimum(min_)
        spin.setMaximum(max_)
        spin.setValue(value)
        if callback:
            spin.valueChanged.connect(callback)
        if prefix:
            spin.setPrefix(f' {prefix}')
        if suffix:
            spin.setSuffix(f' {suffix}')

        layout = QtWidgets.QHBoxLayout()
        widget = QtWidgets.QWidget()
        widget.setLayout(layout)

        if label:
            layout.addWidget(QtWidgets.QLabel(label))
        layout.addWidget(spin)

        getattr(self.widget, f'{area}Layout').addWidget(widget)
        getattr(self, f'{area}_stretch').append(stretch)

        layout.setStretch(0, 0)
        layout.setStretch(1, 1)
        return spin

    # ----------------------------------------------------------------------
    def add_combobox(self, label, items, editable=False, callback=None,
                     area='top', stretch=0):
        """Add an optionally labelled combo box and return it.

        `callback`, when given, is connected to the `activated` signal.
        """
        combo = QtWidgets.QComboBox()
        combo.addItems(items)
        # Callers in this file pass callback=None; do not try to connect it.
        if callback:
            combo.activated.connect(callback)
        combo.setEditable(editable)
        combo.setMinimumWidth(200)

        layout = QtWidgets.QHBoxLayout()
        widget = QtWidgets.QWidget()
        widget.setLayout(layout)

        if label:
            layout.addWidget(QtWidgets.QLabel(label))
        layout.addWidget(combo)

        getattr(self.widget, f'{area}Layout').addWidget(widget)
        getattr(self, f'{area}_stretch').append(stretch)

        layout.setStretch(0, 0)
        layout.setStretch(1, 1)
        return combo

    # ----------------------------------------------------------------------
    # @abstractmethod
    @property
    def pipeline_input(self):
        """Input of this step.

        The output of the previous step when one is linked, else the value
        set manually; `None` (after a warning) when neither exists.
        """
        if hasattr(self, '_previous_pipeline'):
            return self._previous_pipeline.pipeline_output
        elif hasattr(self, '_pipeline_input'):
            return self._pipeline_input
        else:
            logging.warning("'pipeline_input' does not exist yet.")

    # ----------------------------------------------------------------------
    # @abstractmethod
    @pipeline_input.setter
    def pipeline_input(self, input_):
        """Set the input manually."""
        self._pipeline_input = input_

    # ----------------------------------------------------------------------
    # @abstractmethod
    @property
    def pipeline_output(self):
        """Output of this step, or `None` when it has not been stored."""
        if hasattr(self, '_pipeline_output'):
            return self._pipeline_output

    # ----------------------------------------------------------------------
    # @abstractmethod
    @pipeline_output.setter
    def pipeline_output(self, output_):
        """Store the output and propagate it to the next step."""
        self._pipeline_output = output_
        try:
            self.pipeline_output._original_markers = self.pipeline_output.markers
        except Exception:
            # Best effort: not every output object exposes `markers`.
            pass
        self._pipeline_propagate()

    # ----------------------------------------------------------------------
    # @abstractmethod
    @property
    def pipeline_tunned(self):
        """Flag checked before a step is fitted by its predecessor."""
        return getattr(self, '_pipeline_tunned', False)

    # ----------------------------------------------------------------------
    # @abstractmethod
    @pipeline_tunned.setter
    def pipeline_tunned(self, value):
        """Set the `pipeline_tunned` flag."""
        self._pipeline_tunned = value

    # ----------------------------------------------------------------------
    def next_pipeline(self, pipe):
        """Register the step that consumes this step's output."""
        self._next_pipeline = pipe

    # ----------------------------------------------------------------------
    def previous_pipeline(self, pipe):
        """Register the step whose output feeds this step."""
        self._previous_pipeline = pipe

    # ----------------------------------------------------------------------
    def set_pipeline_input(self, in_):
        """Set the input manually (method form of the property setter)."""
        self._pipeline_input = in_

    # ----------------------------------------------------------------------
    # @abstractmethod
    def _pipeline_propagate(self):
        """Fit the next step, if there is one and it is flagged as tuned."""
        next_pipeline = getattr(self, '_next_pipeline', None)
        if next_pipeline and next_pipeline.pipeline_tunned:
            next_pipeline.fit()

    # ----------------------------------------------------------------------
    @abstractmethod
    def fit(self):
        """Process `pipeline_input`; implemented by each concrete step."""


########################################################################
class TimelockSeries(TimelockWidget):
    """Step showing a scrollable time window (`ax1`) over an overview (`ax2`).

    Subclasses are expected to create `ax1` and `ax2` and to call
    `set_window_width_options`, which creates `scroll`, `combobox` and
    `window_value` (window width in seconds).
    """

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)
        self.fill_opacity = 0.2
        self.fill_color = os.environ.get(
            'QTMATERIAL_PRIMARYCOLOR', '#ff0000')

    # ----------------------------------------------------------------------
    def _shade_window(self, start):
        """Replace the collections of `ax2` with the current-window band.

        `start` is the left edge of the window, in seconds.
        """
        # `Axes.collections` cannot be modified in place on current
        # Matplotlib releases; removing each artist works on old and new
        # versions alike.
        for collection in list(self.ax2.collections):
            collection.remove()
        self.ax2.fill_between([start, start + self.window_value],
                              *self.ax1.get_ylim(),
                              color=self.fill_color,
                              alpha=self.fill_opacity)

    # ----------------------------------------------------------------------
    def move_plot(self, value):
        """Move the visible window to `value` (scroll position, in ms)."""
        start = value / 1000
        self.ax1.set_xlim(start, start + self.window_value)
        self._shade_window(start)
        self.draw()

    # ----------------------------------------------------------------------
    def change_window(self):
        """Apply the window width selected in the combo box."""
        self.window_value = self._get_seconds_from_human(
            self.combobox.currentText())

        # Last timestamp of the output, converted from ms to s.
        duration = self.pipeline_output.timestamp[0][-1] / 1000

        # Scroll bar values are integers (milliseconds).
        self.scroll.setMaximum(int((duration - self.window_value) * 1000))
        self.scroll.setMinimum(0)
        self.scroll.setPageStep(int(self.window_value * 1000))

        start = self.scroll.value() / 1000
        self.ax1.set_xlim(start, start + self.window_value)
        # The right edge used to be `(scroll + window) / 1000`, dividing the
        # window width (already in seconds) by 1000; it now matches
        # `move_plot`.
        self._shade_window(start)
        self.draw()

    # ----------------------------------------------------------------------
    def _get_seconds_from_human(self, human):
        """Convert a text such as '500 milliseconds' or '1 hour' to seconds.

        Unit names are replaced by their factor and the resulting numbers
        are multiplied. 'milliseconds' must be replaced before 'second'
        because it contains it.
        """
        value = human.replace('milliseconds', '0.001')
        value = value.replace('second', '1')
        value = value.replace('minute', '60')
        value = value.replace('hour', '60 60')
        return np.prod(list(map(float, value.split())))

    # ----------------------------------------------------------------------
    def set_data(self, timestamp, eeg, labels, ylabel='', xlabel='', legend=True):
        """Plot every row of `eeg` against `timestamp` on both axes.

        Parameters
        ----------
        timestamp
            X values; its last element is used as the overview's right limit.
        eeg
            Iterable of series, one per label.
        labels
            Legend labels, indexed like `eeg`.
        ylabel, xlabel
            Labels of `ax1`'s y axis and `ax2`'s x axis.
        legend
            Whether to draw the legend above `ax1`.
        """
        self.ax1.clear()
        self.ax2.clear()

        for i, ch in enumerate(eeg):
            self.ax1.plot(timestamp, ch, label=labels[i])
            self.ax2.plot(timestamp, ch, alpha=0.5)

        self.ax1.grid(True, axis='x')
        if legend:
            self.ax1.legend(loc='upper center', ncol=8,
                            bbox_to_anchor=(0.5, 1.4), **LEGEND_KWARGS)
        self.ax1.set_xlim(0, self.window_value)

        self.ax2.grid(True, axis='x')
        self.ax2.set_xlim(0, timestamp[-1])
        self._shade_window(0)

        # Scroll bar values are integers (milliseconds).
        self.scroll.setMaximum(
            int((timestamp[-1] - self.window_value) * 1000))
        self.scroll.setMinimum(0)

        self.ax1.set_ylabel(ylabel)
        self.ax2.set_xlabel(xlabel)
        self.draw()

    # ----------------------------------------------------------------------
    def set_window_width_options(self, options):
        """Create the scroll bar and the window-width combo box.

        The first entry of `options` becomes the initial window width.
        """
        self.scroll = self.add_scroll(
            callback=self.move_plot, area='bottom', stretch=1)
        self.combobox = self.add_combobox('', options,
                                          callback=self.change_window,
                                          area='bottom', stretch=0)
        self.window_value = self._get_seconds_from_human(options[0])


########################################################################
class Filters(TimelockWidget):
    """Step that applies the selected notch/band-pass filters to the EEG.

    `ax1` shows the Welch spectrum and `ax2` the stacked time series.
    """

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)
        self.title = 'Filter EEG'

        gs = self.figure.add_gridspec(1, 2)
        self.ax1 = gs.figure.add_subplot(gs[:, 0:-1])
        self.ax2 = gs.figure.add_subplot(gs[:, -1])

        self.figure.subplots_adjust(left=0.05,
                                    bottom=0.12,
                                    right=0.95,
                                    top=0.95,
                                    wspace=None,
                                    hspace=0.6)

        # Selected filter per group: the string 'none' or a callable taken
        # from `flt` (see `set_filters`).
        self.filters = {'Notch': 'none',
                        'Bandpass': 'none',
                        }

        self.notchs = ('none', '50 Hz', '60 Hz')
        self.bandpass = ('none', 'delta', 'theta', 'alpha', 'beta',
                         '0.01-20 Hz',
                         '5-45 Hz', '3-30 Hz', '4-40 Hz', '2-45 Hz', '1-50 Hz',
                         '7-13 Hz', '15-50 Hz', '1-100 Hz', '5-50 Hz')

        self.add_radios('Notch', self.notchs, callback=self.set_filters,
                        area='top', stretch=0)
        self.add_radios('Bandpass', self.bandpass, callback=self.set_filters,
                        area='top', stretch=0)

        self.scale = self.add_spin('Scale', 150, suffix='uv', min_=0,
                                   max_=1000, step=50, callback=self.fit,
                                   area='top', stretch=0)

    # ----------------------------------------------------------------------
    @wait_for_it
    def fit(self, *args, **kwargs):
        """Filter the input EEG, redraw both plots and propagate the result.

        Extra arguments are accepted and ignored: this method is also the
        `valueChanged` slot of the 'Scale' spin box, which may call it with
        the new value; previously that call failed with a `TypeError`.
        """
        eeg = self.pipeline_input.original_eeg

        for name in self.filters:
            if self.filters[name] != 'none':
                eeg = self.filters[name](eeg, fs=1000, axis=1)

        self.ax1.clear()
        self.ax2.clear()

        t = np.linspace(0, eeg.shape[1], eeg.shape[1], endpoint=True) / 1000
        channels = eeg.shape[0]

        # Vertical distance between consecutive channels in the stacked plot.
        threshold = self.scale.value()

        for i, ch in enumerate(eeg):
            self.ax2.plot(t, ch + (threshold * i))

        self.ax1.set_xlabel('Frequency [$Hz$]')
        self.ax1.set_ylabel('Amplitude')
        self.ax2.set_xlabel('Time [$s$]')

        self.ax2.set_yticks([threshold * i for i in range(channels)])
        self.ax2.set_yticklabels(
            list(self.pipeline_input.header['channels'].values()))
        self.ax2.set_ylim(-threshold, threshold * channels)

        w, spectrum = welch(eeg, fs=1000, axis=1,
                            nperseg=1024, noverlap=256, average='median')

        for i, ch in enumerate(spectrum):
            self.ax1.fill_between(w, 0, ch, alpha=0.2, color=f'C{i}')
            self.ax1.plot(w, ch, linewidth=2, color=f'C{i}')
        self.ax1.set_xscale('log')

        self.ax1.set_xlim(0, w[-1])
        self.ax2.set_xlim(0, t[-1])

        self.ax1.grid(True, axis='y')
        self.ax2.grid(True, axis='x')

        self.draw()

        self.pipeline_tunned = True
        # The input object itself is reused as output; only its `eeg`
        # attribute is replaced by the filtered copy.
        self._pipeline_output = self.pipeline_input
        self._pipeline_output.eeg = eeg.copy()
        self._pipeline_propagate()

    # ----------------------------------------------------------------------
    def set_filters(self, group_name, filter_):
        """Select the filter of a group from its radio label and refit.

        Labels are mapped to attributes of `flt`: ``notch<N>`` for notch
        filters, the label itself for the four named bands, and
        ``band<digits>`` (label without ' Hz', '-' and '.') otherwise.
        """
        if filter_ != 'none':
            if group_name == 'Notch':
                filter_ = getattr(flt, f'notch{filter_.replace(" Hz", "")}')
            elif group_name == 'Bandpass':
                if filter_ in self.bandpass[1:5]:
                    filter_ = getattr(flt, f'{filter_}')
                else:
                    filter_ = getattr(
                        flt, f'band{filter_.replace(" Hz", "").replace("-", "").replace(".", "")}')

        self.filters[group_name] = filter_
        self.fit()

    # # ----------------------------------------------------------------------
    # @property
# def output(self):
#     """"""
#     if hasattr(self, 'output_signal'):
#         return self.output_signal


########################################################################
class LoadDatabase(TimelockSeries):
    """First pipeline step: load a database and show its raw EEG."""

    # ----------------------------------------------------------------------
    def __init__(self, height=700, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)
        self.title = 'Raw EEG signal'

        # Create grid plot
        gs = self.figure.add_gridspec(4, 4)
        self.ax1 = gs.figure.add_subplot(gs[0:-1, :])
        self.ax2 = gs.figure.add_subplot(gs[-1, :])
        self.ax2.get_yaxis().set_visible(False)

        self.figure.subplots_adjust(left=0.05,
                                    bottom=0.12,
                                    right=0.95,
                                    top=0.8,
                                    wspace=None,
                                    hspace=0.6)

        self.add_button('Load database',
                        callback=self.load_database, area='top', stretch=0)
        self.add_spacer(area='top')

        self.set_window_width_options(['500 milliseconds'])

        # Candidate window widths; `fit` offers only those shorter than the
        # loaded recording.
        self.window_options = ['500 milliseconds',
                               '1 second',
                               '5 second',
                               '15 second',
                               '30 second',
                               '1 minute',
                               '5 minute',
                               '10 minute',
                               '30 minute',
                               '1 hour']

        self.database_description = self.add_textarea(
            area='right', stretch=0)

    # ----------------------------------------------------------------------
    def load_database(self):
        """Ask for a database, compile the filters for it and fit."""
        datafile = Dialogs.load_database()
        # NOTE(review): presumably the dialog yields None when cancelled --
        # confirm against `Dialogs.load_database`. Without this guard a None
        # result failed on `.header` below.
        if datafile is None:
            return
        self.datafile = datafile

        # Set input manually
        self.pipeline_input = self.datafile

        flt.compile_filters(
            FS=self.pipeline_input.header['sample_rate'], N=2, Q=3)
        self.fit()

    # ----------------------------------------------------------------------
    @wait_for_it
    def fit(self):
        """Plot the decimated EEG of the input file and publish the file."""
        datafile = self.pipeline_input

        header = datafile.header
        eeg = datafile.eeg
        # Bare attribute access kept from the original.
        # NOTE(review): presumably it forces `aux` to be read before the
        # file is closed below -- confirm.
        datafile.aux
        timestamp = datafile.timestamp

        self.database_description.setText(datafile.description)

        eeg = decimate(eeg, 15, axis=1)
        timestamp = np.linspace(
            0, timestamp[0][-1], eeg.shape[1], endpoint=True) / 1000

        eeg = eeg / 1000

        # `window_options` is sorted by duration, so counting the entries
        # shorter than the recording gives the usable prefix.
        durations = (self._get_seconds_from_human(option)
                     for option in self.window_options)
        usable = sum(1 for duration in durations if duration < timestamp[-1])
        self.combobox.clear()
        self.combobox.addItems(self.window_options[:usable])

        self.set_data(timestamp, eeg,
                      labels=list(header['channels'].values()),
                      ylabel='Millivolt [$mv$]',
                      xlabel='Time [$s$]')

        datafile.close()

        self.pipeline_tunned = True
        self.pipeline_output = datafile


########################################################################
class EpochsVisualization(TimelockWidget):
    """Step that averages epochs per marker and plots the evoked responses."""

    # ----------------------------------------------------------------------
    def __init__(self, height=700, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)
        self.title = 'Visualize epochs'

        self.ax1 = self.figure.add_subplot(111)
        self.pipeline_tunned = True

    # ----------------------------------------------------------------------
    def fit(self):
        """Rebuild the controls from the markers and channels of the input."""
        self.clear_widgets()

        markers = sorted(self.pipeline_input.markers.keys())
        channels = list(self.pipeline_input.header['channels'].values())

        self.tmin = self.add_spin('tmin', 0, suffix='s', min_=-99, max_=99,
                                  callback=self.get_epochs, area='top',
                                  stretch=0)
        self.tmax = self.add_spin('tmax', 1, suffix='s', min_=-99, max_=99,
                                  callback=self.get_epochs, area='top',
                                  stretch=0)
        self.method = self.add_combobox(label='Method',
                                        items=['mean', 'median'],
                                        callback=self.get_epochs,
                                        area='top', stretch=0)

        self.add_spacer(area='top', fixed=50)
        self.reject = self.add_spin('Reject', 200, suffix='vpp', min_=0,
                                    max_=500, step=10,
                                    callback=self.get_epochs, area='top',
                                    stretch=0)
        self.flat = self.add_spin('Flat', 10, suffix='vpp', min_=0, max_=500,
                                  step=10, callback=self.get_epochs,
                                  area='top', stretch=0)
        self.add_spacer(area='top')

        self.checkbox = self.add_checkbox('Markers', markers,
                                          callback=self.get_epochs,
                                          area='bottom', stretch=1)
        self.add_spacer(area='bottom')

        self.channels = self.add_channels('Channels', channels,
                                          callback=self.get_epochs,
                                          area='right', stretch=1)
        self.add_spacer(area='right')

    # ----------------------------------------------------------------------
    @wait_for_it
    def get_epochs(self, *args, **kwargs):
        """Compute epochs for the checked markers and plot their averages.

        Arguments are ignored; they only absorb what the connected signals
        and callbacks pass. Nothing is computed when no marker or channel
        is checked, or when 'Reject' is lower than 'Flat'.
        """
        self.figure.clear()
        self.ax1 = self.figure.add_subplot(111)

        markers = sorted(box.text() for box in self.checkbox
                         if box.isChecked())
        channels = sorted(box.text() for box in self.channels
                          if box.isChecked())

        if not markers or not channels:
            return
        if self.reject.value() < self.flat.value():
            return

        epochs = self.pipeline_input.epochs(
            tmin=self.tmin.value(), tmax=self.tmax.value(), markers=markers)

        reject = {'eeg': self.reject.value()}
        flat = {'eeg': self.flat.value()}
        epochs.drop_bad(reject, flat)

        method = self.method.currentText()
        evokeds = {mk: epochs[mk].average(method=method, picks=channels)
                   for mk in markers}

        try:
            mne.viz.plot_compare_evokeds(evokeds, axes=self.ax1,
                                         cmap=('Class', 'cool'),
                                         show=False,
                                         show_sensors=False,
                                         invert_y=True,
                                         styles={},
                                         split_legend=False,
                                         legend='upper center')
        except Exception as error:
            # Plotting stays best effort, but the reason is now reported
            # instead of being silently discarded.
            logging.warning(error)

        self.draw()
        self.pipeline_output = epochs


########################################################################
class AmplitudeAnalysis(TimelockWidget):
    """Step that plots the across-channel amplitude envelope over time."""

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)
        self.title = 'Amplitude analysis'

        self.ax1 = self.figure.add_subplot(111)
        self.pipeline_tunned = True

        self.figure.subplots_adjust(left=0.05,
                                    bottom=0.12,
                                    right=0.95,
                                    top=0.95)

    # ----------------------------------------------------------------------
    @wait_for_it
    def fit(self):
        """Plot the decimated min/mean/max over channels with vpp guides."""
        datafile = self.pipeline_input

        # Timestamps divided by 1000 and then by 60: milliseconds to minutes
        # (the rest of this file treats `timestamp / 1000` as seconds).
        t = datafile.timestamp[0] / 1000 / 60

        eeg = datafile.eeg
        # Remove the per-channel mean before comparing channels.
        eeg = eeg - eeg.mean(axis=1)[:, np.newaxis]

        mx = eeg.max(axis=0)
        mn = eeg.min(axis=0)
        m = eeg.mean(axis=0)

        self.ax1.clear()

        dc = 1000
        mxd = decimate(mx, dc, n=2)
        mnd = decimate(mn, dc, n=2)
        md = decimate(m, dc, n=2)
        # The time axis is only subsampled: it used to go through
        # `decimate` too, i.e. it was low-pass filtered like a signal.
        # `decimate` keeps every `dc`-th sample, so plain slicing yields the
        # matching instants and the same length.
        td = t[::dc]

        self.ax1.fill_between(td, mnd, mxd, color='k',
                              alpha=0.3, linewidth=0)
        self.ax1.plot(td, md, color='C0')

        # Peak-to-peak guide lines, drawn at +/- vpp / 2.
        vpps = [100, 150, 200, 300, 500, 0]
        for i, vpp in enumerate(vpps):
            self.ax1.hlines(vpp / 2, 0, td[-1], linestyle='--',
                            color=pyplot.cm.tab10(i))
            if vpp:
                self.ax1.hlines(-vpp / 2, 0, td[-1], linestyle='--',
                                color=pyplot.cm.tab10(i))

        self.ax1.set_xlim(0, td[-1])
        self.ax1.set_ylim(2 * mn.mean(), 2 * mx.mean())

        ticks = sorted(vpps + [-v for v in vpps])
        self.ax1.set_yticks([v / 2 for v in ticks])
        self.ax1.set_yticklabels([f'{abs(v)} vpp' for v in ticks])

        self.ax1.grid(True, axis='x')

        self.ax1.set_ylabel('Voltage [uv]')
        # The axis holds minutes (see `t` above); the label said seconds.
        self.ax1.set_xlabel('Time [$min$]')

        self.draw()
        self.pipeline_output = self.pipeline_input


########################################################################
class AddMarkers(TimelockSeries):
    """Step that lets the user add markers at the centre of the window."""

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)
        self.title = 'Add new markers'

        # Create grid plot
        gs = self.figure.add_gridspec(4, 1)
        self.ax1 = gs.figure.add_subplot(gs[0:-1, :])
        self.ax2 = gs.figure.add_subplot(gs[-1, :])
        self.ax2.get_yaxis().set_visible(False)

        self.figure.subplots_adjust(left=0.05,
                                    bottom=0.12,
                                    right=0.95,
                                    top=0.95,
                                    wspace=None,
                                    hspace=0.6)

        self.set_window_width_options(
            ['500 milliseconds',
             '1 second',
             '5 second',
             '15 second',
             '30 second',
             '1 minute',
             '5 minute',
             '10 minute',
             '30 minute',
             '1 hour'])

        self.markers = self.add_combobox('Marker', [], callback=None,
                                         editable=True,
                                         area='bottom2', stretch=3)
        self.add_button('Add marker', callback=self.add_marker,
                        area='bottom2', stretch=0)
        self.add_spacer(area='bottom2', stretch=10)

        self.pipeline_tunned = True

    # ----------------------------------------------------------------------
    def add_marker(self):
        """Add a marker, named by the combo box, at the window centre."""
        if self._pipeline_output is None:
            # Nothing has been fitted yet, so there is no marker container.
            logging.warning('No data loaded; marker not added.')
            return

        q = np.mean(self.ax1.get_xlim())
        self.ax1.vlines(q, *self.ax1.get_ylim(), linestyle='--',
                        color='red', linewidth=5, zorder=99)
        self.ax2.vlines(q, *self.ax2.get_ylim(), linestyle='--',
                        color='red', linewidth=3, zorder=99)

        # NOTE(review): `q` is an x-axis position, which `fit` plots in
        # seconds -- confirm it matches the unit of the existing markers.
        markers = self._pipeline_output.markers
        markers.setdefault(self.markers.currentText(), []).append(q)
        self._pipeline_output.markers = markers

        self._pipeline_propagate()
        self.draw()

    # ----------------------------------------------------------------------
    @wait_for_it
    def fit(self):
        """Plot the stacked, decimated EEG and fill the marker combo box."""
        datafile = self.pipeline_input

        # Default names first; skip them below so that they are not listed
        # twice once they exist in the file's markers.
        defaults = ['BAD', 'BLINK']
        markers = defaults + sorted(
            name for name in datafile.markers.keys() if name not in defaults)

        self.markers.clear()
        self.markers.addItems(markers)

        header = datafile.header
        eeg = datafile.eeg
        timestamp = datafile.timestamp

        eeg = decimate(eeg, 15, axis=1)
        timestamp = np.linspace(
            0, timestamp[0][-1], eeg.shape[1], endpoint=True) / 1000

        # Vertical offset between consecutive channels; must be set before
        # `set_data`, which reads it.
        self.threshold = 150
        channels = eeg.shape[0]

        self.set_data(timestamp, eeg,
                      labels=list(header['channels'].values()),
                      ylabel='Millivolt [$mv$]',
                      xlabel='Time [$s$]',
                      legend=False,
                      )

        self.ax1.set_yticks([self.threshold * i for i in range(channels)])
        self.ax1.set_yticklabels(list(header['channels'].values()))
        self.ax1.set_ylim(-self.threshold, self.threshold * channels)
        self.ax2.set_ylim(-self.threshold, self.threshold * channels)

        # Cursor line kept at the centre of the window (see `move_plot`).
        self.vlines = self.ax1.vlines(np.mean(self.ax1.get_xlim()),
                                      *self.ax1.get_ylim(),
                                      linestyle='--',
                                      color='red',
                                      linewidth=2,
                                      zorder=99)

        self.draw()

        datafile.close()
        self.pipeline_tunned = True
        self.pipeline_output = self.pipeline_input

    # ----------------------------------------------------------------------
    def set_data(self, timestamp, eeg, labels, ylabel='', xlabel='', legend=True):
        """Plot every series shifted by `threshold` times its index.

        Same parameters as the base implementation. The window band on
        `ax2` is labelled 'AREA' so that `move_plot` can find it, and the
        canvas is not redrawn here.
        """
        self.ax1.clear()
        self.ax2.clear()

        for i, ch in enumerate(eeg):
            shifted = ch + self.threshold * i
            self.ax1.plot(timestamp, shifted, label=labels[i])
            self.ax2.plot(timestamp, shifted, alpha=0.5)

        self.ax1.grid(True, axis='x')
        if legend:
            self.ax1.legend(loc='upper center', ncol=8,
                            bbox_to_anchor=(0.5, 1.4), **LEGEND_KWARGS)
        self.ax1.set_xlim(0, self.window_value)

        self.ax2.grid(True, axis='x')
        self.ax2.set_xlim(0, timestamp[-1])

        self.ax2.fill_between([0, self.window_value], *self.ax1.get_ylim(),
                              color=self.fill_color,
                              alpha=self.fill_opacity,
                              label='AREA')

        # Scroll bar values are integers (milliseconds).
        self.scroll.setMaximum(
            int((timestamp[-1] - self.window_value) * 1000))
        self.scroll.setMinimum(0)

        self.ax1.set_ylabel(ylabel)
        self.ax2.set_xlabel(xlabel)

    # ----------------------------------------------------------------------
    def move_plot(self, value):
        """Move the window to `value` (ms), its band and the cursor line."""
        start = value / 1000
        self.ax1.set_xlim(start, start + self.window_value)

        # Only the band is removed; marker lines on `ax2` are kept.
        # `Axes.collections` cannot be popped from on current Matplotlib
        # releases, so each artist removes itself.
        for band in [c for c in self.ax2.collections
                     if c.get_label() == 'AREA']:
            band.remove()

        self.ax2.fill_between([start, start + self.window_value],
                              *self.ax1.get_ylim(),
                              color=self.fill_color,
                              alpha=self.fill_opacity,
                              label='AREA')

        # The cursor line exists only after the first `fit`.
        cursor = getattr(self, 'vlines', None)
        if cursor is not None:
            segments = cursor.get_segments()
            segments[0][:, 0] = [np.mean(self.ax1.get_xlim())] * 2
            cursor.set_segments(segments)

        self.draw()

    # ----------------------------------------------------------------------
    def change_window(self):
        """Apply the window width selected in the combo box."""
        self.window_value = self._get_seconds_from_human(
            self.combobox.currentText())

        # Last timestamp of the output, converted from ms to s.
        duration = self.pipeline_output.timestamp[0][-1] / 1000

        # Scroll bar values are integers (milliseconds).
        self.scroll.setMaximum(int((duration - self.window_value) * 1000))
        self.scroll.setMinimum(0)
        self.scroll.setPageStep(int(self.window_value * 1000))

        # Reuse `move_plot` so that the band and the cursor line follow the
        # new width too; previously only the x limits were updated.
        self.move_plot(self.scroll.value())


# ########################################################################
# class ConditionalCreateMarkers(ta.TimelockWidget):
#     """"""

#     # ----------------------------------------------------------------------
#     def __init__(self, height, *args, **kwargs):
#         """Constructor"""
#         super().__init__(height=0, *args, **kwargs)
#         self.title = 'Create markers conditionally'

#         self.layout = QtWidgets.QVBoxLayout()
#         widget = QtWidgets.QWidget()
#         widget.setLayout(self.layout)

#         getattr(self.widget, 'topLayout').addWidget(widget)
#         getattr(self, 'top_stretch').append(1)

#         self.add_button('Add row', callback=self.add_row,
#                         area='bottom',
#                         stretch=0)
#         self.add_spacer(area='bottom', fixed=None, stretch=1)
#         self.new_markers = {}
#     # ----------------------------------------------------------------------
#     @wait_for_it
#     def fit(self):
#         """"""
#     # ----------------------------------------------------------------------
#     def add_new_markers(self, n):
#         """"""
#         for k in self.new_markers:
#             c1, c2, tx = self.new_markers[k]
#             print(f'{c1()}, {c2()}, {tx()}')
#         print('#' * 10)
#     # ----------------------------------------------------------------------
#     def add_row(self):
#         """"""
#         layout = QtWidgets.QHBoxLayout()
#         widget = QtWidgets.QWidget()
#         widget.setLayout(layout)
#         layout.addWidget(QtWidgets.QLabel(
#             'Create new markers in the position of'))
#         combo1 = QtWidgets.QComboBox()
#         combo1.addItems(self.pipeline_input.markers.keys())
#         layout.addWidget(combo1)
#         layout.addWidget(QtWidgets.QLabel('that have a closest'))
#         combo2 = QtWidgets.QComboBox()
#         combo2.addItems(self.pipeline_input.markers.keys())
#         layout.addWidget(combo2)
#         layout.addWidget(QtWidgets.QLabel('as'))
#         edit = QtWidgets.QLineEdit()
#         self.new_markers[edit] = (
#             combo1.currentText, combo2.currentText, edit.text)
#         edit.textChanged.connect(self.add_new_markers)
#         layout.addWidget(edit)
#         layout.setStretch(0, 0)
#         layout.setStretch(1, 0)
#         layout.setStretch(2, 0)
#         layout.setStretch(3, 0)
#         layout.setStretch(4, 0)
#         layout.setStretch(5, 1)
#         self.layout.addWidget(widget)


########################################################################
class MarkersSynchronization(TimelockWidget):
    """Pipeline widget that matches markers against auxiliary-channel rises.

    It plots the selected auxiliary signal around every checked marker
    (``ax1``) and around every rise reported by ``pipeline_input.get_rises``
    (``ax2``), then asks the input to ``fix_markers`` against those rises.
    """

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Build the channel selector, the two threshold spins and the axes.

        Parameters
        ----------
        height
            Forwarded, with ``*args`` and ``**kwargs``, to the base
            constructor.
        """
        super().__init__(height, *args, **kwargs)
        self.title = 'Markers synchronization'

        # self.add_radios('Markers', self.notchs, callback=self.set_filters,
        #                 area='top', stretch=0)

        self.sync_channel = self.add_combobox('Channel', [],
                                              editable=False,
                                              callback=self.update_plot,
                                              area='top2',
                                              stretch=0)
        self.add_spacer(area='top2', fixed=None, stretch=1)

        # Both bounds are passed to ``get_rises`` and reused as y-limits.
        self.upper = self.add_spin('Upper', 500, suffix='vpp', min_=0,
                                   max_=2000, step=10,
                                   callback=self.update_plot,
                                   area='right', stretch=0)
        self.lower = self.add_spin('Lower', 200, suffix='vpp', min_=0,
                                   max_=2000, step=10,
                                   callback=self.update_plot,
                                   area='right', stretch=0)

        self.pipeline_tunned = True

        # ax1 takes the first two thirds of the figure, ax2 the last third.
        gs = self.figure.add_gridspec(1, 3)
        self.ax1 = gs.figure.add_subplot(gs[:, 0:-1])
        self.ax2 = gs.figure.add_subplot(gs[:, -1])
        self.figure.subplots_adjust(left=0.05,
                                    bottom=0.12,
                                    right=0.95,
                                    top=0.8)

    # ----------------------------------------------------------------------
    @wait_for_it
    def fit(self):
        """Refresh the channel list and the marker checkboxes from the input.

        One ``AUX<n>`` entry is offered per row of ``pipeline_input.aux`` and
        one checkbox per key of ``pipeline_input.markers``.
        """
        self.sync_channel.clear()
        # A concrete list is passed because Qt expects a sequence of strings.
        self.sync_channel.addItems(
            [f'AUX{c}' for c in range(self.pipeline_input.aux.shape[0])])

        self.clear_widgets(areas=['left'])
        self.marker_sync = self.add_checkbox('Markers',
                                             self.pipeline_input.markers.keys(),
                                             callback=self.update_plot,
                                             area='left',
                                             stretch=0,
                                             cols=1)
        self.add_spacer(stretch=1, area='left')

    # ----------------------------------------------------------------------
    @staticmethod
    def _first_edge(window):
        """Return the index of the first binarised level change, or ``None``.

        ``window`` is scaled by its peak-to-peak range, shifted so that its
        minimum is zero and thresholded at 0.5; the result is differenced
        with a zero prepended. ``None`` is returned for an empty window, a
        flat window (zero range) or when no level change exists. Index ``0``
        can be returned; it only occurs when the window already starts high.
        """
        if not window.size:
            return None
        span = window.max() - window.min()
        if not span:
            # A flat window has no edge and would divide by zero below.
            return None
        level = window / span
        level = level - level.min()
        binary = (level > 0.5).astype(float)
        edges = np.flatnonzero(np.abs(np.diff(binary, prepend=0)) == 1)
        return int(edges[0]) if edges.size else None

    # ----------------------------------------------------------------------
    def update_plot(self, *args, **kwargs):
        """Redraw both axes and re-run the marker correction.

        ``*args`` and ``**kwargs`` are ignored; they only absorb whatever the
        widget callbacks pass in.

        Steps: detect rises on the selected auxiliary channel with the
        current lower/upper bounds, plot the signal 2000 samples either side
        of every checked marker on ``ax1`` (with a dashed line at the first
        detected edge), plot the signal from 50 samples before to 300 after
        every rise on ``ax2``, show the marker/rise ratio in the ``ax2``
        legend, then reset the markers and, when at least one marker type is
        checked, call ``pipeline_input.fix_markers``.
        """
        self.ax1.clear()
        self.ax2.clear()

        lower_val = self.lower.value()
        upper_val = self.upper.value()

        aux = self.pipeline_input.aux[self.sync_channel.currentIndex()]
        markers = self.pipeline_input.markers
        t = self.pipeline_input.aux_timestamp[0]
        rises = self.pipeline_input.get_rises(
            aux, t, lower=lower_val, upper=upper_val)

        target_markers = [
            ch.text() for ch in self.marker_sync if ch.isChecked()]
        mks = [m for name in target_markers for m in markers[name]]

        trace_color = pyplot.cm.tab10(7)
        edge_color = pyplot.cm.tab10(3)

        for i in mks:
            # NOTE(review): for i < 2000 the negative start index wraps
            # around in NumPy slicing, which normally yields an empty
            # window -- confirm whether clamping is wanted.
            shape = aux[i - 2000:i + 2000]
            ts = np.linspace(-2000, 2000, shape.shape[0])
            self.ax1.plot(ts, shape, color=trace_color,
                          alpha=0.5, linewidth=1)

            edge = self._first_edge(shape)
            # Index 0 is deliberately skipped, as in the original check: it
            # only reflects the zero prepended before differencing.
            if edge:
                self.ax1.vlines(ts[edge], 200, 800, linestyle='--',
                                color=edge_color, alpha=0.5)
        self.ax1.grid(True)

        for rise in rises:
            i = np.argmin(abs(t - rise))
            shape = aux[i - 50:i + 300]
            ts = np.linspace(-50, 300, shape.shape[0])
            self.ax2.plot(ts, shape, color=trace_color,
                          alpha=0.1, linewidth=1)

        # Without any detected rise the ratio is undefined; report 0 %
        # instead of dividing by zero.
        n_rises = len(rises)
        target = 100 * len(mks) / n_rises if n_rises else 0.0

        # Empty line used only as the legend handle, so the legend no longer
        # depends on loop variables that may not exist.
        self.ax2.plot([], [], color=trace_color, alpha=0.1, linewidth=1,
                      label=f'{target:.2f}% of markers synchronized')
        self.ax2.grid(True)
        self.ax2.vlines(0, lower_val, upper_val, linestyle='--',
                        color=edge_color)

        self.ax1.set_title('Original analog rises')
        self.ax2.set_title('Syncronized rises')
        self.ax1.set_xlabel('Time [s]')
        self.ax2.set_xlabel('Time [s]')
        self.ax1.set_ylabel('Amplitude [mV]')
        # self.ax1.legend(ncol=2, loc='upper center')

        # Legend background: tab10 colour 0 when the ratio lies strictly
        # between 90 % and 110 %, colour 3 otherwise.
        in_sync = 90 < target < 110
        self.ax2.legend(loc='lower right',
                        facecolor=pyplot.cm.tab10(0 if in_sync else 3),
                        framealpha=0.5, **LEGEND_KWARGS)

        self.ax1.set_ylim(lower_val, upper_val)
        self.ax2.set_ylim(lower_val, upper_val)
        # self.ax2.set_xlim(0, 20)

        self.pipeline_input.reset_markers()
        if target_markers:
            self.pipeline_input.fix_markers(
                target_markers, rises, range_=2000)

        # self.pipeline_tunned = True
        self.pipeline_output = self.pipeline_input

        self.draw()


########################################################################
class ScriptProcess(TimelockWidget):
    """Pipeline widget titled 'Script process' that is created already tuned."""

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Constructor.

        ``height`` is accepted for signature compatibility but ignored: the
        base constructor always receives ``0``.
        """
        super().__init__(0, *args, **kwargs)
        self.title = 'Script process'
        self.pipeline_tunned = True

    # # ----------------------------------------------------------------------
    # def fit(self):
    #     """"""
    #     self.pipeline_output = self.process(self.pipeline_input)

    # # ----------------------------------------------------------------------
    # def process(self, *args, **kwargs):
    #     """"""
    #     logging.warning('ERROR')