Source code for bciflow.modules.sf.csp
import numpy as np
import scipy as sp
[docs]
class csp:
    '''
    Common Spatial Patterns (CSP) spatial filter for two-class data.

    For every band, the filters are the generalized eigenvectors of the
    class-0 mean covariance with respect to the summed (class-0 + class-1)
    mean covariance, in ascending eigenvalue order. The first ``m_pairs``
    and the last ``m_pairs`` eigenvectors are kept.

    Attributes
    ----------
    n_electrodes :
        (int) The number of electrodes.
    m_pairs :
        (int) The number of pairs of spatial filters to extract (default is 2).
    W :
        (np.ndarray) The spatial filters, shape
        (bands, n_electrodes, 2 * m_pairs). If 2 * m_pairs exceeds the
        number of electrodes, the two selections overlap.
    bands :
        (int) The number of bands used.

    Methods
    -------
    fit(eegdata: dict) -> csp
        Fits the CSP filter to the input data, calculating the spatial filters.
    transform(eegdata: dict) -> dict
        Applies the learned spatial filters to the input data.
    fit_transform(eegdata: dict) -> dict
        Combines fitting and transforming into a single step.

    Example
    -------
    >>> from bciflow.modules.sf.csp import csp
    >>> import numpy as np
    >>> csp_filter = csp(m_pairs=2)
    >>> eegdata = {
    ...     'X': np.random.rand(100, 5, 16, 128),  # trials, bands, electrodes, samples
    ...     'y': np.random.randint(0, 2, size=100),  # binary labels
    ... }
    >>> _ = csp_filter.fit(eegdata)
    >>> transformed_data = csp_filter.transform(eegdata)
    >>> print(transformed_data['X'].shape)
    (100, 5, 4, 128)
    '''

    n_electrodes: int = None
    m_pairs: int = 2
    W: np.ndarray = None
    bands: int = None

    def __init__(self, m_pairs: int = 2):
        '''
        Parameters
        ----------
        m_pairs : int
            Number of filter pairs to keep per band (Python or numpy integer).

        Raises
        ------
        ValueError
            If ``m_pairs`` is not a positive integer.
        '''
        # bool is a subclass of int, but True/False are not pair counts.
        if (isinstance(m_pairs, bool)
                or not isinstance(m_pairs, (int, np.integer))
                or m_pairs <= 0):
            raise ValueError("Must be a positive integer")
        self.m_pairs = int(m_pairs)

    @staticmethod
    def _as_array(value):
        '''Return ``value`` unchanged if it is a numpy array, else raise ValueError.'''
        if not isinstance(value, np.ndarray):
            raise ValueError("data must be an array-like object")
        return value

    @staticmethod
    def _as_4d(X):
        '''Add a band axis to 3-D input; reject anything that is not then 4-D.'''
        if X.ndim == 3:
            # (trials, electrodes, samples) -> (trials, 1, electrodes, samples)
            X = X[:, np.newaxis]
        if X.ndim != 4:
            raise ValueError(
                "X must have shape (trials, bands, electrodes, samples) "
                "or (trials, electrodes, samples).")
        return X

    @staticmethod
    def _mean_covariance(trials):
        '''Average of X_i @ X_i.T over trials, per band -> (bands, e, e).'''
        return np.mean(trials @ np.swapaxes(trials, -1, -2), axis=0)

    @staticmethod
    def _band_filters(cov_a, cov_total):
        '''
        Generalized eigenvectors of (cov_a, cov_total), ascending eigenvalues.

        Falls back to the identity matrix (i.e. raw electrodes) when the
        problem cannot be solved, e.g. ``cov_total`` is not positive definite.
        '''
        # Explicit submodule import: ``import scipy`` alone does not load
        # ``scipy.linalg`` on every SciPy version.
        from scipy import linalg
        try:
            _, vectors = linalg.eigh(cov_a, cov_total)
        except (np.linalg.LinAlgError, ValueError):
            return np.eye(cov_a.shape[0])
        return vectors

    def fit(self, eegdata: dict) -> 'csp':
        '''
        Fits the CSP filter to the input data, calculating the spatial filters.

        Parameters
        ----------
        eegdata : dict
            Must contain 'X', a numpy array of shape
            (trials, bands, electrodes, samples) or, for single-band data,
            (trials, electrodes, samples); and 'y', a numpy array with one
            label per trial and exactly two distinct values. Class 0 is the
            smaller label (``np.unique`` order). Neither array is modified.

        Returns
        -------
        self : csp
            The fitted CSP object with spatial filters stored in W.

        Raises
        ------
        ValueError
            If 'X' or 'y' is not a numpy array, 'X' has the wrong number of
            dimensions, 'y' does not have one label per trial, or 'y' does not
            contain exactly two classes.
        '''
        X = self._as_array(eegdata['X'])
        y = self._as_array(eegdata['y'])
        X = self._as_4d(X)
        y = np.ravel(y)
        if y.shape[0] != X.shape[0]:
            raise ValueError("X and y must contain the same number of trials.")

        classes = np.unique(y)
        if len(classes) != 2:
            raise ValueError("y must have exactly two unique classes.")

        # Accumulate in float64 regardless of the input dtype.
        X = X.astype(np.float64, copy=False)
        cov_a = self._mean_covariance(X[y == classes[0]])
        cov_total = cov_a + self._mean_covariance(X[y == classes[1]])

        n_bands, n_electrodes = X.shape[1], X.shape[2]
        filters = np.empty((n_bands, n_electrodes, n_electrodes))
        for band in range(n_bands):
            filters[band] = self._band_filters(cov_a[band], cov_total[band])

        # Keep the eigenvectors at both ends of the spectrum: smallest and
        # largest class-0 / total variance ratio.
        m = self.m_pairs
        self.bands = n_bands
        self.n_electrodes = n_electrodes
        self.W = np.concatenate((filters[:, :, :m], filters[:, :, -m:]), axis=2)
        return self

    def transform(self, eegdata: dict) -> dict:
        '''
        Applies the learned spatial filters to the input data.

        Parameters
        ----------
        eegdata : dict
            Must contain 'X', a numpy array of shape
            (trials, bands, electrodes, samples) or
            (trials, electrodes, samples) for single-band data.

        Returns
        -------
        eegdata : dict
            The same dict, with 'X' replaced in place by the filtered data of
            shape (trials, bands, W.shape[2], samples).

        Raises
        ------
        ValueError
            If 'X' is not a numpy array, the filter has not been fitted, or the
            input dimensions do not match the fitted data.
        '''
        X = self._as_array(eegdata['X'])
        if self.W is None:
            raise ValueError("The CSP filter must be fitted before calling transform.")
        X = self._as_4d(X)
        if X.shape[1] != self.bands:
            raise ValueError("The number of bands in the input data is different from the number of bands in the fitted data.")
        if X.shape[2] != self.n_electrodes:
            raise ValueError("The number of electrodes in the input data is different from the number of electrodes in the fitted data.")
        # (bands, filters, electrodes) @ (trials, bands, electrodes, samples)
        # broadcasts over trials -> (trials, bands, filters, samples).
        eegdata['X'] = np.swapaxes(self.W, 1, 2) @ X
        return eegdata

    def fit_transform(self, eegdata: dict) -> dict:
        '''
        Combines fitting and transforming into a single step.

        Parameters
        ----------
        eegdata : dict
            The input data containing 'X' (features) and 'y' (labels), as
            accepted by ``fit``.

        Returns
        -------
        eegdata : dict
            The same dict with 'X' replaced by the filtered data.

        Raises
        ------
        ValueError
            Under the same conditions as ``fit`` and ``transform``.
        '''
        return self.fit(eegdata).transform(eegdata)