Training

[4]:
import os

import numpy as np
import mne
import joblib

from scipy.signal import filtfilt, butter
from mne.decoding import CSP

from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.model_selection import GridSearchCV, train_test_split

from openbci_stream.utils.hdf5 import HDF5Reader
from gcpds.mi import CSPEpochsFilterExtractor
[6]:
# Load one recorded motor-imagery session and keep its arrays and metadata
# in module-level variables for the following cells.
# NOTE(review): a later cell still calls `reader.get_data(...)` after this
# `with` block has exited (and presumably closed the file); that only works
# if HDF5Reader keeps the data it already read cached — confirm.
recording_path = os.path.join('databases', 'motor_imagery-01_19_22-20_07_03.h5')

with HDF5Reader(recording_path) as reader:
    # Printing the reader shows the session summary (markers, sample rate,
    # channel map, data shape) in the cell output.
    print(reader)

    # Read order kept as-is: EEG, auxiliary channels, then their timestamps.
    eeg, aux = reader.eeg, reader.aux
    timestamp, aux_timestamp = reader.timestamp, reader.aux_timestamp

    header, markers = reader.header, reader.markers
    fs = reader.header['sample_rate']  # sampling rate (1000 for this recording)
==========
databases/motor_imagery-01_19_22-20_07_03.h5
2022-01-19 20:07:03.725021
==========
MARKERS: ['Up', 'Right', 'Left', 'Bottom']
SAMPLE_RATE: 1000
STREAMING_SAMPLE_RATE: 1000
DATETIME: 1642640823.725021
MONTAGE: standard_1020
CHANNELS: {1: 'Fp1', 2: 'Fp2', 3: 'F7', 4: 'F3', 5: 'Fz', 6: 'F4', 7: 'F8', 8: 'T7', 9: 'C3', 10: 'Cz', 11: 'C4', 12: 'T8', 13: 'P7', 14: 'P3', 15: 'P8', 16: 'P4'}
CHANNELS_BY_BOARD: [16]
SHAPE: [16, 1145997]
==========
[7]:
# Cut the continuous EEG into trials around every marker of the listed types,
# from 0.5 s before to 4 s after each marker (4500 samples at fs=1000).
# NOTE(review): `reader` belongs to a `with` block in an earlier cell that has
# already exited; this presumably works because the reader cached the arrays
# it read inside that block — confirm against HDF5Reader.get_data.
# NOTE(review): how marker names map to the integer labels in `classes` is not
# visible here — confirm before interpreting predictions.
data, classes = reader.get_data(
    markers=['Right', 'Left', 'Bottom', 'Up'],
    tmin=-0.5,
    tmax=4,
)
data.shape  # (trials, channels, time)
[7]:
(180, 16, 4500)
[8]:
# Hold out 30 % of the trials for a final evaluation the grid search never sees.
# Stratifying on the labels keeps the class proportions the same in both
# splits; with only 180 trials over four marker classes, an unstratified split
# can noticeably unbalance the held-out set.
Xtrain, Xtest, ytrain, ytest = train_test_split(
    data, classes, test_size=0.3, random_state=123, stratify=classes)

# Sanity check: every trial ends up in exactly one split.
assert len(Xtrain) + len(Xtest) == len(data)
Xtrain.shape
[8]:
(126, 16, 4500)
[11]:
# Filter bank for the CSP feature extractor (Hz). It covers delta (1-4),
# theta (4-8), alpha/mu (8-12) and beta (12-30), not only alpha and beta.
f_frec = np.array([[1, 4], [4, 8], [8, 12], [12, 30]])

# Build the filter-bank CSP extractor; ncomp is set to the channel count
# (data is trials x channels x time). reg='empirical' is passed through,
# presumably as the covariance estimator as in mne.decoding.CSP — confirm.
csp = CSPEpochsFilterExtractor(fs=fs, f_frec=f_frec, ncomp=data.shape[1], reg='empirical')

# Pipeline1 steps: FB-CSP features -> z-scoring -> RBF-kernel SVM.
steps = [('c_', csp),
         ('nor', StandardScaler()),
         ('cla', SVC(kernel='rbf')),
         ]

# Pipeline1 hyperparameters searched with cross-validation on the training split.
parameters = {'cla__C': [1, 10, 1e2, 1e3, 1e4],
              'cla__gamma': [1e-3, 1e-2, 1e-1, 1, 10],
              }

label_models = ['StandarScaler_SVCrbf']

# verbose=1 reports the fit plan only; verbose=10 printed one START/END line
# per fold and candidate (hundreds of lines).
grid_search = GridSearchCV(Pipeline(steps), parameters, n_jobs=-1, cv=10,
                           scoring='accuracy', refit=True, verbose=1)
grid_search.fit(Xtrain, ytrain)

# Score the refit best pipeline on the held-out trials; until now Xtest/ytest
# were split off but never used.
# NOTE(review): logged CV fold accuracies were mostly 0.23-0.42, close to the
# 0.25 chance level of a balanced 4-class problem — interpret with care.
test_accuracy = grid_search.score(Xtest, ytest)
{'best_params': grid_search.best_params_,
 'cv_accuracy': grid_search.best_score_,
 'test_accuracy': test_accuracy}
Fitting 10 folds for each of 25 candidates, totalling 250 fits
[11]:
GridSearchCV(cv=10,
             estimator=Pipeline(steps=[('c_',
                                        CSPEpochsFilterExtractor(f_frec=array([[ 1,  4],
       [ 4,  8],
       [ 8, 12],
       [12, 30]]),
                                                                 fs=1000,
                                                                 ncomp=16)),
                                       ('nor', StandardScaler()),
                                       ('cla', SVC())]),
             n_jobs=-1,
             param_grid={'cla__C': [1, 10, 100.0, 1000.0, 10000.0],
                         'cla__gamma': [0.001, 0.01, 0.1, 1, 10]},
             scoring='accuracy', verbose=10)
[14]:
# Persist the fitted search object; with refit=True its predict() delegates to
# the best pipeline. Create the output directory first so a fresh checkout
# does not fail with FileNotFoundError.
os.makedirs('models', exist_ok=True)
joblib.dump(grid_search, os.path.join('models', 'model_yn.pkl'))
[14]:
['models/model_yn.pkl']
[16]:
# Smoke test: reload the persisted model and build one random trial with the
# same (channels, time) shape as the training data. The input is noise, so
# the label predicted for it carries no meaning; this only checks the
# save/load round-trip and the input contract.
# Security: joblib.load unpickles the file — only load model files you trust.
model = joblib.load(os.path.join('models', 'model_yn.pkl'))

smoke_rng = np.random.default_rng(123)  # seeded so the cell output is reproducible
input_ = smoke_rng.normal(0, 2, size=(1, *Xtrain.shape[1:]))
input_.shape
[16]:
(1, 16, 4500)
[17]:
 # Predict the label of the random smoke-test trial. The statement sits at
# column 0 so the cell also runs when the notebook is exported to a .py script.
model.predict(input_)  # presumably one of the labels 1-4; confirm marker-to-label mapping
[17]:
array([4])
[CV 9/10; 1/25] START cla__C=1, cla__gamma=0.001................................
[CV 9/10; 1/25] END .cla__C=1, cla__gamma=0.001;, score=0.250 total time=  24.8s
[CV 4/10; 3/25] START cla__C=1, cla__gamma=0.1..................................
[CV 4/10; 3/25] END ...cla__C=1, cla__gamma=0.1;, score=0.308 total time=  23.6s
[CV 5/10; 4/25] START cla__C=1, cla__gamma=1....................................
[CV 5/10; 4/25] END .....cla__C=1, cla__gamma=1;, score=0.308 total time=  23.8s
[CV 7/10; 5/25] START cla__C=1, cla__gamma=10...................................
[CV 7/10; 5/25] END ....cla__C=1, cla__gamma=10;, score=0.250 total time=  24.3s
[CV 9/10; 6/25] START cla__C=10, cla__gamma=0.001...............................
[CV 9/10; 6/25] END cla__C=10, cla__gamma=0.001;, score=0.250 total time=  24.9s
[CV 2/10; 8/25] START cla__C=10, cla__gamma=0.1.................................
[CV 2/10; 8/25] END ..cla__C=10, cla__gamma=0.1;, score=0.308 total time=  23.0s
[CV 3/10; 9/25] START cla__C=10, cla__gamma=1...................................
[CV 3/10; 9/25] END ....cla__C=10, cla__gamma=1;, score=0.308 total time=  23.0s
[CV 5/10; 10/25] START cla__C=10, cla__gamma=10.................................
[CV 5/10; 10/25] END ..cla__C=10, cla__gamma=10;, score=0.308 total time=  22.6s
[CV 7/10; 11/25] START cla__C=100.0, cla__gamma=0.001...........................
[CV 7/10; 11/25] END cla__C=100.0, cla__gamma=0.001;, score=0.333 total time=  23.5s
[CV 9/10; 12/25] START cla__C=100.0, cla__gamma=0.01............................
[CV 9/10; 12/25] END cla__C=100.0, cla__gamma=0.01;, score=0.333 total time=  24.9s
[CV 1/10; 14/25] START cla__C=100.0, cla__gamma=1...............................
[CV 1/10; 14/25] END cla__C=100.0, cla__gamma=1;, score=0.308 total time=  23.8s
[CV 3/10; 15/25] START cla__C=100.0, cla__gamma=10..............................
[CV 3/10; 15/25] END cla__C=100.0, cla__gamma=10;, score=0.308 total time=  24.1s
[CV 5/10; 16/25] START cla__C=1000.0, cla__gamma=0.001..........................
[CV 5/10; 16/25] END cla__C=1000.0, cla__gamma=0.001;, score=0.308 total time=  24.4s
[CV 7/10; 17/25] START cla__C=1000.0, cla__gamma=0.01...........................
[CV 7/10; 17/25] END cla__C=1000.0, cla__gamma=0.01;, score=0.417 total time=  25.2s
[CV 9/10; 18/25] START cla__C=1000.0, cla__gamma=0.1............................
[CV 9/10; 18/25] END cla__C=1000.0, cla__gamma=0.1;, score=0.250 total time=  24.6s
[CV 1/10; 20/25] START cla__C=1000.0, cla__gamma=10.............................
[CV 1/10; 20/25] END cla__C=1000.0, cla__gamma=10;, score=0.308 total time=  24.5s
[CV 3/10; 21/25] START cla__C=10000.0, cla__gamma=0.001.........................
[CV 3/10; 21/25] END cla__C=10000.0, cla__gamma=0.001;, score=0.231 total time=  24.5s
[CV 5/10; 22/25] START cla__C=10000.0, cla__gamma=0.01..........................
[CV 5/10; 22/25] END cla__C=10000.0, cla__gamma=0.01;, score=0.385 total time=  24.9s
[CV 7/10; 23/25] START cla__C=10000.0, cla__gamma=0.1...........................
[CV 7/10; 23/25] END cla__C=10000.0, cla__gamma=0.1;, score=0.250 total time=  24.5s
[CV 9/10; 24/25] START cla__C=10000.0, cla__gamma=1.............................
[CV 9/10; 24/25] END cla__C=10000.0, cla__gamma=1;, score=0.250 total time=  23.9s
[CV 7/10; 1/25] START cla__C=1, cla__gamma=0.001................................
[CV 7/10; 1/25] END .cla__C=1, cla__gamma=0.001;, score=0.250 total time=  24.9s
[CV 3/10; 3/25] START cla__C=1, cla__gamma=0.1..................................
[CV 3/10; 3/25] END ...cla__C=1, cla__gamma=0.1;, score=0.308 total time=  23.7s
[CV 6/10; 4/25] START cla__C=1, cla__gamma=1....................................
[CV 6/10; 4/25] END .....cla__C=1, cla__gamma=1;, score=0.308 total time=  23.9s
[CV 8/10; 5/25] START cla__C=1, cla__gamma=10...................................
[CV 8/10; 5/25] END ....cla__C=1, cla__gamma=10;, score=0.250 total time=  24.2s
[CV 10/10; 6/25] START cla__C=10, cla__gamma=0.001..............................
[CV 10/10; 6/25] END cla__C=10, cla__gamma=0.001;, score=0.000 total time=  24.8s
[CV 1/10; 8/25] START cla__C=10, cla__gamma=0.1.................................
[CV 1/10; 8/25] END ..cla__C=10, cla__gamma=0.1;, score=0.308 total time=  23.6s
[CV 4/10; 9/25] START cla__C=10, cla__gamma=1...................................
[CV 4/10; 9/25] END ....cla__C=10, cla__gamma=1;, score=0.308 total time=  22.6s
[CV 6/10; 10/25] START cla__C=10, cla__gamma=10.................................
[CV 6/10; 10/25] END ..cla__C=10, cla__gamma=10;, score=0.308 total time=  23.9s
[CV 8/10; 11/25] START cla__C=100.0, cla__gamma=0.001...........................
[CV 8/10; 11/25] END cla__C=100.0, cla__gamma=0.001;, score=0.333 total time=  24.5s
[CV 10/10; 12/25] START cla__C=100.0, cla__gamma=0.01...........................
[CV 10/10; 12/25] END cla__C=100.0, cla__gamma=0.01;, score=0.000 total time=  24.4s
[CV 2/10; 14/25] START cla__C=100.0, cla__gamma=1...............................
[CV 2/10; 14/25] END cla__C=100.0, cla__gamma=1;, score=0.308 total time=  24.6s
[CV 4/10; 15/25] START cla__C=100.0, cla__gamma=10..............................
[CV 4/10; 15/25] END cla__C=100.0, cla__gamma=10;, score=0.308 total time=  24.8s
[CV 6/10; 16/25] START cla__C=1000.0, cla__gamma=0.001..........................
[CV 6/10; 16/25] END cla__C=1000.0, cla__gamma=0.001;, score=0.231 total time=  24.6s
[CV 8/10; 17/25] START cla__C=1000.0, cla__gamma=0.01...........................
[CV 8/10; 17/25] END cla__C=1000.0, cla__gamma=0.01;, score=0.417 total time=  25.0s
[CV 10/10; 18/25] START cla__C=1000.0, cla__gamma=0.1...........................
[CV 10/10; 18/25] END cla__C=1000.0, cla__gamma=0.1;, score=0.250 total time=  25.0s
[CV 2/10; 20/25] START cla__C=1000.0, cla__gamma=10.............................
[CV 2/10; 20/25] END cla__C=1000.0, cla__gamma=10;, score=0.308 total time=  24.9s
[CV 4/10; 21/25] START cla__C=10000.0, cla__gamma=0.001.........................
[CV 4/10; 21/25] END cla__C=10000.0, cla__gamma=0.001;, score=0.538 total time=  24.3s
[CV 6/10; 22/25] START cla__C=10000.0, cla__gamma=0.01..........................
[CV 6/10; 22/25] END cla__C=10000.0, cla__gamma=0.01;, score=0.308 total time=  25.2s
[CV 8/10; 23/25] START cla__C=10000.0, cla__gamma=0.1...........................
[CV 8/10; 23/25] END cla__C=10000.0, cla__gamma=0.1;, score=0.250 total time=  24.7s
[CV 10/10; 24/25] START cla__C=10000.0, cla__gamma=1............................
[CV 10/10; 24/25] END cla__C=10000.0, cla__gamma=1;, score=0.250 total time=  24.0s