# Source code for bci_framework.extensions.timelock_analysis.timelock_analysis

import os
import sys
import logging
import copy
from abc import ABCMeta, abstractmethod

import mne
import numpy as np
# from scipy.fftpack import rfft, rfftfreq
from scipy.signal import welch, decimate
from scipy.signal import decimate, welch

from cycler import cycler
import matplotlib
from matplotlib import pyplot
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure

from PySide6.QtCore import Qt
from PySide6 import QtWidgets
from PySide6.QtUiTools import QUiLoader
from PySide6.QtWidgets import QSpacerItem, QSizePolicy

from gcpds.filters import frequency as flt
from gcpds.filters import frequency as flt
from bci_framework.extensions.timelock_analysis import timelock_analysis as ta
from bci_framework.framework.dialogs import Dialogs


# Set logger: MNE is very verbose, only let critical messages through.
logger = logging.getLogger("mne")
logger.setLevel(logging.CRITICAL)

# matplotlib's font manager floods the log while looking up fonts.
logging.getLogger('matplotlib.font_manager').disabled = True

# Root logger: warnings and above, reported under the tool's name.
logging.root.setLevel(logging.WARNING)
logging.getLogger().name = "TimelockAnalysis"


########################################################################
class Canvas(FigureCanvasQTAgg):
    """Qt canvas that owns a themed matplotlib :class:`Figure`."""

    # ----------------------------------------------------------------------
    def __init__(self, *args, **kwargs):
        """Configure matplotlib, create the figure and attach it to Qt.

        All arguments are forwarded to :class:`matplotlib.figure.Figure`.
        """
        # Configure matplotlib before creating the figure so that the style
        # and rcParams set in `configure` apply to it.
        self.configure()
        self.figure = Figure(*args, **kwargs)
        super().__init__(self.figure)

    # ----------------------------------------------------------------------
    def configure(self):
        """Apply the matplotlib style used by the timelock widgets.

        The ``dark_background`` style is applied unless a light theme was
        requested, either by a ``light`` command-line argument or through the
        ``QTMATERIAL_THEME`` environment variable.  Afterwards the color
        cycle (16 colors sampled from the ``cool`` colormap), DPI and font
        are set.  Failures while setting them are logged at debug level and
        otherwise ignored, since the styling is purely cosmetic.
        """
        light = ('light' in sys.argv
                 or 'light' in os.environ.get('QTMATERIAL_THEME', ''))
        if not light:
            pyplot.style.use('dark_background')

        try:
            # `matplotlib.cm.get_cmap` was removed in matplotlib 3.9; use the
            # colormap registry whenever this matplotlib provides it.
            registry = getattr(matplotlib, 'colormaps', None)
            cmap = (registry['cool'] if registry is not None
                    else matplotlib.cm.get_cmap('cool'))
            matplotlib.rcParams['axes.prop_cycle'] = cycler(
                color=[cmap(m) for m in np.linspace(0, 1, 16)])
            matplotlib.rcParams['figure.dpi'] = 70
            matplotlib.rcParams['font.family'] = 'monospace'
            matplotlib.rcParams['font.size'] = 15
        except Exception:
            # e.g. 'rcParams' object does not support item assignment.
            logging.debug('Could not configure matplotlib', exc_info=True)
########################################################################
class TimelockWidget(metaclass=ABCMeta):
    """Base class for the widgets of the timelock-analysis pipeline.

    Each instance loads the ``locktime_widget.ui`` layout, embeds a
    matplotlib :class:`Canvas` in its ``gridLayout`` and offers helpers to
    fill the ``left``, ``right``, ``top``, ``bottom``, ``top2`` and
    ``bottom2`` layouts with controls.  Widgets are chained with
    :meth:`previous_pipeline` / :meth:`next_pipeline`.  Assigning
    :attr:`pipeline_output` re-fits the next widget when that widget is
    already tuned.  Subclasses implement :meth:`fit`.
    """

    # Areas of the .ui file that accept controls (`<area>Layout`).
    _AREAS = ('left', 'right', 'top', 'bottom', 'top2', 'bottom2')

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Constructor

        Parameters
        ----------
        height : int or None
            Minimum height of the widget; a falsy value leaves it unset.
        *args, **kwargs
            Forwarded to :class:`Canvas` (and from there to
            :class:`matplotlib.figure.Figure`).
        """
        # One `<area>_stretch` list per area, consumed by `_add_spacers`.
        self._reset_stretch()
        self._pipeline_output = None

        ui = os.path.realpath(os.path.join(
            os.environ['BCISTREAM_ROOT'], 'framework', 'qtgui',
            'locktime_widget.ui'))
        self.widget = QUiLoader().load(ui)
        if height:
            self.widget.setMinimumHeight(height)

        self.canvas = Canvas(*args, **kwargs)
        self.figure = self.canvas.figure
        self.widget.gridLayout.addWidget(self.canvas)

    # ----------------------------------------------------------------------
    def _reset_stretch(self):
        """Start an empty stretch-factor list for every area."""
        for area in self._AREAS:
            setattr(self, f'{area}_stretch', [])

    # ----------------------------------------------------------------------
    def draw(self):
        """Re-apply the matplotlib configuration and redraw the canvas."""
        self.canvas.configure()
        self.canvas.draw()

    # ----------------------------------------------------------------------
    def _add_spacers(self):
        """Apply every recorded stretch factor to its area layout.

        The i-th recorded factor of an area is applied to the i-th item of
        that area's layout.
        """
        for area in self._AREAS:
            layout = getattr(self.widget, f'{area}Layout')
            for index, stretch in enumerate(getattr(self, f'{area}_stretch')):
                layout.setStretch(index, stretch)

    # ----------------------------------------------------------------------
    def add_spacer(self, area='top', fixed=None, stretch=0):
        """Add a spacer to *area*.

        With *fixed*, the spacer has that fixed length; otherwise it expands
        along the layout direction (vertical for ``left``/``right``,
        horizontal for the other areas).  The stretch factor is always
        recorded, even when 0, so that recorded factors stay aligned with
        the layout's item indices.
        """
        layout = getattr(self.widget, f'{area}Layout')
        vertical = area in ('left', 'right')
        if fixed:
            size = (20, fixed) if vertical else (fixed, 20)
            policy = (QSizePolicy.Minimum, QSizePolicy.Minimum)
        elif vertical:
            size = (20, 20000)
            policy = (QSizePolicy.Minimum, QSizePolicy.Expanding)
        else:
            size = (20000, 20)
            policy = (QSizePolicy.Expanding, QSizePolicy.Minimum)
        layout.addItem(QSpacerItem(*size, *policy))
        getattr(self, f'{area}_stretch').append(stretch)

    # ----------------------------------------------------------------------
    def clear_layout(self, layout):
        """Remove the widgets and spacers of *layout*; recurse into sub-layouts.

        Items are visited from last to first because removing an item shifts
        the index of every item after it.  Widgets are detached from the
        layout at once and deleted later by Qt.  Nested layouts are emptied
        but kept in place.
        """
        for index in reversed(range(layout.count())):
            item = layout.itemAt(index)
            if item is None:
                continue
            if widget := item.widget():
                layout.removeWidget(widget)
                widget.deleteLater()
            elif item.spacerItem():
                layout.removeItem(item)
            elif sub_layout := item.layout():
                self.clear_layout(sub_layout)

    # ----------------------------------------------------------------------
    def clear_widgets(self):
        """Empty every area and forget the stretch factors recorded for them."""
        for area in self._AREAS:
            self.clear_layout(getattr(self.widget, f'{area}Layout'))
        self._reset_stretch()

    # ----------------------------------------------------------------------
    def _place(self, widget, area, stretch):
        """Append *widget* to the *area* layout and record its stretch."""
        getattr(self.widget, f'{area}Layout').addWidget(widget)
        getattr(self, f'{area}_stretch').append(stretch)

    # ----------------------------------------------------------------------
    @staticmethod
    def _bind(callback, *args):
        """Return a slot that ignores the signal arguments and calls ``callback(*args)``."""
        def slot(*_):
            return callback(*args)
        return slot

    # ----------------------------------------------------------------------
    def add_textarea(self, content='', area='top', stretch=0):
        """Add a read-only text area showing *content* and return it."""
        textarea = QtWidgets.QTextEdit(content)
        textarea.setProperty('class', 'clear')
        textarea.setMinimumWidth(500)
        textarea.setReadOnly(True)
        self._place(textarea, area, stretch)
        return textarea

    # ----------------------------------------------------------------------
    def add_button(self, label, callback=None, area='top', stretch=0):
        """Add a push button calling *callback* when clicked; return it."""
        button = QtWidgets.QPushButton(label)
        if callback:
            button.clicked.connect(callback)
        self._place(button, area, stretch)
        return button

    # ----------------------------------------------------------------------
    def _add_button_group(self, button_class, group_name, labels, ncol,
                          callback, area, stretch):
        """Add a titled group of *button_class* buttons in rows of *ncol*.

        Only the first button starts checked.  With *callback*, clicking a
        button calls ``callback(group_name, label)``.  Returns the buttons.
        """
        group = QtWidgets.QGroupBox(group_name)
        group.setProperty('class', 'fill_background')
        rows = QtWidgets.QVBoxLayout()
        group.setLayout(rows)
        if ncol is None:
            ncol = len(labels)

        buttons = []
        for index, label in enumerate(labels):
            if index % ncol == 0:
                row = QtWidgets.QHBoxLayout()
                rows.addLayout(row)
            button = button_class()
            button.setText(label)
            button.setChecked(index == 0)
            if callback:
                button.clicked.connect(self._bind(callback, group_name, label))
            row.addWidget(button)
            buttons.append(button)

        self._place(group, area, stretch)
        return buttons

    # ----------------------------------------------------------------------
    def add_radios(self, group_name, radios, cols=None, callback=None,
                   area='top', stretch=1):
        """Add a group of radio buttons (first one selected) and return them.

        *cols* buttons are placed per row (all in one row by default);
        clicking one calls ``callback(group_name, label)``.
        """
        return self._add_button_group(QtWidgets.QRadioButton, group_name,
                                      radios, cols, callback, area, stretch)

    # ----------------------------------------------------------------------
    def add_checkbox(self, group_name, radios, ncol=None, callback=None,
                     area='top', stretch=1):
        """Add a group of checkboxes (first one checked) and return them.

        *ncol* checkboxes are placed per row (all in one row by default);
        clicking one calls ``callback(group_name, label)``.
        """
        return self._add_button_group(QtWidgets.QCheckBox, group_name,
                                      radios, ncol, callback, area, stretch)

    # ----------------------------------------------------------------------
    def add_channels(self, group_name, radios, callback=None, area='top',
                     stretch=1):
        """Add one checked checkbox per channel name, in three columns.

        Names ending in an odd digit go to the left column, names ending in
        an even digit to the right column, and any other name to the center
        column.  Clicking a checkbox calls ``callback(group_name, name)``.
        Returns the checkboxes in the order of *radios*.
        """
        group = QtWidgets.QGroupBox(group_name)
        group.setProperty('class', 'fill_background')
        columns = QtWidgets.QHBoxLayout()
        group.setLayout(columns)
        left = QtWidgets.QVBoxLayout()
        center = QtWidgets.QVBoxLayout()
        right = QtWidgets.QVBoxLayout()
        for column in (left, center, right):
            columns.addLayout(column)

        checkboxes = []
        for name in radios:
            checkbox = QtWidgets.QCheckBox()
            checkbox.setText(name)
            checkbox.setChecked(True)
            checkboxes.append(checkbox)

            # `name[-1:]` is '' for an empty name, which goes to the center.
            last = name[-1:]
            if last.isdecimal():
                (left if int(last) % 2 else right).addWidget(checkbox)
            else:
                center.addWidget(checkbox)

            if callback:
                checkbox.clicked.connect(self._bind(callback, group_name, name))

        self._place(group, area, stretch)
        return checkboxes

    # ----------------------------------------------------------------------
    def add_scroll(self, callback=None, area='bottom', stretch=0):
        """Add a horizontal scroll bar; *callback* receives ``sliderMoved``."""
        scroll = QtWidgets.QScrollBar()
        scroll.setOrientation(Qt.Horizontal)
        if callback:
            scroll.sliderMoved.connect(callback)
        scroll.setProperty('class', 'big')
        self._place(scroll, area, stretch)
        return scroll

    # ----------------------------------------------------------------------
    def add_slider(self, callback=None, area='bottom', stretch=0):
        """Add a horizontal 0-500 slider starting at 500; return it.

        *callback* receives ``valueChanged``; it is connected after the
        initial value is set, so it is not called for it.
        """
        slider = QtWidgets.QSlider()
        slider.setOrientation(Qt.Horizontal)
        slider.setMinimum(0)
        slider.setMaximum(500)
        slider.setValue(500)
        if callback:
            slider.valueChanged.connect(callback)
        self._place(slider, area, stretch)
        return slider

    # ----------------------------------------------------------------------
    def add_spin(self, label, value, decimals=1, step=0.1, prefix='',
                 suffix='', min_=0, max_=999, callback=None, area='top',
                 stretch=0):
        """Add a labelled float spin box and return the spin box.

        *callback* receives ``valueChanged``.
        """
        spin = QtWidgets.QDoubleSpinBox()
        spin.setDecimals(decimals)
        spin.setSingleStep(step)
        spin.setMinimum(min_)
        spin.setMaximum(max_)
        spin.setValue(value)
        if callback:
            spin.valueChanged.connect(callback)
        if prefix:
            spin.setPrefix(f' {prefix}')
        if suffix:
            spin.setSuffix(f' {suffix}')

        layout = QtWidgets.QHBoxLayout()
        container = QtWidgets.QWidget()
        container.setLayout(layout)
        if label:
            layout.addWidget(QtWidgets.QLabel(label))
        layout.addWidget(spin)
        self._place(container, area, stretch)
        layout.setStretch(0, 0)
        layout.setStretch(1, 1)
        return spin

    # ----------------------------------------------------------------------
    def add_combobox(self, label, items, editable=False, callback=None,
                     area='top', stretch=0):
        """Add a labelled combo box filled with *items*; return the combo box.

        *callback* receives ``activated``.
        """
        combo = QtWidgets.QComboBox()
        combo.addItems(items)
        if callback:
            combo.activated.connect(callback)
        combo.setEditable(editable)

        layout = QtWidgets.QHBoxLayout()
        container = QtWidgets.QWidget()
        container.setLayout(layout)
        if label:
            layout.addWidget(QtWidgets.QLabel(label))
        layout.addWidget(combo)
        self._place(container, area, stretch)
        layout.setStretch(0, 0)
        layout.setStretch(1, 1)
        return combo

    # ----------------------------------------------------------------------
    @property
    def pipeline_input(self):
        """Input data: the previous widget's output, else the one set manually.

        Logs a warning and returns ``None`` when neither is available.
        """
        if hasattr(self, '_previous_pipeline'):
            return self._previous_pipeline.pipeline_output
        if hasattr(self, '_pipeline_input'):
            return self._pipeline_input
        logging.warning("'pipeline_input' does not exist yet.")

    # ----------------------------------------------------------------------
    @pipeline_input.setter
    def pipeline_input(self, input_):
        """Set the input data manually."""
        self._pipeline_input = input_

    # ----------------------------------------------------------------------
    @property
    def pipeline_output(self):
        """Output data produced by the last :meth:`fit` (``None`` before)."""
        return getattr(self, '_pipeline_output', None)

    # ----------------------------------------------------------------------
    @pipeline_output.setter
    def pipeline_output(self, output_):
        """Store the output data and re-fit the next widget if it is tuned."""
        self._pipeline_output = output_
        self._pipeline_propagate()

    # ----------------------------------------------------------------------
    @property
    def pipeline_tunned(self):
        """Whether this widget is ready to be re-fitted automatically."""
        return getattr(self, '_pipeline_tunned', False)

    # ----------------------------------------------------------------------
    @pipeline_tunned.setter
    def pipeline_tunned(self, value):
        """Mark the widget as (not) ready to be re-fitted automatically."""
        self._pipeline_tunned = value

    # ----------------------------------------------------------------------
    def next_pipeline(self, pipe):
        """Set the widget that consumes this widget's output."""
        self._next_pipeline = pipe

    # ----------------------------------------------------------------------
    def previous_pipeline(self, pipe):
        """Set the widget whose output is this widget's input."""
        self._previous_pipeline = pipe

    # ----------------------------------------------------------------------
    def set_pipeline_input(self, in_):
        """Set the input data manually (same as assigning ``pipeline_input``)."""
        self._pipeline_input = in_

    # ----------------------------------------------------------------------
    def _pipeline_propagate(self):
        """Call ``fit`` on the next widget, if there is one and it is tuned."""
        next_pipeline = getattr(self, '_next_pipeline', None)
        if next_pipeline is not None and next_pipeline.pipeline_tunned:
            next_pipeline.fit()

    # ----------------------------------------------------------------------
    @abstractmethod
    def fit(self):
        """Process ``pipeline_input``, update the plots and set the output."""


########################################################################
class TimelockSeries(TimelockWidget):
    """Time-series widget with a zoomed view and an overview strip.

    ``ax1`` shows a window of ``window_value`` seconds; ``ax2`` shows the
    whole signal with that window shaded.  Subclasses create ``ax1``/``ax2``
    and call :meth:`set_window_width_options` before plotting.
    """

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)
        self.fill_opacity = 0.2
        self.fill_color = os.environ.get(
            'QTMATERIAL_PRIMARYCOLOR', '#ff0000')

    # ----------------------------------------------------------------------
    @staticmethod
    def _remove_collections(ax):
        """Remove every collection (e.g. ``fill_between`` areas) from *ax*.

        Artists are removed one by one; ``ax.collections`` is not a mutable
        list in recent matplotlib releases.
        """
        for collection in list(ax.collections):
            collection.remove()

    # ----------------------------------------------------------------------
    def _draw_window_area(self, start):
        """Shade ``[start, start + window_value]`` (seconds) on ``ax2``."""
        self._remove_collections(self.ax2)
        self.ax2.fill_between([start, start + self.window_value],
                              *self.ax1.get_ylim(),
                              color=self.fill_color,
                              alpha=self.fill_opacity)

    # ----------------------------------------------------------------------
    def _set_scroll_range(self, duration):
        """Let the scroll bar (in ms) span *duration* seconds minus the window."""
        self.scroll.setMinimum(0)
        self.scroll.setMaximum(
            max(0, int((duration - self.window_value) * 1000)))

    # ----------------------------------------------------------------------
    def move_plot(self, value):
        """Show the window starting at *value* milliseconds."""
        start = value / 1000
        self.ax1.set_xlim(start, start + self.window_value)
        self._draw_window_area(start)
        self.draw()

    # ----------------------------------------------------------------------
    def change_window(self):
        """Apply the window width selected in the combo box."""
        self.window_value = self._get_seconds_from_human(
            self.combobox.currentText())
        duration = self.pipeline_output.timestamp[0][-1] / 1000

        self._set_scroll_range(duration)
        self.scroll.setPageStep(int(self.window_value * 1000))

        start = self.scroll.value() / 1000
        self.ax1.set_xlim(start, start + self.window_value)
        self._draw_window_area(start)
        self.draw()

    # ----------------------------------------------------------------------
    def _get_seconds_from_human(self, human):
        """Convert a text such as ``'500 milliseconds'`` or ``'5 minute'`` to seconds.

        Whitespace-separated tokens are multiplied together; unit names
        (singular or plural: millisecond, second, minute, hour) stand for
        their length in seconds, anything else is parsed as a float.  An
        empty text gives 1.0.
        """
        units = {'millisecond': 0.001, 'second': 1, 'minute': 60,
                 'hour': 3600}
        factors = []
        for token in human.split():
            if token in units:
                factors.append(units[token])
            elif token.endswith('s') and token[:-1] in units:
                factors.append(units[token[:-1]])
            else:
                factors.append(float(token))
        return np.prod(factors)

    # ----------------------------------------------------------------------
    def set_data(self, timestamp, eeg, labels, ylabel='', xlabel=''):
        """Plot every row of *eeg* against *timestamp* (seconds) on both axes."""
        self.ax1.clear()
        self.ax2.clear()
        for i, channel in enumerate(eeg):
            self.ax1.plot(timestamp, channel, label=labels[i])
            self.ax2.plot(timestamp, channel, alpha=0.5)

        self.ax1.grid(True)
        self.ax1.legend(loc='upper center', ncol=8, labelcolor='k',
                        bbox_to_anchor=(0.5, 1.4))
        self.ax1.set_xlim(0, self.window_value)

        self.ax2.grid(True)
        self.ax2.set_xlim(0, timestamp[-1])
        self._draw_window_area(0)

        self._set_scroll_range(timestamp[-1])
        self.ax1.set_ylabel(ylabel)
        self.ax2.set_xlabel(xlabel)
        self.draw()

    # ----------------------------------------------------------------------
    def set_window_width_options(self, options):
        """Add the scroll bar and the window-width combo box (*options*)."""
        self.scroll = self.add_scroll(
            callback=self.move_plot, area='bottom', stretch=1)
        self.combobox = self.add_combobox('', options,
                                          callback=self.change_window,
                                          area='bottom',
                                          stretch=0)
        self.window_value = self._get_seconds_from_human(options[0])


########################################################################
class TimelockFilters(TimelockWidget):
    """Apply notch/band-pass filters and show spectra and stacked traces."""

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)

        gs = self.figure.add_gridspec(1, 2)
        self.ax1 = gs.figure.add_subplot(gs[:, 0:-1])
        self.ax2 = gs.figure.add_subplot(gs[:, -1])
        self.figure.subplots_adjust(left=0.05,
                                    bottom=0.12,
                                    right=0.95,
                                    top=0.8,
                                    wspace=None,
                                    hspace=0.6)

        # Active filter per group: 'none' or a filter function from `flt`.
        self.filters = {'Notch': 'none',
                        'Bandpass': 'none',
                        }
        self.notchs = ('none', '50 Hz', '60 Hz')
        self.bandpass = ('none', 'delta', 'theta', 'alpha', 'beta',
                         '0.01-20 Hz',
                         '5-45 Hz', '3-30 Hz', '4-40 Hz', '2-45 Hz', '1-50 Hz',
                         '7-13 Hz', '15-50 Hz', '1-100 Hz', '5-50 Hz')

        self.add_radios('Notch', self.notchs, callback=self.set_filters,
                        area='top', stretch=0)
        self.add_radios('Bandpass', self.bandpass, callback=self.set_filters,
                        area='top', stretch=1)
        self.scale = self.add_spin('Scale', 150, suffix='uv', min_=0,
                                   max_=1000, step=50, callback=self.fit,
                                   area='top', stretch=0)

    # ----------------------------------------------------------------------
    def _sample_rate(self):
        """Sampling rate from the input header (1000 when it is missing)."""
        try:
            return self.pipeline_input.header['sample_rate']
        except (KeyError, TypeError):
            return 1000

    # ----------------------------------------------------------------------
    def fit(self):
        """Filter ``original_eeg``, plot spectra and traces, then propagate.

        The filtered signal replaces ``eeg`` on the input object, which is
        also used as this widget's output.
        """
        data = self.pipeline_input
        fs = self._sample_rate()

        eeg = data.original_eeg
        for filter_ in self.filters.values():
            if filter_ != 'none':
                eeg = filter_(eeg, fs=fs, axis=1)

        self.ax1.clear()
        self.ax2.clear()

        channels = eeg.shape[0]
        t = np.arange(eeg.shape[1]) / fs

        # Channels are stacked vertically, `threshold` (the Scale spin box)
        # apart.
        threshold = self.scale.value()
        for i, channel in enumerate(eeg):
            self.ax2.plot(t, channel + (threshold * i))

        self.ax1.set_xlabel('Frequency [$Hz$]')
        self.ax1.set_ylabel('Amplitude')
        self.ax2.set_xlabel('Time [$s$]')

        self.ax2.set_yticks([threshold * i for i in range(channels)])
        self.ax2.set_yticklabels(list(data.header['channels'].values()))
        self.ax2.set_ylim(-threshold, threshold * channels)

        w, spectrum = welch(eeg, fs=fs, axis=1, nperseg=1024, noverlap=256,
                            average='median')
        for i, channel in enumerate(spectrum):
            self.ax1.fill_between(w, 0, channel, alpha=0.2, color=f'C{i}')
            self.ax1.plot(w, channel, linewidth=2, color=f'C{i}')

        self.ax1.set_xscale('log')
        # A log axis cannot start at 0 Hz; start at the first positive bin.
        positive = w[w > 0]
        if positive.size:
            self.ax1.set_xlim(positive[0], w[-1])
        if t.size:
            self.ax2.set_xlim(0, t[-1])

        self.ax1.grid(True)
        self.ax2.grid(True)

        self.draw()

        self.pipeline_tunned = True
        self._pipeline_output = data
        self._pipeline_output.eeg = eeg.copy()
        self._pipeline_propagate()

    # ----------------------------------------------------------------------
    def set_filters(self, group_name, filter_):
        """Select the *filter_* label for *group_name* and re-fit.

        ``'50 Hz'`` maps to ``flt.notch50``; band names map to the function
        of the same name; ranges such as ``'0.01-20 Hz'`` map to
        ``flt.band00120``.
        """
        if filter_ != 'none':
            if group_name == 'Notch':
                filter_ = getattr(flt, f'notch{filter_.replace(" Hz", "")}')
            elif group_name == 'Bandpass':
                if filter_ in self.bandpass[1:5]:
                    filter_ = getattr(flt, filter_)
                else:
                    band = (filter_.replace(" Hz", "")
                            .replace("-", "")
                            .replace(".", ""))
                    filter_ = getattr(flt, f'band{band}')
        self.filters[group_name] = filter_
        self.fit()


########################################################################
class LoadDatabase(TimelockSeries):
    """First pipeline step: load a database file and plot its EEG."""

    # ----------------------------------------------------------------------
    def __init__(self, height=700, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)

        # Create grid plot
        gs = self.figure.add_gridspec(4, 4)
        self.ax1 = gs.figure.add_subplot(gs[0:-1, :])
        self.ax2 = gs.figure.add_subplot(gs[-1, :])
        self.ax2.get_yaxis().set_visible(False)

        self.figure.subplots_adjust(left=0.05,
                                    bottom=0.12,
                                    right=0.95,
                                    top=0.8,
                                    wspace=None,
                                    hspace=0.6)

        self.add_button('Load database', callback=self.load_database,
                        area='top', stretch=0)
        self.add_spacer(area='top')

        self.set_window_width_options(['500 milliseconds'])
        # Window widths, ascending; only those shorter than the recording
        # are offered after loading.
        self.window_options = ['500 milliseconds',
                               '1 second',
                               '5 second',
                               '15 second',
                               '30 second',
                               '1 minute',
                               '5 minute',
                               '10 minute',
                               '30 minute',
                               '1 hour']

        self.database_description = self.add_textarea(
            area='right', stretch=0)

    # ----------------------------------------------------------------------
    def load_database(self):
        """Ask the user for a database, compile the filters for it and fit."""
        self.datafile = Dialogs.load_database()
        if self.datafile is None:
            # NOTE(review): assumes a cancelled dialog returns None.
            return

        # Set input manually
        self.pipeline_input = self.datafile
        flt.compile_filters(
            FS=self.pipeline_input.header['sample_rate'], N=2, Q=3)
        self.fit()

    # ----------------------------------------------------------------------
    def fit(self):
        """Plot the (decimated) EEG of the loaded file and publish the file."""
        datafile = self.pipeline_input
        header = datafile.header
        self.database_description.setText(datafile.description)

        # Decimate a local copy for display only.
        eeg = decimate(datafile.eeg, 15, axis=1)
        timestamp = np.linspace(
            0, datafile.timestamp[0][-1], eeg.shape[1], endpoint=True) / 1000
        eeg = eeg / 1000

        available = [option for option in self.window_options
                     if self._get_seconds_from_human(option) < timestamp[-1]]
        self.combobox.clear()
        self.combobox.addItems(available)
        if available:
            # Repopulating selects the first option; keep the plotted window
            # consistent with it.
            self.window_value = self._get_seconds_from_human(available[0])

        self.set_data(timestamp, eeg,
                      labels=list(header['channels'].values()),
                      ylabel='Millivolt [$mv$]',
                      xlabel='Time [$s$]')

        datafile.close()
        self.pipeline_tunned = True
        self.pipeline_output = datafile


########################################################################
class EpochsVisualization(TimelockWidget):
    """Epoch the input by markers and plot the averaged evoked responses."""

    # ----------------------------------------------------------------------
    def __init__(self, height=700, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)
        self.ax1 = self.figure.add_subplot(111)
        self.pipeline_tunned = True

    # ----------------------------------------------------------------------
    def fit(self):
        """Rebuild the controls for the markers and channels of the input."""
        self.clear_widgets()
        markers = list(self.pipeline_input.file.markers.keys())
        channels = list(self.pipeline_input.header['channels'].values())

        self.tmin = self.add_spin('tmin', 0, suffix='s', min_=-99, max_=99,
                                  callback=self.get_epochs, area='top',
                                  stretch=0)
        self.tmax = self.add_spin('tmax', 1, suffix='s', min_=-99, max_=99,
                                  callback=self.get_epochs, area='top',
                                  stretch=0)
        self.method = self.add_combobox(label='Method',
                                        items=['mean', 'median'],
                                        callback=self.get_epochs,
                                        area='top', stretch=0)
        self.add_spacer(area='top', fixed=50)
        self.reject = self.add_spin('Reject', 200, suffix='vpp', min_=0,
                                    max_=500, step=10,
                                    callback=self.get_epochs, area='top',
                                    stretch=0)
        self.flat = self.add_spin('Flat', 10, suffix='vpp', min_=0, max_=500,
                                  step=10, callback=self.get_epochs,
                                  area='top', stretch=0)
        self.add_spacer(area='top')

        self.checkbox = self.add_checkbox('Markers', markers,
                                          callback=self.get_epochs,
                                          area='left', ncol=1, stretch=1)
        self.add_spacer(area='left')

        self.channels = self.add_channels('Channels', channels,
                                          callback=self.get_epochs,
                                          area='right', stretch=1)
        self.add_spacer(area='right')

    # ----------------------------------------------------------------------
    def get_epochs(self, *args, **kwargs):
        """Compute epochs for the selected markers/channels and plot evokeds.

        Nothing is computed when no marker or channel is selected, or when
        the reject threshold is below the flat threshold.
        """
        self.figure.clear()
        self.ax1 = self.figure.add_subplot(111)

        markers = sorted(box.text() for box in self.checkbox if box.isChecked())
        channels = sorted(box.text() for box in self.channels if box.isChecked())

        if not markers or not channels:
            return
        if self.reject.value() < self.flat.value():
            return

        epochs = self.pipeline_input.epochs(tmin=self.tmin.value(),
                                            tmax=self.tmax.value(),
                                            markers=markers)
        reject = {'eeg': self.reject.value()}
        flat = {'eeg': self.flat.value()}
        epochs.drop_bad(reject, flat)

        evokeds = {
            marker: epochs[marker].average(method=self.method.currentText(),
                                           picks=channels)
            for marker in markers
        }

        try:
            mne.viz.plot_compare_evokeds(evokeds, axes=self.ax1,
                                         cmap=('Class', 'cool'), show=False,
                                         show_sensors=False, invert_y=True,
                                         styles={}, split_legend=False,
                                         legend='upper center')
        except Exception as error:
            # Plotting is best-effort (e.g. every epoch may have been
            # dropped); report the problem instead of hiding it.
            logging.warning('Could not plot evoked responses: %s', error)

        self.draw()
        self.pipeline_output = epochs


########################################################################
class TimelockAmplitudeAnalysis(TimelockWidget):
    """Plot the amplitude envelope across channels against vpp reference lines."""

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)
        self.ax1 = self.figure.add_subplot(111)
        self.pipeline_tunned = True

    # ----------------------------------------------------------------------
    def fit(self):
        """Plot min/max/mean across channels and pass the input through."""
        datafile = self.pipeline_input
        # Timestamps divided by 1000 and 60: milliseconds -> minutes.
        t = datafile.timestamp[0] / 1000 / 60

        eeg = datafile.eeg
        eeg = eeg - eeg.mean(axis=1)[:, np.newaxis]  # remove each channel's offset
        mx = eeg.max(axis=0)
        mn = eeg.min(axis=0)
        m = eeg.mean(axis=0)

        self.ax1.clear()

        dc = 1000
        mxd, mnd, md, td = (decimate(x, dc, n=2) for x in (mx, mn, m, t))

        self.ax1.fill_between(td, mnd, mxd, color='k',
                              alpha=0.3, linewidth=0)
        self.ax1.plot(td, md, color='C0')

        vpps = [100, 150, 200, 300, 500, 0]
        for i, vpp in enumerate(vpps):
            color = pyplot.cm.tab10(i)
            self.ax1.hlines(vpp / 2, 0, td[-1], linestyle='--', color=color)
            if vpp:
                self.ax1.hlines(-vpp / 2, 0, td[-1], linestyle='--',
                                color=color)

        self.ax1.set_xlim(0, td[-1])
        self.ax1.set_ylim(2 * mn.mean(), 2 * mx.mean())

        # `set` drops the duplicate 0 produced by mirroring the list.
        ticks = sorted(set(vpps + [-v for v in vpps]))
        self.ax1.set_yticks([v / 2 for v in ticks])
        self.ax1.set_yticklabels([f'{abs(v)} vpp' for v in ticks])

        self.ax1.grid(True)
        self.ax1.set_ylabel('Voltage [uv]')
        self.ax1.set_xlabel('Time [$min$]')
        self.draw()

        self.pipeline_output = self.pipeline_input


########################################################################
class AddMarkers(TimelockSeries):
    """Browse stacked EEG traces and draw markers at the window center."""

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Constructor"""
        super().__init__(height, *args, **kwargs)

        # Create grid plot
        gs = self.figure.add_gridspec(4, 1)
        self.ax1 = gs.figure.add_subplot(gs[0:-1, :])
        self.ax2 = gs.figure.add_subplot(gs[-1, :])
        self.ax2.get_yaxis().set_visible(False)

        self.set_window_width_options(
            ['500 milliseconds',
             '1 second',
             '5 second',
             '15 second',
             '30 second',
             '1 minute',
             '5 minute',
             '10 minute',
             '30 minute',
             '1 hour'])

        self.markers = self.add_combobox('Marker', [], callback=None,
                                         editable=True, area='bottom2',
                                         stretch=3)
        self.add_button('Add marker', callback=self.add_marker,
                        area='bottom2', stretch=0)
        self.add_spacer(area='bottom2', stretch=10)

        self.pipeline_tunned = True

    # ----------------------------------------------------------------------
    def add_marker(self):
        """Draw a red dashed line at the center of the visible window."""
        q = np.mean(self.ax1.get_xlim())
        self.ax1.vlines(q, *self.ax1.get_ylim(), linestyle='--',
                        color='red', linewidth=5, zorder=99)
        self.ax2.vlines(q, *self.ax2.get_ylim(), linestyle='--',
                        color='red', linewidth=3, zorder=99)
        self.draw()

    # ----------------------------------------------------------------------
    def fit(self):
        """Plot the stacked (decimated) EEG of the input and pass it through."""
        datafile = self.pipeline_input

        markers = ['BAD', 'BLINK'] + sorted(datafile.markers.keys())
        self.markers.clear()
        self.markers.addItems(markers)

        header = datafile.header
        labels = list(header['channels'].values())
        eeg = decimate(datafile.eeg, 15, axis=1)
        timestamp = np.linspace(
            0, datafile.timestamp[0][-1], eeg.shape[1], endpoint=True) / 1000

        # Vertical offset between consecutive stacked channels.
        self.threshold = 150
        channels = eeg.shape[0]

        self.set_data(timestamp, eeg,
                      labels=labels,
                      ylabel='Millivolt [$mv$]',
                      xlabel='Time [$s$]')

        self.ax1.set_yticks([self.threshold * i for i in range(channels)])
        self.ax1.set_yticklabels(labels)
        self.ax1.set_ylim(-self.threshold, self.threshold * channels)
        self.ax2.set_ylim(-self.threshold, self.threshold * channels)

        # Line marking the center of the visible window.
        self.vlines = self.ax1.vlines(np.mean(self.ax1.get_xlim()),
                                      *self.ax1.get_ylim(), linestyle='--',
                                      color='red', linewidth=2, zorder=99)

        self.draw()
        datafile.close()
        self.pipeline_tunned = True
        self.pipeline_output = datafile

    # ----------------------------------------------------------------------
    def set_data(self, timestamp, eeg, labels, ylabel='', xlabel=''):
        """Plot the channels stacked ``threshold`` apart; does not redraw."""
        self.ax1.clear()
        self.ax2.clear()
        for i, channel in enumerate(eeg):
            offset = self.threshold * i
            self.ax1.plot(timestamp, channel + offset, label=labels[i])
            self.ax2.plot(timestamp, channel + offset, alpha=0.5)

        self.ax1.grid(True)
        self.ax1.legend(loc='upper center', ncol=8, labelcolor='k',
                        bbox_to_anchor=(0.5, 1.4))
        self.ax1.set_xlim(0, self.window_value)

        self.ax2.grid(True)
        self.ax2.set_xlim(0, timestamp[-1])
        self.area = self.ax2.fill_between([0, self.window_value],
                                          *self.ax1.get_ylim(),
                                          color=self.fill_color,
                                          alpha=self.fill_opacity,
                                          label='AREA')

        self._set_scroll_range(timestamp[-1])
        self.ax1.set_ylabel(ylabel)
        self.ax2.set_xlabel(xlabel)

    # ----------------------------------------------------------------------
    def _move_window(self, start):
        """Show the window starting at *start* seconds; move area and center line."""
        self.ax1.set_xlim(start, start + self.window_value)

        # Edit the shaded area's vertices in place.  Both masks are computed
        # before writing because the new start may equal the old left edge.
        xs = self.area.get_paths()[0].vertices[:, 0]
        left = xs == xs.min()
        right = xs == xs.max()
        xs[left] = start
        xs[right] = start + self.window_value

        segments = self.vlines.get_segments()
        segments[0][:, 0] = np.mean(self.ax1.get_xlim())
        self.vlines.set_segments(segments)

    # ----------------------------------------------------------------------
    def move_plot(self, value):
        """Show the window starting at *value* milliseconds."""
        self._move_window(value / 1000)
        self.draw()

    # ----------------------------------------------------------------------
    def change_window(self):
        """Apply the window width selected in the combo box."""
        self.window_value = self._get_seconds_from_human(
            self.combobox.currentText())
        duration = self.pipeline_output.timestamp[0][-1] / 1000

        self._set_scroll_range(duration)
        self.scroll.setPageStep(int(self.window_value * 1000))

        self._move_window(self.scroll.value() / 1000)
        self.draw()