Coverage for contextualized/baselines/networks.py: 94%

171 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:38 -0400

"""
Baseline network models for learning the causal structure of data.

Includes:
- CorrelationNetwork: learns the correlation matrix of the data
- MarkovNetwork: learns the Markov blanket of each variable
- BayesianNetwork: learns the DAG structure of the data

- GroupedNetworks: learns a separate model for each group of data. This is a wrapper around the above models.
"""

10 

11import numpy as np 

12from sklearn.linear_model import LinearRegression 

13import torch 

14from torch import nn 

15from torch.utils.data import DataLoader 

16import pytorch_lightning as pl 

17from contextualized.dags.graph_utils import project_to_dag_torch, dag_pred 

18from contextualized.dags.losses import dag_loss_notears, mse_loss, l1_loss 

19 

20 

class NOTEARSTrainer(pl.Trainer):
    """Trainer for NOTEARS models, with helpers for predicting data and weights."""

    def predict(self, model, dataloader):
        """Run prediction over all batches and concatenate into one tensor."""
        preds = super().predict(model, dataloader)
        return torch.cat(preds)

    def predict_w(self, model, dataloader, project_to_dag=True):
        """
        Return the learned adjacency matrix, tiled once per sample.

        Args:
            model: a fitted NOTEARS module (provides W and diag_mask).
            dataloader: loader whose dataset length sets the output batch size.
            project_to_dag (bool): if True, project W to the nearest DAG.

        Returns:
            np.ndarray of shape (n_samples, x_dim, x_dim).
        """
        # The learned weights do not depend on the data, so just count the
        # samples instead of running a full (discarded) forward pass over
        # the entire dataset as before.
        n_samples = len(dataloader.dataset)
        W = model.W.detach() * model.diag_mask
        if project_to_dag:
            W = torch.tensor(project_to_dag_torch(W.numpy(force=True))[0])
        W_batch = W.unsqueeze(0).expand(n_samples, -1, -1)
        return W_batch.numpy()

33 

34 

class NOTEARS(pl.LightningModule):
    """
    NOTEARS model for learning the DAG structure of data.

    Learns a weighted adjacency matrix W under the NOTEARS continuous
    acyclicity penalty, tightening the (alpha, rho) penalty weights over
    training whenever the DAG loss stops improving.
    """

    def __init__(self, x_dim, l1=1e-3, alpha=1e-8, rho=1e-8, learning_rate=1e-2):
        """
        Args:
            x_dim (int): number of variables (nodes in the graph).
            l1 (float): L1 sparsity penalty weight on W.
            alpha (float): initial linear weight on the acyclicity penalty.
            rho (float): initial quadratic weight on the acyclicity penalty.
            learning_rate (float): Adam learning rate.
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.l1 = l1
        self.alpha = alpha
        self.rho = rho
        # Zero diagonal: a variable never predicts itself.
        diag_mask = torch.ones(x_dim, x_dim) - torch.eye(x_dim)
        self.register_buffer("diag_mask", diag_mask)
        # Small uniform init in [-1e-2, 1e-2) with the diagonal zeroed out.
        init_mat = (torch.rand(x_dim, x_dim) * 2e-2 - 1e-2) * diag_mask
        self.W = nn.parameter.Parameter(init_mat, requires_grad=True)
        self.tolerance = 0.25  # required fractional DAG-loss improvement per epoch
        self.prev_dag = 0.0

    def forward(self, X):
        """Predict each variable from the others via the masked adjacency."""
        return dag_pred(X, self.W * self.diag_mask)

    def _batch_loss(self, batch, batch_idx):
        """Training objective: 0.5*MSE + L1 sparsity + NOTEARS acyclicity."""
        x_hat = self(batch)
        recon = mse_loss(batch, x_hat)
        sparsity = l1_loss(self.W, self.l1)
        acyclicity = dag_loss_notears(self.W, alpha=self.alpha, rho=self.rho)
        return 0.5 * recon + sparsity + acyclicity

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        loss = self._batch_loss(batch, batch_idx)
        self.log_dict({"train_loss": loss})
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation loss: reconstruction MSE plus a hard acyclicity penalty."""
        x_hat = self(batch).detach()
        recon = mse_loss(batch, x_hat)
        # Score DAG-ness at the maximal penalty weights so validation reflects
        # the final (hard) acyclicity requirement, not the current schedule.
        acyclicity = dag_loss_notears(self.W, alpha=1e12, rho=1e12).detach()
        loss = recon + acyclicity
        self.log_dict({"val_loss": loss})
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._batch_loss(batch, batch_idx)
        self.log_dict({"test_loss": loss})
        return loss

    def on_train_epoch_end(self, *args, **kwargs):
        """Penalty update: grow alpha and rho when the DAG loss has stalled."""
        dag = dag_loss_notears(self.W, alpha=self.alpha, rho=self.rho).item()
        stalled = dag > self.tolerance * self.prev_dag
        if stalled and self.alpha < 1e12 and self.rho < 1e12:
            self.alpha = self.alpha + self.rho * dag
            self.rho = self.rho * 10
        self.prev_dag = dag

    def dataloader(self, X, **kwargs):
        """Wrap an array X in a float32 DataLoader (default batch_size=32)."""
        kwargs.setdefault("batch_size", 32)
        return DataLoader(dataset=torch.Tensor(X).to(torch.float32), **kwargs)

102 

103 

class CorrelationNetwork:
    """
    Standard correlation network fit with linear regression.
    """

    def fit(self, X):
        """Fit one univariate regression per ordered pair of variables."""
        self.p = X.shape[-1]
        self.regs = [
            [LinearRegression() for _ in range(self.p)] for _ in range(self.p)
        ]
        for target in range(self.p):
            for source in range(self.p):
                self.regs[target][source].fit(
                    X[:, source, np.newaxis], X[:, target, np.newaxis]
                )
        return self

    def predict(self, n):
        """Return the correlation estimate, tiled to shape (n, p, p)."""
        slopes = np.zeros((self.p, self.p))
        for target in range(self.p):
            for source in range(self.p):
                slopes[target, source] = self.regs[target][source].coef_.squeeze()
        # beta_ij * beta_ji symmetrizes the pairwise slopes.
        corrs = slopes * slopes.T
        return np.tile(corrs[np.newaxis], (n, 1, 1))

    def measure_mses(self, X):
        """Per-sample squared error, averaged over all p^2 pairwise regressions."""
        mses = np.zeros(len(X))
        for target in range(self.p):
            for source in range(self.p):
                preds = self.regs[target][source].predict(X[:, source, np.newaxis])
                residual = (preds - X[:, target, np.newaxis])[:, 0]
                mses += (residual**2) / self.p**2
        return mses

135 

136 

class MarkovNetwork:
    """
    Standard Markov Network fit with neighborhood regression.
    """

    def __init__(self, alpha=1e-3):
        # NOTE(review): alpha is stored but never used by LinearRegression —
        # possibly intended for a regularized regressor; confirm with authors.
        self.alpha = alpha
        self.p = -1
        self.regs = []

    def fit(self, X):
        """Regress each variable on all the others (its candidate neighborhood)."""
        self.p = X.shape[-1]
        self.regs = [LinearRegression() for _ in range(self.p)]
        for target in range(self.p):
            covariates = np.ones_like(X)
            covariates[:, target] = 0  # exclude the target from its own predictors
            self.regs[target].fit(X * covariates, X[:, target, np.newaxis])
        return self

    def predict(self, n):
        """Symmetrized precision-style estimate, tiled to shape (n, p, p)."""
        coefs = np.zeros((self.p, self.p))
        for target in range(self.p):
            coefs[target] = self.regs[target].coef_.squeeze()
            coefs[target, target] = 0
        precision = -np.sign(coefs) * np.sqrt(np.abs(coefs * coefs.T))
        return np.tile(precision[np.newaxis], (n, 1, 1))

    def measure_mses(self, X):
        """Per-sample squared error, averaged over the p neighborhood regressions."""
        errors = np.zeros(len(X))
        for target in range(self.p):
            covariates = np.ones_like(X)
            covariates[:, target] = 0
            residual = (
                self.regs[target].predict(X * covariates) - X[:, target, np.newaxis]
            )[:, 0]
            errors += (residual**2) / self.p
        return errors

173 

174 

class BayesianNetwork:
    """
    A standard Bayesian Network fit with NOTEARS loss.
    """

    def __init__(self, **kwargs):
        """
        Keyword Args:
            l1 (float): L1 sparsity penalty weight (default 1e-3).
            alpha (float): initial linear DAG-penalty weight (default 1e-8).
            rho (float): initial quadratic DAG-penalty weight (default 1e-8).
            learning_rate (float): Adam learning rate (default 1e-2).
        """
        self.p = -1  # number of variables; set during fit()
        self.model = None
        self.trainer = None
        self.l1 = kwargs.get("l1", 1e-3)
        self.alpha = kwargs.get("alpha", 1e-8)
        self.rho = kwargs.get("rho", 1e-8)
        self.learning_rate = kwargs.get("learning_rate", 1e-2)

    def fit(self, X, max_epochs=50):
        """Fit a NOTEARS model to X of shape (n_samples, n_features); returns self."""
        self.p = X.shape[-1]
        self.model = NOTEARS(
            self.p,
            l1=self.l1,
            alpha=self.alpha,
            rho=self.rho,
            learning_rate=self.learning_rate,
        )
        dataset = self.model.dataloader(X)
        accelerator = "gpu" if torch.cuda.is_available() else "cpu"
        self.trainer = NOTEARSTrainer(
            max_epochs=max_epochs, accelerator=accelerator, devices=1
        )
        self.trainer.fit(self.model, dataset)
        return self

    def predict(self, n):
        """Return the learned (DAG-projected) adjacency, tiled to (n, p, p)."""
        dummy_X = np.zeros((n, self.p))
        dummy_dataset = self.model.dataloader(dummy_X)
        return self.trainer.predict_w(self.model, dummy_dataset)

    def measure_mses(self, X):
        """Per-sample mean squared reconstruction error under the learned DAG."""
        # Mask the diagonal so self-loops are excluded, matching NOTEARS.forward
        # and predict_w (the raw W was used here before, inconsistently).
        W_pred = (self.model.W * self.model.diag_mask).detach()
        X_preds = (
            dag_pred(torch.tensor(X, dtype=torch.float32), W_pred).detach().numpy()
        )
        return ((X_preds - X) ** 2).mean(axis=1)

218 

219 

class GroupedNetworks:
    """
    Fit a separate network for each group.
    Wrapper around CorrelationNetwork, MarkovNetwork, or BayesianNetwork.
    Assumes that the labels are 0-indexed integers and already learned.
    """

    def __init__(self, model_class):
        self.model_class = model_class
        self.models = {}
        self.p = -1

    def fit(self, X, labels):
        """Fit one model per unique label, using only that label's rows."""
        self.p = X.shape[-1]
        for group in np.unique(labels):
            self.models[group] = self.model_class().fit(X[labels == group])
        return self

    def predict(self, labels):
        """Return a (len(labels), p, p) array of per-sample group networks."""
        networks = np.zeros((len(labels), self.p, self.p))
        for group in np.unique(labels):
            members = labels == group
            networks[members] = self.models[group].predict(members.sum())
        return networks

    def measure_mses(self, X, labels):
        """Per-sample reconstruction error, scored by each sample's group model."""
        errors = np.zeros(len(X))
        for group in np.unique(labels):
            members = labels == group
            errors[members] = self.models[group].measure_mses(X[members])
        return errors