Coverage for contextualized/regression/tests.py: 99%
144 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
1"""
2Unit tests for Contextualized Regression.
3"""
5import unittest
6import numpy as np
7import torch
9# from contextualized.modules import NGAM, MLP, SoftSelect, Explainer
10from contextualized.regression.lightning_modules import *
11from contextualized.regression.trainers import *
12from contextualized.functions import LINK_FUNCTIONS
13from contextualized.utils import DummyParamPredictor, DummyYPredictor
class TestRegression(unittest.TestCase):
    """
    Test regression models.

    Every test builds one or more models on the small synthetic dataset
    created in ``setUp`` and hands them to ``_quicktest``, which passes only
    if a short training run lowers the mean squared prediction error.
    """

    def setUp(self):
        """
        Shared unit test setup code.

        Seeds NumPy and PyTorch, draws ``n = 100`` contexts ``C`` (4 columns)
        and features ``X`` (5 columns) uniformly from [-0.5, 0.5), and derives
        three target columns from them:

        * ``Y_1 = X[:, 0] * C.sum(1) ** 2 + C[:, 0]``
        * ``Y_2 = -X[:, 1] * C.sum(1) + C[:, 1]``
        * ``Y_3 = X.sum(1)``

        ``C``, ``X`` and ``Y`` are stored on the instance as NumPy arrays.
        """
        np.random.seed(0)
        torch.manual_seed(0)
        n = 100
        c_dim = 4
        x_dim = 5
        y_dim = 3

        # C must be drawn before X so the seeded random stream is consumed in
        # a fixed order.
        C = torch.rand((n, c_dim)) - 0.5
        X = torch.rand((n, x_dim)) - 0.5

        # Context-dependent coefficients and offsets, each of shape (n, 1).
        context_sum = C.sum(dim=1, keepdim=True)
        W_1 = context_sum**2
        W_2 = -context_sum
        b_1 = C[:, 0:1]
        b_2 = C[:, 1:2]

        Y_1 = X[:, 0:1] * W_1 + b_1
        Y_2 = X[:, 1:2] * W_2 + b_2
        Y_3 = X.sum(dim=1, keepdim=True)  # context-independent target
        Y = torch.cat((Y_1, Y_2, Y_3), dim=1)

        self.k = 10  # not read by any test in this class
        self.epochs = 1
        self.batch_size = 32
        self.c_dim, self.x_dim, self.y_dim = c_dim, x_dim, y_dim
        self.C, self.X, self.Y = C.numpy(), X.numpy(), Y.numpy()

    def _quicktest(self, model, univariate=False, correlation=False, markov=False):
        """
        Briefly train ``model`` and check that its mean squared error drops.

        The untrained model's predictions are scored against the targets, the
        model is fit, validated and tested for ``self.epochs`` epochs, and the
        predictions are scored again.

        :param model: Model under test; its ``dataloader`` method is used to
            wrap the synthetic arrays.
        :param univariate: If True, the target is ``self.Y`` repeated along a
            new trailing axis of length ``x_dim``. Only consulted when neither
            ``correlation`` nor ``markov`` is set. (Default value = False)
        :param correlation: If True, train with ``CorrelationTrainer`` on
            ``(C, X)`` and also call ``predict_correlation``. Takes precedence
            over ``markov`` when selecting the trainer. (Default value = False)
        :param markov: If True, train with ``MarkovTrainer`` on ``(C, X)`` and
            also call ``predict_precision``. (Default value = False)
        :raises AssertionError: If the error after training is not strictly
            lower than the error before training.
        """
        print(f"\n{type(model)} quicktest")
        trainer_kwargs = {"max_epochs": self.epochs, "enable_progress_bar": False}
        x_dim = self.X.shape[-1]
        if correlation:
            dataloader = model.dataloader(self.C, self.X, batch_size=self.batch_size)
            trainer = CorrelationTrainer(**trainer_kwargs)
            # Target has shape (n, x_dim, x_dim): X repeated along a new axis.
            y_true = np.tile(self.X[:, :, np.newaxis], (1, 1, x_dim))
        elif markov:
            dataloader = model.dataloader(self.C, self.X, batch_size=self.batch_size)
            trainer = MarkovTrainer(**trainer_kwargs)
            y_true = self.X
        else:
            dataloader = model.dataloader(
                self.C, self.X, self.Y, batch_size=self.batch_size
            )
            trainer = RegressionTrainer(**trainer_kwargs)
            if univariate:
                # Target has shape (n, y_dim, x_dim): Y repeated along a new axis.
                y_true = np.tile(self.Y[:, :, np.newaxis], (1, 1, x_dim))
            else:
                y_true = self.Y

        y_preds = trainer.predict_y(model, dataloader)
        err_init = ((y_true - y_preds) ** 2).mean()

        trainer.fit(model, dataloader)
        trainer.validate(model, dataloader)
        trainer.test(model, dataloader)

        # The calls below are smoke checks: they only have to run without
        # raising. Unpacking predict_params also checks it returns two values.
        _betas, _mus = trainer.predict_params(model, dataloader)
        if correlation:
            trainer.predict_correlation(model, dataloader)
        if markov:
            trainer.predict_precision(model, dataloader)

        y_preds = trainer.predict_y(model, dataloader)
        err_trained = ((y_true - y_preds) ** 2).mean()
        # assertLess is used instead of a bare assert so the check survives
        # running Python with -O.
        self.assertLess(err_trained, err_init, "Model failed to converge")

    def _naive_model(self, encoder_link, link, **model_kwargs):
        """
        Build a ``NaiveContextualizedRegression`` with the shared encoder setup.

        :param encoder_link: Key into ``LINK_FUNCTIONS`` for the encoder's
            ``link_fn``.
        :param link: Key into ``LINK_FUNCTIONS`` for the model's ``link_fn``.
        :param model_kwargs: Extra keyword arguments forwarded unchanged to
            the model constructor.
        :return: The constructed model.
        """
        return NaiveContextualizedRegression(
            self.c_dim,
            self.x_dim,
            self.y_dim,
            encoder_kwargs={
                "width": 25,
                "layers": 2,
                "link_fn": LINK_FUNCTIONS[encoder_link],
            },
            link_fn=LINK_FUNCTIONS[link],
            **model_kwargs,
        )

    def test_naive(self):
        """
        Test Naive Multivariate regression.

        Covers combinations of encoder link function, model link function,
        the "ngam" encoder type, and dummy base predictors.
        """
        # Naive Multivariate
        self._quicktest(self._naive_model("identity", "identity"))
        self._quicktest(
            self._naive_model("identity", "identity", encoder_type="ngam")
        )
        self._quicktest(self._naive_model("softmax", "identity"))
        self._quicktest(self._naive_model("identity", "logistic"))
        self._quicktest(self._naive_model("softmax", "logistic"))

        # Each base predictor is created right before the model that uses it,
        # keeping the construction order of the individual objects fixed.
        parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1))
        self._quicktest(
            self._naive_model("softmax", "logistic", base_param_predictor=parambase)
        )

        ybase = DummyYPredictor((self.y_dim, 1))
        self._quicktest(
            self._naive_model("softmax", "logistic", base_y_predictor=ybase)
        )

    def test_subtype(self):
        """
        Test subtype multivariate regression.
        """
        # Subtype Multivariate
        parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1))
        ybase = DummyYPredictor((self.y_dim, 1))
        model = ContextualizedRegression(
            self.c_dim,
            self.x_dim,
            self.y_dim,
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model)

    def test_multitask(self):
        """
        Test multitask multivariate regression.
        """
        # Multitask Multivariate
        parambase = DummyParamPredictor((self.x_dim,), (1,))
        ybase = DummyYPredictor((1,))
        model = MultitaskContextualizedRegression(
            self.c_dim,
            self.x_dim,
            self.y_dim,
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model)

    def test_tasksplit(self):
        """
        Test tasksplit multivariate regression.
        """
        # Tasksplit Multivariate
        parambase = DummyParamPredictor((self.x_dim,), (1,))
        ybase = DummyYPredictor((1,))
        model = TasksplitContextualizedRegression(
            self.c_dim,
            self.x_dim,
            self.y_dim,
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model)

    def test_univariate_subtype(self):
        """
        Test univariate regression with ``ContextualizedUnivariateRegression``.
        """
        # Univariate
        parambase = DummyParamPredictor(
            (self.y_dim, self.x_dim, 1), (self.y_dim, self.x_dim, 1)
        )
        ybase = DummyYPredictor((self.y_dim, self.x_dim, 1))
        model = ContextualizedUnivariateRegression(
            self.c_dim,
            self.x_dim,
            self.y_dim,
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model, univariate=True)

    def test_univariate_tasksplit(self):
        """
        Test task-split univariate regression.
        """
        # Tasksplit Univariate
        parambase = DummyParamPredictor((1,), (1,))
        ybase = DummyYPredictor((1,))
        model = TasksplitContextualizedUnivariateRegression(
            self.c_dim,
            self.x_dim,
            self.y_dim,
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model, univariate=True)

    def test_correlation_subtype(self):
        """
        Test correlation.
        """
        # Correlation
        parambase = DummyParamPredictor(
            (self.x_dim, self.x_dim, 1), (self.x_dim, self.x_dim, 1)
        )
        ybase = DummyYPredictor((self.x_dim, self.x_dim, 1))
        model = ContextualizedCorrelation(
            self.c_dim,
            self.x_dim,
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model, correlation=True)

    def test_correlation_tasksplit(self):
        """
        Test task-split correlation.
        """
        # Tasksplit Correlation
        parambase = DummyParamPredictor((1,), (1,))
        ybase = DummyYPredictor((1,))
        model = TasksplitContextualizedCorrelation(
            self.c_dim,
            self.x_dim,
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model, correlation=True)

    def test_markov_subtype(self):
        """
        Test Markov Graphs.
        """
        # Markov Graph
        parambase = DummyParamPredictor((self.x_dim, self.x_dim), (self.x_dim, 1))
        ybase = DummyYPredictor((self.x_dim, 1))
        model = ContextualizedMarkovGraph(
            self.c_dim,
            self.x_dim,
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model, markov=True)

    def test_neighborhood_subtype(self):
        """
        Test Neighborhood Selection.
        """
        parambase = DummyParamPredictor((self.x_dim, self.x_dim), (self.x_dim, 1))
        ybase = DummyYPredictor((self.x_dim, 1))
        model = ContextualizedNeighborhoodSelection(
            self.c_dim,
            self.x_dim,
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model, markov=True)

    def test_metamodel_switch(self):
        """
        Test switching between meta-models.

        Builds the same regression model with ``metamodel_type="naive"`` and
        with ``metamodel_type="subtype"``.
        """
        parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1))
        ybase = DummyYPredictor((self.y_dim, 1))
        model = ContextualizedRegression(
            self.c_dim,
            self.x_dim,
            self.y_dim,
            metamodel_type="naive",
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model)

        parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1))
        ybase = DummyYPredictor((self.y_dim, 1))
        model = ContextualizedRegression(
            self.c_dim,
            self.x_dim,
            self.y_dim,
            metamodel_type="subtype",
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        self._quicktest(model)

    def test_fit_intercept(self):
        """
        Test that ``fit_intercept=False`` keeps the predicted intercepts at zero.

        The intercepts returned by ``predict_params`` are checked both before
        and after the model has been trained by ``_quicktest``.
        """
        parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1))
        ybase = DummyYPredictor((self.y_dim, 1))
        model = ContextualizedRegression(
            self.c_dim,
            self.x_dim,
            self.y_dim,
            metamodel_type="naive",
            fit_intercept=False,
            base_param_predictor=parambase,
            base_y_predictor=ybase,
        )
        dataloader = model.dataloader(
            self.C, self.X, self.Y, batch_size=self.batch_size
        )
        # This trainer is only used for prediction; training happens inside
        # _quicktest with its own trainer.
        trainer = RegressionTrainer(enable_progress_bar=False)

        _, mu_preds = trainer.predict_params(model, dataloader)
        self.assertTrue(
            (mu_preds == 0).all(), "Intercepts must be zero before training"
        )

        self._quicktest(model)

        _, mu_preds = trainer.predict_params(model, dataloader)
        self.assertTrue(
            (mu_preds == 0).all(), "Intercepts must stay zero after training"
        )
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()