Coverage for contextualized/regression/trainers.py: 100%
19 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
1"""
2PyTorch-Lightning trainers used for Contextualized regression.
3"""
5import numpy as np
6import pytorch_lightning as pl
class RegressionTrainer(pl.Trainer):
    """
    Lightning trainer for the contextualized.regression lightning_modules.

    Adds prediction helpers that run the inherited ``predict`` loop and then
    let the model reshape the raw output.
    """

    def predict_params(self, model, dataloader):
        """
        Predict the context-specific regression models.

        The raw output of the inherited ``predict`` call is passed, together
        with the dataloader, to ``model._params_reshape``; whatever that hook
        returns is returned unchanged.

        Returns:
        - beta (numpy.ndarray): (n, y_dim, x_dim)
        - mu (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate])
        """
        raw_output = super().predict(model, dataloader)
        reshaped_params = model._params_reshape(raw_output, dataloader)
        return reshaped_params

    def predict_y(self, model, dataloader):
        """
        Predict the context-specific response Y.

        The raw output of the inherited ``predict`` call is passed, together
        with the dataloader, to ``model._y_reshape``; whatever that hook
        returns is returned unchanged.

        Returns:
        - y_hat (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate])
        """
        raw_output = super().predict(model, dataloader)
        reshaped_y = model._y_reshape(raw_output, dataloader)
        return reshaped_y
class CorrelationTrainer(RegressionTrainer):
    """
    Lightning trainer for the contextualized.regression correlation
    lightning_modules.
    """

    def predict_correlation(self, model, dataloader):
        """
        Predict context-specific correlation networks containing Pearson's
        correlation coefficient.

        Each entry is built from the pair of regression coefficients
        beta[i, j] and beta[j, i]: the magnitude is the square root of the
        absolute value of their product, and the sign is their shared sign.
        Pairs whose signs disagree are set to zero.

        Returns:
        - correlation (numpy.ndarray): (n, x_dim, x_dim)
        """
        swap_last_two = (0, 2, 1)
        coefs, _ = super().predict_params(model, dataloader)

        # Sign of each coefficient; np.sign allocates a new array, so the
        # in-place masking below does not touch `coefs`.
        direction = np.sign(coefs)
        # Remove asymmetric estimations: zero every entry whose sign differs
        # from that of its transposed counterpart. The mask is fully computed
        # before the assignment modifies `direction`.
        asymmetric = direction != np.transpose(direction, swap_last_two)
        direction[asymmetric] = 0

        # Magnitude: square root of |beta_ij * beta_ji|.
        pair_product = coefs * np.transpose(coefs, swap_last_two)
        magnitude = np.sqrt(np.abs(pair_product))
        return direction * magnitude
class MarkovTrainer(CorrelationTrainer):
    """
    Lightning trainer for the contextualized.regression markov graph
    lightning_modules.
    """

    def predict_precision(self, model, dataloader):
        """
        Predict the context-specific precision matrix under a Gaussian
        graphical model.

        Assuming all diagonal precisions are equal and constant over context,
        this is equivalent to the negative of the multivariate regression
        coefficient.

        Returns:
        - precision (numpy.ndarray): (n, x_dim, x_dim)
        """
        # A trick in the markov lightning_module predict_step makes the
        # predict_correlation output equivalent to negative precision values,
        # so flipping its sign yields the precision estimates.
        negated_precision = super().predict_correlation(model, dataloader)
        return -negated_precision