Coverage for contextualized/easy/tests/test_regressor.py: 98%
56 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 16:32 -0400
"""
Unit tests for Easy Regressor.
"""
5import unittest
6import numpy as np
7import torch
9from contextualized.easy import ContextualizedRegressor
10from contextualized.utils import DummyParamPredictor, DummyYPredictor
class TestEasyRegression(unittest.TestCase):
    """
    Test Easy Regression models.
    """

    @staticmethod
    def _residual_norm(Y_2d, preds):
        """Return the Frobenius norm of ``Y_2d - preds``.

        ``preds`` is reshaped to ``Y_2d.shape`` first, so the subtraction is
        strictly elementwise and can never broadcast into a larger matrix.
        """
        return np.linalg.norm(Y_2d - np.asarray(preds).reshape(Y_2d.shape))

    def _quicktest(self, model, C, X, Y, **kwargs):
        """Fit ``model``, check output shapes, and check that training lowers error.

        The model is first fit with ``max_epochs=0`` to record the error of the
        untrained model. It is then refit with ``kwargs``, and the trained
        error must be strictly smaller.

        Args:
            model: Estimator exposing ``fit(C, X, Y, **kwargs)``,
                ``predict(C, X)`` and ``predict_params(C)``.
            C: Context array, one row per sample.
            X: Predictor array of shape ``(n_samples, x_dim)``.
            Y: Outcome array of shape ``(n_samples,)`` or
                ``(n_samples, y_dim)``. A 1-D ``Y`` is treated as ``y_dim == 1``.
            **kwargs: Forwarded to the second ``model.fit`` call.

        Raises:
            AssertionError: If a predicted shape is wrong or training does not
                reduce the error.
        """
        # View Y as (n_samples, y_dim) so both errors are computed the same way.
        # Subtracting (n, 1) predictions from a 1-D Y would otherwise broadcast
        # to an (n, n) matrix and inflate the "initial" error.
        Y_2d = np.asarray(Y).reshape(len(Y), -1)
        y_dim = Y_2d.shape[1]

        model.fit(C, X, Y, max_epochs=0)
        err_init = self._residual_norm(Y_2d, model.predict(C, X))

        model.fit(C, X, Y, **kwargs)
        beta_preds, mu_preds = model.predict_params(C)
        self.assertEqual(beta_preds.shape, (X.shape[0], y_dim, X.shape[1]))
        self.assertEqual(mu_preds.shape, (X.shape[0], y_dim))

        y_preds = model.predict(C, X)
        self.assertEqual(y_preds.shape, (len(Y), y_dim))
        err_trained = self._residual_norm(Y_2d, y_preds)
        self.assertLess(
            err_trained,
            err_init,
            f"{type(model).__name__}: training did not reduce error "
            f"(trained={err_trained}, init={err_init})",
        )

    def test_regressor(self):
        """Test Case for ContextualizedRegressor."""
        # Seed the global torch RNG so the synthetic data (and, presumably, any
        # model weight initialization drawn from it) is reproducible.
        torch.manual_seed(0)
        n_samples = 1000
        c_dim = 2
        x_dim = 3
        y_dim = 2

        # Two outcomes whose slopes and offsets are functions of the context C.
        C = torch.rand((n_samples, c_dim)) - 0.5
        beta_1 = C.sum(dim=1).unsqueeze(-1) ** 2
        beta_2 = -C.sum(dim=1).unsqueeze(-1)
        b_1 = C[:, 0].unsqueeze(-1)
        b_2 = C[:, 1].unsqueeze(-1)
        X = torch.rand((n_samples, x_dim)) - 0.5
        outcome_1 = X[:, 0].unsqueeze(-1) * beta_1 + b_1
        outcome_2 = X[:, 1].unsqueeze(-1) * beta_2 + b_2
        Y = torch.cat((outcome_1, outcome_2), dim=1)

        C, X, Y = C.numpy(), X.numpy(), Y.numpy()

        # Infinite es_patience so early stopping never cuts the 10-epoch runs
        # short (assumes es_patience is the early-stopping patience).
        train_kwargs = {"max_epochs": 10, "es_patience": float("inf")}
        regularized = {
            "num_archetypes": 4,
            "alpha": 1e-1,
            "l1_ratio": 0.5,
            "mu_ratio": 0.1,
        }

        with self.subTest("naive multivariate"):
            parambase = DummyParamPredictor((y_dim, x_dim), (y_dim, 1))
            ybase = DummyYPredictor((y_dim, 1))
            model = ContextualizedRegressor(
                base_param_predictor=parambase, base_y_predictor=ybase
            )
            self._quicktest(model, C, X, Y, learning_rate=1e-3, **train_kwargs)

        with self.subTest("num_archetypes=0"):
            model = ContextualizedRegressor(num_archetypes=0)
            self._quicktest(model, C, X, Y, **train_kwargs)

        with self.subTest("num_archetypes=4"):
            model = ContextualizedRegressor(num_archetypes=4)
            self._quicktest(model, C, X, Y, **train_kwargs)

        with self.subTest("with regularization"):
            model = ContextualizedRegressor(**regularized)
            self._quicktest(model, C, X, Y, **train_kwargs)

        with self.subTest("with bootstrap"):
            model = ContextualizedRegressor(**regularized)
            self._quicktest(
                model, C, X, Y, n_bootstraps=2, learning_rate=1e-3, **train_kwargs
            )

        with self.subTest("single-output Y"):
            # A 1-D Y exercises the y_dim == 1 path in _quicktest.
            model = ContextualizedRegressor(**regularized)
            self._quicktest(
                model,
                C,
                X,
                Y[:, 0],
                n_bootstraps=2,
                learning_rate=1e-3,
                **train_kwargs,
            )
# Allow running this module directly (python test_regressor.py) as well as via a
# test runner.
if __name__ == "__main__":
    unittest.main()