Coverage for contextualized/regression/lightning_modules.py: 97%

303 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2This class contains tools for solving context-specific regression problems: 

3 

4Y = g(beta(C)*X + mu(C)) 

5 

6C: Context 

7X: Explainable features 

8Y: Outcome, aka response (regression) or labels (classification)

9g: Link Function for contextualized generalized linear models. 

10 

11Implemented with PyTorch Lightning 

12""" 

13 

14from abc import abstractmethod 

15import numpy as np 

16import torch 

17from torch.utils.data import DataLoader 

18import pytorch_lightning as pl 

19 

20from contextualized.regression.regularizers import REGULARIZERS 

21from contextualized.regression.losses import MSE 

22from contextualized.functions import LINK_FUNCTIONS 

23 

24from contextualized.regression.metamodels import ( 

25 NaiveMetamodel, 

26 SubtypeMetamodel, 

27 MultitaskMetamodel, 

28 TasksplitMetamodel, 

29 SINGLE_TASK_METAMODELS, 

30 MULTITASK_METAMODELS, 

31) 

32from contextualized.regression.datasets import ( 

33 DataIterable, 

34 MultivariateDataset, 

35 UnivariateDataset, 

36 MultitaskMultivariateDataset, 

37 MultitaskUnivariateDataset, 

38) 

39 

40 

class ContextualizedRegressionBase(pl.LightningModule):
    """
    Abstract class for Contextualized Regression.

    Fits models of the form Y = link_fn(beta(C) * X + mu(C)), where the
    regression parameters beta and mu are produced by a context-driven
    metamodel. Subclasses define the metamodel construction, the batch
    loss, and how batched predictions are reshaped into per-sample arrays.
    """

    def __init__(
        self,
        *args,
        learning_rate=1e-3,
        metamodel_type="subtype",
        fit_intercept=True,
        link_fn=LINK_FUNCTIONS["identity"],
        loss_fn=MSE,
        model_regularizer=REGULARIZERS["none"],
        base_y_predictor=None,
        base_param_predictor=None,
        **kwargs,
    ):
        """
        :param learning_rate: Adam learning rate (Default value = 1e-3)
        :param metamodel_type: key into SINGLE_TASK_METAMODELS (Default value = "subtype")
        :param fit_intercept: if False, mu(C) is zeroed out in forward (Default value = True)
        :param link_fn: link function g applied to beta(C)*X + mu(C)
        :param loss_fn: prediction loss (Default value = MSE)
        :param model_regularizer: regularizer applied to (beta, mu) predictions
        :param base_y_predictor: optional model whose predict_y output is added to Y predictions
        :param base_param_predictor: optional model whose predict_params output is added to (beta, mu)

        Remaining *args / **kwargs are forwarded to the metamodel constructor.
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.metamodel_type = metamodel_type
        self.fit_intercept = fit_intercept
        self.link_fn = link_fn
        self.loss_fn = loss_fn
        self.model_regularizer = model_regularizer
        self.base_y_predictor = base_y_predictor
        self.base_param_predictor = base_param_predictor
        self._build_metamodel(*args, **kwargs)

    @abstractmethod
    def _build_metamodel(self, *args, **kwargs):
        """
        Build self.metamodel; overridden by subclasses.

        :param *args: forwarded to the metamodel constructor
        :param **kwargs: forwarded to the metamodel constructor
        """
        # builds the metamodel
        kwargs["univariate"] = False
        self.metamodel = SINGLE_TASK_METAMODELS[self.metamodel_type](*args, **kwargs)

    @abstractmethod
    def dataloader(self, C, X, Y, batch_size=32):
        """
        Build a DataLoader over (C, X, Y) appropriate for this model type.

        :param C: context array
        :param X: explainable-feature array
        :param Y: outcome array
        :param batch_size: (Default value = 32)
        """
        # returns the dataloader for this class

    @abstractmethod
    def _batch_loss(self, batch, batch_idx):
        """
        Compute the loss (prediction + regularization) for one batch.

        :param batch:
        :param batch_idx:
        """
        # MSE loss by default

    @abstractmethod
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Predict the contextualized parameters for one batch.

        :param batch:
        :param batch_idx:
        :param dataloader_idx: (Default value = 0)
        """
        # returns predicted params on the given batch

    @abstractmethod
    def _params_reshape(self, beta_preds, mu_preds, dataloader):
        """
        Reshape batched parameter predictions to per-sample arrays.

        :param beta_preds:
        :param mu_preds:
        :param dataloader:
        """
        # reshapes the batch parameter predictions into beta (y_dim, x_dim)

    @abstractmethod
    def _y_reshape(self, y_preds, dataloader):
        """
        Reshape batched outcome predictions to per-sample arrays.

        :param y_preds:
        :param dataloader:
        """
        # reshapes the batch y predictions into a desirable format

    def forward(self, *args, **kwargs):
        """
        Return contextualized (beta, mu) from the metamodel.

        :param *args: forwarded to the metamodel (typically the context C,
            plus a task encoding for multitask metamodels)

        NOTE(review): **kwargs is accepted but not forwarded to the metamodel.
        """
        beta, mu = self.metamodel(*args)
        if not self.fit_intercept:
            mu = torch.zeros_like(mu)
        if self.base_param_predictor is not None:
            base_beta, base_mu = self.base_param_predictor.predict_params(*args)
            beta = beta + base_beta.to(beta.device)
            mu = mu + base_mu.to(mu.device)
        return beta, mu

    def configure_optimizers(self):
        """
        Set up optimizer (Adam with self.learning_rate).
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        """
        Compute and log the training loss for one batch.

        :param batch:
        :param batch_idx:
        """
        loss = self._batch_loss(batch, batch_idx)
        self.log_dict({"train_loss": loss})
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Compute and log the validation loss for one batch.

        :param batch:
        :param batch_idx:
        """
        loss = self._batch_loss(batch, batch_idx)
        self.log_dict({"val_loss": loss})
        return loss

    def test_step(self, batch, batch_idx):
        """
        Compute and log the test loss for one batch.

        :param batch:
        :param batch_idx:
        """
        loss = self._batch_loss(batch, batch_idx)
        self.log_dict({"test_loss": loss})
        return loss

    def _predict_from_models(self, X, beta_hat, mu_hat):
        """
        Apply the link function to the contextualized linear model:
        link_fn(sum(beta_hat * X) + mu_hat).

        :param X:
        :param beta_hat:
        :param mu_hat:
        """
        return self.link_fn((beta_hat * X).sum(axis=-1).unsqueeze(-1) + mu_hat)

    def _predict_y(self, C, X, beta_hat, mu_hat):
        """
        Predict Y, adding the base predictor's output when one is configured.

        :param C:
        :param X:
        :param beta_hat:
        :param mu_hat:
        """
        Y = self._predict_from_models(X, beta_hat, mu_hat)
        if self.base_y_predictor is not None:
            Y_base = self.base_y_predictor.predict_y(C, X)
            Y = Y + Y_base.to(Y.device)
        return Y

    def _dataloader(self, C, X, Y, dataset_constructor, **kwargs):
        """
        Wrap (C, X, Y) in dataset_constructor and return a DataLoader.

        :param C:
        :param X:
        :param Y:
        :param dataset_constructor: dataset class from contextualized.regression.datasets
        :param **kwargs: forwarded to DataLoader; num_workers defaults to 0,
            batch_size defaults to 32
        """
        kwargs["num_workers"] = kwargs.get("num_workers", 0)
        kwargs["batch_size"] = kwargs.get("batch_size", 32)
        return DataLoader(dataset=DataIterable(dataset_constructor(C, X, Y)), **kwargs)

229 

230 

class NaiveContextualizedRegression(ContextualizedRegressionBase):
    """See NaiveMetamodel"""

    def _build_metamodel(self, *args, **kwargs):
        """Construct a multivariate NaiveMetamodel."""
        kwargs["univariate"] = False
        self.metamodel = NaiveMetamodel(*args, **kwargs)

    def _batch_loss(self, batch, batch_idx):
        """Prediction loss plus regularization for one (C, X, Y, idx) batch."""
        C, X, Y, _ = batch
        beta, mu = self.predict_step(batch, batch_idx)
        fit_loss = self.loss_fn(Y, self._predict_y(C, X, beta, mu))
        return fit_loss + self.model_regularizer(beta, mu)

    def predict_step(self, batch, batch_idx):
        """Predict contextualized (beta, mu) from the batch context."""
        C = batch[0]
        return self(C)

    def _params_reshape(self, preds, dataloader):
        """Assemble per-sample betas (n, y_dim, x_dim) and mus (n, y_dim)."""
        ds = dataloader.dataset.dataset
        betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        mus = np.zeros((ds.n, ds.y_dim))
        for (beta_batch, mu_batch), (_, _, _, idx) in zip(preds, dataloader):
            betas[idx] = beta_batch
            mus[idx] = mu_batch.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Assemble per-sample outcome predictions (n, y_dim)."""
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim))
        for (beta_batch, mu_batch), (C, X, _, idx) in zip(preds, dataloader):
            ys[idx] = self._predict_y(C, X, beta_batch, mu_batch).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a MultivariateDataset DataLoader over (C, X, Y)."""
        return self._dataloader(C, X, Y, MultivariateDataset, **kwargs)

308 

309 

class ContextualizedRegression(ContextualizedRegressionBase):
    """Supports SubtypeMetamodel and NaiveMetamodel, see selected metamodel for docs"""

    def _build_metamodel(self, *args, **kwargs):
        """Construct the configured single-task metamodel in multivariate mode."""
        kwargs["univariate"] = False
        self.metamodel = SINGLE_TASK_METAMODELS[self.metamodel_type](*args, **kwargs)

    def _batch_loss(self, batch, batch_idx):
        """Prediction loss plus regularization for one (C, X, Y, idx) batch."""
        C, X, Y, _ = batch
        beta, mu = self.predict_step(batch, batch_idx)
        fit_loss = self.loss_fn(Y, self._predict_y(C, X, beta, mu))
        return fit_loss + self.model_regularizer(beta, mu)

    def predict_step(self, batch, batch_idx):
        """Predict contextualized (beta, mu) from the batch context."""
        C = batch[0]
        return self(C)

    def _params_reshape(self, preds, dataloader):
        """Assemble per-sample betas (n, y_dim, x_dim) and mus (n, y_dim)."""
        ds = dataloader.dataset.dataset
        betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        mus = np.zeros((ds.n, ds.y_dim))
        for (beta_batch, mu_batch), (_, _, _, idx) in zip(preds, dataloader):
            betas[idx] = beta_batch
            mus[idx] = mu_batch.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Assemble per-sample outcome predictions (n, y_dim)."""
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim))
        for (beta_batch, mu_batch), (C, X, _, idx) in zip(preds, dataloader):
            ys[idx] = self._predict_y(C, X, beta_batch, mu_batch).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a MultivariateDataset DataLoader over (C, X, Y)."""
        return self._dataloader(C, X, Y, MultivariateDataset, **kwargs)

392 

393 

class MultitaskContextualizedRegression(ContextualizedRegressionBase):
    """See MultitaskMetamodel"""

    def _build_metamodel(self, *args, **kwargs):
        """Construct a multivariate MultitaskMetamodel."""
        kwargs["univariate"] = False
        self.metamodel = MultitaskMetamodel(*args, **kwargs)

    def _batch_loss(self, batch, batch_idx):
        """Prediction loss plus regularization for one (C, T, X, Y, n_idx, y_idx) batch."""
        C, _, X, Y, _, _ = batch
        beta, mu = self.predict_step(batch, batch_idx)
        fit_loss = self.loss_fn(Y, self._predict_y(C, X, beta, mu))
        return fit_loss + self.model_regularizer(beta, mu)

    def predict_step(self, batch, batch_idx):
        """Predict contextualized (beta, mu) from context and task encoding."""
        C, T = batch[0], batch[1]
        return self(C, T)

    def _params_reshape(self, preds, dataloader):
        """Scatter per-task parameter predictions into (n, y_dim, x_dim) betas and (n, y_dim) mus."""
        ds = dataloader.dataset.dataset
        betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        mus = np.zeros((ds.n, ds.y_dim))
        for (beta_batch, mu_batch), (_, _, _, _, n_idx, y_idx) in zip(preds, dataloader):
            betas[n_idx, y_idx] = beta_batch
            mus[n_idx, y_idx] = mu_batch.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Scatter per-task outcome predictions into an (n, y_dim) array."""
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim))
        for (beta_batch, mu_batch), (C, _, X, _, n_idx, y_idx) in zip(preds, dataloader):
            ys[n_idx, y_idx] = self._predict_y(C, X, beta_batch, mu_batch).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a MultitaskMultivariateDataset DataLoader over (C, X, Y)."""
        return self._dataloader(C, X, Y, MultitaskMultivariateDataset, **kwargs)

471 

472 

class TasksplitContextualizedRegression(ContextualizedRegressionBase):
    """See TasksplitMetamodel"""

    def _build_metamodel(self, *args, **kwargs):
        """Construct a multivariate TasksplitMetamodel."""
        kwargs["univariate"] = False
        self.metamodel = TasksplitMetamodel(*args, **kwargs)

    def _batch_loss(self, batch, batch_idx):
        """Prediction loss plus regularization for one (C, T, X, Y, n_idx, y_idx) batch."""
        C, _, X, Y, _, _ = batch
        beta, mu = self.predict_step(batch, batch_idx)
        fit_loss = self.loss_fn(Y, self._predict_y(C, X, beta, mu))
        return fit_loss + self.model_regularizer(beta, mu)

    def predict_step(self, batch, batch_idx):
        """Predict contextualized (beta, mu) from context and task encoding."""
        C, T = batch[0], batch[1]
        return self(C, T)

    def _params_reshape(self, preds, dataloader):
        """Scatter per-task parameter predictions into (n, y_dim, x_dim) betas and (n, y_dim) mus."""
        ds = dataloader.dataset.dataset
        betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        mus = np.zeros((ds.n, ds.y_dim))
        for (beta_batch, mu_batch), (_, _, _, _, n_idx, y_idx) in zip(preds, dataloader):
            betas[n_idx, y_idx] = beta_batch
            mus[n_idx, y_idx] = mu_batch.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Scatter per-task outcome predictions into an (n, y_dim) array."""
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim))
        for (beta_batch, mu_batch), (C, _, X, _, n_idx, y_idx) in zip(preds, dataloader):
            ys[n_idx, y_idx] = self._predict_y(C, X, beta_batch, mu_batch).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a MultitaskMultivariateDataset DataLoader over (C, X, Y)."""
        return self._dataloader(C, X, Y, MultitaskMultivariateDataset, **kwargs)

550 

551 

class ContextualizedUnivariateRegression(ContextualizedRegression):
    """Supports SubtypeMetamodel and NaiveMetamodel, see selected metamodel for docs"""

    def _build_metamodel(self, *args, **kwargs):
        """Construct the configured single-task metamodel in univariate mode."""
        kwargs["univariate"] = True
        self.metamodel = SINGLE_TASK_METAMODELS[self.metamodel_type](*args, **kwargs)

    def _params_reshape(self, preds, dataloader):
        """Assemble per-sample betas and mus, each of shape (n, y_dim, x_dim)."""
        ds = dataloader.dataset.dataset
        shape = (ds.n, ds.y_dim, ds.x_dim)
        betas = np.zeros(shape)
        mus = np.zeros(shape)
        for (beta_batch, mu_batch), (_, _, _, idx) in zip(preds, dataloader):
            betas[idx] = beta_batch.squeeze(-1)
            mus[idx] = mu_batch.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Assemble per-sample univariate predictions of shape (n, y_dim, x_dim)."""
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        for (beta_batch, mu_batch), (C, X, _, idx) in zip(preds, dataloader):
            ys[idx] = self._predict_y(C, X, beta_batch, mu_batch).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a UnivariateDataset DataLoader over (C, X, Y)."""
        return self._dataloader(C, X, Y, UnivariateDataset, **kwargs)

605 

606 

class TasksplitContextualizedUnivariateRegression(TasksplitContextualizedRegression):
    """See TasksplitMetamodel"""

    def _build_metamodel(self, *args, **kwargs):
        """Construct a univariate TasksplitMetamodel."""
        kwargs["univariate"] = True
        self.metamodel = TasksplitMetamodel(*args, **kwargs)

    def _batch_loss(self, batch, batch_idx):
        """Prediction loss plus regularization for one (C, T, X, Y, n_idx, x_idx, y_idx) batch."""
        C, _, X, Y, _, _, _ = batch
        beta, mu = self.predict_step(batch, batch_idx)
        fit_loss = self.loss_fn(Y, self._predict_y(C, X, beta, mu))
        return fit_loss + self.model_regularizer(beta, mu)

    def predict_step(self, batch, batch_idx):
        """Predict contextualized (beta, mu) from context and task encoding."""
        C, T = batch[0], batch[1]
        return self(C, T)

    def _params_reshape(self, preds, dataloader):
        """Scatter per-(sample, y, x) parameter predictions into (n, y_dim, x_dim) arrays."""
        ds = dataloader.dataset.dataset
        betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        mus = betas.copy()
        for (beta_batch, mu_batch), (_, _, _, _, n_idx, x_idx, y_idx) in zip(preds, dataloader):
            betas[n_idx, y_idx, x_idx] = beta_batch.squeeze(-1)
            mus[n_idx, y_idx, x_idx] = mu_batch.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Scatter per-(sample, y, x) outcome predictions into an (n, y_dim, x_dim) array."""
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        for (beta_batch, mu_batch), (C, _, X, _, n_idx, x_idx, y_idx) in zip(preds, dataloader):
            ys[n_idx, y_idx, x_idx] = self._predict_y(C, X, beta_batch, mu_batch).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a MultitaskUnivariateDataset DataLoader over (C, X, Y)."""
        return self._dataloader(C, X, Y, MultitaskUnivariateDataset, **kwargs)

686 

687 

class ContextualizedCorrelation(ContextualizedUnivariateRegression):
    """Using univariate contextualized regression to estimate Pearson's correlation
    See SubtypeMetamodel for assumptions and full docstring
    """

    def __init__(self, context_dim, x_dim, **kwargs):
        """
        :param context_dim: dimensionality of the context C
        :param x_dim: dimensionality of X; correlation regresses X on itself,
            so any y_dim passed by the caller is discarded and forced to x_dim
        """
        if "y_dim" in kwargs:
            del kwargs["y_dim"]
        super().__init__(context_dim, x_dim, x_dim, **kwargs)

    def dataloader(self, C, X, Y=None, **kwargs):
        """
        Build a DataLoader regressing X on itself.

        :param C:
        :param X:
        :param Y: ignored; self-correlation uses X as both input and target
            (Default value = None)
        :param **kwargs:
        """
        if Y is not None:
            # Fixed typo in user-facing message ("featuers" -> "features").
            print(
                "Passed a Y, but this is self-correlation between X features. Ignoring Y."
            )
        return super().dataloader(C, X, X, **kwargs)

714 

715 

class TasksplitContextualizedCorrelation(TasksplitContextualizedUnivariateRegression):
    """Using multitask univariate contextualized regression to estimate Pearson's correlation
    See TasksplitMetamodel for assumptions and full docstring
    """

    def __init__(self, context_dim, x_dim, **kwargs):
        """
        :param context_dim: dimensionality of the context C
        :param x_dim: dimensionality of X; correlation regresses X on itself,
            so any y_dim passed by the caller is discarded and forced to x_dim
        """
        if "y_dim" in kwargs:
            del kwargs["y_dim"]
        super().__init__(context_dim, x_dim, x_dim, **kwargs)

    def dataloader(self, C, X, Y=None, **kwargs):
        """
        Build a DataLoader regressing X on itself.

        :param C:
        :param X:
        :param Y: ignored; self-correlation uses X as both input and target
            (Default value = None)
        :param **kwargs:
        """
        if Y is not None:
            # Fixed typo in user-facing message ("featuers" -> "features").
            print(
                "Passed a Y, but this is self-correlation between X features. Ignoring Y."
            )
        return super().dataloader(C, X, X, **kwargs)

742 

743 

class ContextualizedNeighborhoodSelection(ContextualizedRegression):
    """Using singletask multivariate contextualized regression to do edge-regression for
    estimating conditional dependencies
    See SubtypeMetamodel for assumptions and full docstring
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        model_regularizer=REGULARIZERS["l1"](1e-3, mu_ratio=0),
        **kwargs,
    ):
        """
        :param context_dim: dimensionality of the context C
        :param x_dim: dimensionality of X; neighborhood selection regresses X
            on itself, so any y_dim passed by the caller is discarded
        :param model_regularizer: L1 on betas by default to encourage sparse edges
        """
        if "y_dim" in kwargs:
            del kwargs["y_dim"]
        super().__init__(
            context_dim, x_dim, x_dim, model_regularizer=model_regularizer, **kwargs
        )
        # Mask that zeros the diagonal of each predicted beta matrix (no self-edges).
        self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim))

    def predict_step(self, batch, batch_idx):
        """
        Predict contextualized (beta, mu) with self-edges masked out.

        :param batch:
        :param batch_idx:
        """
        C, _, _, _ = batch
        beta_hat, mu_hat = self(C)
        beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1)
        return beta_hat, mu_hat

    def dataloader(self, C, X, Y=None, **kwargs):
        """
        Build a DataLoader regressing X on itself.

        :param C:
        :param X:
        :param Y: ignored; the model regresses X features on each other
            (Default value = None)
        :param **kwargs:
        """
        if Y is not None:
            # Fixed typo in user-facing message ("featuers" -> "features").
            print(
                "Passed a Y, but this is a Markov Graph between X features. Ignoring Y."
            )
        return super().dataloader(C, X, X, **kwargs)

793 

794 

class ContextualizedMarkovGraph(ContextualizedRegression):
    """Using singletask multivariate contextualized regression to do edge-regression for
    estimating conditional dependencies
    See SubtypeMetamodel for assumptions and full docstring
    """

    def __init__(self, context_dim, x_dim, **kwargs):
        """
        :param context_dim: dimensionality of the context C
        :param x_dim: dimensionality of X; the graph regresses X on itself,
            so any y_dim passed by the caller is discarded
        """
        if "y_dim" in kwargs:
            del kwargs["y_dim"]
        super().__init__(context_dim, x_dim, x_dim, **kwargs)
        # Mask that zeros the diagonal of each predicted beta matrix (no self-edges).
        self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim))

    def predict_step(self, batch, batch_idx):
        """
        Predict contextualized (beta, mu) as a symmetric graph with no self-edges.

        :param batch:
        :param batch_idx:
        """
        C, _, _, _ = batch
        beta_hat, mu_hat = self(C)
        beta_hat = beta_hat + torch.transpose(
            beta_hat, 1, 2
        )  # hotfix to enforce symmetry
        beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1)
        return beta_hat, mu_hat

    def dataloader(self, C, X, Y=None, **kwargs):
        """
        Build a DataLoader regressing X on itself.

        :param C:
        :param X:
        :param Y: ignored; the Markov graph is estimated between X features
            (Default value = None)
        :param **kwargs:
        """
        if Y is not None:
            # Fixed typo in user-facing message ("featuers" -> "features").
            print(
                "Passed a Y, but this is a Markov Graph between X features. Ignoring Y."
            )
        return super().dataloader(C, X, X, **kwargs)