Coverage for contextualized/baselines/tests.py: 97%

36 statements  

Report generated by coverage.py v7.4.4 at 2024-04-21 13:38 -0400.

1""" 

2Unit tests for baseline models. 

3""" 

4 

5import unittest 

6import numpy as np 

7 

8from contextualized.baselines import ( 

9 CorrelationNetwork, 

10 MarkovNetwork, 

11 BayesianNetwork, 

12 GroupedNetworks, 

13) 

14 

15 

class TestBaselineNetworks(unittest.TestCase):
    """
    Test that the baseline networks can be fit and predict the correct shape.

    Each test fits a baseline network (directly, or per group through
    ``GroupedNetworks``) on random data. It then checks that the predicted
    networks have shape ``(n_samples, x_dim, x_dim)`` and that the mean of the
    values returned by ``measure_mses`` is below 1.0.
    """

    # Fixed seed so the random fixture, and therefore the MSE thresholds
    # checked below, is reproducible from run to run.
    SEED = 0

    def setUp(self):
        """
        Shared data setup.

        Builds ``n_samples`` rows of ``x_dim`` features drawn uniformly from
        [-1, 1), plus one integer group label in [0, 5) per row.
        """
        # A local Generator avoids touching NumPy's global random state.
        rng = np.random.default_rng(self.SEED)
        self.n_samples, self.x_dim = 100, 20
        self.labels = rng.integers(0, 5, size=(self.n_samples,))
        self.X = rng.uniform(-1, 1, size=(self.n_samples, self.x_dim))

    def _assert_prediction_shape(self, predictions):
        """Assert that ``predictions`` has shape (n_samples, x_dim, x_dim)."""
        self.assertEqual(
            predictions.shape,
            (self.n_samples, self.x_dim, self.x_dim),
        )

    def test_correlation_network(self):
        """
        Test that the correlation network can be fit and predicts the correct shape.
        """
        corr = CorrelationNetwork().fit(self.X)
        self._assert_prediction_shape(corr.predict(self.n_samples))
        self.assertLess(corr.measure_mses(self.X).mean(), 1.0)

    def test_grouped_corr_network(self):
        """
        Test that the grouped correlation network can be fit and predicts the correct shape.
        """
        grouped_corr = GroupedNetworks(CorrelationNetwork).fit(self.X, self.labels)
        self._assert_prediction_shape(grouped_corr.predict(self.labels))
        self.assertLess(grouped_corr.measure_mses(self.X, self.labels).mean(), 1.0)

    def test_markov_network(self):
        """
        Test that the markov network can be fit and predicts the correct shape.
        """
        mark = MarkovNetwork().fit(self.X)
        self._assert_prediction_shape(mark.predict(self.n_samples))
        self.assertLess(mark.measure_mses(self.X).mean(), 1.0)

    def test_grouped_markov_network(self):
        """
        Test that the grouped markov network can be fit and predicts the correct shape.
        """
        grouped_mark = GroupedNetworks(MarkovNetwork).fit(self.X, self.labels)
        # Previously predict() was called without checking its result; check
        # the shape as the docstring (and every sibling test) promises.
        self._assert_prediction_shape(grouped_mark.predict(self.labels))
        self.assertLess(grouped_mark.measure_mses(self.X, self.labels).mean(), 1.0)

    def test_bayesian_network(self):
        """
        Test that the bayesian network can be fit and predicts the correct shape.
        """
        dag = BayesianNetwork().fit(self.X)
        self._assert_prediction_shape(dag.predict(self.n_samples))
        self.assertLess(dag.measure_mses(self.X).mean(), 1.0)

    def test_grouped_bayesian_network(self):
        """
        Test that the grouped bayesian network can be fit and predicts the correct shape.
        """
        grouped_dag = GroupedNetworks(BayesianNetwork).fit(self.X, self.labels)
        self._assert_prediction_shape(grouped_dag.predict(self.labels))
        self.assertLess(grouped_dag.measure_mses(self.X, self.labels).mean(), 1.0)

99 

100 

if __name__ == "__main__":
    # Running this file directly hands control to unittest's command-line runner.
    unittest.main()