Coverage for contextualized/baselines/tests.py: 97%
36 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:38 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:38 -0400
1"""
2Unit tests for baseline models.
3"""
5import unittest
6import numpy as np
8from contextualized.baselines import (
9 CorrelationNetwork,
10 MarkovNetwork,
11 BayesianNetwork,
12 GroupedNetworks,
13)
class TestBaselineNetworks(unittest.TestCase):
    """
    Test that the baseline networks can be fit, predict the correct shape,
    and reconstruct the training data with a mean MSE below 1.0.
    """

    # Fixed seed so the random training data -- and therefore the MSE
    # thresholds asserted below -- are identical on every run instead of
    # depending on unseeded global NumPy random state.
    SEED = 0

    def setUp(self):
        """
        Shared data setup.

        Builds ``self.X``, a (n_samples, x_dim) matrix drawn uniformly from
        [-1, 1), and ``self.labels``, n_samples integer group labels in
        [0, 5), from a locally seeded generator.
        """
        rng = np.random.default_rng(self.SEED)
        self.n_samples, self.x_dim = 100, 20
        self.labels = rng.integers(0, 5, size=(self.n_samples,))
        self.X = rng.uniform(-1, 1, size=(self.n_samples, self.x_dim))

    def _assert_network_shape(self, networks):
        """
        Assert that ``networks`` holds one (x_dim, x_dim) matrix per sample.
        """
        # tuple() normalizes shape objects (e.g. tuple subclasses) so the
        # failure message shows a plain tuple comparison.
        self.assertEqual(
            tuple(networks.shape),
            (self.n_samples, self.x_dim, self.x_dim),
        )

    def test_correlation_network(self):
        """
        Test that the correlation network can be fit and predicts the correct shape.
        """
        corr = CorrelationNetwork().fit(self.X)
        self._assert_network_shape(corr.predict(self.n_samples))
        self.assertLess(corr.measure_mses(self.X).mean(), 1.0)

    def test_grouped_corr_network(self):
        """
        Test that the grouped correlation network can be fit and predicts the correct shape.
        """
        grouped_corr = GroupedNetworks(CorrelationNetwork).fit(self.X, self.labels)
        self._assert_network_shape(grouped_corr.predict(self.labels))
        self.assertLess(grouped_corr.measure_mses(self.X, self.labels).mean(), 1.0)

    def test_markov_network(self):
        """
        Test that the markov network can be fit and predicts the correct shape.
        """
        mark = MarkovNetwork().fit(self.X)
        self._assert_network_shape(mark.predict(self.n_samples))
        self.assertLess(mark.measure_mses(self.X).mean(), 1.0)

    def test_grouped_markov_network(self):
        """
        Test that the grouped markov network can be fit and predicts the correct shape.
        """
        grouped_mark = GroupedNetworks(MarkovNetwork).fit(self.X, self.labels)
        # Previously the prediction was computed but its shape never checked,
        # unlike every other grouped/ungrouped test in this class.
        self._assert_network_shape(grouped_mark.predict(self.labels))
        self.assertLess(grouped_mark.measure_mses(self.X, self.labels).mean(), 1.0)

    def test_bayesian_network(self):
        """
        Test that the bayesian network can be fit and predicts the correct shape.
        """
        dag = BayesianNetwork().fit(self.X)
        self._assert_network_shape(dag.predict(self.n_samples))
        self.assertLess(dag.measure_mses(self.X).mean(), 1.0)

    def test_grouped_bayesian_network(self):
        """
        Test that the grouped bayesian network can be fit and predicts the correct shape.
        """
        grouped_dag = GroupedNetworks(BayesianNetwork).fit(self.X, self.labels)
        self._assert_network_shape(grouped_dag.predict(self.labels))
        self.assertLess(grouped_dag.measure_mses(self.X, self.labels).mean(), 1.0)
if __name__ == "__main__":
    # Run every TestCase defined in this module when the file is executed
    # directly; "__main__" is the module unittest.main() inspects by default.
    unittest.main(module="__main__")