import os
import sys
import logging
import copy
from abc import ABCMeta, abstractmethod
import mne
import numpy as np
# from scipy.fftpack import rfft, rfftfreq
from scipy.signal import welch, decimate
from scipy.signal import decimate, welch
from cycler import cycler
import matplotlib
from matplotlib import pyplot
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
from PySide6.QtCore import Qt
from PySide6 import QtWidgets
from PySide6.QtUiTools import QUiLoader
from PySide6.QtWidgets import QSpacerItem, QSizePolicy
from gcpds.filters import frequency as flt
from gcpds.filters import frequency as flt
from bci_framework.extensions.timelock_analysis import timelock_analysis as ta
from bci_framework.framework.dialogs import Dialogs
# Logging setup: name and level of the root logger first, then the
# per-library overrides (MNE only reports critical errors, matplotlib's
# font manager is switched off entirely).
logging.root.name = "TimelockAnalysis"
logging.getLogger().setLevel(logging.WARNING)
logging.getLogger('matplotlib.font_manager').disabled = True
logger = logging.getLogger("mne")
logger.setLevel(logging.CRITICAL)
########################################################################
class Canvas(FigureCanvasQTAgg):
    """Qt canvas that creates and owns its matplotlib ``Figure``.

    Positional and keyword arguments are forwarded unchanged to
    ``matplotlib.figure.Figure``.
    """

    # ----------------------------------------------------------------------
    def __init__(self, *args, **kwargs):
        """Configure matplotlib, build the figure and initialise the canvas."""
        # Configure matplotlib before the figure is instantiated.
        self.configure()
        self.figure = Figure(*args, **kwargs)
        super().__init__(self.figure)

    # ----------------------------------------------------------------------
    def configure(self):
        """Apply the global matplotlib style used by the framework.

        The dark background style is used unless ``'light'`` appears in
        ``sys.argv`` or in the ``QTMATERIAL_THEME`` environment variable.
        The colour cycle (16 samples of the ``cool`` colormap), figure dpi
        and font are then written to ``matplotlib.rcParams`` on a
        best-effort basis: a failure is logged and otherwise ignored.
        """
        light_theme = ('light' in sys.argv) or (
            'light' in os.environ.get('QTMATERIAL_THEME', ''))
        if not light_theme:
            pyplot.style.use('dark_background')

        try:
            # ``pyplot.get_cmap`` is used instead of ``matplotlib.cm.get_cmap``
            # because the latter no longer exists in recent matplotlib.
            cmap = pyplot.get_cmap('cool')
            matplotlib.rcParams['axes.prop_cycle'] = cycler(
                color=[cmap(m) for m in np.linspace(0, 1, 16)])
            matplotlib.rcParams['figure.dpi'] = 70
            matplotlib.rcParams['font.family'] = 'monospace'
            matplotlib.rcParams['font.size'] = 15
        except Exception:
            # e.g. "'rcParams' object does not support item assignment";
            # styling problems must not prevent the canvas from being used.
            logging.debug('Could not apply matplotlib rcParams',
                          exc_info=True)
########################################################################
class TimelockWidget(metaclass=ABCMeta):
    """Base class for one stage of a timelock-analysis pipeline.

    The constructor loads the ``locktime_widget.ui`` form, embeds a
    :class:`Canvas` in its ``gridLayout`` and keeps one list of stretch
    factors per layout area. The ``add_*`` helpers place controls in the
    ``<area>Layout`` of the form and record the requested stretch;
    :meth:`_add_spacers` applies the recorded stretches.

    Stages are linked with :meth:`next_pipeline` / :meth:`previous_pipeline`.
    Assigning :attr:`pipeline_output` calls ``fit()`` on the next stage, but
    only when that stage reports ``pipeline_tunned``.
    """

    # Every area maps to ``self.widget.<area>Layout`` and ``self.<area>_stretch``.
    _AREAS = ('left', 'right', 'top', 'bottom', 'top2', 'bottom2')
    _HORIZONTAL_AREAS = ('top', 'bottom', 'top2', 'bottom2')
    _VERTICAL_AREAS = ('left', 'right')

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Load the form and create the canvas.

        ``height`` is applied as minimum height of the form when truthy;
        ``*args`` / ``**kwargs`` are forwarded to :class:`Canvas`. Raises
        ``KeyError`` when the ``BCISTREAM_ROOT`` environment variable is
        missing, because the ``.ui`` path is built from it.
        """
        for area in self._AREAS:
            setattr(self, f'{area}_stretch', [])

        self._pipeline_output = None

        ui = os.path.realpath(os.path.join(
            os.environ['BCISTREAM_ROOT'], 'framework', 'qtgui', 'locktime_widget.ui'))
        self.widget = QUiLoader().load(ui)

        if height:
            self.widget.setMinimumHeight(height)

        self.canvas = Canvas(*args, **kwargs)
        self.figure = self.canvas.figure
        self.widget.gridLayout.addWidget(self.canvas)

    # ----------------------------------------------------------------------
    def draw(self):
        """Re-apply the matplotlib configuration and redraw the canvas."""
        self.canvas.configure()
        self.canvas.draw()

    # ----------------------------------------------------------------------
    def _add_spacers(self):
        """Apply every recorded stretch factor to its layout, by index."""
        for area in self._AREAS:
            layout = getattr(self.widget, f'{area}Layout')
            for index, stretch in enumerate(getattr(self, f'{area}_stretch')):
                layout.setStretch(index, stretch)

    # ----------------------------------------------------------------------
    def _attach_widget(self, widget, area, stretch):
        """Add ``widget`` to the layout of ``area`` and record its stretch."""
        getattr(self.widget, f'{area}Layout').addWidget(widget)
        getattr(self, f'{area}_stretch').append(stretch)

    # ----------------------------------------------------------------------
    @staticmethod
    def _click_handler(callback, *bound):
        """Return a one-argument slot that calls ``callback(*bound)``.

        The single argument received from the signal is discarded.
        """
        def handler(_signal_arg):
            return callback(*bound)
        return handler

    # ----------------------------------------------------------------------
    def add_spacer(self, area='top', fixed=None, stretch=0):
        """Append a spacer item to the layout of ``area``.

        With a truthy ``fixed`` the spacer has that size along the layout
        direction and ``Minimum`` policies; otherwise it is an expanding
        spacer. Areas other than the six known ones are ignored.
        """
        minimum, expanding = QSizePolicy.Minimum, QSizePolicy.Expanding

        if area in self._VERTICAL_AREAS:
            if fixed:
                spacer = QSpacerItem(20, fixed, minimum, minimum)
            else:
                spacer = QSpacerItem(20, 20000, minimum, expanding)
        elif area in self._HORIZONTAL_AREAS:
            if fixed:
                spacer = QSpacerItem(fixed, 20, minimum, minimum)
            else:
                spacer = QSpacerItem(20000, 20, expanding, minimum)
        else:
            return

        getattr(self.widget, f'{area}Layout').addItem(spacer)
        # The spacer occupies a layout index, so its stretch is always
        # recorded (even 0); otherwise the stretches of items added later
        # would be applied to the wrong index in ``_add_spacers``.
        getattr(self, f'{area}_stretch').append(stretch)

    # ----------------------------------------------------------------------
    def clear_layout(self, layout):
        """Remove the widgets and spacers of ``layout``.

        Widgets are taken out of the layout and scheduled for deletion,
        spacer items are removed, and nested layouts are cleared
        recursively (the nested layout itself stays in place).
        """
        # Walk backwards so that removing an item never shifts the index of
        # an item that has not been visited yet.
        for index in reversed(range(layout.count())):
            item = layout.itemAt(index)
            if item is None:
                continue

            widget = item.widget()
            if widget is not None:
                layout.removeWidget(widget)
                widget.deleteLater()
                continue

            if item.spacerItem() is not None:
                layout.removeItem(item)
                continue

            child = item.layout()
            if child is not None:
                self.clear_layout(child)

    # ----------------------------------------------------------------------
    def clear_widgets(self):
        """Clear all six area layouts and forget their recorded stretches."""
        for area in self._AREAS:
            self.clear_layout(getattr(self.widget, f'{area}Layout'))
            # The layout is empty again, so stale stretch entries would be
            # applied to whatever is added next.
            getattr(self, f'{area}_stretch').clear()

    # ----------------------------------------------------------------------
    def add_textarea(self, content='', area='top', stretch=0):
        """Add a read-only ``QTextEdit`` showing ``content`` and return it."""
        textarea = QtWidgets.QTextEdit(content)
        textarea.setProperty('class', 'clear')
        textarea.setMinimumWidth(500)
        textarea.setReadOnly(True)
        self._attach_widget(textarea, area, stretch)
        return textarea

    # ----------------------------------------------------------------------
    def add_button(self, label, callback=None, area='top', stretch=0):
        """Add a push button; ``callback`` (if any) is connected to ``clicked``."""
        button = QtWidgets.QPushButton(label)
        if callback:
            button.clicked.connect(callback)
        self._attach_widget(button, area, stretch)
        return button

    # ----------------------------------------------------------------------
    def add_radios(self, group_name, radios, cols=None, callback=None, area='top', stretch=1):
        """Add a group box of radio buttons laid out in rows of ``cols``.

        The first radio starts checked. When given, ``callback`` is called
        as ``callback(group_name, radio_text)`` on click. ``cols`` defaults
        to a single row.
        """
        group = QtWidgets.QGroupBox(group_name)
        group.setProperty('class', 'fill_background')
        rows = QtWidgets.QVBoxLayout()
        group.setLayout(rows)

        if cols is None:
            cols = len(radios)

        for index, text in enumerate(radios):
            if index % cols == 0:  # start a new row
                row = QtWidgets.QHBoxLayout()
                rows.addLayout(row)

            radio = QtWidgets.QRadioButton()
            radio.setText(text)
            radio.setChecked(index == 0)
            if callback:
                radio.clicked.connect(
                    self._click_handler(callback, group_name, text))
            row.addWidget(radio)

        self._attach_widget(group, area, stretch)

    # ----------------------------------------------------------------------
    def add_checkbox(self, group_name, radios, ncol=None, callback=None, area='top', stretch=1):
        """Add a group box of check boxes in rows of ``ncol``; return them.

        Only the first box starts checked. When given, ``callback`` is
        called as ``callback(group_name, box_text)`` on click.
        """
        group = QtWidgets.QGroupBox(group_name)
        group.setProperty('class', 'fill_background')
        rows = QtWidgets.QVBoxLayout()
        group.setLayout(rows)

        if ncol is None:
            ncol = len(radios)

        boxes = []
        for index, text in enumerate(radios):
            if index % ncol == 0:  # start a new row
                row = QtWidgets.QHBoxLayout()
                rows.addLayout(row)

            box = QtWidgets.QCheckBox()
            box.setText(text)
            box.setChecked(index == 0)
            boxes.append(box)
            if callback:
                box.clicked.connect(
                    self._click_handler(callback, group_name, text))
            row.addWidget(box)

        self._attach_widget(group, area, stretch)
        return boxes

    # ----------------------------------------------------------------------
    def add_channels(self, group_name, radios, callback=None, area='top', stretch=1):
        """Add a three-column group of checked boxes, one per channel name.

        Names ending in an odd digit go to the first column, names ending
        in an even digit to the third one and every other name to the
        middle column. Returns the list of check boxes in input order.
        """
        group = QtWidgets.QGroupBox(group_name)
        group.setProperty('class', 'fill_background')
        columns = QtWidgets.QHBoxLayout()
        group.setLayout(columns)

        first_column = QtWidgets.QVBoxLayout()   # names ending in an odd digit
        middle_column = QtWidgets.QVBoxLayout()  # names without a final digit
        third_column = QtWidgets.QVBoxLayout()   # names ending in an even digit
        for column in (first_column, middle_column, third_column):
            columns.addLayout(column)

        boxes = []
        for text in radios:
            box = QtWidgets.QCheckBox()
            box.setText(text)
            box.setChecked(True)
            boxes.append(box)

            last = text[-1]
            if not last.isdecimal():
                middle_column.addWidget(box)
            elif int(last) % 2:
                first_column.addWidget(box)
            else:
                third_column.addWidget(box)

            if callback:
                box.clicked.connect(
                    self._click_handler(callback, group_name, text))

        self._attach_widget(group, area, stretch)
        return boxes

    # ----------------------------------------------------------------------
    def add_scroll(self, callback=None, area='bottom', stretch=0):
        """Add a horizontal scroll bar; ``callback`` receives ``sliderMoved``."""
        scroll = QtWidgets.QScrollBar()
        scroll.setOrientation(Qt.Horizontal)
        if callback:  # the default ``None`` must not be connected
            scroll.sliderMoved.connect(callback)
        scroll.setProperty('class', 'big')
        self._attach_widget(scroll, area, stretch)
        return scroll

    # ----------------------------------------------------------------------
    def add_slider(self, callback=None, area='bottom', stretch=0):
        """Add a horizontal slider with range 0-500, initially at 500."""
        slider = QtWidgets.QSlider()
        slider.setOrientation(Qt.Horizontal)
        slider.setMinimum(0)
        slider.setMaximum(500)
        slider.setValue(500)
        if callback:  # the default ``None`` must not be connected
            slider.valueChanged.connect(callback)
        self._attach_widget(slider, area, stretch)
        return slider

    # ----------------------------------------------------------------------
    def add_spin(self, label, value, decimals=1, step=0.1, prefix='', suffix='', min_=0, max_=999, callback=None, area='top', stretch=0):
        """Add a ``QDoubleSpinBox`` (with optional label) and return it.

        ``callback`` is connected to ``valueChanged`` after the initial
        value has been set. ``prefix`` and ``suffix`` are shown with a
        leading space.
        """
        spin = QtWidgets.QDoubleSpinBox()
        spin.setDecimals(decimals)
        spin.setSingleStep(step)
        spin.setMinimum(min_)
        spin.setMaximum(max_)
        spin.setValue(value)
        if callback:
            spin.valueChanged.connect(callback)
        if prefix:
            spin.setPrefix(f' {prefix}')
        if suffix:
            spin.setSuffix(f' {suffix}')

        container = QtWidgets.QWidget()
        row = QtWidgets.QHBoxLayout()
        container.setLayout(row)
        if label:
            row.addWidget(QtWidgets.QLabel(label))
        row.addWidget(spin)

        self._attach_widget(container, area, stretch)
        row.setStretch(0, 0)
        row.setStretch(1, 1)
        return spin

    # ----------------------------------------------------------------------
    def add_combobox(self, label, items, editable=False, callback=None, area='top', stretch=0):
        """Add a combo box (with optional label) and return it.

        ``callback`` (if any) is connected to the ``activated`` signal.
        """
        combo = QtWidgets.QComboBox()
        combo.addItems(items)
        if callback:  # callers do pass ``callback=None``
            combo.activated.connect(callback)
        combo.setEditable(editable)

        container = QtWidgets.QWidget()
        row = QtWidgets.QHBoxLayout()
        container.setLayout(row)
        if label:
            row.addWidget(QtWidgets.QLabel(label))
        row.addWidget(combo)

        self._attach_widget(container, area, stretch)
        row.setStretch(0, 0)
        row.setStretch(1, 1)
        return combo

    # ----------------------------------------------------------------------
    @property
    def pipeline_input(self):
        """Input of this stage.

        The output of the previous stage when one is linked, otherwise the
        value assigned manually. Logs a warning and returns ``None`` when
        neither exists.
        """
        if hasattr(self, '_previous_pipeline'):
            return self._previous_pipeline.pipeline_output
        if hasattr(self, '_pipeline_input'):
            return self._pipeline_input
        logging.warning("'pipeline_input' does not exist yet.")
        return None

    # ----------------------------------------------------------------------
    @pipeline_input.setter
    def pipeline_input(self, input_):
        """Set the manual input of this stage."""
        self._pipeline_input = input_

    # ----------------------------------------------------------------------
    @property
    def pipeline_output(self):
        """Last output stored by this stage (``None`` when there is none)."""
        return getattr(self, '_pipeline_output', None)

    # ----------------------------------------------------------------------
    @pipeline_output.setter
    def pipeline_output(self, output_):
        """Store the output and propagate it to the next stage."""
        self._pipeline_output = output_
        self._pipeline_propagate()

    # ----------------------------------------------------------------------
    @property
    def pipeline_tunned(self):
        """Whether this stage has been flagged as tuned (default ``False``)."""
        return getattr(self, '_pipeline_tunned', False)

    # ----------------------------------------------------------------------
    @pipeline_tunned.setter
    def pipeline_tunned(self, value):
        """Set the tuned flag."""
        self._pipeline_tunned = value

    # ----------------------------------------------------------------------
    def next_pipeline(self, pipe):
        """Register ``pipe`` as the stage that consumes this stage's output."""
        self._next_pipeline = pipe

    # ----------------------------------------------------------------------
    def previous_pipeline(self, pipe):
        """Register ``pipe`` as the stage that provides this stage's input."""
        self._previous_pipeline = pipe

    # ----------------------------------------------------------------------
    def set_pipeline_input(self, in_):
        """Set the manual input of this stage (method form of the setter)."""
        self._pipeline_input = in_

    # ----------------------------------------------------------------------
    def _pipeline_propagate(self):
        """Call ``fit()`` on the next stage if it exists and is tuned."""
        following = getattr(self, '_next_pipeline', None)
        if following is None or not following.pipeline_tunned:
            return
        following.fit()

    # ----------------------------------------------------------------------
    @abstractmethod
    def fit(self):
        """Compute and display this stage; implemented by subclasses."""
########################################################################
class TimelockSeries(TimelockWidget):
    """Pipeline stage that shows time series in a detail/overview pair.

    Subclasses create ``self.ax1`` (detail view) and ``self.ax2``
    (overview). A scroll bar moves a window of ``self.window_value``
    seconds over the data, and that window is highlighted on ``ax2``.
    Scroll-bar values are in milliseconds.
    """

    # Seconds per unit accepted by ``_get_seconds_from_human``.
    _UNIT_SECONDS = {
        'millisecond': 0.001,
        'second': 1.0,
        'minute': 60.0,
        'hour': 3600.0,
    }

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Initialise the base widget and the highlight colour/opacity."""
        super().__init__(height, *args, **kwargs)
        self.fill_opacity = 0.2
        self.fill_color = os.environ.get(
            'QTMATERIAL_PRIMARYCOLOR', '#ff0000')

    # ----------------------------------------------------------------------
    def _highlight_window(self, start):
        """Redraw the highlighted span ``[start, start + window]`` on ``ax2``.

        ``start`` is in seconds. Existing collections of ``ax2`` are removed
        one by one because ``Axes.collections`` cannot be modified in place
        in current matplotlib releases.
        """
        for collection in list(self.ax2.collections):
            collection.remove()
        self.ax2.fill_between([start, start + self.window_value],
                              *self.ax1.get_ylim(),
                              color=self.fill_color,
                              alpha=self.fill_opacity)

    # ----------------------------------------------------------------------
    def _set_scroll_range(self, duration):
        """Let the scroll bar cover ``duration`` seconds minus one window."""
        # Qt expects integers; a window longer than the data gives range 0.
        self.scroll.setMaximum(
            max(0, int((duration - self.window_value) * 1000)))
        self.scroll.setMinimum(0)

    # ----------------------------------------------------------------------
    def move_plot(self, value):
        """Show the window starting at ``value`` milliseconds."""
        start = value / 1000
        self.ax1.set_xlim(start, start + self.window_value)
        self._highlight_window(start)
        self.draw()

    # ----------------------------------------------------------------------
    def change_window(self):
        """Apply the window width currently selected in the combo box."""
        self.window_value = self._get_seconds_from_human(
            self.combobox.currentText())

        # Last timestamp of the stored output, converted to seconds.
        duration = self.pipeline_output.timestamp[0][-1] / 1000
        self._set_scroll_range(duration)
        self.scroll.setPageStep(int(self.window_value * 1000))

        start = self.scroll.value() / 1000
        self.ax1.set_xlim(start, start + self.window_value)
        # The right edge is ``start + window`` in seconds; the window width
        # is already in seconds and must not be divided by 1000.
        self._highlight_window(start)
        self.draw()

    # ----------------------------------------------------------------------
    def _get_seconds_from_human(self, human):
        """Convert a text such as ``'500 milliseconds'`` into seconds.

        Every whitespace-separated token is either a number or one of the
        units millisecond, second, minute, hour (singular or plural); the
        result is the product of all tokens. Raises ``ValueError`` for a
        token that is neither a unit nor a number.
        """
        factors = []
        for token in human.split():
            unit = token.lower()
            if unit.endswith('s'):
                unit = unit[:-1]
            if unit in self._UNIT_SECONDS:
                factors.append(self._UNIT_SECONDS[unit])
            else:
                factors.append(float(token))
        return np.prod(factors)

    # ----------------------------------------------------------------------
    def set_data(self, timestamp, eeg, labels, ylabel='', xlabel=''):
        """Plot every row of ``eeg`` against ``timestamp`` on both axes.

        ``labels[i]`` is the legend entry of row ``i``. The detail view
        starts at 0 with the current window width and the scroll range is
        derived from ``timestamp[-1]``.
        """
        self.ax1.clear()
        self.ax2.clear()

        for label, channel in zip(labels, eeg):
            self.ax1.plot(timestamp, channel, label=label)
            self.ax2.plot(timestamp, channel, alpha=0.5)

        self.ax1.grid(True)
        self.ax1.legend(loc='upper center', ncol=8,
                        labelcolor='k', bbox_to_anchor=(0.5, 1.4))
        self.ax1.set_xlim(0, self.window_value)

        self.ax2.grid(True)
        self.ax2.set_xlim(0, timestamp[-1])
        self._highlight_window(0)

        self._set_scroll_range(timestamp[-1])

        self.ax1.set_ylabel(ylabel)
        self.ax2.set_xlabel(xlabel)
        self.draw()

    # ----------------------------------------------------------------------
    def set_window_width_options(self, options):
        """Create the scroll bar and the window-width combo box.

        ``options`` are texts understood by ``_get_seconds_from_human``;
        the first one becomes the initial window width.
        """
        self.scroll = self.add_scroll(
            callback=self.move_plot, area='bottom', stretch=1)
        self.combobox = self.add_combobox('', options,
                                          callback=self.change_window,
                                          area='bottom',
                                          stretch=0)
        self.window_value = self._get_seconds_from_human(options[0])
########################################################################
class TimelockFilters(TimelockWidget):
    """Pipeline stage that applies notch/band-pass filters to the EEG.

    ``ax1`` shows the Welch spectrum of every channel and ``ax2`` the
    filtered channels stacked with a vertical offset taken from the
    ``Scale`` spin box.
    """

    # Value used as ``fs`` for the filters and the Welch estimate, and as
    # divisor of the time axis.
    # NOTE(review): hard-coded in the original code as well; confirm it
    # matches the sample rate of the loaded recording.
    _FS = 1000

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Create both axes and the filter/scale controls."""
        super().__init__(height, *args, **kwargs)

        gs = self.figure.add_gridspec(1, 2)
        self.ax1 = gs.figure.add_subplot(gs[:, 0:-1])
        self.ax2 = gs.figure.add_subplot(gs[:, -1])

        self.figure.subplots_adjust(left=0.05,
                                    bottom=0.12,
                                    right=0.95,
                                    top=0.8,
                                    wspace=None,
                                    hspace=0.6)

        # Active filter per group: the string 'none' or a callable.
        self.filters = {'Notch': 'none',
                        'Bandpass': 'none',
                        }

        self.notchs = ('none', '50 Hz', '60 Hz')
        self.bandpass = ('none', 'delta', 'theta', 'alpha', 'beta',
                         '0.01-20 Hz',
                         '5-45 Hz', '3-30 Hz', '4-40 Hz', '2-45 Hz', '1-50 Hz',
                         '7-13 Hz', '15-50 Hz', '1-100 Hz', '5-50 Hz')

        self.add_radios('Notch', self.notchs, callback=self.set_filters,
                        area='top', stretch=0)
        self.add_radios('Bandpass', self.bandpass, callback=self.set_filters,
                        area='top', stretch=1)

        self.scale = self.add_spin('Scale', 150, suffix='uv', min_=0,
                                   max_=1000, step=50, callback=self.fit, area='top',
                                   stretch=0)

    # ----------------------------------------------------------------------
    def fit(self):
        """Filter the input EEG, redraw both axes and publish the result.

        The filtered copy is stored in the ``eeg`` attribute of the input
        object, which also becomes this stage's output.
        """
        eeg = self.pipeline_input.original_eeg

        for active in self.filters.values():
            if active != 'none':
                eeg = active(eeg, fs=self._FS, axis=1)

        self.ax1.clear()
        self.ax2.clear()

        samples = eeg.shape[1]
        t = np.linspace(0, samples, samples, endpoint=True) / self._FS
        channels = eeg.shape[0]
        threshold = self.scale.value()  # vertical offset between channels

        for i, ch in enumerate(eeg):
            self.ax2.plot(t, ch + (threshold * i))

        self.ax1.set_xlabel('Frequency [$Hz$]')
        self.ax1.set_ylabel('Amplitude')
        self.ax2.set_xlabel('Time [$s$]')
        self.ax2.set_yticks([threshold * i for i in range(channels)])
        self.ax2.set_yticklabels(
            self.pipeline_input.header['channels'].values())
        self.ax2.set_ylim(-threshold, threshold * channels)

        w, spectrum = welch(eeg, fs=self._FS, axis=1,
                            nperseg=1024, noverlap=256, average='median')

        for i, ch in enumerate(spectrum):
            self.ax1.fill_between(w, 0, ch, alpha=0.2, color=f'C{i}')
            self.ax1.plot(w, ch, linewidth=2, color=f'C{i}')

        self.ax1.set_xscale('log')
        # A log axis cannot start at 0, so the first positive frequency is
        # used as lower limit.
        positive = w[w > 0]
        if positive.size:
            self.ax1.set_xlim(positive[0], w[-1])
        self.ax2.set_xlim(0, t[-1])
        self.ax1.grid(True)
        self.ax2.grid(True)

        self.draw()

        self.pipeline_tunned = True
        self._pipeline_output = self.pipeline_input
        self._pipeline_output.eeg = eeg.copy()
        self._pipeline_propagate()

    # ----------------------------------------------------------------------
    def set_filters(self, group_name, filter_):
        """Select the filter of ``group_name`` and refit.

        ``filter_`` is the radio text. Anything but ``'none'`` is resolved
        to an attribute of the ``flt`` module: ``notch<N>`` for notches,
        the band name for the first four band-pass entries, and
        ``band<digits>`` (text without ``' Hz'``, ``'-'`` and ``'.'``) for
        the remaining ones.
        """
        if filter_ != 'none':
            if group_name == 'Notch':
                filter_ = getattr(flt, f'notch{filter_.replace(" Hz", "")}')
            elif group_name == 'Bandpass':
                if filter_ in self.bandpass[1:5]:
                    filter_ = getattr(flt, f'{filter_}')
                else:
                    filter_ = getattr(
                        flt, f'band{filter_.replace(" Hz", "").replace("-", "").replace(".", "")}')

        self.filters[group_name] = filter_
        self.fit()
# # ----------------------------------------------------------------------
# @property
# def output(self):
# """"""
# if hasattr(self, 'output_signal'):
# return self.output_signal
########################################################################
class LoadDatabase(ta.TimelockSeries):
    """First pipeline stage: load a recording and preview its channels."""

    # ----------------------------------------------------------------------
    def __init__(self, height=700, *args, **kwargs):
        """Build the detail/overview axes and the loading controls."""
        super().__init__(height, *args, **kwargs)

        # 4x4 grid: the detail view takes the upper three rows, the
        # overview the last one.
        grid = self.figure.add_gridspec(4, 4)
        self.ax1 = grid.figure.add_subplot(grid[0:-1, :])
        self.ax2 = grid.figure.add_subplot(grid[-1, :])
        self.ax2.get_yaxis().set_visible(False)

        self.figure.subplots_adjust(
            left=0.05, bottom=0.12, right=0.95, top=0.8,
            wspace=None, hspace=0.6)

        self.add_button('Load database', callback=self.load_database,
                        area='top', stretch=0)
        self.add_spacer(area='top')

        # The combo box starts with one entry; ``fit`` fills it with the
        # options that are shorter than the loaded recording.
        self.set_window_width_options(['500 milliseconds'])
        self.window_options = [
            '500 milliseconds',
            '1 second',
            '5 second',
            '15 second',
            '30 second',
            '1 minute',
            '5 minute',
            '10 minute',
            '30 minute',
            '1 hour',
        ]

        self.database_description = self.add_textarea(
            area='right', stretch=0)

    # ----------------------------------------------------------------------
    def load_database(self):
        """Ask for a data file, compile the filters for it and fit."""
        self.datafile = Dialogs.load_database()

        # This stage has no predecessor, so its input is set by hand.
        self.pipeline_input = self.datafile

        sample_rate = self.pipeline_input.header['sample_rate']
        flt.compile_filters(FS=sample_rate, N=2, Q=3)
        self.fit()

    # ----------------------------------------------------------------------
    def fit(self):
        """Plot the decimated recording and publish the data file."""
        source = self.pipeline_input
        header = source.header
        signal = source.eeg
        stamps = source.timestamp

        self.database_description.setText(source.description)

        # Decimate by 15 for display and rebuild a matching time axis from
        # the last timestamp (divided by 1000).
        signal = decimate(signal, 15, axis=1)
        time_axis = np.linspace(
            0, stamps[0][-1], signal.shape[1], endpoint=True) / 1000
        signal = signal / 1000

        # Offer only the leading window widths shorter than the recording.
        usable = sum(
            1 for text in self.window_options
            if self._get_seconds_from_human(text) < time_axis[-1])
        self.combobox.clear()
        self.combobox.addItems(self.window_options[:usable])

        self.set_data(time_axis, signal,
                      labels=list(header['channels'].values()),
                      ylabel='Millivolt [$mv$]',
                      xlabel='Time [$s$]')

        source.close()
        self.pipeline_tunned = True
        self.pipeline_output = source
########################################################################
class EpochsVisualization(ta.TimelockWidget):
    """Pipeline stage that epochs the input and plots the evoked responses."""

    # ----------------------------------------------------------------------
    def __init__(self, height=700, *args, **kwargs):
        """Create the single axes; the stage is flagged as tuned at once."""
        super().__init__(height, *args, **kwargs)
        self.ax1 = self.figure.add_subplot(111)
        self.pipeline_tunned = True

    # ----------------------------------------------------------------------
    def fit(self):
        """Rebuild every control from the markers/channels of the input."""
        self.clear_widgets()

        markers = list(self.pipeline_input.file.markers.keys())
        channels = list(self.pipeline_input.header['channels'].values())

        self.tmin = self.add_spin('tmin', 0, suffix='s', min_=-99,
                                  max_=99, callback=self.get_epochs, area='top', stretch=0)
        self.tmax = self.add_spin(
            'tmax', 1, suffix='s', min_=-99, max_=99, callback=self.get_epochs, area='top', stretch=0)
        self.method = self.add_combobox(label='Method', items=[
            'mean', 'median'], callback=self.get_epochs, area='top', stretch=0)

        self.add_spacer(area='top', fixed=50)
        self.reject = self.add_spin('Reject', 200, suffix='vpp', min_=0,
                                    max_=500, step=10, callback=self.get_epochs, area='top', stretch=0)
        self.flat = self.add_spin('Flat', 10, suffix='vpp', min_=0, max_=500,
                                  step=10, callback=self.get_epochs, area='top', stretch=0)
        self.add_spacer(area='top')

        self.checkbox = self.add_checkbox(
            'Markers', markers, callback=self.get_epochs, area='left', ncol=1, stretch=1)
        self.add_spacer(area='left')

        self.channels = self.add_channels(
            'Channels', channels, callback=self.get_epochs, area='right', stretch=1)
        self.add_spacer(area='right')

    # ----------------------------------------------------------------------
    def get_epochs(self, *args, **kwargs):
        """Extract epochs with the current settings and plot their averages.

        Used as callback of every control, so the signal arguments are
        accepted and ignored. Nothing is computed when no marker or no
        channel is checked, or when the reject value is below the flat
        value. The resulting epochs become this stage's output.
        """
        self.figure.clear()
        self.ax1 = self.figure.add_subplot(111)

        markers = sorted(box.text()
                         for box in self.checkbox if box.isChecked())
        channels = sorted(box.text()
                          for box in self.channels if box.isChecked())

        if not markers or not channels:
            return
        if self.reject.value() < self.flat.value():
            return

        epochs = self.pipeline_input.epochs(
            tmin=self.tmin.value(), tmax=self.tmax.value(), markers=markers)

        reject = {'eeg': self.reject.value()}
        flat = {'eeg': self.flat.value()}
        epochs.drop_bad(reject, flat)

        evokeds = {
            mk: epochs[mk].average(
                method=self.method.currentText(), picks=channels)
            for mk in markers
        }

        try:
            mne.viz.plot_compare_evokeds(evokeds, axes=self.ax1, cmap=(
                'Class', 'cool'), show=False, show_sensors=False, invert_y=True, styles={}, split_legend=False, legend='upper center')
        except Exception:
            # Plotting stays best-effort (the epochs are still published),
            # but the failure is reported instead of vanishing silently.
            logging.exception('Could not plot the evoked responses')

        self.draw()
        self.pipeline_output = epochs
########################################################################
class TimelockAmplitudeAnalysis(ta.TimelockWidget):
    """Pipeline stage that plots the amplitude envelope of the recording."""

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Create the single axes; the stage is flagged as tuned at once."""
        super().__init__(height, *args, **kwargs)
        self.ax1 = self.figure.add_subplot(111)
        self.pipeline_tunned = True

    # ----------------------------------------------------------------------
    def fit(self):
        """Plot min/max/mean across channels with reference Vpp lines.

        Every channel is centred on its own mean first. The curves are
        decimated by 1000 before plotting and the input is passed through
        unchanged as output.
        """
        datafile = self.pipeline_input

        # Two divisions: by 1000 and then by 60. The other stages of this
        # file label ``timestamp / 1000`` as seconds, so this axis is in
        # minutes.
        t = datafile.timestamp[0] / 1000 / 60

        eeg = datafile.eeg
        eeg = eeg - eeg.mean(axis=1)[:, np.newaxis]

        # Envelope across channels (axis 0) for every sample.
        mx = eeg.max(axis=0)
        mn = eeg.min(axis=0)
        m = eeg.mean(axis=0)

        self.ax1.clear()

        dc = 1000
        mxd = decimate(mx, dc, n=2)
        mnd = decimate(mn, dc, n=2)
        md = decimate(m, dc, n=2)
        td = decimate(t, dc, n=2)

        self.ax1.fill_between(td, mnd, mxd, color='k',
                              alpha=0.3, linewidth=0)
        self.ax1.plot(td, md, color='C0')

        # Reference lines at +/- vpp/2; the 0 entry is drawn only once.
        vpps = [100, 150, 200, 300, 500, 0]
        for i, vpp in enumerate(vpps):
            color = pyplot.cm.tab10(i)
            self.ax1.hlines(
                vpp / 2, 0, td[-1], linestyle='--', color=color)
            if vpp:
                self.ax1.hlines(-vpp / 2, 0,
                                td[-1], linestyle='--', color=color)

        self.ax1.set_xlim(0, td[-1])
        self.ax1.set_ylim(2 * mn.mean(), 2 * mx.mean())

        ticks = sorted(vpps + [-v for v in vpps])
        self.ax1.set_yticks([v / 2 for v in ticks])
        self.ax1.set_yticklabels([f'{abs(v)} vpp' for v in ticks])
        self.ax1.grid(True)

        self.ax1.set_ylabel('Voltage [uv]')
        # The axis is in minutes (see ``t`` above); it used to be labelled
        # as seconds.
        self.ax1.set_xlabel('Time [$min$]')

        self.draw()

        self.pipeline_output = self.pipeline_input
########################################################################
class AddMarkers(ta.TimelockSeries):
    """Pipeline stage for browsing stacked channels and placing markers.

    Channels are drawn with a vertical offset of ``self.threshold`` per
    channel. A dashed red line marks the centre of the detail view, which
    is where ``add_marker`` draws a new marker line.
    """

    # ----------------------------------------------------------------------
    def __init__(self, height, *args, **kwargs):
        """Build the detail/overview axes and the marker controls."""
        super().__init__(height, *args, **kwargs)

        # 4x1 grid: the detail view takes the upper three rows, the
        # overview the last one.
        gs = self.figure.add_gridspec(4, 1)
        self.ax1 = gs.figure.add_subplot(gs[0:-1, :])
        self.ax2 = gs.figure.add_subplot(gs[-1, :])
        self.ax2.get_yaxis().set_visible(False)

        self.set_window_width_options(
            ['500 milliseconds',
             '1 second',
             '5 second',
             '15 second',
             '30 second',
             '1 minute',
             '5 minute',
             '10 minute',
             '30 minute',
             '1 hour'])

        self.markers = self.add_combobox('Marker', [], callback=None, editable=True,
                                         area='bottom2', stretch=3)
        self.add_button('Add marker', callback=self.add_marker,
                        area='bottom2', stretch=0)
        self.add_spacer(area='bottom2', stretch=10)

        self.pipeline_tunned = True

    # ----------------------------------------------------------------------
    def _redraw_area(self, start):
        """Replace the highlighted span on ``ax2`` by ``[start, start + window]``.

        ``start`` is in seconds. The old polygon is removed and a new one
        is drawn; editing its vertices in place collapsed it to zero width
        whenever the new left edge equalled the old one.
        """
        self.area.remove()
        self.area = self.ax2.fill_between(
            [start, start + self.window_value], *self.ax1.get_ylim(),
            color=self.fill_color, alpha=self.fill_opacity, label='AREA')

    # ----------------------------------------------------------------------
    def _set_scroll_range(self, duration):
        """Let the scroll bar cover ``duration`` seconds minus one window."""
        # Qt expects integers; a window longer than the data gives range 0.
        self.scroll.setMaximum(
            max(0, int((duration - self.window_value) * 1000)))
        self.scroll.setMinimum(0)

    # ----------------------------------------------------------------------
    def add_marker(self):
        """Draw a dashed red line at the centre of the detail view."""
        q = np.mean(self.ax1.get_xlim())
        self.ax1.vlines(q, *self.ax1.get_ylim(),
                        linestyle='--', color='red', linewidth=5, zorder=99)
        self.ax2.vlines(q, *self.ax2.get_ylim(),
                        linestyle='--', color='red', linewidth=3, zorder=99)
        self.draw()

    # ----------------------------------------------------------------------
    def fit(self):
        """Plot the decimated input and refresh the list of marker names.

        The marker combo box offers ``'BAD'`` and ``'BLINK'`` followed by
        the sorted marker names of the input. The input is passed through
        unchanged as output.
        """
        datafile = self.pipeline_input

        markers = ['BAD', 'BLINK']
        markers += sorted(datafile.markers.keys())
        self.markers.clear()
        self.markers.addItems(markers)

        header = datafile.header
        eeg = datafile.eeg
        timestamp = datafile.timestamp

        # Decimate by 15 for display and rebuild a matching time axis from
        # the last timestamp (divided by 1000).
        eeg = decimate(eeg, 15, axis=1)
        timestamp = np.linspace(
            0, timestamp[0][-1], eeg.shape[1], endpoint=True) / 1000

        # Must be set before ``set_data``, which uses it as channel offset.
        self.threshold = 150
        channels = eeg.shape[0]

        self.set_data(timestamp, eeg,
                      labels=list(header['channels'].values()),
                      ylabel='Millivolt [$mv$]',
                      xlabel='Time [$s$]')

        self.ax1.set_yticks([self.threshold * i for i in range(channels)])
        self.ax1.set_yticklabels(
            self.pipeline_input.header['channels'].values())
        self.ax1.set_ylim(-self.threshold, self.threshold * channels)
        self.ax2.set_ylim(-self.threshold, self.threshold * channels)

        # Cursor line at the centre of the detail view; ``move_plot`` keeps
        # it centred.
        self.vlines = self.ax1.vlines(np.mean(self.ax1.get_xlim()),
                                      *self.ax1.get_ylim(), linestyle='--', color='red', linewidth=2, zorder=99)

        self.draw()

        datafile.close()
        self.pipeline_tunned = True
        self.pipeline_output = self.pipeline_input

    # ----------------------------------------------------------------------
    def set_data(self, timestamp, eeg, labels, ylabel='', xlabel=''):
        """Plot every row of ``eeg`` stacked by ``self.threshold``.

        Same role as the inherited method, but channels are offset
        vertically, the highlighted span is kept in ``self.area`` and the
        canvas is not redrawn here.
        """
        self.ax1.clear()
        self.ax2.clear()

        for i, ch in enumerate(eeg):
            shifted = ch + self.threshold * i
            self.ax1.plot(timestamp, shifted, label=labels[i])
            self.ax2.plot(timestamp, shifted, alpha=0.5)

        self.ax1.grid(True)
        self.ax1.legend(loc='upper center', ncol=8,
                        labelcolor='k', bbox_to_anchor=(0.5, 1.4))
        self.ax1.set_xlim(0, self.window_value)

        self.ax2.grid(True)
        self.ax2.set_xlim(0, timestamp[-1])
        self.area = self.ax2.fill_between([0, self.window_value], *self.ax1.get_ylim(),
                                          color=self.fill_color, alpha=self.fill_opacity, label='AREA')

        self._set_scroll_range(timestamp[-1])

        self.ax1.set_ylabel(ylabel)
        self.ax2.set_xlabel(xlabel)

    # ----------------------------------------------------------------------
    def move_plot(self, value):
        """Show the window starting at ``value`` milliseconds."""
        start = value / 1000
        self.ax1.set_xlim(start, start + self.window_value)
        self._redraw_area(start)

        # Keep the cursor line at the centre of the detail view.
        segments = self.vlines.get_segments()
        segments[0][:, 0] = [np.mean(self.ax1.get_xlim())] * 2
        self.vlines.set_segments(segments)

        self.draw()

    # ----------------------------------------------------------------------
    def change_window(self):
        """Apply the window width currently selected in the combo box."""
        self.window_value = self._get_seconds_from_human(
            self.combobox.currentText())

        # Last timestamp of the stored output, converted to seconds.
        duration = self.pipeline_output.timestamp[0][-1] / 1000
        self._set_scroll_range(duration)
        self.scroll.setPageStep(int(self.window_value * 1000))

        start = self.scroll.value() / 1000
        self.ax1.set_xlim(start, start + self.window_value)
        self._redraw_area(start)

        self.draw()