Coverage for mldsutils\classification\_classifiers.py: 100%

13 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-24 21:06 +0200

1from sklearn.base import BaseEstimator, ClassifierMixin 

2from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 

3from sklearn.utils.validation import check_array 

4 

5 

# Mixins go to the left of BaseEstimator, as scikit-learn's developer guide
# requires; otherwise ClassifierMixin's classifier tag can be shadowed in the MRO.
class ShrinkageLDAClassifier(ClassifierMixin, BaseEstimator):
    """
    Wrapper around sklearn's LinearDiscriminantAnalysis with covariance matrix
    shrinkage enabled (``shrinkage="auto"`` uses the Ledoit-Wolf lemma).

    Covariance matrix shrinkage is an effective regularization technique for LDA,
    especially when the input data has many variables (columns) relative to the
    number of data points (rows).

    Parameters
    ----------
    solver : str, default="lsqr"
        Solver passed to LinearDiscriminantAnalysis. Least squares is the default
        because it performed better than eigenvalue decomposition for the original
        author. Per the sklearn docs, shrinkage only works with the "lsqr" and
        "eigen" solvers.
    shrinkage : "auto", float or None, default="auto"
        Shrinkage passed to LinearDiscriminantAnalysis.

    Attributes
    ----------
    estimator : LinearDiscriminantAnalysis
        The wrapped model. Kept for backward compatibility; it is unfitted until
        ``fit`` is called and is then replaced by the fitted instance.
    estimator_ : LinearDiscriminantAnalysis
        The fitted wrapped model (set by ``fit``). Its trailing underscore lets
        sklearn's ``check_is_fitted`` recognise this classifier as fitted.
    classes_ : ndarray
        Class labels seen during ``fit``.
    """

    def __init__(self, solver="lsqr", shrinkage="auto"):
        """
        Store the hyper-parameters.

        The hyper-parameters are stored under their own names so that
        ``get_params``/``set_params``/``clone`` and grid search can see them.
        """
        self.solver = solver
        self.shrinkage = shrinkage
        # Backward compatibility: callers may read ``clf.estimator`` before fit.
        self.estimator = LinearDiscriminantAnalysis(solver=solver, shrinkage=shrinkage)

    def _make_estimator(self):
        """Build a fresh, unfitted LDA from the current hyper-parameters."""
        return LinearDiscriminantAnalysis(solver=self.solver, shrinkage=self.shrinkage)

    def fit(self, X, y):
        """
        Fit a new LDA model on ``X`` and ``y``.

        A fresh LDA is built on every call, so parameters changed via
        ``set_params`` take effect. Input validation is done by
        LinearDiscriminantAnalysis.fit.

        Returns
        -------
        self : ShrinkageLDAClassifier
            The fitted classifier.
        """
        fitted = self._make_estimator().fit(X, y)
        self.estimator_ = fitted
        self.estimator = fitted
        self.classes_ = fitted.classes_
        # Return the classifier
        return self

    def predict(self, X):
        """
        Predict class labels for the samples in ``X``.

        Before ``fit`` the wrapped LDA is unfitted, so it raises sklearn's
        NotFittedError.
        """
        # Input validation
        X = check_array(X)
        return self.estimator.predict(X)

    def predict_proba(self, X):
        """Return class probabilities for ``X`` from the wrapped LDA."""
        X = check_array(X)
        return self.estimator.predict_proba(X)

    def decision_function(self, X):
        """Return the wrapped LDA's decision-function scores for ``X``."""
        X = check_array(X)
        return self.estimator.decision_function(X)