Coverage for mldsutils\classification\_classifiers.py: 100%
13 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-24 21:06 +0200
1from sklearn.base import BaseEstimator, ClassifierMixin
2from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
3from sklearn.utils.validation import check_array
class ShrinkageLDAClassifier(BaseEstimator, ClassifierMixin):
    """
    Wrapper around sklearn's ``LinearDiscriminantAnalysis`` with covariance
    matrix shrinkage enabled (by default, automatic shrinkage via the
    Ledoit-Wolf lemma).

    Covariance matrix shrinkage is an effective regularization technique for
    LDA, especially when the input data has a large number of variables
    (columns) relative to the number of data points (rows).

    Parameters
    ----------
    solver : str, default="lsqr"
        Solver passed to ``LinearDiscriminantAnalysis``. Least squares is the
        default because it was noted to perform better than eigenvalue
        decomposition. sklearn's ``"svd"`` solver does not support shrinkage.
    shrinkage : "auto", float or None, default="auto"
        Shrinkage passed to ``LinearDiscriminantAnalysis``. ``"auto"`` uses
        the Ledoit-Wolf lemma; a float in [0, 1] fixes the shrinkage amount.

    Attributes
    ----------
    estimator_ : LinearDiscriminantAnalysis
        The fitted wrapped estimator, set by ``fit``.
    classes_ : ndarray
        Class labels seen during ``fit``.
    """

    def __init__(self, solver="lsqr", shrinkage="auto"):
        # sklearn convention: __init__ only stores hyper-parameters, so that
        # get_params / set_params / clone work. The wrapped estimator is
        # built in fit.
        self.solver = solver
        self.shrinkage = shrinkage

    def _make_estimator(self):
        """Return a new, unfitted LDA configured from the current parameters."""
        return LinearDiscriminantAnalysis(
            solver=self.solver, shrinkage=self.shrinkage
        )

    @property
    def estimator(self):
        """
        Backward-compatible access to the wrapped LDA.

        Returns the fitted ``estimator_`` after ``fit``. Before that, it
        returns a new unfitted LDA configured from the current parameters.
        """
        try:
            return self.estimator_
        except AttributeError:
            return self._make_estimator()

    def _fitted_estimator(self):
        """Return ``estimator_``; raise ``NotFittedError`` if not yet fitted."""
        from sklearn.utils.validation import check_is_fitted

        check_is_fitted(self, "estimator_")
        return self.estimator_

    def fit(self, X, y):
        """
        Fit a freshly configured LDA on ``X`` and ``y``.

        Input validation is delegated to ``LinearDiscriminantAnalysis``.

        Returns
        -------
        self
            The fitted classifier.
        """
        self.estimator_ = self._make_estimator().fit(X, y)
        # sklearn classifiers must expose classes_; scorers and
        # meta-estimators read it from the outer estimator.
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X):
        """
        Predict class labels for ``X``.

        Raises ``NotFittedError`` if called before ``fit``.
        """
        # X is passed through unchanged. The wrapped LDA validates it, and
        # converting a DataFrame to an ndarray here would trigger sklearn's
        # "X does not have valid feature names" warning after fitting on a
        # DataFrame.
        return self._fitted_estimator().predict(X)

    def predict_proba(self, X):
        """Return class probabilities for ``X`` from the wrapped LDA."""
        return self._fitted_estimator().predict_proba(X)

    def decision_function(self, X):
        """Return decision-function scores for ``X`` from the wrapped LDA."""
        return self._fitted_estimator().decision_function(X)