Coverage for contextualized/dags/tests.py: 99% (131 statements)


"""
Unit tests for DAG models.
"""

import unittest
import numpy as np
from pytorch_lightning import seed_everything

from contextualized.dags.lightning_modules import NOTMAD
from contextualized.dags import graph_utils
from contextualized.dags.trainers import GraphTrainer
from contextualized.dags.losses import mse_loss as mse


class TestNOTMADFast(unittest.TestCase):
    """Fast smoke tests for NOTMAD on small random data."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

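    # Seed RNGs and draw a fresh random context matrix C and feature matrix X
    # for every test.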
    def setUp(self):
        seed_everything(0)
        self.n = 10
        self.c_dim = 4
        self.x_dim = 3
        self.C = np.random.uniform(-1, 1, size=(self.n, self.c_dim))
        self.X = np.random.uniform(-1, 1, size=(self.n, self.x_dim))

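    # Shared helper: build a NOTMAD model with k archetypes, train briefly,
    # and assert that training lowers the reconstruction MSE of the
    # predicted per-sample graphs.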
    def _train(self, model_args, n_epochs):
        k = 6
        INIT_MAT = np.random.uniform(-0.1, 0.1, size=(k, 4, 4)) * np.tile(
            1 - np.eye(4), (k, 1, 1)
        )
        model = NOTMAD(
            self.C.shape[-1],
            self.X.shape[-1],
            archetype_params={
                "l1": 0.0,
                "dag": model_args.get(
                    "dag",
                    {
                        "loss_type": "NOTEARS",
                        "params": {
                            "alpha": 1e-1,
                            "rho": 1e-2,
                            "h_old": 0.0,
                            "tol": 0.25,
                            "use_dynamic_alpha_rho": True,
                        },
                    },
                ),
                "init_mat": INIT_MAT,
                "num_factors": model_args.get("num_factors", 0),
                "factor_mat_l1": 0.0,
                "num_archetypes": model_args.get("num_archetypes", k),
            },
        )
        dataloader = model.dataloader(self.C, self.X, batch_size=1, num_workers=0)
        trainer = GraphTrainer(
            max_epochs=n_epochs, deterministic=True, enable_progress_bar=False
        )
        predict_trainer = GraphTrainer(
            deterministic=True, enable_progress_bar=False, devices=1
        )
        init_preds = predict_trainer.predict_params(
            model,
            dataloader,
            project_to_dag=False,
        )
        assert init_preds.shape == (self.n, self.x_dim, self.x_dim)
        init_mse = mse(graph_utils.dag_pred_np(self.X, init_preds), self.X)
        trainer.fit(model, dataloader)
        final_preds = predict_trainer.predict_params(
            model, dataloader, project_to_dag=False
        )
        assert final_preds.shape == (self.n, self.x_dim, self.x_dim)
        final_mse = mse(graph_utils.dag_pred_np(self.X, final_preds), self.X)
        assert final_mse < init_mse

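    # One-epoch smoke test with the DAGMA acyclicity loss.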
    def test_notmad_dagma(self):
        self._train(
            {
                "dag": {
                    "loss_type": "DAGMA",
                    "params": {
                        "alpha": 1.0,
                    },
                }
            },
            1,
        )

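    # One-epoch smoke test with the default NOTEARS acyclicity loss.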
    def test_notmad_notears(self):
        self._train({}, 1)

    def test_notmad_factor_graphs(self):
        """Unit test for the factor-graph feature of NOTMAD."""
        self._train({"num_factors": 3}, 1)


class TestNOTMAD(unittest.TestCase):
    """Full accuracy tests for NOTMAD on simulated context-dependent DAGs."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

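    # Build the simulated dataset and split C, X, and W by context range.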
    def setUp(self):
        (
            self.C,
            self.X,
            self.W,
            self.train_idx,
            self.test_idx,
            self.val_idx,
        ) = self._create_cwx_dataset()
        self.C_train, self.C_test, self.C_val = (
            self.C[self.train_idx],
            self.C[self.test_idx],
            self.C[self.val_idx],
        )
        self.X_train, self.X_test, self.X_val = (
            self.X[self.train_idx],
            self.X[self.test_idx],
            self.X[self.val_idx],
        )
        self.W_train, self.W_test, self.W_val = (
            self.W[self.train_idx],
            self.W[self.test_idx],
            self.W[self.val_idx],
        )

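    # Simulate a context-dependent DAG dataset: contexts C span [1, 2], four
    # edge weights in W vary smoothly with C, and each row of X is sampled
    # from the corresponding linear SEM. The validation (1.7 <= C < 1.8) and
    # test (1.8 <= C < 1.9) context ranges are held out of training.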
    def _create_cwx_dataset(self, n=500):
        np.random.seed(0)
        C = np.linspace(1, 2, n).reshape((n, 1))
        W = np.zeros((4, 4, n, 1))
        W[0, 1] = C - 2
        W[2, 1] = 2 * C
        W[3, 1] = C**2
        W[3, 2] = C
        W = np.squeeze(W)

        W = np.transpose(W, (2, 0, 1))
        X = np.zeros((n, 4))
        for i, w in enumerate(W):
            x = graph_utils.simulate_linear_sem(w, 1, "uniform", noise_scale=0.1)[0]
            X[i] = x
        train_idx = np.argwhere(np.logical_or(C < 1.7, C >= 1.9)[:, 0])[:, 0]
        np.random.shuffle(train_idx)
        test_idx = np.argwhere(np.logical_and(C >= 1.8, C < 1.9)[:, 0])[:, 0]
        val_idx = np.argwhere(np.logical_and(C >= 1.7, C < 1.8)[:, 0])[:, 0]
        return (
            C,
            X,
            W,
            train_idx,
            test_idx,
            val_idx,
        )

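    # Returns the L2 errors of predicted vs. true graphs, followed by the
    # reconstruction MSEs of the implied predictions, per split.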
    def _evaluate(self, train_preds, test_preds, val_preds):
        return (
            mse(train_preds, self.W_train),
            mse(test_preds, self.W_test),
            mse(val_preds, self.W_val),
            mse(graph_utils.dag_pred_np(self.X_train, train_preds), self.X_train),
            mse(graph_utils.dag_pred_np(self.X_test, test_preds), self.X_test),
            mse(graph_utils.dag_pred_np(self.X_val, val_preds), self.X_val),
        )

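    # Shared helper: same NOTMAD construction as in TestNOTMADFast, but fit
    # on the train split with validation and test loops, returning the
    # post-training predictions plus the initial (untrained) L2 errors.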
    def _train(self, model_args, n_epochs):
        seed_everything(0)
        k = 6
        INIT_MAT = np.random.uniform(-0.1, 0.1, size=(k, 4, 4)) * np.tile(
            1 - np.eye(4), (k, 1, 1)
        )
        model = NOTMAD(
            self.C.shape[-1],
            self.X.shape[-1],
            archetype_params={
                "l1": 0.0,
                "dag": model_args.get(
                    "dag",
                    {
                        "loss_type": "NOTEARS",
                        "params": {
                            "alpha": 1e-1,
                            "rho": 1e-2,
                            "h_old": 0.0,
                            "tol": 0.25,
                            "use_dynamic_alpha_rho": True,
                        },
                    },
                ),
                "init_mat": INIT_MAT,
                "num_factors": model_args.get("num_factors", 0),
                "factor_mat_l1": 0.0,
                "num_archetypes": model_args.get("num_archetypes", k),
            },
        )
        train_dataloader = model.dataloader(
            self.C_train, self.X_train, batch_size=1, num_workers=0
        )
        test_dataloader = model.dataloader(
            self.C_test, self.X_test, batch_size=10, num_workers=0
        )
        val_dataloader = model.dataloader(
            self.C_val, self.X_val, batch_size=10, num_workers=0
        )
        trainer = GraphTrainer(
            max_epochs=n_epochs, deterministic=True, enable_progress_bar=False
        )
        predict_trainer = GraphTrainer(
            deterministic=True, enable_progress_bar=False, devices=1
        )
        preds_train = predict_trainer.predict_params(
            model, train_dataloader, project_to_dag=True
        )
        preds_test = predict_trainer.predict_params(
            model, test_dataloader, project_to_dag=True
        )
        preds_val = predict_trainer.predict_params(
            model, val_dataloader, project_to_dag=True
        )
        init_train_l2, init_test_l2, init_val_l2, _, _, _ = self._evaluate(
            preds_train, preds_test, preds_val
        )
        trainer.fit(model, train_dataloader)
        trainer.validate(model, val_dataloader)
        trainer.test(model, test_dataloader)

        # Evaluate results
        preds_train = predict_trainer.predict_params(
            model, train_dataloader, project_to_dag=True
        )
        preds_test = predict_trainer.predict_params(
            model, test_dataloader, project_to_dag=True
        )
        preds_val = predict_trainer.predict_params(
            model, val_dataloader, project_to_dag=True
        )

        return (
            preds_train,
            preds_test,
            preds_val,
            init_train_l2,
            init_test_l2,
            init_val_l2,
        )

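    # Ten-epoch run with the DAGMA loss; predicted graphs and their implied
    # reconstructions must beat fixed error thresholds.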
    def test_notmad_dagma(self):
        train_preds, test_preds, val_preds, _, _, _ = self._train(
            {
                "dag": {
                    "loss_type": "DAGMA",
                    "params": {
                        "alpha": 1.0,
                    },
                }
            },
            10,
        )
        train_l2, test_l2, val_l2, train_mse, test_mse, val_mse = self._evaluate(
            train_preds, test_preds, val_preds
        )
        print(f"Train L2: {train_l2}")
        print(f"Test L2: {test_l2}")
        print(f"Val L2: {val_l2}")
        print(f"Train MSE: {train_mse}")
        print(f"Test MSE: {test_mse}")
        print(f"Val MSE: {val_mse}")
        assert train_l2 < 1e-1
        assert test_l2 < 1e-1
        assert val_l2 < 1e-1
        assert train_mse < 1e-2
        assert test_mse < 1e-2
        assert val_mse < 1e-2

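    # Same thresholds as above, with the default NOTEARS loss.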
    def test_notmad_notears(self):
        train_preds, test_preds, val_preds, _, _, _ = self._train({}, 10)
        train_l2, test_l2, val_l2, train_mse, test_mse, val_mse = self._evaluate(
            train_preds, test_preds, val_preds
        )
        print(f"Train L2: {train_l2}")
        print(f"Test L2: {test_l2}")
        print(f"Val L2: {val_l2}")
        print(f"Train MSE: {train_mse}")
        print(f"Test MSE: {test_mse}")
        print(f"Val MSE: {val_mse}")
        assert train_l2 < 1e-1
        assert test_l2 < 1e-1
        assert val_l2 < 1e-1
        assert train_mse < 1e-2
        assert test_mse < 1e-2
        assert val_mse < 1e-2

    def test_notmad_factor_graphs(self):
        """Unit test for the factor-graph feature of NOTMAD."""
        (
            train_preds,
            test_preds,
            val_preds,
            init_train_l2,
            init_test_l2,
            init_val_l2,
        ) = self._train({"num_factors": 3}, 10)
        train_l2, test_l2, val_l2, _, _, _ = self._evaluate(
            train_preds, test_preds, val_preds
        )
        assert train_preds.shape == self.W_train.shape
        assert val_preds.shape == self.W_val.shape
        assert test_preds.shape == self.W_test.shape
        assert train_l2 < init_train_l2
        assert test_l2 < init_test_l2
        assert val_l2 < init_val_l2


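# Entry point when this file is executed directly (python tests.py); a test
# runner such as pytest collects the test cases automatically.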
if __name__ == "__main__":
    unittest.main()