Coverage for mldsutils\tests\test_shrinkage_lda_classifier.py: 100%
34 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-24 21:06 +0200
1from sklearn.datasets import make_classification
2from sklearn.model_selection import train_test_split
3from sklearn.metrics import accuracy_score
4from mldsutils.classification import ShrinkageLDAClassifier
def test_shrinkage_lda_classifier():
    """Basic fit/predict test for ShrinkageLDAClassifier.

    Trains on a reproducible synthetic dataset and checks that the
    predictions have one entry per test sample, only use labels seen during
    training, and beat the chance level of the two-class problem.
    """
    # Generate a reproducible synthetic dataset
    X, y = make_classification(n_samples=100, n_features=10, random_state=42)

    # Hold out 20% of the samples for evaluation
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    classifier = ShrinkageLDAClassifier()
    classifier.fit(X_train, y_train)
    y_pred = classifier.predict(X_test)

    # One prediction per test sample
    assert y_pred.shape == y_test.shape

    # Predictions must be drawn from the labels the model was trained on
    assert set(y_pred) <= set(y_train)

    # accuracy_score is always within [0, 1], so bounding it there checks
    # nothing. make_classification defaults to two balanced classes, so
    # require the model to do better than guessing instead.
    accuracy = accuracy_score(y_test, y_pred)
    assert accuracy > 0.5
# NOTE(review): this test is disabled. Before re-enabling it, be aware that
# make_classification does not guarantee linearly separable data: by default
# flip_y=0.01 assigns the class of ~1% of samples at random, so the
# `accuracy_linear == 1.0` assertion can fail. The "linear" and "overlap"
# datasets below also differ only in n_informative/n_redundant, not in how
# well the classes are separated. Consider flip_y=0 and a larger class_sep
# for the separable case instead.
# def test_shrinkage_lda_classifier_different_datasets():
#     # Test case 2: Behavior with different types of datasets
#
#     # Generate a linearly separable dataset
#     X_linear, y_linear = make_classification(n_samples=100, n_features=10, random_state=42)
#
#     # Generate a dataset with overlapping classes
#     X_overlap, y_overlap = make_classification(n_samples=100, n_features=10, n_informative=5, n_redundant=5,
#                                                random_state=42)
#
#     # Initialize the ShrinkageLDAClassifier
#     classifier = ShrinkageLDAClassifier()
#
#     # Fit and predict for the linearly separable dataset
#     classifier.fit(X_linear, y_linear)
#     y_pred_linear = classifier.predict(X_linear)
#
#     # Fit and predict for the dataset with overlapping classes
#     classifier.fit(X_overlap, y_overlap)
#     y_pred_overlap = classifier.predict(X_overlap)
#
#     # Check if the predictions have the correct shape
#     assert y_pred_linear.shape == y_linear.shape
#     assert y_pred_overlap.shape == y_overlap.shape
#
#     # Check if the accuracy scores are reasonable
#     accuracy_linear = accuracy_score(y_linear, y_pred_linear)
#     assert accuracy_linear == 1.0
#
#     accuracy_overlap = accuracy_score(y_overlap, y_pred_overlap)
#     assert accuracy_overlap >= 0.5
def test_shrinkage_lda_classifier_edge_cases():
    """Edge case: the classifier must handle a dataset with a single feature."""
    # One informative feature, no redundant ones, one cluster per class
    features, labels = make_classification(
        n_samples=100,
        n_features=1,
        n_informative=1,
        n_redundant=0,
        n_clusters_per_class=1,
        random_state=42,
    )

    # Train and predict on the same one-column data
    model = ShrinkageLDAClassifier()
    model.fit(features, labels)
    predictions = model.predict(features)

    # Exactly one prediction per input sample
    assert predictions.shape == labels.shape
def test_shrinkage_lda_classifier_comparison():
    """Check ShrinkageLDAClassifier is competitive with logistic regression.

    Both models are linear classifiers and neither dominates the other in
    general. On a 20-sample test split a single sample moves accuracy by 0.05,
    so demanding that shrinkage LDA be at least as accurate as logistic
    regression is brittle. Instead, the LDA model may misclassify at most a
    small, explicit number of extra test samples.
    """
    from sklearn.linear_model import LogisticRegression

    # Maximum number of additional test samples shrinkage LDA may get wrong
    # compared with logistic regression.
    max_extra_errors = 2

    # Generate a reproducible synthetic dataset and hold out 20% for testing
    X, y = make_classification(n_samples=100, n_features=10, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    shrinkage_lda_classifier = ShrinkageLDAClassifier()
    logistic_regression = LogisticRegression()

    shrinkage_lda_classifier.fit(X_train, y_train)
    y_pred_shrinkage_lda = shrinkage_lda_classifier.predict(X_test)

    logistic_regression.fit(X_train, y_train)
    y_pred_logistic_regression = logistic_regression.predict(X_test)

    # One prediction per test sample for both models
    assert y_pred_shrinkage_lda.shape == y_test.shape
    assert y_pred_logistic_regression.shape == y_test.shape

    # normalize=False returns the number of correctly classified samples, so
    # the comparison uses exact integers rather than rounded fractions.
    correct_shrinkage_lda = accuracy_score(
        y_test, y_pred_shrinkage_lda, normalize=False
    )
    correct_logistic_regression = accuracy_score(
        y_test, y_pred_logistic_regression, normalize=False
    )
    assert correct_shrinkage_lda >= correct_logistic_regression - max_extra_errors