Coverage for contextualized/regression/tests.py: 99%

144 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2Unit tests for Contextualized Regression. 

3""" 

4 

5import unittest 

6import numpy as np 

7import torch 

8 

9# from contextualized.modules import NGAM, MLP, SoftSelect, Explainer 

10from contextualized.regression.lightning_modules import * 

11from contextualized.regression.trainers import * 

12from contextualized.functions import LINK_FUNCTIONS 

13from contextualized.utils import DummyParamPredictor, DummyYPredictor 

14 

15 

16class TestRegression(unittest.TestCase): 

17 """ 

18 Test regression models. 

19 """ 

20 

21 def __init__(self, *args, **kwargs): 

22 super().__init__(*args, **kwargs) 

23 

24 def setUp(self): 

25 """ 

26 Shared unit test setup code. 

27 """ 

28 np.random.seed(0) 

29 torch.manual_seed(0) 

30 n = 100 

31 c_dim = 4 

32 x_dim = 5 

33 y_dim = 3 

34 C = torch.rand((n, c_dim)) - 0.5 

35 W_1 = C.sum(axis=1).unsqueeze(-1) ** 2 

36 W_2 = -C.sum(axis=1).unsqueeze(-1) 

37 b_1 = C[:, 0].unsqueeze(-1) 

38 b_2 = C[:, 1].unsqueeze(-1) 

39 # W_full = torch.cat((W_1, W_2), axis=1) 

40 # b_full = b_1 + b_2 

41 X = torch.rand((n, x_dim)) - 0.5 

42 Y_1 = X[:, 0].unsqueeze(-1) * W_1 + b_1 

43 Y_2 = X[:, 1].unsqueeze(-1) * W_2 + b_2 

44 Y_3 = X.sum(axis=1).unsqueeze(-1) 

45 Y = torch.cat((Y_1, Y_2, Y_3), axis=1) 

46 

47 self.k = 10 

48 self.epochs = 1 

49 self.batch_size = 32 

50 self.c_dim, self.x_dim, self.y_dim = c_dim, x_dim, y_dim 

51 self.C, self.X, self.Y = C.numpy(), X.numpy(), Y.numpy() 

52 

53 def _quicktest(self, model, univariate=False, correlation=False, markov=False): 

54 """ 

55 

56 :param model: 

57 :param univariate: (Default value = False) 

58 :param correlation: (Default value = False) 

59 :param markov: (Default value = False) 

60 

61 """ 

62 print(f"\n{type(model)} quicktest") 

63 if correlation: 

64 dataloader = model.dataloader(self.C, self.X, batch_size=self.batch_size) 

65 trainer = CorrelationTrainer( 

66 max_epochs=self.epochs, enable_progress_bar=False 

67 ) 

68 y_true = np.tile(self.X[:, :, np.newaxis], (1, 1, self.X.shape[-1])) 

69 elif markov: 

70 dataloader = model.dataloader(self.C, self.X, batch_size=self.batch_size) 

71 trainer = MarkovTrainer(max_epochs=self.epochs, enable_progress_bar=False) 

72 y_true = self.X 

73 else: 

74 dataloader = model.dataloader( 

75 self.C, self.X, self.Y, batch_size=self.batch_size 

76 ) 

77 trainer = RegressionTrainer( 

78 max_epochs=self.epochs, enable_progress_bar=False 

79 ) 

80 if univariate: 

81 y_true = np.tile(self.Y[:, :, np.newaxis], (1, 1, self.X.shape[-1])) 

82 else: 

83 y_true = self.Y 

84 y_preds = trainer.predict_y(model, dataloader) 

85 err_init = ((y_true - y_preds) ** 2).mean() 

86 trainer.fit(model, dataloader) 

87 trainer.validate(model, dataloader) 

88 trainer.test(model, dataloader) 

89 beta_preds, mu_preds = trainer.predict_params(model, dataloader) 

90 if correlation: 

91 rhos = trainer.predict_correlation(model, dataloader) 

92 if markov: 

93 omegas = trainer.predict_precision(model, dataloader) 

94 y_preds = trainer.predict_y(model, dataloader) 

95 err_trained = ((y_true - y_preds) ** 2).mean() 

96 assert err_trained < err_init, "Model failed to converge" 

97 

98 def test_naive(self): 

99 """ 

100 Test Naive Multivariate regression. 

101 """ 

102 # Naive Multivariate 

103 model = NaiveContextualizedRegression( 

104 self.c_dim, 

105 self.x_dim, 

106 self.y_dim, 

107 encoder_kwargs={ 

108 "width": 25, 

109 "layers": 2, 

110 "link_fn": LINK_FUNCTIONS["identity"], 

111 }, 

112 link_fn=LINK_FUNCTIONS["identity"], 

113 ) 

114 self._quicktest(model) 

115 

116 model = NaiveContextualizedRegression( 

117 self.c_dim, 

118 self.x_dim, 

119 self.y_dim, 

120 encoder_type="ngam", 

121 encoder_kwargs={ 

122 "width": 25, 

123 "layers": 2, 

124 "link_fn": LINK_FUNCTIONS["identity"], 

125 }, 

126 link_fn=LINK_FUNCTIONS["identity"], 

127 ) 

128 self._quicktest(model) 

129 

130 model = NaiveContextualizedRegression( 

131 self.c_dim, 

132 self.x_dim, 

133 self.y_dim, 

134 encoder_kwargs={ 

135 "width": 25, 

136 "layers": 2, 

137 "link_fn": LINK_FUNCTIONS["softmax"], 

138 }, 

139 link_fn=LINK_FUNCTIONS["identity"], 

140 ) 

141 self._quicktest(model) 

142 

143 model = NaiveContextualizedRegression( 

144 self.c_dim, 

145 self.x_dim, 

146 self.y_dim, 

147 encoder_kwargs={ 

148 "width": 25, 

149 "layers": 2, 

150 "link_fn": LINK_FUNCTIONS["identity"], 

151 }, 

152 link_fn=LINK_FUNCTIONS["logistic"], 

153 ) 

154 self._quicktest(model) 

155 

156 model = NaiveContextualizedRegression( 

157 self.c_dim, 

158 self.x_dim, 

159 self.y_dim, 

160 encoder_kwargs={ 

161 "width": 25, 

162 "layers": 2, 

163 "link_fn": LINK_FUNCTIONS["softmax"], 

164 }, 

165 link_fn=LINK_FUNCTIONS["logistic"], 

166 ) 

167 self._quicktest(model) 

168 

169 parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1)) 

170 model = NaiveContextualizedRegression( 

171 self.c_dim, 

172 self.x_dim, 

173 self.y_dim, 

174 encoder_kwargs={ 

175 "width": 25, 

176 "layers": 2, 

177 "link_fn": LINK_FUNCTIONS["softmax"], 

178 }, 

179 link_fn=LINK_FUNCTIONS["logistic"], 

180 base_param_predictor=parambase, 

181 ) 

182 self._quicktest(model) 

183 

184 ybase = DummyYPredictor((self.y_dim, 1)) 

185 model = NaiveContextualizedRegression( 

186 self.c_dim, 

187 self.x_dim, 

188 self.y_dim, 

189 encoder_kwargs={ 

190 "width": 25, 

191 "layers": 2, 

192 "link_fn": LINK_FUNCTIONS["softmax"], 

193 }, 

194 link_fn=LINK_FUNCTIONS["logistic"], 

195 base_y_predictor=ybase, 

196 ) 

197 self._quicktest(model) 

198 

199 def test_subtype(self): 

200 """ 

201 Test subtype multivariate regression. 

202 """ 

203 # Subtype Multivariate 

204 parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1)) 

205 ybase = DummyYPredictor((self.y_dim, 1)) 

206 model = ContextualizedRegression( 

207 self.c_dim, 

208 self.x_dim, 

209 self.y_dim, 

210 base_param_predictor=parambase, 

211 base_y_predictor=ybase, 

212 ) 

213 self._quicktest(model) 

214 

215 def test_multitask(self): 

216 """ 

217 Test multitask multivariate regression. 

218 """ 

219 # Multitask Multivariate 

220 parambase = DummyParamPredictor((self.x_dim,), (1,)) 

221 ybase = DummyYPredictor((1,)) 

222 model = MultitaskContextualizedRegression( 

223 self.c_dim, 

224 self.x_dim, 

225 self.y_dim, 

226 base_param_predictor=parambase, 

227 base_y_predictor=ybase, 

228 ) 

229 self._quicktest(model) 

230 

231 def test_tasksplit(self): 

232 """ 

233 Test tasksplit multivariate regression. 

234 """ 

235 # Tasksplit Multivariate 

236 parambase = DummyParamPredictor((self.x_dim,), (1,)) 

237 ybase = DummyYPredictor((1,)) 

238 model = TasksplitContextualizedRegression( 

239 self.c_dim, 

240 self.x_dim, 

241 self.y_dim, 

242 base_param_predictor=parambase, 

243 base_y_predictor=ybase, 

244 ) 

245 self._quicktest(model) 

246 

247 def test_univariate_subtype(self): 

248 """ 

249 Test naive univariate regression. 

250 """ 

251 # Naive Univariate 

252 parambase = DummyParamPredictor( 

253 (self.y_dim, self.x_dim, 1), (self.y_dim, self.x_dim, 1) 

254 ) 

255 ybase = DummyYPredictor((self.y_dim, self.x_dim, 1)) 

256 model = ContextualizedUnivariateRegression( 

257 self.c_dim, 

258 self.x_dim, 

259 self.y_dim, 

260 base_param_predictor=parambase, 

261 base_y_predictor=ybase, 

262 ) 

263 self._quicktest(model, univariate=True) 

264 

265 def test_univariate_tasksplit(self): 

266 """ 

267 Test task-split univariate regression. 

268 """ 

269 # Tasksplit Univariate 

270 parambase = DummyParamPredictor((1,), (1,)) 

271 ybase = DummyYPredictor((1,)) 

272 model = TasksplitContextualizedUnivariateRegression( 

273 self.c_dim, 

274 self.x_dim, 

275 self.y_dim, 

276 base_param_predictor=parambase, 

277 base_y_predictor=ybase, 

278 ) 

279 self._quicktest(model, univariate=True) 

280 

281 def test_correlation_subtype(self): 

282 """ 

283 Test correlation. 

284 """ 

285 # Correlation 

286 parambase = DummyParamPredictor( 

287 (self.x_dim, self.x_dim, 1), (self.x_dim, self.x_dim, 1) 

288 ) 

289 ybase = DummyYPredictor((self.x_dim, self.x_dim, 1)) 

290 model = ContextualizedCorrelation( 

291 self.c_dim, 

292 self.x_dim, 

293 base_param_predictor=parambase, 

294 base_y_predictor=ybase, 

295 ) 

296 self._quicktest(model, correlation=True) 

297 

298 def test_correlation_tasksplit(self): 

299 """ 

300 Test task-split correlation. 

301 """ 

302 # Tasksplit Correlation 

303 parambase = DummyParamPredictor((1,), (1,)) 

304 ybase = DummyYPredictor((1,)) 

305 model = TasksplitContextualizedCorrelation( 

306 self.c_dim, 

307 self.x_dim, 

308 base_param_predictor=parambase, 

309 base_y_predictor=ybase, 

310 ) 

311 self._quicktest(model, correlation=True) 

312 

313 def test_markov_subtype(self): 

314 """ 

315 Test Markov Graphs. 

316 """ 

317 # Markov Graph 

318 parambase = DummyParamPredictor((self.x_dim, self.x_dim), (self.x_dim, 1)) 

319 ybase = DummyYPredictor((self.x_dim, 1)) 

320 model = ContextualizedMarkovGraph( 

321 self.c_dim, 

322 self.x_dim, 

323 base_param_predictor=parambase, 

324 base_y_predictor=ybase, 

325 ) 

326 self._quicktest(model, markov=True) 

327 

328 def test_neighborhood_subtype(self): 

329 """ 

330 Test Neighborhood Selection. 

331 """ 

332 parambase = DummyParamPredictor((self.x_dim, self.x_dim), (self.x_dim, 1)) 

333 ybase = DummyYPredictor((self.x_dim, 1)) 

334 model = ContextualizedNeighborhoodSelection( 

335 self.c_dim, 

336 self.x_dim, 

337 base_param_predictor=parambase, 

338 base_y_predictor=ybase, 

339 ) 

340 self._quicktest(model, markov=True) 

341 

342 def test_metamodel_switch(self): 

343 """ 

344 Test switching between meta-models. 

345 """ 

346 parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1)) 

347 ybase = DummyYPredictor((self.y_dim, 1)) 

348 model = ContextualizedRegression( 

349 self.c_dim, 

350 self.x_dim, 

351 self.y_dim, 

352 metamodel_type="naive", 

353 base_param_predictor=parambase, 

354 base_y_predictor=ybase, 

355 ) 

356 self._quicktest(model) 

357 

358 parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1)) 

359 ybase = DummyYPredictor((self.y_dim, 1)) 

360 model = ContextualizedRegression( 

361 self.c_dim, 

362 self.x_dim, 

363 self.y_dim, 

364 metamodel_type="subtype", 

365 base_param_predictor=parambase, 

366 base_y_predictor=ybase, 

367 ) 

368 self._quicktest(model) 

369 

370 def test_fit_intercept(self): 

371 """ 

372 Test switching between meta-models. 

373 """ 

374 parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1)) 

375 ybase = DummyYPredictor((self.y_dim, 1)) 

376 model = ContextualizedRegression( 

377 self.c_dim, 

378 self.x_dim, 

379 self.y_dim, 

380 metamodel_type="naive", 

381 fit_intercept=False, 

382 base_param_predictor=parambase, 

383 base_y_predictor=ybase, 

384 ) 

385 dataloader = model.dataloader( 

386 self.C, self.X, self.Y, batch_size=self.batch_size 

387 ) 

388 trainer = RegressionTrainer(enable_progress_bar=False) 

389 beta_preds, mu_preds = trainer.predict_params(model, dataloader) 

390 assert (mu_preds == 0).all() 

391 self._quicktest(model) 

392 beta_preds, mu_preds = trainer.predict_params(model, dataloader) 

393 assert (mu_preds == 0).all() 

394 

395 

396if __name__ == "__main__": 

397 unittest.main()