Coverage for contextualized/dags/lightning_modules.py: 95%

152 statements  

coverage.py v7.4.4, created at 2024-04-21 13:38 -0400

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from contextualized.functions import identity_link
from contextualized.dags.graph_utils import (
    project_to_dag_torch,
    trim_params,
    dag_pred,
    dag_pred_with_factors,
)
from contextualized.dags.losses import (
    dag_loss_notears,
    dag_loss_dagma,
    dag_loss_poly,
    l1_loss,
    mse_loss,
    linear_sem_loss,
    linear_sem_loss_with_factors,
)
from contextualized.modules import ENCODERS, Explainer

DAG_LOSSES = {
    "NOTEARS": dag_loss_notears,
    "DAGMA": dag_loss_dagma,
    "poly": dag_loss_poly,
}
DEFAULT_DAG_LOSS_TYPE = "NOTEARS"
DEFAULT_DAG_LOSS_PARAMS = {
    "NOTEARS": {
        "alpha": 1e-1,
        "rho": 1e-2,
        "tol": 0.25,
        "use_dynamic_alpha_rho": False,
    },
    "DAGMA": {"s": 1, "alpha": 1e0},
    "poly": {},
}
DEFAULT_SS_PARAMS = {
    "l1": 0.0,
    "dag": {
        "loss_type": "NOTEARS",
        "params": {
            "alpha": 1e-1,
            "rho": 1e-2,
            "h_old": 0.0,
            "tol": 0.25,
            "use_dynamic_alpha_rho": False,
        },
    },
}
DEFAULT_ARCH_PARAMS = {
    "l1": 0.0,
    "dag": {
        "loss_type": "NOTEARS",
        "params": {
            "alpha": 0.0,
            "rho": 0.0,
            "h_old": 0.0,
            "tol": 0.25,
            "use_dynamic_alpha_rho": False,
        },
    },
    "init_mat": None,
    "num_factors": 0,
    "factor_mat_l1": 0.0,
    "num_archetypes": 4,
}
DEFAULT_ENCODER_KWARGS = {
    "type": "mlp",
    "params": {"width": 32, "layers": 2, "link_fn": identity_link},
}
DEFAULT_OPT_PARAMS = {
    "learning_rate": 1e-3,
    "opt_step": 50,
}



class NOTMAD(pl.LightningModule):
    """
    NOTMAD model: predicts context-specific DAG adjacency matrices as mixtures
    of learned archetype graphs, regularized toward acyclicity with NOTEARS,
    DAGMA, or polynomial DAG losses.
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        sample_specific_loss_params=DEFAULT_SS_PARAMS,
        archetype_loss_params=DEFAULT_ARCH_PARAMS,
        opt_params=DEFAULT_OPT_PARAMS,
        encoder_kwargs=DEFAULT_ENCODER_KWARGS,
        **kwargs,
    ):
95 """Initialize NOTMAD. 

96 

97 Args: 

98 context_dim (int): context dimensionality 

99 x_dim (int): predictor dimensionality 

100 

101 Kwargs: 

102 Explainer Kwargs 

103 ---------------- 

104 init_mat (np.array): 3D Custom initial weights for each archetype. Defaults to None. 

105 num_archetypes (int:4): Number of archetypes in explainer 

106 

107 Encoder Kwargs 

108 ---------------- 

109 encoder_kwargs(dict): Dictionary of width, layers, and link_fn associated with encoder. 

110 

111 Optimization Kwargs 

112 ------------------- 

113 learning_rate(float): Optimizer learning rate 

114 opt_step(int): Optimizer step size 

115 

116 Loss Kwargs 

117 ----------- 

118 sample_specific_loss_params (dict of str: int): Dict of params used by NOTEARS loss (l1, alpha, rho) 

119 archetype_loss_params (dict of str: int): Dict of params used by Archetype loss (l1, alpha, rho) 

120 

121 """ 

        super(NOTMAD, self).__init__()

        # dataset params
        self.context_dim = context_dim
        self.x_dim = x_dim
        self.num_archetypes = archetype_loss_params.get(
            "num_archetypes", DEFAULT_ARCH_PARAMS["num_archetypes"]
        )
        num_factors = archetype_loss_params.pop("num_factors", 0)
        if 0 < num_factors < self.x_dim:
            self.latent_dim = num_factors
        else:
            if num_factors < 0:
                print(
                    f"Requested num_factors={num_factors}, but this should be a positive integer."
                )
            if num_factors > self.x_dim:
                print(
                    f"Requested num_factors={num_factors}, but this should be smaller than x_dim={self.x_dim}."
                )
            if num_factors == self.x_dim:
                print(
                    f"Requested num_factors={num_factors}, but this equals x_dim={self.x_dim}, so ignoring."
                )
            self.latent_dim = self.x_dim

        # DAG regularizers
        self.ss_dag_params = sample_specific_loss_params["dag"].get(
            "params",
            DEFAULT_DAG_LOSS_PARAMS[
                sample_specific_loss_params["dag"]["loss_type"]
            ].copy(),
        )

        self.arch_dag_params = archetype_loss_params["dag"].get(
            "params",
            DEFAULT_DAG_LOSS_PARAMS[archetype_loss_params["dag"]["loss_type"]].copy(),
        )

        self.val_dag_loss_params = {"alpha": 1e0, "rho": 1e0}
        self.ss_dag_loss = DAG_LOSSES[sample_specific_loss_params["dag"]["loss_type"]]
        self.arch_dag_loss = DAG_LOSSES[archetype_loss_params["dag"]["loss_type"]]

        # Sparsity regularizers
        self.arch_l1 = archetype_loss_params.get("l1", 0.0)
        self.ss_l1 = sample_specific_loss_params.get("l1", 0.0)

        # Archetype params
        self.init_mat = archetype_loss_params.get("init_mat", None)
        self.factor_mat_l1 = archetype_loss_params.get("factor_mat_l1", 0.0)

        # opt params
        self.learning_rate = opt_params.get("learning_rate", 1e-3)
        self.opt_step = opt_params.get("opt_step", 50)
        # self.project_distance = 0.1

        # layers
        self.encoder = ENCODERS[encoder_kwargs["type"]](
            context_dim,
            self.num_archetypes,
            **encoder_kwargs["params"],
        )
        self.register_buffer(
            "diag_mask",
            torch.ones(self.latent_dim, self.latent_dim) - torch.eye(self.latent_dim),
        )
        self.explainer = Explainer(
            self.num_archetypes, (self.latent_dim, self.latent_dim)
        )
        self.explainer.set_archetypes(
            self._mask(self.explainer.get_archetypes())
        )  # initialize archetypes with 0 diagonal
        if self.latent_dim != self.x_dim:
            factor_mat_init = torch.rand([self.latent_dim, self.x_dim]) * 2e-2 - 1e-2
            self.factor_mat_raw = nn.parameter.Parameter(
                factor_mat_init, requires_grad=True
            )
            self.factor_softmax = nn.Softmax(
                dim=0
            )  # Sums to one along the latent factor axis, so each feature should only be projected to a single factor.

        self.training_step_outputs = []

    def forward(self, context):
        subtype = self.encoder(context)
        out = self.explainer(subtype)
        return out

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.opt_step, gamma=0.5
        )
        # learning rate scheduler
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": sch,
                "monitor": "train_loss",
            },
        }

    def _factor_mat(self):
        return self.factor_softmax(self.factor_mat_raw)

    def _batch_loss(self, batch, batch_idx):
        _, x_true = batch
        w_pred = self.predict_step(batch, batch_idx)
        if self.latent_dim < self.x_dim:
            mse_term = linear_sem_loss_with_factors(x_true, w_pred, self._factor_mat())
        else:
            mse_term = linear_sem_loss(x_true, w_pred)
        l1_term = l1_loss(w_pred, self.ss_l1)
        dag_term = self.ss_dag_loss(w_pred, **self.ss_dag_params)
        notears = mse_term + l1_term + dag_term
        W_arch = self.explainer.get_archetypes()
        arch_l1_term = l1_loss(W_arch, self.arch_l1)
        arch_dag_term = len(W_arch) * self.arch_dag_loss(W_arch, **self.arch_dag_params)
        # todo: scale archetype loss?
        if self.latent_dim < self.x_dim:
            factor_mat_term = l1_loss(self.factor_mat_raw, self.factor_mat_l1)
            loss = notears + arch_l1_term + arch_dag_term + factor_mat_term
            return (
                loss,
                notears.detach(),
                mse_term.detach(),
                l1_term.detach(),
                dag_term.detach(),
                arch_l1_term.detach(),
                arch_dag_term.detach(),
                factor_mat_term.detach(),
            )
        else:
            loss = notears + arch_l1_term + arch_dag_term
            return (
                loss,
                notears.detach(),
                mse_term.detach(),
                l1_term.detach(),
                dag_term.detach(),
                arch_l1_term.detach(),
                arch_dag_term.detach(),
                0.0,
            )

    def training_step(self, batch, batch_idx):
        (
            loss,
            notears,
            mse_term,
            l1_term,
            dag_term,
            arch_l1_term,
            arch_dag_term,
            factor_mat_term,
        ) = self._batch_loss(batch, batch_idx)
        ret = {
            "loss": loss,
            "train_loss": loss,
            "train_mse_loss": mse_term,
            "train_l1_loss": l1_term,
            "train_dag_loss": dag_term,
            "train_arch_l1_loss": arch_l1_term,
            "train_arch_dag_loss": arch_dag_term,
            "train_factor_l1_loss": factor_mat_term,
        }
        self.log_dict(ret)
        ret.update(
            {
                "train_batch": batch,
                "train_batch_idx": batch_idx,
            }
        )
        self.training_step_outputs.append(ret)
        return ret

    def test_step(self, batch, batch_idx):
        (
            loss,
            notears,
            mse_term,
            l1_term,
            dag_term,
            arch_l1_term,
            arch_dag_term,
            factor_mat_term,
        ) = self._batch_loss(batch, batch_idx)
        ret = {
            "test_loss": loss,
            "test_mse_loss": mse_term,
            "test_l1_loss": l1_term,
            "test_dag_loss": dag_term,
            "test_arch_l1_loss": arch_l1_term,
            "test_arch_dag_loss": arch_dag_term,
            "test_factor_l1_loss": factor_mat_term,
        }
        self.log_dict(ret)
        return ret

    def validation_step(self, batch, batch_idx):
        _, x_true = batch
        w_pred = self.predict_step(batch, batch_idx)
        if self.latent_dim < self.x_dim:
            X_pred = dag_pred_with_factors(x_true, w_pred, self._factor_mat())
        else:
            X_pred = dag_pred(x_true, w_pred)
        mse_term = 0.5 * x_true.shape[-1] * mse_loss(x_true, X_pred)
        l1_term = l1_loss(w_pred, self.ss_l1).mean()
        # ignore archetype loss, use constant alpha/rho upper bound for validation
        dag_term = self.ss_dag_loss(w_pred, **self.val_dag_loss_params).mean()
        if self.latent_dim < self.x_dim:
            factor_mat_term = l1_loss(self.factor_mat_raw, self.factor_mat_l1)
            loss = mse_term + l1_term + dag_term + factor_mat_term
            ret = {
                "val_loss": loss,
                "val_mse_loss": mse_term,
                "val_l1_loss": l1_term,
                "val_dag_loss": dag_term,
                "val_factor_l1_loss": factor_mat_term,
            }
        else:
            loss = mse_term + l1_term + dag_term
            ret = {
                "val_loss": loss,
                "val_mse_loss": mse_term,
                "val_l1_loss": l1_term,
                "val_dag_loss": dag_term,
                "val_factor_l1_loss": 0.0,
            }
        self.log_dict(ret)
        return ret

    def predict_step(self, batch, batch_idx):
        c, _ = batch
        w_pred = self(c)
        return self._mask(w_pred)

    def _project_factor_graph_to_var(self, w_preds):
        """
        Projects the graphs in factor space to variable space.
        w_preds: n x latent x latent
        """
        P_sums = self._factor_mat().sum(axis=1)
        w_preds = np.tensordot(
            w_preds,
            (self._factor_mat().T.detach().numpy() / P_sums.detach().numpy()).T,
            axes=1,
        )  # n x latent x x_dims
        w_preds = np.swapaxes(w_preds, 1, 2)  # n x x_dims x latent
        w_preds = np.tensordot(
            w_preds, self._factor_mat().detach().numpy(), axes=1
        )  # n x x_dims x x_dims
        w_preds = np.swapaxes(w_preds, 1, 2)  # n x x_dims x x_dims
        return w_preds

    def _format_params(self, w_preds, **kwargs):
        """
        Format the parameters to be returned by the model.
        args:
            w_preds: the predicted parameters
            project_to_dag: whether to project the parameters to a DAG
            threshold: the threshold to use for minimum edge weight magnitude
            factors: whether to return the factor graph or the variable graph.
        """
        if 0 < self.latent_dim < self.x_dim and not kwargs.get("factors", False):
            w_preds = self._project_factor_graph_to_var(w_preds)
        if kwargs.get("project_to_dag", False):
            try:
                w_preds = np.array([project_to_dag_torch(w)[0] for w in w_preds])
            except Exception:
                print("Error, couldn't project to dag. Returning normal predictions.")
        return trim_params(w_preds, thresh=kwargs.get("threshold", 0.0))

    def on_train_epoch_end(self, logs=None):
        training_step_outputs = self.training_step_outputs
        # update alpha/rho based on average end-of-epoch dag loss
        epoch_samples = sum(
            [len(ret["train_batch"][0]) for ret in training_step_outputs]
        )
        epoch_dag_loss = 0
        for ret in training_step_outputs:
            batch_dag_loss = self.ss_dag_loss(
                self.predict_step(ret["train_batch"], ret["train_batch_idx"]),
                **self.ss_dag_params,
            ).detach()
            epoch_dag_loss += (
                len(ret["train_batch"][0]) / epoch_samples * batch_dag_loss
            )
        self.ss_dag_params = self._maybe_update_alpha_rho(
            epoch_dag_loss, self.ss_dag_params
        )
        self.arch_dag_params = self._maybe_update_alpha_rho(
            epoch_dag_loss, self.arch_dag_params
        )
        self.training_step_outputs.clear()  # free memory

    def _maybe_update_alpha_rho(self, epoch_dag_loss, dag_params):
        """
        Update alpha/rho if use_dynamic_alpha_rho is True.
        """
        if (
            dag_params.get("use_dynamic_alpha_rho", False)
            and epoch_dag_loss
            > dag_params.get("tol", 0.25) * dag_params.get("h_old", 0)
            and dag_params["alpha"] < 1e12
            and dag_params["rho"] < 1e12
        ):
            dag_params["alpha"] = (
                dag_params["alpha"] + dag_params["rho"] * epoch_dag_loss
            )
            dag_params["rho"] *= dag_params.get("rho_mult", 1.1)
            dag_params["h_old"] = epoch_dag_loss
        return dag_params
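    # Worked example of the update above (illustrative numbers only): with
    # use_dynamic_alpha_rho=True, alpha=0.1, rho=0.01, h_old=0.0, and an
    # epoch_dag_loss of 2.0, the guard passes (2.0 > 0.25 * 0.0), so alpha
    # becomes 0.1 + 0.01 * 2.0 = 0.12, rho becomes 0.01 * 1.1 = 0.011, and
    # h_old is set to 2.0 for the next epoch's tolerance check.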


    # helpers
    def _mask(self, W):
        """
        Mask out the diagonal of the adjacency matrix.
        """
        return torch.multiply(W, self.diag_mask)

    def dataloader(self, C, X, **kwargs):
        """
        Build a DataLoader of (context, features) pairs.

        :param C: context array of shape (n_samples, context_dim)
        :param X: feature array of shape (n_samples, x_dim)

        """
        kwargs["num_workers"] = kwargs.get("num_workers", 0)
        kwargs["batch_size"] = kwargs.get("batch_size", 32)
        dataset = TensorDataset(
            torch.tensor(C, dtype=torch.float),
            torch.tensor(X, dtype=torch.float),
        )
        return DataLoader(dataset=dataset, shuffle=False, **kwargs)
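

# Minimal usage sketch (illustrative only): trains NOTMAD on randomly generated
# data with the defaults above and recovers per-sample adjacency matrices. The
# shapes, epoch count, and edge threshold are arbitrary choices, not values
# prescribed by this module.
if __name__ == "__main__":
    n_samples, context_dim, x_dim = 200, 5, 8
    C = np.random.normal(size=(n_samples, context_dim))
    X = np.random.normal(size=(n_samples, x_dim))

    model = NOTMAD(context_dim, x_dim)
    train_loader = model.dataloader(C, X, batch_size=32)

    trainer = pl.Trainer(
        max_epochs=5, accelerator="cpu", logger=False, enable_checkpointing=False
    )
    trainer.fit(model, train_loader)

    # predict_step yields one (batch_size, x_dim, x_dim) tensor of masked
    # network weights per batch; concatenate them into a single array.
    w_batches = trainer.predict(model, train_loader)
    w_pred = torch.cat(w_batches).numpy()

    # Optionally project each predicted graph to a DAG and trim small edges
    # using the module's own formatting helper.
    w_dags = model._format_params(w_pred, project_to_dag=True, threshold=0.1)
    print(w_dags.shape)  # (n_samples, x_dim, x_dim)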