Coverage for contextualized/tests.py: 99%
165 statements
coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
import os
import tempfile
import unittest

import numpy as np
import torch

from contextualized.modules import SoftSelect, Explainer, MLP, NGAM, Linear
from contextualized.easy import (
    ContextualizedRegressor,
    ContextualizedBayesianNetworks,
    ContextualizedCorrelationNetworks,
)
from contextualized.baselines import BayesianNetwork, CorrelationNetwork
from contextualized.utils import save, load
15class TestModules(unittest.TestCase):
17 def __init__(self, *args, **kwargs):
18 super().__init__(*args, **kwargs)
20 def setUp(self):
21 """
22 Shared data setup.
23 """
24 self.N_SAMPLES = 100
25 self.X_DIM = 10
26 self.Y_DIM = 5
27 self.K_ARCHETYPES = 3
28 self.WIDTH = 50
29 self.LAYERS = 5
30 self.X_data = torch.rand((self.N_SAMPLES, self.X_DIM))
31 self.IN_DIMS = (3, 4)
32 self.OUT_SHAPE = (5, 6)
33 self.Z1 = torch.randn(self.N_SAMPLES, self.IN_DIMS[0])
34 self.Z2 = torch.randn(self.N_SAMPLES, self.IN_DIMS[1])
36 def test_mlp(self):
37 """
38 Test that the output shape of the MLP is as expected.
39 """
40 mlp = MLP(self.X_DIM, self.Y_DIM, self.WIDTH, self.LAYERS)
41 assert mlp(self.X_data).shape == (self.N_SAMPLES, self.Y_DIM)
43 def test_ngam(self):
44 """
45 Test that the output shape of the NGAM is as expected.
46 """
47 ngam = NGAM(self.X_DIM, self.Y_DIM, self.WIDTH, self.LAYERS)
48 assert ngam(self.X_data).shape == (self.N_SAMPLES, self.Y_DIM)
50 def test_softselect(self):
51 """
52 Test that the output shape of the SoftSelect is as expected.
53 """
54 softselect = SoftSelect(self.IN_DIMS, self.OUT_SHAPE)
55 assert softselect(self.Z1, self.Z2).shape == (self.N_SAMPLES, *self.OUT_SHAPE)
57 precycle_vals = softselect.archetypes
58 assert precycle_vals.shape == (*self.OUT_SHAPE, *self.IN_DIMS)
59 postcycle_vals = softselect.get_archetypes()
60 assert postcycle_vals.shape == (*self.IN_DIMS, *self.OUT_SHAPE)
61 softselect.set_archetypes(torch.randn(*self.IN_DIMS, *self.OUT_SHAPE))
62 assert (softselect.archetypes != precycle_vals).any()
63 softselect.set_archetypes(postcycle_vals)
64 assert (softselect.archetypes == precycle_vals).all()
66 def test_explainer(self):
67 explainer = Explainer(self.IN_DIMS[0], self.OUT_SHAPE)
68 ret = explainer(self.Z1)
70 precycle_vals = explainer.archetypes
71 assert precycle_vals.shape == (*self.OUT_SHAPE, self.IN_DIMS[0])
72 postcycle_vals = explainer.get_archetypes()
73 assert postcycle_vals.shape == (self.IN_DIMS[0], *self.OUT_SHAPE)
74 explainer.set_archetypes(torch.randn(self.IN_DIMS[0], *self.OUT_SHAPE))
75 assert (explainer.archetypes != precycle_vals).any()
76 explainer.set_archetypes(postcycle_vals)
77 assert (explainer.archetypes == precycle_vals).all()
79 def test_linear(self):
80 linear_encoder = Linear(self.X_DIM, self.Y_DIM)
81 linear_output = linear_encoder(self.X_data)
82 assert linear_output.shape == (self.N_SAMPLES, self.Y_DIM)
85class TestSaveLoad(unittest.TestCase):
87 def __init__(self, *args, **kwargs):
88 super().__init__(*args, **kwargs)
90 def test_save_load(self):
91 """
92 Test saving and loading of contextualized objects
93 """
94 C = np.random.uniform(0, 1, size=(100, 2))
95 X = np.random.uniform(0, 1, size=(100, 2))
96 Y = np.random.uniform(0, 1, size=(100, 2))
97 C2 = np.random.uniform(0, 1, size=(100, 2))
98 X2 = np.random.uniform(0, 1, size=(100, 2))
99 Y2 = np.random.uniform(0, 1, size=(100, 2))
100 mlp = MLP(2, 2, 50, 5)
101 Y_pred = mlp(torch.Tensor(X)).detach().numpy()
102 save(mlp, "unittest_model.pt")
103 del mlp
104 mlp_loaded = load("unittest_model.pt")
105 Y_pred_loaded = mlp_loaded(torch.Tensor(X)).detach().numpy()
106 assert np.all(Y_pred == Y_pred_loaded)
107 os.remove("unittest_model.pt")
109 model = ContextualizedRegressor()
110 model.fit(C, X, Y)
111 Y_pred = model.predict(C, X)
112 save(model, "unittest_model.pt")
113 del model
114 model_loaded = load("unittest_model.pt")
115 Y_pred_loaded = model_loaded.predict(C, X)
116 assert np.all(Y_pred == Y_pred_loaded)
117 os.remove("unittest_model.pt")
118 model_loaded.fit(C2, X2, Y2)
119 Y_pred2 = model_loaded.predict(C2, X2)
120 assert not np.all(Y_pred_loaded == Y_pred2)
121 save(model_loaded, "unittest_model.pt")
122 del model_loaded
123 model_loaded2 = load("unittest_model.pt")
124 Y_pred_loaded2 = model_loaded2.predict(C2, X2)
125 assert np.all(Y_pred2 == Y_pred_loaded2)
126 os.remove("unittest_model.pt")
128 model = ContextualizedBayesianNetworks()
129 model.fit(C, X)
130 pred = model.predict_networks(C)
131 save(model, "unittest_model.pt")
132 del model
133 model_loaded = load("unittest_model.pt")
134 pred_loaded = model_loaded.predict_networks(C)
135 assert np.all(np.array(pred) == np.array(pred_loaded))
136 os.remove("unittest_model.pt")
137 model_loaded.fit(C2, X2)
138 pred2 = model_loaded.predict_networks(C2)
139 assert not np.all(np.array(pred_loaded) == np.array(pred2))
140 save(model_loaded, "unittest_model.pt")
141 del model_loaded
142 model_loaded2 = load("unittest_model.pt")
143 pred_loaded2 = model_loaded2.predict_networks(C2)
144 assert np.all(np.array(pred2) == np.array(pred_loaded2))
145 os.remove("unittest_model.pt")
147 model = ContextualizedCorrelationNetworks()
148 model.fit(C, X)
149 pred = model.predict_correlation(C)
150 save(model, "unittest_model.pt")
151 del model
152 model_loaded = load("unittest_model.pt")
153 pred_loaded = model_loaded.predict_correlation(C)
154 assert np.all(np.array(pred) == np.array(pred_loaded))
155 os.remove("unittest_model.pt")
156 model_loaded.fit(C2, X2)
157 pred2 = model_loaded.predict_correlation(C2)
158 assert not np.all(np.array(pred_loaded) == np.array(pred2))
159 save(model_loaded, "unittest_model.pt")
160 del model_loaded
161 model_loaded2 = load("unittest_model.pt")
162 pred_loaded2 = model_loaded2.predict_correlation(C2)
163 assert np.all(np.array(pred2) == np.array(pred_loaded2))
164 os.remove("unittest_model.pt")
166 model = BayesianNetwork()
167 model.fit(X)
168 pred = model.measure_mses(X)
169 save(model, "unittest_model.pt")
170 del model
171 model_loaded = load("unittest_model.pt")
172 pred_loaded = model_loaded.measure_mses(X)
173 assert np.all(np.array(pred) == np.array(pred_loaded))
174 os.remove("unittest_model.pt")
175 model_loaded.fit(X2)
176 pred2 = model_loaded.measure_mses(X2)
177 assert not np.all(np.array(pred_loaded) == np.array(pred2))
178 save(model_loaded, "unittest_model.pt")
179 del model_loaded
180 model_loaded2 = load("unittest_model.pt")
181 pred_loaded2 = model_loaded2.measure_mses(X2)
182 assert np.all(np.array(pred2) == np.array(pred_loaded2))
183 os.remove("unittest_model.pt")
185 model = CorrelationNetwork()
186 model.fit(X)
187 pred = model.measure_mses(X)
188 save(model, "unittest_model.pt")
189 del model
190 model_loaded = load("unittest_model.pt")
191 pred_loaded = model_loaded.measure_mses(X)
192 assert np.all(np.array(pred) == np.array(pred_loaded))
193 os.remove("unittest_model.pt")
194 model_loaded.fit(X2)
195 pred2 = model_loaded.measure_mses(X2)
196 assert not np.all(np.array(pred_loaded) == np.array(pred2))
197 save(model_loaded, "unittest_model.pt")
198 del model_loaded
199 model_loaded2 = load("unittest_model.pt")
200 pred_loaded2 = model_loaded2.measure_mses(X2)
201 assert np.all(np.array(pred2) == np.array(pred_loaded2))
202 os.remove("unittest_model.pt")
if __name__ == "__main__":
    # Only runs when this file is executed directly, not when it is imported.
    unittest.main()