Coverage for contextualized/easy/tests/test_classifier.py: 97%
33 statements
« prev ^ index » next — coverage.py v7.4.4, created at 2024-04-18 16:32 -0400
"""
Unit tests for Easy Classifier.
"""
5import unittest
6import numpy as np
8from contextualized.easy import ContextualizedClassifier
class TestEasyClassifier(unittest.TestCase):
    """
    Test Easy Classifier models.

    Each test fits a model on synthetic data, then checks that the predicted
    parameters and labels have the expected shapes, contain no NaNs, and
    that training lowers the in-sample misclassification count.
    """

    def _quicktest(self, model, C, X, Y, **kwargs):
        """Fit ``model`` and sanity-check its parameters and predictions.

        Args:
            model: Estimator exposing ``fit(C, X, Y, **kwargs)``,
                ``predict(C, X)`` and ``predict_params(C)``.
            C: Context array with one row per sample.
            X: Feature array of shape ``(n_samples, x_dim)``.
            Y: Label array of shape ``(n_samples, y_dim)``.
            **kwargs: Extra keyword arguments forwarded to the second
                (real) ``fit`` call.

        Raises:
            AssertionError: If any shape, NaN, or error-reduction check fails.
        """
        print(f"{type(model)} quicktest")
        # A zero-epoch fit presumably leaves the model untrained; its error
        # serves as the baseline the trained model must beat.
        model.fit(C, X, Y, max_epochs=0)
        err_init = (Y != model.predict(C, X)).sum()

        model.fit(C, X, Y, **kwargs)
        beta_preds, mu_preds = model.predict_params(C)
        self.assertEqual(beta_preds.shape, (X.shape[0], Y.shape[1], X.shape[1]))
        self.assertEqual(mu_preds.shape, (X.shape[0], Y.shape[1]))
        self.assertFalse(np.any(np.isnan(beta_preds)))
        self.assertFalse(np.any(np.isnan(mu_preds)))

        y_preds = model.predict(C, X)
        self.assertEqual(y_preds.shape, Y.shape)
        err_trained = (Y != y_preds).sum()
        # Print before asserting so the counts are visible when the check fails.
        print(err_trained, err_init)
        self.assertLess(err_trained, err_init)

    def test_classifier(self):
        """Test Case for ContextualizedClassifier."""
        n_samples = 1000
        c_dim = 100
        x_dim = 3
        y_dim = 1
        # Seeded generator so the synthetic dataset is reproducible across runs.
        # NOTE(review): model-side randomness (e.g. weight init) is not seeded
        # here — confirm whether the library needs its own seed for determinism.
        rng = np.random.default_rng(0)
        C = rng.uniform(-1, 1, size=(n_samples, c_dim))
        X = rng.uniform(-1, 1, size=(n_samples, x_dim))
        # Labels are drawn independently of C and X, so any drop in error is
        # the model fitting the training set rather than a real signal.
        Y = rng.binomial(1, 0.5, size=(n_samples, y_dim))

        model = ContextualizedClassifier(alpha=1e-1, encoder_type="mlp")
        # es_patience=inf presumably disables early stopping so all epochs run.
        self._quicktest(model, C, X, Y, max_epochs=10, es_patience=float("inf"))
if __name__ == "__main__":
    # Running this file directly discovers and runs the TestCase classes above.
    unittest.main()