Coverage for contextualized/dags/tests_fast.py: 98%
41 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 16:32 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-18 16:32 -0400
1"""
2Unit tests for DAG models.
3"""
4import unittest
5import numpy as np
6from pytorch_lightning import seed_everything
7from pytorch_lightning.callbacks import LearningRateFinder
10from contextualized.dags.lightning_modules import NOTMAD
11from contextualized.dags import graph_utils
12from contextualized.dags.trainers import GraphTrainer
13from contextualized.dags.losses import mse_loss as mse
class TestNOTMAD(unittest.TestCase):
    """Unit tests for NOTMAD."""

    def setUp(self):
        """Seed the RNGs and draw a small random context/feature dataset."""
        seed_everything(0)
        self.n = 10
        self.c_dim = 4
        self.x_dim = 3
        self.C = np.random.uniform(-1, 1, size=(self.n, self.c_dim))
        self.X = np.random.uniform(-1, 1, size=(self.n, self.x_dim))

    def _train(self, model_args, n_epochs):
        """Fit a NOTMAD model and check that training lowers the MSE.

        Builds a NOTMAD model from ``model_args``, predicts parameters before
        and after fitting for ``n_epochs`` epochs, and asserts that both
        predictions have shape ``(n, x_dim, x_dim)`` and that the
        reconstruction error on ``self.X`` decreased.

        Args:
            model_args: dict of optional overrides; the keys read here are
                ``"dag"``, ``"num_factors"`` and ``"num_archetypes"``.
            n_epochs: value passed as ``max_epochs`` to the fitting trainer.
        """
        k = 6
        # Small random archetype initialisation with the diagonal zeroed out.
        # NOTE(review): the side length is hard-coded to 4 while self.x_dim
        # is 3 -- confirm this is intended.
        init_mat = np.random.uniform(-0.1, 0.1, size=(k, 4, 4)) * np.tile(
            1 - np.eye(4), (k, 1, 1)
        )
        # DAG loss configuration used when the caller does not supply one.
        default_dag = {
            "loss_type": "NOTEARS",
            "params": {
                "alpha": 1e-1,
                "rho": 1e-2,
                "h_old": 0.0,
                "tol": 0.25,
                "use_dynamic_alpha_rho": True,
            },
        }
        model = NOTMAD(
            self.C.shape[-1],
            self.X.shape[-1],
            archetype_params={
                "l1": 0.0,
                "dag": model_args.get("dag", default_dag),
                "init_mat": init_mat,
                "num_factors": model_args.get("num_factors", 0),
                "factor_mat_l1": 0.0,
                "num_archetypes": model_args.get("num_archetypes", k),
            },
        )
        dataloader = model.dataloader(
            self.C, self.X, batch_size=1, num_workers=0
        )
        trainer = GraphTrainer(
            max_epochs=n_epochs, deterministic=True, enable_progress_bar=False
        )
        predict_trainer = GraphTrainer(
            deterministic=True, enable_progress_bar=False, devices=1
        )
        expected_shape = (self.n, self.x_dim, self.x_dim)

        # Baseline error before any training step has been taken.
        init_preds = predict_trainer.predict_params(
            model, dataloader, project_to_dag=False
        )
        self.assertEqual(init_preds.shape, expected_shape)
        init_mse = mse(graph_utils.dag_pred_np(self.X, init_preds), self.X)

        trainer.fit(model, dataloader)

        final_preds = predict_trainer.predict_params(
            model, dataloader, project_to_dag=False
        )
        self.assertEqual(final_preds.shape, expected_shape)
        final_mse = mse(graph_utils.dag_pred_np(self.X, final_preds), self.X)
        # unittest assertions survive ``python -O`` and report both values.
        self.assertLess(final_mse, init_mse)

    def test_notmad_dagma(self):
        """Train for one epoch with the DAGMA loss type."""
        self._train(
            {
                "dag": {
                    "loss_type": "DAGMA",
                    "params": {
                        "alpha": 1.0,
                    },
                }
            },
            1,
        )

    def test_notmad_notears(self):
        """Train for one epoch with the default NOTEARS loss configuration."""
        self._train({}, 1)

    def test_notmad_factor_graphs(self):
        """
        Unit tests for factor graph feature of NOTMAD.
        """
        self._train({"num_factors": 3}, 1)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()