Coverage for contextualized/easy/tests/test_classifier.py: 97%

33 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-18 16:32 -0400

1"""" 

2Unit tests for Easy Classifier. 

3""" 

4 

5import unittest 

6import numpy as np 

7 

8from contextualized.easy import ContextualizedClassifier 

9 

10 

class TestEasyClassifier(unittest.TestCase):
    """
    Test Easy Classifier models.
    """

    def _quicktest(self, model, C, X, Y, **kwargs):
        """Fit ``model`` and sanity-check its parameters and predictions.

        The model is first fit with ``max_epochs=0`` to record a baseline
        (initial) error. It is then refit with ``kwargs``, and the method
        checks that:

        - the predicted parameters have the expected shapes and contain no NaNs;
        - predictions have the same shape as ``Y``;
        - the training-set error decreased relative to the baseline.

        Args:
            model: Estimator exposing ``fit(C, X, Y, **kw)``, ``predict(C, X)``
                and ``predict_params(C)``.
            C: Context matrix; ``C.shape[0]`` rows, one per sample.
            X: Predictor matrix, shape ``(n_samples, x_dim)``.
            Y: Label matrix, shape ``(n_samples, y_dim)``.
            **kwargs: Extra keyword arguments forwarded to the second ``fit``.
        """
        print(f"{type(model)} quicktest")

        # Zero-epoch fit gives the initial model's error as a baseline.
        model.fit(C, X, Y, max_epochs=0)
        err_init = (Y != model.predict(C, X)).sum()

        model.fit(C, X, Y, **kwargs)
        beta_preds, mu_preds = model.predict_params(C)
        self.assertEqual(beta_preds.shape, (X.shape[0], Y.shape[1], X.shape[1]))
        self.assertEqual(mu_preds.shape, (X.shape[0], Y.shape[1]))
        self.assertFalse(np.any(np.isnan(beta_preds)))
        self.assertFalse(np.any(np.isnan(mu_preds)))

        y_preds = model.predict(C, X)
        self.assertEqual(y_preds.shape, Y.shape)
        err_trained = (Y != y_preds).sum()
        # Print before asserting so the numbers are visible when the check fails.
        print(err_trained, err_init)
        self.assertLess(err_trained, err_init)

    def test_classifier(self):
        """Test Case for ContextualizedClassifier."""
        # Seeded local generator: with unseeded global draws the data (and
        # therefore the err_trained < err_init comparison) changed every run.
        # NOTE(review): model weight initialisation may use a separate RNG
        # (e.g. torch) that this seed does not control -- confirm if flakiness remains.
        rng = np.random.default_rng(0)

        n_samples = 1000
        c_dim = 100
        x_dim = 3
        y_dim = 1
        C = rng.uniform(-1, 1, size=(n_samples, c_dim))
        X = rng.uniform(-1, 1, size=(n_samples, x_dim))
        # Labels are independent of C and X, so the error check in _quicktest
        # only verifies that training fits the training data, not generalization.
        Y = rng.binomial(1, 0.5, size=(n_samples, y_dim))

        model = ContextualizedClassifier(alpha=1e-1, encoder_type="mlp")
        self._quicktest(model, C, X, Y, max_epochs=10, es_patience=float("inf"))

48 

49 

if __name__ == "__main__":
    # Executed as a script: discover and run the TestCase classes defined in
    # this module (module="__main__" is unittest.main's default, stated explicitly).
    unittest.main(module="__main__")