Coverage for contextualized/regression/lightning_modules.py: 97%
303 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
1"""
2This class contains tools for solving context-specific regression problems:
4Y = g(beta(C)*X + mu(C))
6C: Context
7X: Explainable features
 8Y: Outcome, aka response (regression) or labels (classification)
9g: Link Function for contextualized generalized linear models.
11Implemented with PyTorch Lightning
12"""
14from abc import abstractmethod
15import numpy as np
16import torch
17from torch.utils.data import DataLoader
18import pytorch_lightning as pl
20from contextualized.regression.regularizers import REGULARIZERS
21from contextualized.regression.losses import MSE
22from contextualized.functions import LINK_FUNCTIONS
24from contextualized.regression.metamodels import (
25 NaiveMetamodel,
26 SubtypeMetamodel,
27 MultitaskMetamodel,
28 TasksplitMetamodel,
29 SINGLE_TASK_METAMODELS,
30 MULTITASK_METAMODELS,
31)
32from contextualized.regression.datasets import (
33 DataIterable,
34 MultivariateDataset,
35 UnivariateDataset,
36 MultitaskMultivariateDataset,
37 MultitaskUnivariateDataset,
38)
class ContextualizedRegressionBase(pl.LightningModule):
    """
    Abstract base class for Contextualized Regression.

    Solves context-specific regression problems of the form
    Y = g(beta(C) * X + mu(C)), where a metamodel maps context C to
    sample-specific parameters (beta, mu).

    :param learning_rate: Adam step size (Default value = 1e-3)
    :param metamodel_type: key into SINGLE_TASK_METAMODELS (Default value = "subtype")
    :param fit_intercept: if False, mu(C) is zeroed out (Default value = True)
    :param link_fn: link function g applied to the linear predictor
    :param loss_fn: loss between observed and predicted Y (Default value = MSE)
    :param model_regularizer: regularizer applied to (beta, mu)
    :param base_y_predictor: optional model whose predict_y output is added
        to this model's Y predictions (residual/boosting-style fit)
    :param base_param_predictor: optional model whose predict_params output
        is added to the predicted (beta, mu)
    """

    def __init__(
        self,
        *args,
        learning_rate=1e-3,
        metamodel_type="subtype",
        fit_intercept=True,
        link_fn=LINK_FUNCTIONS["identity"],
        loss_fn=MSE,
        model_regularizer=REGULARIZERS["none"],
        base_y_predictor=None,
        base_param_predictor=None,
        **kwargs,
    ):
        super().__init__()
        self.learning_rate = learning_rate
        self.metamodel_type = metamodel_type
        self.fit_intercept = fit_intercept
        self.link_fn = link_fn
        self.loss_fn = loss_fn
        self.model_regularizer = model_regularizer
        self.base_y_predictor = base_y_predictor
        self.base_param_predictor = base_param_predictor
        self._build_metamodel(*args, **kwargs)

    @abstractmethod
    def _build_metamodel(self, *args, **kwargs):
        """Construct self.metamodel; subclasses override to pick the metamodel.

        :param *args: forwarded to the metamodel constructor
        :param **kwargs: forwarded to the metamodel constructor
        """
        # builds the metamodel
        kwargs["univariate"] = False
        self.metamodel = SINGLE_TASK_METAMODELS[self.metamodel_type](*args, **kwargs)

    @abstractmethod
    def dataloader(self, C, X, Y, batch_size=32):
        """Return the DataLoader over (C, X, Y) appropriate for this class.

        :param C: context array
        :param X: predictor array
        :param Y: outcome array
        :param batch_size: (Default value = 32)
        """
        # returns the dataloader for this class

    @abstractmethod
    def _batch_loss(self, batch, batch_idx):
        """Compute prediction + regularization loss on one batch.

        :param batch:
        :param batch_idx:
        """
        # MSE loss by default

    @abstractmethod
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Predict contextualized parameters on the given batch.

        :param batch:
        :param batch_idx:
        :param dataloader_idx: (Default value = 0)
        """
        # returns predicted params on the given batch

    @abstractmethod
    def _params_reshape(self, preds, dataloader):
        """Reshape batched parameter predictions into per-sample arrays.

        Note: signature matches the concrete implementations, which consume
        an iterable of per-batch (beta_hat, mu_hat) pairs.

        :param preds: per-batch (beta_hat, mu_hat) predictions
        :param dataloader: dataloader the predictions were computed from
        """
        # reshapes the batch parameter predictions into beta (y_dim, x_dim)

    @abstractmethod
    def _y_reshape(self, y_preds, dataloader):
        """Reshape batched Y predictions into a per-sample array.

        :param y_preds:
        :param dataloader:
        """
        # reshapes the batch y predictions into a desirable format

    def forward(self, *args, **kwargs):
        """Predict contextualized parameters (beta, mu).

        :param *args: metamodel inputs (context C, plus task T for
            multitask metamodels)
        """
        beta, mu = self.metamodel(*args)
        if not self.fit_intercept:
            # Intercept disabled: force mu to zero while keeping shape/device.
            mu = torch.zeros_like(mu)
        if self.base_param_predictor is not None:
            # Residual parameterization on top of a base predictor.
            base_beta, base_mu = self.base_param_predictor.predict_params(*args)
            beta = beta + base_beta.to(beta.device)
            mu = mu + base_mu.to(mu.device)
        return beta, mu

    def configure_optimizers(self):
        """
        Set up the Adam optimizer over all parameters.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        """Lightning training hook; logs and returns the batch loss.

        :param batch:
        :param batch_idx:
        """
        loss = self._batch_loss(batch, batch_idx)
        self.log_dict({"train_loss": loss})
        return loss

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; logs and returns the batch loss.

        :param batch:
        :param batch_idx:
        """
        loss = self._batch_loss(batch, batch_idx)
        self.log_dict({"val_loss": loss})
        return loss

    def test_step(self, batch, batch_idx):
        """Lightning test hook; logs and returns the batch loss.

        :param batch:
        :param batch_idx:
        """
        loss = self._batch_loss(batch, batch_idx)
        self.log_dict({"test_loss": loss})
        return loss

    def _predict_from_models(self, X, beta_hat, mu_hat):
        """Apply the link function to the contextualized linear model.

        :param X: predictors
        :param beta_hat: predicted coefficients
        :param mu_hat: predicted intercepts
        """
        return self.link_fn((beta_hat * X).sum(axis=-1).unsqueeze(-1) + mu_hat)

    def _predict_y(self, C, X, beta_hat, mu_hat):
        """Predict Y, adding the base predictor's output if one is set.

        :param C: contexts (used only by the base predictor)
        :param X: predictors
        :param beta_hat: predicted coefficients
        :param mu_hat: predicted intercepts
        """
        Y = self._predict_from_models(X, beta_hat, mu_hat)
        if self.base_y_predictor is not None:
            Y_base = self.base_y_predictor.predict_y(C, X)
            Y = Y + Y_base.to(Y.device)
        return Y

    def _dataloader(self, C, X, Y, dataset_constructor, **kwargs):
        """Wrap (C, X, Y) in the given dataset type and build a DataLoader.

        :param C:
        :param X:
        :param Y:
        :param dataset_constructor: dataset class to wrap the arrays
        :param **kwargs: forwarded to DataLoader
        """
        kwargs["num_workers"] = kwargs.get("num_workers", 0)
        kwargs["batch_size"] = kwargs.get("batch_size", 32)
        return DataLoader(dataset=DataIterable(dataset_constructor(C, X, Y)), **kwargs)
class NaiveContextualizedRegression(ContextualizedRegressionBase):
    """See NaiveMetamodel"""

    def _build_metamodel(self, *args, **kwargs):
        """Instantiate a NaiveMetamodel in multivariate mode.

        :param *args:
        :param **kwargs:
        """
        kwargs["univariate"] = False
        self.metamodel = NaiveMetamodel(*args, **kwargs)

    def _batch_loss(self, batch, batch_idx):
        """Prediction loss plus parameter regularization for one batch.

        :param batch:
        :param batch_idx:
        """
        C, X, Y, _ = batch
        beta_hat, mu_hat = self.predict_step(batch, batch_idx)
        y_hat = self._predict_y(C, X, beta_hat, mu_hat)
        return self.loss_fn(Y, y_hat) + self.model_regularizer(beta_hat, mu_hat)

    def predict_step(self, batch, batch_idx):
        """Predict (beta, mu) from the batch contexts.

        :param batch:
        :param batch_idx:
        """
        context = batch[0]
        return self(context)

    def _params_reshape(self, preds, dataloader):
        """Scatter per-batch parameter predictions into per-sample arrays.

        :param preds:
        :param dataloader:
        """
        dataset = dataloader.dataset.dataset
        betas = np.zeros((dataset.n, dataset.y_dim, dataset.x_dim))
        mus = np.zeros((dataset.n, dataset.y_dim))
        for (beta_hats, mu_hats), (_, _, _, n_idx) in zip(preds, dataloader):
            betas[n_idx] = beta_hats
            mus[n_idx] = mu_hats.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Scatter per-batch Y predictions into a per-sample array.

        :param preds:
        :param dataloader:
        """
        dataset = dataloader.dataset.dataset
        ys = np.zeros((dataset.n, dataset.y_dim))
        for (beta_hats, mu_hats), (C, X, _, n_idx) in zip(preds, dataloader):
            ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a multivariate dataloader over (C, X, Y).

        :param C:
        :param X:
        :param Y:
        :param **kwargs:
        """
        return self._dataloader(C, X, Y, MultivariateDataset, **kwargs)
class ContextualizedRegression(ContextualizedRegressionBase):
    """Supports SubtypeMetamodel and NaiveMetamodel, see selected metamodel for docs"""

    def _build_metamodel(self, *args, **kwargs):
        """Instantiate the configured single-task metamodel in multivariate mode.

        :param *args:
        :param **kwargs:
        """
        kwargs["univariate"] = False
        metamodel_cls = SINGLE_TASK_METAMODELS[self.metamodel_type]
        self.metamodel = metamodel_cls(*args, **kwargs)

    def _batch_loss(self, batch, batch_idx):
        """Sum of prediction loss and parameter regularization on a batch.

        :param batch:
        :param batch_idx:
        """
        C, X, Y, _ = batch
        beta_hat, mu_hat = self.predict_step(batch, batch_idx)
        pred_loss = self.loss_fn(Y, self._predict_y(C, X, beta_hat, mu_hat))
        return pred_loss + self.model_regularizer(beta_hat, mu_hat)

    def predict_step(self, batch, batch_idx):
        """Predict (beta, mu) from the batch contexts.

        :param batch:
        :param batch_idx:
        """
        C = batch[0]
        beta_hat, mu_hat = self(C)
        return beta_hat, mu_hat

    def _params_reshape(self, preds, dataloader):
        """Scatter per-batch parameter predictions into per-sample arrays.

        :param preds:
        :param dataloader:
        """
        ds = dataloader.dataset.dataset
        betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        mus = np.zeros((ds.n, ds.y_dim))
        for (beta_hats, mu_hats), batch in zip(preds, dataloader):
            n_idx = batch[-1]
            betas[n_idx] = beta_hats
            mus[n_idx] = mu_hats.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Scatter per-batch Y predictions into a per-sample array.

        :param preds:
        :param dataloader:
        """
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim))
        for (beta_hats, mu_hats), (C, X, _, n_idx) in zip(preds, dataloader):
            ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a multivariate dataloader over (C, X, Y).

        :param C:
        :param X:
        :param Y:
        :param **kwargs:
        """
        return self._dataloader(C, X, Y, MultivariateDataset, **kwargs)
class MultitaskContextualizedRegression(ContextualizedRegressionBase):
    """See MultitaskMetamodel"""

    def _build_metamodel(self, *args, **kwargs):
        """Instantiate a MultitaskMetamodel in multivariate mode.

        :param *args:
        :param **kwargs:
        """
        kwargs["univariate"] = False
        self.metamodel = MultitaskMetamodel(*args, **kwargs)

    def _batch_loss(self, batch, batch_idx):
        """Prediction loss plus regularization for a multitask batch.

        :param batch:
        :param batch_idx:
        """
        C, _, X, Y, _, _ = batch
        beta_hat, mu_hat = self.predict_step(batch, batch_idx)
        y_hat = self._predict_y(C, X, beta_hat, mu_hat)
        return self.loss_fn(Y, y_hat) + self.model_regularizer(beta_hat, mu_hat)

    def predict_step(self, batch, batch_idx):
        """Predict (beta, mu) from the batch contexts and task encodings.

        :param batch:
        :param batch_idx:
        """
        C, T = batch[0], batch[1]
        return self(C, T)

    def _params_reshape(self, preds, dataloader):
        """Scatter per-batch parameter predictions by (sample, task) index.

        :param preds:
        :param dataloader:
        """
        ds = dataloader.dataset.dataset
        betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        mus = np.zeros((ds.n, ds.y_dim))
        for (beta_hats, mu_hats), batch in zip(preds, dataloader):
            n_idx, y_idx = batch[-2], batch[-1]
            betas[n_idx, y_idx] = beta_hats
            mus[n_idx, y_idx] = mu_hats.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Scatter per-batch Y predictions by (sample, task) index.

        :param preds:
        :param dataloader:
        """
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim))
        for (beta_hats, mu_hats), (C, _, X, _, n_idx, y_idx) in zip(preds, dataloader):
            ys[n_idx, y_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a multitask multivariate dataloader over (C, X, Y).

        :param C:
        :param X:
        :param Y:
        :param **kwargs:
        """
        return self._dataloader(C, X, Y, MultitaskMultivariateDataset, **kwargs)
class TasksplitContextualizedRegression(ContextualizedRegressionBase):
    """See TasksplitMetamodel"""

    def _build_metamodel(self, *args, **kwargs):
        """Instantiate a TasksplitMetamodel in multivariate mode.

        :param *args:
        :param **kwargs:
        """
        kwargs["univariate"] = False
        self.metamodel = TasksplitMetamodel(*args, **kwargs)

    def _batch_loss(self, batch, batch_idx):
        """Prediction loss plus regularization for a tasksplit batch.

        :param batch:
        :param batch_idx:
        """
        C, _, X, Y, _, _ = batch
        beta_hat, mu_hat = self.predict_step(batch, batch_idx)
        fit_loss = self.loss_fn(Y, self._predict_y(C, X, beta_hat, mu_hat))
        return fit_loss + self.model_regularizer(beta_hat, mu_hat)

    def predict_step(self, batch, batch_idx):
        """Predict (beta, mu) from the batch contexts and task encodings.

        :param batch:
        :param batch_idx:
        """
        C, T = batch[0], batch[1]
        return self(C, T)

    def _params_reshape(self, preds, dataloader):
        """Scatter per-batch parameter predictions by (sample, task) index.

        :param preds:
        :param dataloader:
        """
        ds = dataloader.dataset.dataset
        betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        mus = np.zeros((ds.n, ds.y_dim))
        for (beta_hats, mu_hats), (_, _, _, _, n_idx, y_idx) in zip(preds, dataloader):
            betas[n_idx, y_idx] = beta_hats
            mus[n_idx, y_idx] = mu_hats.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Scatter per-batch Y predictions by (sample, task) index.

        :param preds:
        :param dataloader:
        """
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim))
        for (beta_hats, mu_hats), batch in zip(preds, dataloader):
            C, X = batch[0], batch[2]
            n_idx, y_idx = batch[-2], batch[-1]
            ys[n_idx, y_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a multitask multivariate dataloader over (C, X, Y).

        :param C:
        :param X:
        :param Y:
        :param **kwargs:
        """
        return self._dataloader(C, X, Y, MultitaskMultivariateDataset, **kwargs)
class ContextualizedUnivariateRegression(ContextualizedRegression):
    """Supports SubtypeMetamodel and NaiveMetamodel, see selected metamodel for docs"""

    def _build_metamodel(self, *args, **kwargs):
        """Instantiate the configured single-task metamodel in univariate mode.

        :param *args:
        :param **kwargs:
        """
        kwargs["univariate"] = True
        self.metamodel = SINGLE_TASK_METAMODELS[self.metamodel_type](*args, **kwargs)

    def _params_reshape(self, preds, dataloader):
        """Scatter per-batch univariate parameter predictions per sample.

        :param preds:
        :param dataloader:
        """
        ds = dataloader.dataset.dataset
        full_shape = (ds.n, ds.y_dim, ds.x_dim)
        betas = np.zeros(full_shape)
        mus = np.zeros(full_shape)
        for (beta_hats, mu_hats), batch in zip(preds, dataloader):
            n_idx = batch[-1]
            betas[n_idx] = beta_hats.squeeze(-1)
            mus[n_idx] = mu_hats.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Scatter per-batch univariate Y predictions per sample.

        :param preds:
        :param dataloader:
        """
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        for (beta_hats, mu_hats), (C, X, _, n_idx) in zip(preds, dataloader):
            ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a univariate dataloader over (C, X, Y).

        :param C:
        :param X:
        :param Y:
        :param **kwargs:
        """
        return self._dataloader(C, X, Y, UnivariateDataset, **kwargs)
class TasksplitContextualizedUnivariateRegression(TasksplitContextualizedRegression):
    """See TasksplitMetamodel"""

    def _build_metamodel(self, *args, **kwargs):
        """Instantiate a TasksplitMetamodel in univariate mode.

        :param *args:
        :param **kwargs:
        """
        kwargs["univariate"] = True
        self.metamodel = TasksplitMetamodel(*args, **kwargs)

    def _batch_loss(self, batch, batch_idx):
        """Prediction loss plus regularization for a univariate tasksplit batch.

        :param batch:
        :param batch_idx:
        """
        C, _, X, Y, _, _, _ = batch
        beta_hat, mu_hat = self.predict_step(batch, batch_idx)
        fit_loss = self.loss_fn(Y, self._predict_y(C, X, beta_hat, mu_hat))
        return fit_loss + self.model_regularizer(beta_hat, mu_hat)

    def predict_step(self, batch, batch_idx):
        """Predict (beta, mu) from the batch contexts and task encodings.

        :param batch:
        :param batch_idx:
        """
        C, T = batch[0], batch[1]
        return self(C, T)

    def _params_reshape(self, preds, dataloader):
        """Scatter parameter predictions by (sample, outcome, predictor) index.

        :param preds:
        :param dataloader:
        """
        ds = dataloader.dataset.dataset
        betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        mus = betas.copy()
        for (beta_hats, mu_hats), batch in zip(preds, dataloader):
            n_idx, x_idx, y_idx = batch[-3], batch[-2], batch[-1]
            betas[n_idx, y_idx, x_idx] = beta_hats.squeeze(-1)
            mus[n_idx, y_idx, x_idx] = mu_hats.squeeze(-1)
        return betas, mus

    def _y_reshape(self, preds, dataloader):
        """Scatter Y predictions by (sample, outcome, predictor) index.

        :param preds:
        :param dataloader:
        """
        ds = dataloader.dataset.dataset
        ys = np.zeros((ds.n, ds.y_dim, ds.x_dim))
        for (beta_hats, mu_hats), batch in zip(preds, dataloader):
            C, X = batch[0], batch[2]
            n_idx, x_idx, y_idx = batch[-3], batch[-2], batch[-1]
            y_hat = self._predict_y(C, X, beta_hats, mu_hats)
            ys[n_idx, y_idx, x_idx] = y_hat.squeeze(-1)
        return ys

    def dataloader(self, C, X, Y, **kwargs):
        """Build a multitask univariate dataloader over (C, X, Y).

        :param C:
        :param X:
        :param Y:
        :param **kwargs:
        """
        return self._dataloader(C, X, Y, MultitaskUnivariateDataset, **kwargs)
class ContextualizedCorrelation(ContextualizedUnivariateRegression):
    """Using univariate contextualized regression to estimate Pearson's correlation
    See SubtypeMetamodel for assumptions and full docstring
    """

    def __init__(self, context_dim, x_dim, **kwargs):
        # Self-correlation: the outcome dimension is forced to x_dim, so any
        # caller-supplied y_dim is dropped.
        if "y_dim" in kwargs:
            del kwargs["y_dim"]
        super().__init__(context_dim, x_dim, x_dim, **kwargs)

    def dataloader(self, C, X, Y=None, **kwargs):
        """Build a univariate dataloader regressing X against itself.

        :param C:
        :param X:
        :param Y: ignored; correlation is computed among X features
        :param **kwargs:
        """
        if Y is not None:
            # Fixed typo in user-facing message ("featuers" -> "features").
            print(
                "Passed a Y, but this is self-correlation between X features. Ignoring Y."
            )
        return super().dataloader(C, X, X, **kwargs)
class TasksplitContextualizedCorrelation(TasksplitContextualizedUnivariateRegression):
    """Using multitask univariate contextualized regression to estimate Pearson's correlation
    See TasksplitMetamodel for assumptions and full docstring
    """

    def __init__(self, context_dim, x_dim, **kwargs):
        # Self-correlation: the outcome dimension is forced to x_dim, so any
        # caller-supplied y_dim is dropped.
        if "y_dim" in kwargs:
            del kwargs["y_dim"]
        super().__init__(context_dim, x_dim, x_dim, **kwargs)

    def dataloader(self, C, X, Y=None, **kwargs):
        """Build a multitask univariate dataloader regressing X against itself.

        :param C:
        :param X:
        :param Y: ignored; correlation is computed among X features
        :param **kwargs:
        """
        if Y is not None:
            # Fixed typo in user-facing message ("featuers" -> "features").
            print(
                "Passed a Y, but this is self-correlation between X features. Ignoring Y."
            )
        return super().dataloader(C, X, X, **kwargs)
class ContextualizedNeighborhoodSelection(ContextualizedRegression):
    """Using singletask multivariate contextualized regression to do edge-regression for
    estimating conditional dependencies
    See SubtypeMetamodel for assumptions and full docstring
    """

    def __init__(
        self,
        context_dim,
        x_dim,
        model_regularizer=REGULARIZERS["l1"](1e-3, mu_ratio=0),
        **kwargs,
    ):
        # Self-regression: the outcome dimension is forced to x_dim, so any
        # caller-supplied y_dim is dropped.
        if "y_dim" in kwargs:
            del kwargs["y_dim"]
        super().__init__(
            context_dim, x_dim, x_dim, model_regularizer=model_regularizer, **kwargs
        )
        # Mask with zero diagonal so a feature cannot predict itself.
        self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim))

    def predict_step(self, batch, batch_idx):
        """Predict (beta, mu) with self-edges (diagonal) masked out.

        :param batch:
        :param batch_idx:
        """
        C, _, _, _ = batch
        beta_hat, mu_hat = self(C)
        beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1)
        return beta_hat, mu_hat

    def dataloader(self, C, X, Y=None, **kwargs):
        """Build a multivariate dataloader regressing X against itself.

        :param C:
        :param X:
        :param Y: ignored; edges are estimated among X features
        :param **kwargs:
        """
        if Y is not None:
            # Fixed typo in user-facing message ("featuers" -> "features").
            print(
                "Passed a Y, but this is a Markov Graph between X features. Ignoring Y."
            )
        return super().dataloader(C, X, X, **kwargs)
class ContextualizedMarkovGraph(ContextualizedRegression):
    """Using singletask multivariate contextualized regression to do edge-regression for
    estimating conditional dependencies
    See SubtypeMetamodel for assumptions and full docstring
    """

    def __init__(self, context_dim, x_dim, **kwargs):
        # Self-regression: the outcome dimension is forced to x_dim, so any
        # caller-supplied y_dim is dropped.
        if "y_dim" in kwargs:
            del kwargs["y_dim"]
        super().__init__(context_dim, x_dim, x_dim, **kwargs)
        # Mask with zero diagonal so a feature cannot predict itself.
        self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim))

    def predict_step(self, batch, batch_idx):
        """Predict symmetrized (beta, mu) with self-edges masked out.

        :param batch:
        :param batch_idx:
        """
        C, _, _, _ = batch
        beta_hat, mu_hat = self(C)
        beta_hat = beta_hat + torch.transpose(
            beta_hat, 1, 2
        )  # hotfix to enforce symmetry
        beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1)
        return beta_hat, mu_hat

    def dataloader(self, C, X, Y=None, **kwargs):
        """Build a multivariate dataloader regressing X against itself.

        :param C:
        :param X:
        :param Y: ignored; edges are estimated among X features
        :param **kwargs:
        """
        if Y is not None:
            # Fixed typo in user-facing message ("featuers" -> "features").
            print(
                "Passed a Y, but this is a Markov Graph between X features. Ignoring Y."
            )
        return super().dataloader(C, X, X, **kwargs)