Coverage for contextualized/baselines/networks.py: 94% (171 statements)
1"""
2Baseline network models for learning the causal structure of data.
3Includes:
4 - CorrelationNetwork: learns the correlation matrix of the data
5 - MarkovNetwork: learns the Markov blanket of each variable
6 - BayesianNetwork: learns the DAG structure of the data
8 - GroupedNetworks: learns a separate model for each group of data. This is a wrapper around the above models.
9"""
import numpy as np
from sklearn.linear_model import LinearRegression
import torch
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from contextualized.dags.graph_utils import project_to_dag_torch, dag_pred
from contextualized.dags.losses import dag_loss_notears, mse_loss, l1_loss

class NOTEARSTrainer(pl.Trainer):
    """Trainer with helpers for predicting reconstructions and adjacency matrices."""

    def predict(self, model, dataloader):
        preds = super().predict(model, dataloader)
        return torch.cat(preds)

    def predict_w(self, model, dataloader, project_to_dag=True):
        preds = self.predict(model, dataloader)
        W = model.W.detach() * model.diag_mask
        if project_to_dag:
            W = torch.tensor(project_to_dag_torch(W.numpy(force=True))[0])
        W_batch = W.unsqueeze(0).expand(len(preds), -1, -1)
        return W_batch.numpy()

class NOTEARS(pl.LightningModule):
    """
    NOTEARS model for learning the DAG structure of data.
    """

    def __init__(self, x_dim, l1=1e-3, alpha=1e-8, rho=1e-8, learning_rate=1e-2):
        super().__init__()
        self.learning_rate = learning_rate
        self.l1 = l1
        self.alpha = alpha
        self.rho = rho
        # Mask that zeroes the diagonal of W, forbidding self-loops.
        diag_mask = torch.ones(x_dim, x_dim) - torch.eye(x_dim)
        self.register_buffer("diag_mask", diag_mask)
        init_mat = (torch.rand(x_dim, x_dim) * 2e-2 - 1e-2) * diag_mask
        self.W = nn.parameter.Parameter(init_mat, requires_grad=True)
        self.tolerance = 0.25
        self.prev_dag = 0.0

    def forward(self, X):
        W = self.W * self.diag_mask
        return dag_pred(X, W)

    def _batch_loss(self, batch, batch_idx):
        x_true = batch
        x_pred = self(x_true)
        mse = mse_loss(x_true, x_pred)
        l1 = l1_loss(self.W, self.l1)
        dag = dag_loss_notears(self.W, alpha=self.alpha, rho=self.rho)
        return 0.5 * mse + l1 + dag

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        loss = self._batch_loss(batch, batch_idx)
        self.log_dict({"train_loss": loss})
        return loss

    def validation_step(self, batch, batch_idx):
        x_true = batch
        x_pred = self(x_true).detach()
        mse = mse_loss(x_true, x_pred)
        # Use an effectively hard DAG penalty so the validation loss
        # reflects acyclicity rather than the current annealing state.
        dag = dag_loss_notears(self.W, alpha=1e12, rho=1e12).detach()
        loss = mse + dag
        self.log_dict({"val_loss": loss})
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._batch_loss(batch, batch_idx)
        self.log_dict({"test_loss": loss})
        return loss

    def on_train_epoch_end(self, *args, **kwargs):
        # Augmented-Lagrangian schedule: if the DAG penalty has not shrunk
        # below tolerance * its previous value, strengthen the constraint.
        dag = dag_loss_notears(self.W, alpha=self.alpha, rho=self.rho).item()
        if (
            dag > self.tolerance * self.prev_dag
            and self.alpha < 1e12
            and self.rho < 1e12
        ):
            self.alpha = self.alpha + self.rho * dag
            self.rho = self.rho * 10
        self.prev_dag = dag

    def dataloader(self, X, **kwargs):
        kwargs["batch_size"] = kwargs.get("batch_size", 32)
        X_tensor = torch.Tensor(X).to(torch.float32)
        return DataLoader(dataset=X_tensor, **kwargs)
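
# A minimal training sketch for NOTEARS (not part of the module; assumes X is
# an (n, p) numpy array of observations). This mirrors what
# BayesianNetwork.fit does below:
#
#   model = NOTEARS(x_dim=X.shape[-1])
#   trainer = NOTEARSTrainer(max_epochs=50, accelerator="cpu", devices=1)
#   trainer.fit(model, model.dataloader(X))
#   W = trainer.predict_w(model, model.dataloader(X))  # (n, p, p) DAG estimates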

class CorrelationNetwork:
    """
    Standard correlation network fit with linear regression.
    """

    def fit(self, X):
        self.p = X.shape[-1]
        self.regs = [[LinearRegression() for _ in range(self.p)] for _ in range(self.p)]
        for i in range(self.p):
            for j in range(self.p):
                self.regs[i][j].fit(X[:, j, np.newaxis], X[:, i, np.newaxis])
        return self

    def predict(self, n):
        betas = np.zeros((self.p, self.p))
        for i in range(self.p):
            for j in range(self.p):
                betas[i, j] = self.regs[i][j].coef_.squeeze()
        # beta_ij * beta_ji equals the squared Pearson correlation of variables i and j.
        corrs = betas * betas.T
        return np.tile(np.expand_dims(corrs, axis=0), (n, 1, 1))

    def measure_mses(self, X):
        mses = np.zeros(len(X))
        for i in range(self.p):
            for j in range(self.p):
                residual = (
                    self.regs[i][j].predict(X[:, j, np.newaxis]) - X[:, i, np.newaxis]
                )
                residual = residual[:, 0]
                mses += (residual**2) / self.p**2
        return mses
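
# A minimal usage sketch (assumes X is an (n, p) numpy array):
#
#   net = CorrelationNetwork().fit(X)
#   corrs = net.predict(len(X))   # (n, p, p) tiled correlation estimates
#   errs = net.measure_mses(X)    # per-sample mean squared residual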

class MarkovNetwork:
    """
    Standard Markov network fit with neighborhood regression.
    """

    def __init__(self, alpha=1e-3):
        # Note: alpha is stored but unused here; the neighborhood
        # regressions below are ordinary (unregularized) least squares.
        self.alpha = alpha
        self.p = -1
        self.regs = []

    def fit(self, X):
        self.p = X.shape[-1]
        self.regs = [LinearRegression() for _ in range(self.p)]
        for i in range(self.p):
            # Regress variable i on all other variables (mask out column i).
            mask = np.ones_like(X)
            mask[:, i] = 0
            self.regs[i].fit(X * mask, X[:, i, np.newaxis])
        return self

    def predict(self, n):
        betas = np.zeros((self.p, self.p))
        for i in range(self.p):
            betas[i] = self.regs[i].coef_.squeeze()
            betas[i, i] = 0
        # Symmetrize neighborhood coefficients into a signed precision-style matrix.
        precision = -np.sign(betas) * np.sqrt(np.abs(betas * betas.T))
        return np.tile(np.expand_dims(precision, axis=0), (n, 1, 1))

    def measure_mses(self, X):
        mses = np.zeros(len(X))
        for i in range(self.p):
            mask = np.ones_like(X)
            mask[:, i] = 0
            residual = self.regs[i].predict(X * mask) - X[:, i, np.newaxis]
            residual = residual[:, 0]
            mses += (residual**2) / self.p
        return mses
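
# A minimal usage sketch (assumes X is an (n, p) numpy array):
#
#   net = MarkovNetwork().fit(X)
#   precisions = net.predict(len(X))   # (n, p, p) signed precision-style estimates
#   errs = net.measure_mses(X)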

class BayesianNetwork:
    """
    A standard Bayesian network fit with the NOTEARS loss.
    """

    def __init__(self, **kwargs):
        self.p = -1
        self.model = None
        self.trainer = None
        self.l1 = kwargs.get("l1", 1e-3)
        self.alpha = kwargs.get("alpha", 1e-8)
        self.rho = kwargs.get("rho", 1e-8)
        self.learning_rate = kwargs.get("learning_rate", 1e-2)

    def fit(self, X, max_epochs=50):
        self.p = X.shape[-1]
        self.model = NOTEARS(
            self.p,
            l1=self.l1,
            alpha=self.alpha,
            rho=self.rho,
            learning_rate=self.learning_rate,
        )
        dataset = self.model.dataloader(X)
        accelerator = "gpu" if torch.cuda.is_available() else "cpu"
        self.trainer = NOTEARSTrainer(
            max_epochs=max_epochs, accelerator=accelerator, devices=1
        )
        self.trainer.fit(self.model, dataset)
        return self

    def predict(self, n):
        dummy_X = np.zeros((n, self.p))
        dummy_dataset = self.model.dataloader(dummy_X)
        return self.trainer.predict_w(self.model, dummy_dataset)

    def measure_mses(self, X):
        # Mask the diagonal as in NOTEARS.forward before predicting.
        W_pred = self.model.W.detach() * self.model.diag_mask
        X_preds = (
            dag_pred(torch.tensor(X, dtype=torch.float32), W_pred).detach().numpy()
        )
        return ((X_preds - X) ** 2).mean(axis=1)
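
# A minimal usage sketch (assumes X is an (n, p) numpy array):
#
#   net = BayesianNetwork(l1=1e-3).fit(X, max_epochs=50)
#   dags = net.predict(len(X))   # (n, p, p) adjacencies, projected to DAGs
#   errs = net.measure_mses(X)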

class GroupedNetworks:
    """
    Fit a separate network for each group of data.
    Wrapper around CorrelationNetwork, MarkovNetwork, or BayesianNetwork.
    Assumes the group labels are 0-indexed integers and have already been assigned.
    """

    def __init__(self, model_class):
        self.model_class = model_class
        self.models = {}
        self.p = -1

    def fit(self, X, labels):
        self.p = X.shape[-1]
        for label in np.unique(labels):
            model = self.model_class().fit(X[labels == label])
            self.models[label] = model
        return self

    def predict(self, labels):
        networks = np.zeros((len(labels), self.p, self.p))
        for label in np.unique(labels):
            label_idx = labels == label
            networks[label_idx] = self.models[label].predict(label_idx.sum())
        return networks

    def measure_mses(self, X, labels):
        mses = np.zeros(len(X))
        for label in np.unique(labels):
            label_idx = labels == label
            mses[label_idx] = self.models[label].measure_mses(X[label_idx])
        return mses
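
# A minimal usage sketch (assumes X is an (n, p) numpy array and labels is a
# length-n integer array of group assignments):
#
#   grouped = GroupedNetworks(MarkovNetwork).fit(X, labels)
#   networks = grouped.predict(labels)      # one (p, p) network per sample
#   errs = grouped.measure_mses(X, labels)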