Coverage for contextualized/dags/tests.py: 99% (131 statements)
1"""
2Unit tests for DAG models.
3"""
5import unittest
6import numpy as np
7import igraph as ig
8from pytorch_lightning import seed_everything
9from pytorch_lightning.callbacks import LearningRateFinder
12from contextualized.dags.lightning_modules import NOTMAD
13from contextualized.dags import graph_utils
14from contextualized.dags.trainers import GraphTrainer
15from contextualized.dags.losses import mse_loss as mse
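

# TestNOTMADFast is a quick smoke test: it fits NOTMAD for a single epoch on a
# tiny random (C, X) dataset and only checks prediction shapes plus that
# training reduces the data reconstruction MSE relative to the untrained model.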
class TestNOTMADFast(unittest.TestCase):
    """Fast smoke tests for NOTMAD on a small random dataset."""

    def __init__(self, *args, **kwargs):
        super(TestNOTMADFast, self).__init__(*args, **kwargs)

    def setUp(self):
        seed_everything(0)
        self.n = 10
        self.c_dim = 4
        self.x_dim = 3
        self.C = np.random.uniform(-1, 1, size=(self.n, self.c_dim))
        self.X = np.random.uniform(-1, 1, size=(self.n, self.x_dim))

    def _train(self, model_args, n_epochs):
        k = 6
        # k archetype adjacency matrices, randomly initialized with zeroed diagonals.
        INIT_MAT = np.random.uniform(-0.1, 0.1, size=(k, 4, 4)) * np.tile(
            1 - np.eye(4), (k, 1, 1)
        )
        model = NOTMAD(
            self.C.shape[-1],
            self.X.shape[-1],
            archetype_params={
                "l1": 0.0,
                "dag": model_args.get(
                    "dag",
                    {
                        "loss_type": "NOTEARS",
                        "params": {
                            "alpha": 1e-1,
                            "rho": 1e-2,
                            "h_old": 0.0,
                            "tol": 0.25,
                            "use_dynamic_alpha_rho": True,
                        },
                    },
                ),
                "init_mat": INIT_MAT,
                "num_factors": model_args.get("num_factors", 0),
                "factor_mat_l1": 0.0,
                "num_archetypes": model_args.get("num_archetypes", k),
            },
        )
        dataloader = model.dataloader(self.C, self.X, batch_size=1, num_workers=0)
        trainer = GraphTrainer(
            max_epochs=n_epochs, deterministic=True, enable_progress_bar=False
        )
        predict_trainer = GraphTrainer(
            deterministic=True, enable_progress_bar=False, devices=1
        )
        init_preds = predict_trainer.predict_params(
            model,
            dataloader,
            project_to_dag=False,
        )
        assert init_preds.shape == (self.n, self.x_dim, self.x_dim)
        init_mse = mse(graph_utils.dag_pred_np(self.X, init_preds), self.X)
        trainer.fit(model, dataloader)
        final_preds = predict_trainer.predict_params(
            model, dataloader, project_to_dag=False
        )
        assert final_preds.shape == (self.n, self.x_dim, self.x_dim)
        final_mse = mse(graph_utils.dag_pred_np(self.X, final_preds), self.X)
        # Even a single epoch of training should improve the reconstruction MSE.
        assert final_mse < init_mse

    def test_notmad_dagma(self):
        self._train(
            {
                "dag": {
                    "loss_type": "DAGMA",
                    "params": {
                        "alpha": 1.0,
                    },
                }
            },
            1,
        )

    def test_notmad_notears(self):
        self._train({}, 1)

    def test_notmad_factor_graphs(self):
        """
        Unit tests for the factor graph feature of NOTMAD.
        """
        self._train({"num_factors": 3}, 1)
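

# TestNOTMAD is the slower end-to-end suite: it trains for 10 epochs on a
# simulated context-dependent linear SEM and checks both parameter recovery
# (MSE against the ground-truth graphs W) and data reconstruction MSE on the
# train, test, and validation context splits.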
class TestNOTMAD(unittest.TestCase):
    """End-to-end tests for NOTMAD on a simulated context-dependent SEM."""

    def __init__(self, *args, **kwargs):
        super(TestNOTMAD, self).__init__(*args, **kwargs)

    def setUp(self):
        (
            self.C,
            self.X,
            self.W,
            self.train_idx,
            self.test_idx,
            self.val_idx,
        ) = self._create_cwx_dataset()
        self.C_train, self.C_test, self.C_val = (
            self.C[self.train_idx],
            self.C[self.test_idx],
            self.C[self.val_idx],
        )
        self.X_train, self.X_test, self.X_val = (
            self.X[self.train_idx],
            self.X[self.test_idx],
            self.X[self.val_idx],
        )
        self.W_train, self.W_test, self.W_val = (
            self.W[self.train_idx],
            self.W[self.test_idx],
            self.W[self.val_idx],
        )
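
    # _create_cwx_dataset builds a 1-D context C on [1, 2], a 4-node DAG whose
    # edge weights vary with C, and samples X from the corresponding linear SEM.
    # Test and validation indices are carved from held-out context intervals
    # (C in [1.8, 1.9) and [1.7, 1.8)), so evaluation probes generalization to
    # unseen contexts.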
    def _create_cwx_dataset(self, n=500):
        np.random.seed(0)
        C = np.linspace(1, 2, n).reshape((n, 1))
        W = np.zeros((4, 4, n, 1))
        W[0, 1] = C - 2
        W[2, 1] = 2 * C
        W[3, 1] = C**2
        W[3, 2] = C
        W = np.squeeze(W)

        W = np.transpose(W, (2, 0, 1))
        X = np.zeros((n, 4))
        for i, w in enumerate(W):
            x = graph_utils.simulate_linear_sem(w, 1, "uniform", noise_scale=0.1)[0]
            X[i] = x
        train_idx = np.argwhere(np.logical_or(C < 1.7, C >= 1.9)[:, 0])[:, 0]
        np.random.shuffle(train_idx)
        test_idx = np.argwhere(np.logical_and(C >= 1.8, C < 1.9)[:, 0])[:, 0]
        val_idx = np.argwhere(np.logical_and(C >= 1.7, C < 1.8)[:, 0])[:, 0]
        return (
            C,
            X,
            W,
            train_idx,
            test_idx,
            val_idx,
        )
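
    # _evaluate returns six numbers: parameter-recovery MSE between predicted
    # and ground-truth graphs (train/test/val), then reconstruction MSE of X
    # under the predicted graphs (train/test/val).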
    def _evaluate(self, train_preds, test_preds, val_preds):
        return (
            mse(train_preds, self.W_train),
            mse(test_preds, self.W_test),
            mse(val_preds, self.W_val),
            mse(graph_utils.dag_pred_np(self.X_train, train_preds), self.X_train),
            mse(graph_utils.dag_pred_np(self.X_test, test_preds), self.X_test),
            mse(graph_utils.dag_pred_np(self.X_val, val_preds), self.X_val),
        )

    def _train(self, model_args, n_epochs):
        seed_everything(0)
        k = 6
        INIT_MAT = np.random.uniform(-0.1, 0.1, size=(k, 4, 4)) * np.tile(
            1 - np.eye(4), (k, 1, 1)
        )
        model = NOTMAD(
            self.C.shape[-1],
            self.X.shape[-1],
            archetype_params={
                "l1": 0.0,
                "dag": model_args.get(
                    "dag",
                    {
                        "loss_type": "NOTEARS",
                        "params": {
                            "alpha": 1e-1,
                            "rho": 1e-2,
                            "h_old": 0.0,
                            "tol": 0.25,
                            "use_dynamic_alpha_rho": True,
                        },
                    },
                ),
                "init_mat": INIT_MAT,
                "num_factors": model_args.get("num_factors", 0),
                "factor_mat_l1": 0.0,
                "num_archetypes": model_args.get("num_archetypes", k),
            },
        )
        train_dataloader = model.dataloader(
            self.C_train, self.X_train, batch_size=1, num_workers=0
        )
        test_dataloader = model.dataloader(
            self.C_test, self.X_test, batch_size=10, num_workers=0
        )
        val_dataloader = model.dataloader(
            self.C_val, self.X_val, batch_size=10, num_workers=0
        )
        trainer = GraphTrainer(
            max_epochs=n_epochs, deterministic=True, enable_progress_bar=False
        )
        predict_trainer = GraphTrainer(
            deterministic=True, enable_progress_bar=False, devices=1
        )
        # Pre-training predictions serve as the baseline for the factor-graph test.
        preds_train = predict_trainer.predict_params(
            model, train_dataloader, project_to_dag=True
        )
        preds_test = predict_trainer.predict_params(
            model, test_dataloader, project_to_dag=True
        )
        preds_val = predict_trainer.predict_params(
            model, val_dataloader, project_to_dag=True
        )
        init_train_l2, init_test_l2, init_val_l2, _, _, _ = self._evaluate(
            preds_train, preds_test, preds_val
        )
        trainer.fit(model, train_dataloader)
        trainer.validate(model, val_dataloader)
        trainer.test(model, test_dataloader)

        # Evaluate results
        preds_train = predict_trainer.predict_params(
            model, train_dataloader, project_to_dag=True
        )
        preds_test = predict_trainer.predict_params(
            model, test_dataloader, project_to_dag=True
        )
        preds_val = predict_trainer.predict_params(
            model, val_dataloader, project_to_dag=True
        )

        return (
            preds_train,
            preds_test,
            preds_val,
            init_train_l2,
            init_test_l2,
            init_val_l2,
        )

    def test_notmad_dagma(self):
        train_preds, test_preds, val_preds, _, _, _ = self._train(
            {
                "dag": {
                    "loss_type": "DAGMA",
                    "params": {
                        "alpha": 1.0,
                    },
                }
            },
            10,
        )
        train_l2, test_l2, val_l2, train_mse, test_mse, val_mse = self._evaluate(
            train_preds, test_preds, val_preds
        )
        print(f"Train L2: {train_l2}")
        print(f"Test L2: {test_l2}")
        print(f"Val L2: {val_l2}")
        print(f"Train mse: {train_mse}")
        print(f"Test mse: {test_mse}")
        print(f"Val mse: {val_mse}")
        assert train_l2 < 1e-1
        assert test_l2 < 1e-1
        assert val_l2 < 1e-1
        assert train_mse < 1e-2
        assert test_mse < 1e-2
        assert val_mse < 1e-2

    def test_notmad_notears(self):
        train_preds, test_preds, val_preds, _, _, _ = self._train({}, 10)
        train_l2, test_l2, val_l2, train_mse, test_mse, val_mse = self._evaluate(
            train_preds, test_preds, val_preds
        )
        print(f"Train L2: {train_l2}")
        print(f"Test L2: {test_l2}")
        print(f"Val L2: {val_l2}")
        print(f"Train mse: {train_mse}")
        print(f"Test mse: {test_mse}")
        print(f"Val mse: {val_mse}")
        assert train_l2 < 1e-1
        assert test_l2 < 1e-1
        assert val_l2 < 1e-1
        assert train_mse < 1e-2
        assert test_mse < 1e-2
        assert val_mse < 1e-2

    def test_notmad_factor_graphs(self):
        """
        Unit tests for the factor graph feature of NOTMAD.
        """
        (
            train_preds,
            test_preds,
            val_preds,
            init_train_l2,
            init_test_l2,
            init_val_l2,
        ) = self._train({"num_factors": 3}, 10)
        train_l2, test_l2, val_l2, _, _, _ = self._evaluate(
            train_preds, test_preds, val_preds
        )
        assert train_preds.shape == self.W_train.shape
        assert val_preds.shape == self.W_val.shape
        assert test_preds.shape == self.W_test.shape
        # Training should improve on the untrained baseline graphs.
        assert train_l2 < init_train_l2
        assert test_l2 < init_test_l2
        assert val_l2 < init_val_l2


if __name__ == "__main__":
    unittest.main()