Coverage for contextualized/regression/trainers.py: 100%

19 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2PyTorch-Lightning trainers used for Contextualized regression. 

3""" 

4 

5import numpy as np 

6import pytorch_lightning as pl 

7 

8 

class RegressionTrainer(pl.Trainer):
    """
    Trains the contextualized.regression lightning_modules.

    On top of the stock Lightning trainer, this class provides two prediction
    helpers. Each one runs the base-class ``predict`` over a dataloader and then
    asks the model to reshape the raw outputs into per-sample arrays.
    """

    def predict_params(self, model, dataloader):
        """
        Returns context-specific regression models
        - beta (numpy.ndarray): (n, y_dim, x_dim)
        - mu (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate])

        The outputs of the base-class ``predict`` call, along with the same
        ``dataloader``, are handed to ``model._params_reshape``.
        """
        # Prediction runs first; reshaping is delegated to the model.
        raw_outputs = super().predict(model, dataloader)
        return model._params_reshape(raw_outputs, dataloader)

    def predict_y(self, model, dataloader):
        """
        Returns context-specific predictions of the response Y
        - y_hat (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate])

        The outputs of the base-class ``predict`` call, along with the same
        ``dataloader``, are handed to ``model._y_reshape``.
        """
        # Prediction runs first; reshaping is delegated to the model.
        raw_outputs = super().predict(model, dataloader)
        return model._y_reshape(raw_outputs, dataloader)

30 

31 

class CorrelationTrainer(RegressionTrainer):
    """
    Trains the contextualized.regression correlation lightning_modules.
    """

    def predict_correlation(self, model, dataloader):
        """
        Returns context-specific correlation networks containing Pearson's correlation coefficient
        - correlation (numpy.ndarray): (n, x_dim, x_dim)

        Entry [k, i, j] is built from the two directed coefficients
        beta[k, i, j] and beta[k, j, i] as
        sign(beta[k, i, j]) * sqrt(|beta[k, i, j] * beta[k, j, i]|).
        When the two directed coefficients disagree in sign, the entry is
        set to 0.
        """
        betas, _ = super().predict_params(model, dataloader)
        # Swap the last two axes so that swapped[k, i, j] == betas[k, j, i].
        swapped = np.transpose(betas, (0, 2, 1))
        signs = np.sign(betas)
        # Remove asymmetric estimations: a zeroed sign zeroes the correlation.
        disagree = signs != np.sign(swapped)
        signs[disagree] = 0
        magnitude = np.sqrt(np.abs(betas * swapped))
        return signs * magnitude

49 

50 

class MarkovTrainer(CorrelationTrainer):
    """
    Trains the contextualized.regression markov graph lightning_modules.
    """

    def predict_precision(self, model, dataloader):
        """
        Returns context-specific precision matrix under a Gaussian graphical model
        Assuming all diagonal precisions are equal and constant over context,
        this is equivalent to the negative of the multivariate regression coefficient.
        - precision (numpy.ndarray): (n, x_dim, x_dim)
        """
        # The markov lightning_module's predict_step (defined elsewhere) is
        # arranged so that predict_correlation yields negative precision
        # values here. Negating them recovers the precision matrix.
        negated_precision = super().predict_correlation(model, dataloader)
        return -negated_precision