Coverage for contextualized/dags/lightning_modules.py: 95%
152 statements
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from contextualized.functions import identity_link
from contextualized.dags.graph_utils import (
    project_to_dag_torch,
    trim_params,
    dag_pred,
    dag_pred_with_factors,
)
from contextualized.dags.losses import (
    dag_loss_notears,
    dag_loss_dagma,
    dag_loss_poly,
    l1_loss,
    mse_loss,
    linear_sem_loss,
    linear_sem_loss_with_factors,
)
from contextualized.modules import ENCODERS, Explainer

DAG_LOSSES = {
    "NOTEARS": dag_loss_notears,
    "DAGMA": dag_loss_dagma,
    "poly": dag_loss_poly,
}
DEFAULT_DAG_LOSS_TYPE = "NOTEARS"
DEFAULT_DAG_LOSS_PARAMS = {
    "NOTEARS": {
        "alpha": 1e-1,
        "rho": 1e-2,
        "tol": 0.25,
        "use_dynamic_alpha_rho": False,
    },
    "DAGMA": {"s": 1, "alpha": 1e0},
    "poly": {},
}
DEFAULT_SS_PARAMS = {
    "l1": 0.0,
    "dag": {
        "loss_type": "NOTEARS",
        "params": {
            "alpha": 1e-1,
            "rho": 1e-2,
            "h_old": 0.0,
            "tol": 0.25,
            "use_dynamic_alpha_rho": False,
        },
    },
}
DEFAULT_ARCH_PARAMS = {
    "l1": 0.0,
    "dag": {
        "loss_type": "NOTEARS",
        "params": {
            "alpha": 0.0,
            "rho": 0.0,
            "h_old": 0.0,
            "tol": 0.25,
            "use_dynamic_alpha_rho": False,
        },
    },
    "init_mat": None,
    "num_factors": 0,
    "factor_mat_l1": 0.0,
    "num_archetypes": 4,
}
DEFAULT_ENCODER_KWARGS = {
    "type": "mlp",
    "params": {"width": 32, "layers": 2, "link_fn": identity_link},
}
DEFAULT_OPT_PARAMS = {
    "learning_rate": 1e-3,
    "step": 50,
}


class NOTMAD(pl.LightningModule):
    """
    NOTMAD model: predicts context-specific DAGs as mixtures of archetypal DAGs.
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        sample_specific_loss_params=DEFAULT_SS_PARAMS,
        archetype_loss_params=DEFAULT_ARCH_PARAMS,
        opt_params=DEFAULT_OPT_PARAMS,
        encoder_kwargs=DEFAULT_ENCODER_KWARGS,
        **kwargs,
    ):
95 """Initialize NOTMAD.
97 Args:
98 context_dim (int): context dimensionality
99 x_dim (int): predictor dimensionality
101 Kwargs:
102 Explainer Kwargs
103 ----------------
104 init_mat (np.array): 3D Custom initial weights for each archetype. Defaults to None.
105 num_archetypes (int:4): Number of archetypes in explainer
107 Encoder Kwargs
108 ----------------
109 encoder_kwargs(dict): Dictionary of width, layers, and link_fn associated with encoder.
111 Optimization Kwargs
112 -------------------
113 learning_rate(float): Optimizer learning rate
114 opt_step(int): Optimizer step size
116 Loss Kwargs
117 -----------
118 sample_specific_loss_params (dict of str: int): Dict of params used by NOTEARS loss (l1, alpha, rho)
119 archetype_loss_params (dict of str: int): Dict of params used by Archetype loss (l1, alpha, rho)
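
        Example:
            A minimal usage sketch. ``C`` and ``X`` are assumed to be numpy arrays of
            shape (n_samples, context_dim) and (n_samples, x_dim); the Trainer settings
            are illustrative only::

                model = NOTMAD(context_dim=C.shape[-1], x_dim=X.shape[-1])
                trainer = pl.Trainer(max_epochs=10)
                trainer.fit(model, model.dataloader(C, X, batch_size=32))
                # list of per-batch tensors of sample-specific adjacency matrices
                w_preds = trainer.predict(model, model.dataloader(C, X))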
121 """
        super().__init__()

        # dataset params
        self.context_dim = context_dim
        self.x_dim = x_dim
        self.num_archetypes = archetype_loss_params.get(
            "num_archetypes", DEFAULT_ARCH_PARAMS["num_archetypes"]
        )
        # num_factors defaults to 0 (no factor graphs); read without mutating the shared default dict
        num_factors = archetype_loss_params.get("num_factors", 0)
        if 0 < num_factors < self.x_dim:
            self.latent_dim = num_factors
        else:
            if num_factors < 0:
                print(
                    f"Requested num_factors={num_factors}, but this should be a positive integer."
                )
            if num_factors > self.x_dim:
                print(
                    f"Requested num_factors={num_factors}, but this should be smaller than x_dim={self.x_dim}."
                )
            if num_factors == self.x_dim:
                print(
                    f"Requested num_factors={num_factors}, but this equals x_dim={self.x_dim}, so ignoring."
                )
            self.latent_dim = self.x_dim

        # DAG regularizers
        self.ss_dag_params = sample_specific_loss_params["dag"].get(
            "params",
            DEFAULT_DAG_LOSS_PARAMS[
                sample_specific_loss_params["dag"]["loss_type"]
            ].copy(),
        )

        self.arch_dag_params = archetype_loss_params["dag"].get(
            "params",
            DEFAULT_DAG_LOSS_PARAMS[archetype_loss_params["dag"]["loss_type"]].copy(),
        )

        self.val_dag_loss_params = {"alpha": 1e0, "rho": 1e0}
        self.ss_dag_loss = DAG_LOSSES[sample_specific_loss_params["dag"]["loss_type"]]
        self.arch_dag_loss = DAG_LOSSES[archetype_loss_params["dag"]["loss_type"]]

        # Sparsity regularizers
        self.arch_l1 = archetype_loss_params.get("l1", 0.0)
        self.ss_l1 = sample_specific_loss_params.get("l1", 0.0)

        # Archetype params
        self.init_mat = archetype_loss_params.get("init_mat", None)
        self.factor_mat_l1 = archetype_loss_params.get("factor_mat_l1", 0.0)

        # opt params
        self.learning_rate = opt_params.get("learning_rate", 1e-3)
        self.opt_step = opt_params.get("step", 50)
        # self.project_distance = 0.1

        # layers
        self.encoder = ENCODERS[encoder_kwargs["type"]](
            context_dim,
            self.num_archetypes,
            **encoder_kwargs["params"],
        )
        self.register_buffer(
            "diag_mask",
            torch.ones(self.latent_dim, self.latent_dim) - torch.eye(self.latent_dim),
        )
        self.explainer = Explainer(
            self.num_archetypes, (self.latent_dim, self.latent_dim)
        )
        self.explainer.set_archetypes(
            self._mask(self.explainer.get_archetypes())
        )  # initialize archetypes with zero diagonals
        if self.latent_dim != self.x_dim:
            factor_mat_init = torch.rand([self.latent_dim, self.x_dim]) * 2e-2 - 1e-2
            self.factor_mat_raw = nn.parameter.Parameter(
                factor_mat_init, requires_grad=True
            )
            # Softmax over the latent factor axis: each feature's loadings over the
            # factors sum to one, giving a soft assignment of features to factors.
            self.factor_softmax = nn.Softmax(dim=0)

        self.training_step_outputs = []

    def forward(self, context):
        subtype = self.encoder(context)
        out = self.explainer(subtype)
        return out

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # learning rate scheduler
        sch = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.opt_step, gamma=0.5
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": sch,
                "monitor": "train_loss",
            },
        }

    def _factor_mat(self):
        return self.factor_softmax(self.factor_mat_raw)

    def _batch_loss(self, batch, batch_idx):
        _, x_true = batch
        w_pred = self.predict_step(batch, batch_idx)
        if self.latent_dim < self.x_dim:
            mse_term = linear_sem_loss_with_factors(x_true, w_pred, self._factor_mat())
        else:
            mse_term = linear_sem_loss(x_true, w_pred)
        l1_term = l1_loss(w_pred, self.ss_l1)
        dag_term = self.ss_dag_loss(w_pred, **self.ss_dag_params)
        notears = mse_term + l1_term + dag_term
        W_arch = self.explainer.get_archetypes()
        arch_l1_term = l1_loss(W_arch, self.arch_l1)
        arch_dag_term = len(W_arch) * self.arch_dag_loss(W_arch, **self.arch_dag_params)
        # todo: scale archetype loss?
        if self.latent_dim < self.x_dim:
            factor_mat_term = l1_loss(self.factor_mat_raw, self.factor_mat_l1)
            loss = notears + arch_l1_term + arch_dag_term + factor_mat_term
            return (
                loss,
                notears.detach(),
                mse_term.detach(),
                l1_term.detach(),
                dag_term.detach(),
                arch_l1_term.detach(),
                arch_dag_term.detach(),
                factor_mat_term.detach(),
            )
        else:
            loss = notears + arch_l1_term + arch_dag_term
            return (
                loss,
                notears.detach(),
                mse_term.detach(),
                l1_term.detach(),
                dag_term.detach(),
                arch_l1_term.detach(),
                arch_dag_term.detach(),
                0.0,
            )

    def training_step(self, batch, batch_idx):
        (
            loss,
            notears,
            mse_term,
            l1_term,
            dag_term,
            arch_l1_term,
            arch_dag_term,
            factor_mat_term,
        ) = self._batch_loss(batch, batch_idx)
        ret = {
            "loss": loss,
            "train_loss": loss,
            "train_mse_loss": mse_term,
            "train_l1_loss": l1_term,
            "train_dag_loss": dag_term,
            "train_arch_l1_loss": arch_l1_term,
            "train_arch_dag_loss": arch_dag_term,
            "train_factor_l1_loss": factor_mat_term,
        }
        self.log_dict(ret)
        ret.update(
            {
                "train_batch": batch,
                "train_batch_idx": batch_idx,
            }
        )
        self.training_step_outputs.append(ret)
        return ret

    def test_step(self, batch, batch_idx):
        (
            loss,
            notears,
            mse_term,
            l1_term,
            dag_term,
            arch_l1_term,
            arch_dag_term,
            factor_mat_term,
        ) = self._batch_loss(batch, batch_idx)
        ret = {
            "test_loss": loss,
            "test_mse_loss": mse_term,
            "test_l1_loss": l1_term,
            "test_dag_loss": dag_term,
            "test_arch_l1_loss": arch_l1_term,
            "test_arch_dag_loss": arch_dag_term,
            "test_factor_l1_loss": factor_mat_term,
        }
        self.log_dict(ret)
        return ret

    def validation_step(self, batch, batch_idx):
        _, x_true = batch
        w_pred = self.predict_step(batch, batch_idx)
        if self.latent_dim < self.x_dim:
            X_pred = dag_pred_with_factors(x_true, w_pred, self._factor_mat())
        else:
            X_pred = dag_pred(x_true, w_pred)
        mse_term = 0.5 * x_true.shape[-1] * mse_loss(x_true, X_pred)
        l1_term = l1_loss(w_pred, self.ss_l1).mean()
        # ignore archetype loss, use constant alpha/rho upper bound for validation
        dag_term = self.ss_dag_loss(w_pred, **self.val_dag_loss_params).mean()
        if self.latent_dim < self.x_dim:
            factor_mat_term = l1_loss(self.factor_mat_raw, self.factor_mat_l1)
            loss = mse_term + l1_term + dag_term + factor_mat_term
            ret = {
                "val_loss": loss,
                "val_mse_loss": mse_term,
                "val_l1_loss": l1_term,
                "val_dag_loss": dag_term,
                "val_factor_l1_loss": factor_mat_term,
            }
        else:
            loss = mse_term + l1_term + dag_term
            ret = {
                "val_loss": loss,
                "val_mse_loss": mse_term,
                "val_l1_loss": l1_term,
                "val_dag_loss": dag_term,
                "val_factor_l1_loss": 0.0,
            }
        self.log_dict(ret)
        return ret

    def predict_step(self, batch, batch_idx):
        c, _ = batch
        w_pred = self(c)
        return self._mask(w_pred)

    def _project_factor_graph_to_var(self, w_preds):
        """
        Projects the graphs in factor space to variable space.
        w_preds: n x latent x latent
        """
        P_sums = self._factor_mat().sum(axis=1)
        w_preds = np.tensordot(
            w_preds,
            (self._factor_mat().T.detach().numpy() / P_sums.detach().numpy()).T,
            axes=1,
        )  # n x latent x x_dims
        w_preds = np.swapaxes(w_preds, 1, 2)  # n x x_dims x latent
        w_preds = np.tensordot(
            w_preds, self._factor_mat().detach().numpy(), axes=1
        )  # n x x_dims x x_dims
        w_preds = np.swapaxes(w_preds, 1, 2)  # n x x_dims x x_dims
        return w_preds

    def _format_params(self, w_preds, **kwargs):
        """
        Format the predicted parameters before returning them.

        Args:
            w_preds: the predicted parameters
            project_to_dag: whether to project the predicted graphs to DAGs
            threshold: minimum edge weight magnitude; smaller edges are trimmed
            factors: whether to return the factor graph or the variable graph
        """
        if 0 < self.latent_dim < self.x_dim and not kwargs.get("factors", False):
            w_preds = self._project_factor_graph_to_var(w_preds)
        if kwargs.get("project_to_dag", False):
            try:
                w_preds = np.array([project_to_dag_torch(w)[0] for w in w_preds])
            except Exception:
                print("Error, couldn't project to dag. Returning normal predictions.")
        return trim_params(w_preds, thresh=kwargs.get("threshold", 0.0))

    def on_train_epoch_end(self, logs=None):
        training_step_outputs = self.training_step_outputs
        # update alpha/rho based on average end-of-epoch dag loss
        epoch_samples = sum(
            [len(ret["train_batch"][0]) for ret in training_step_outputs]
        )
        epoch_dag_loss = 0
        for ret in training_step_outputs:
            batch_dag_loss = self.ss_dag_loss(
                self.predict_step(ret["train_batch"], ret["train_batch_idx"]),
                **self.ss_dag_params,
            ).detach()
            epoch_dag_loss += (
                len(ret["train_batch"][0]) / epoch_samples * batch_dag_loss
            )
        self.ss_dag_params = self._maybe_update_alpha_rho(
            epoch_dag_loss, self.ss_dag_params
        )
        self.arch_dag_params = self._maybe_update_alpha_rho(
            epoch_dag_loss, self.arch_dag_params
        )
        self.training_step_outputs.clear()  # free memory

    def _maybe_update_alpha_rho(self, epoch_dag_loss, dag_params):
        """
        Update alpha/rho if use_dynamic_alpha_rho is True and the epoch's DAG loss
        has not decreased enough relative to the previous epoch.
        """
        if (
            dag_params.get("use_dynamic_alpha_rho", False)
            and epoch_dag_loss
            > dag_params.get("tol", 0.25) * dag_params.get("h_old", 0)
            and dag_params["alpha"] < 1e12
            and dag_params["rho"] < 1e12
        ):
            dag_params["alpha"] = (
                dag_params["alpha"] + dag_params["rho"] * epoch_dag_loss
            )
            dag_params["rho"] *= dag_params.get("rho_mult", 1.1)
        dag_params["h_old"] = epoch_dag_loss
        return dag_params

    # helpers
    def _mask(self, W):
        """
        Mask out the diagonal of the adjacency matrix.
        """
        return torch.multiply(W, self.diag_mask)

    def dataloader(self, C, X, **kwargs):
        """
        Build a DataLoader of (context, data) pairs.

        :param C: context array of shape (n_samples, context_dim)
        :param X: data array of shape (n_samples, x_dim)
        :return: a torch DataLoader over (C, X) tensor pairs
        """
450 kwargs["num_workers"] = kwargs.get("num_workers", 0)
451 kwargs["batch_size"] = kwargs.get("batch_size", 32)
452 dataset = TensorDataset(
453 torch.tensor(C, dtype=torch.float),
454 torch.tensor(X, dtype=torch.float),
455 )
456 return DataLoader(dataset=dataset, shuffle=False, **kwargs)