Coverage for contextualized/easy/ContextualizedNetworks.py: 95%

97 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2sklearn-like interface to Contextualized Networks. 

3""" 

4 

5from typing import * 

6 

7import numpy as np 

8 

9from contextualized.easy.wrappers import SKLearnWrapper 

10from contextualized.regression.trainers import CorrelationTrainer, MarkovTrainer 

11from contextualized.regression.lightning_modules import ( 

12 ContextualizedCorrelation, 

13 ContextualizedMarkovGraph, 

14) 

15from contextualized.dags.lightning_modules import ( 

16 NOTMAD, 

17 DEFAULT_DAG_LOSS_TYPE, 

18 DEFAULT_DAG_LOSS_PARAMS, 

19) 

20from contextualized.dags.trainers import GraphTrainer 

21from contextualized.dags.graph_utils import dag_pred_np 

22 

23 

class ContextualizedNetworks(SKLearnWrapper):
    """
    Common sklearn-like base for the Contextualized Networks estimators.

    Network estimators are fit from contextual features ``C`` and a data matrix
    ``X`` alone; no separate target matrix ``Y`` is required.
    """

    def _split_train_data(
        self, C: np.ndarray, X: np.ndarray, **kwargs
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Divides C and X into training and held-out portions.

        Args:
            C (np.ndarray): Contextual features for each sample.
            X (np.ndarray): The data matrix.
            **kwargs: Forwarded unchanged to the parent implementation.

        Returns:
            Tuple[List[np.ndarray], List[np.ndarray]]: ([C_train, X_train], [C_test, X_test]).
        """
        parent = super()
        # Networks have no target, so the parent must not demand a Y.
        return parent._split_train_data(C, X, Y_required=False, **kwargs)

    def predict_networks(
        self,
        C: np.ndarray,
        with_offsets: bool = False,
        individual_preds: bool = False,
        **kwargs,
    ) -> Union[
        np.ndarray,
        List[np.ndarray],
        Tuple[np.ndarray, np.ndarray],
        Tuple[List[np.ndarray], List[np.ndarray]],
    ]:
        """Computes context-specific network parameters for each row of C.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            with_offsets (bool, optional): When True, the offsets are returned alongside the network parameters. Defaults to False.
            individual_preds (bool, optional): When True, one prediction per bootstrap is returned. Defaults to False.
            **kwargs: Forwarded to predict_params.

        Returns:
            Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]: The network parameters, or (parameters, offsets) when with_offsets is True; per-bootstrap when individual_preds is True.
        """
        betas, mus = self.predict_params(
            C, individual_preds=individual_preds, uses_y=False, **kwargs
        )
        return (betas, mus) if with_offsets else betas

    def predict_X(
        self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Rebuilds the data matrix from the predicted contextualized networks and the observed X.

        Handy for measuring reconstruction error or for imputation.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            X (np.ndarray): The data matrix (n_samples, n_features)
            individual_preds (bool, optional): When True, one reconstruction per bootstrap is returned. Defaults to False.
            **kwargs: Forwarded to self.predict.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The reconstructed data matrix (n_samples, n_features), or one per bootstrap when individual_preds is True.
        """
        kwargs["individual_preds"] = individual_preds
        return self.predict(C, X, **kwargs)

88 

89 

class ContextualizedCorrelationNetworks(ContextualizedNetworks):
    """
    Contextualized Correlation Networks reveal context-varying feature correlations, interaction strengths, dependencies in feature groups.
    Uses the Contextualized Networks model, see the `paper <https://doi.org/10.1101/2023.12.01.569658>`__ for detailed estimation procedures.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 10. Always uses archetypes in the ContextualizedMetaModel.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        super().__init__(
            ContextualizedCorrelation, [], [], CorrelationTrainer, **kwargs
        )

    def predict_correlation(
        self, C: np.ndarray, individual_preds: bool = True, squared: bool = True
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Predicts context-specific correlations between features.

        Args:
            C (Numpy ndarray): Contextual features for each sample (n_samples, n_context_features)
            individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to True.
            squared (bool, optional): If True, returns the squared correlations. When individual_preds is False, the bootstrap mean is taken first and then squared. Defaults to True.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The predicted context-specific correlation matrices, or matrices for each bootstrap if individual_preds is True (n_samples, n_features, n_features).
        """
        # Prediction only needs C; X is a zero placeholder of the fitted width.
        rhos = np.array(
            [
                self.trainers[i].predict_params(
                    self.models[i],
                    self.models[i].dataloader(C, np.zeros((len(C), self.x_dim))),
                )[0]
                for i in range(len(self.models))
            ]
        )
        if not individual_preds:
            rhos = np.mean(rhos, axis=0)
        return np.square(rhos) if squared else rhos

    def measure_mses(
        self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Measures mean-squared errors.

        For every ordered feature pair (i, j), the residual of the pairwise
        regression ``X[:, i] - betas[..., i, j] * X[:, j] - mus[..., i, j]`` is
        squared, and the result is averaged over all n_features**2 pairs.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            X (np.ndarray): The data matrix (n_samples, n_features)
            individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples).
        """
        X = np.asarray(X)
        betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True)
        betas = np.asarray(betas)
        mus = np.asarray(mus)
        # betas/mus are indexed [bootstrap, sample, i, j]; broadcast X so that
        # axis 2 carries the target feature i and axis 3 the predictor feature j.
        targets = X[np.newaxis, :, :, np.newaxis]
        predictors = X[np.newaxis, :, np.newaxis, :]
        residuals = targets - betas * predictors - mus
        # n_bootstraps x n_samples, accumulated in float64 like the original loop.
        mses = np.mean(residuals**2, axis=(2, 3), dtype=np.float64)
        if not individual_preds:
            mses = np.mean(mses, axis=0)
        return mses

164 

165 

class ContextualizedMarkovNetworks(ContextualizedNetworks):
    """
    Contextualized Markov Networks reveal context-varying feature dependencies, cliques, and modules.
    Implemented as Contextualized Gaussian Precision Matrices, directly interpretable as Markov Networks.
    Uses the Contextualized Networks model, see the `paper <https://doi.org/10.1101/2023.12.01.569658>`__ for detailed estimation procedures.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 10. Always uses archetypes in the ContextualizedMetaModel.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        super().__init__(ContextualizedMarkovGraph, [], [], MarkovTrainer, **kwargs)

    def predict_precisions(
        self, C: np.ndarray, individual_preds: bool = True
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Predicts context-specific precision matrices.

        Can be converted to context-specific Markov networks by binarizing the networks and setting all non-zero entries to 1.
        Can be converted to context-specific covariance matrices by taking the inverse.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to True.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The predicted context-specific Markov networks as precision matrices, or matrices for each bootstrap if individual_preds is True (n_samples, n_features, n_features).
        """
        # Prediction only needs C; X is a zero placeholder of the fitted width.
        precisions = np.array(
            [
                self.trainers[i].predict_precision(
                    self.models[i],
                    self.models[i].dataloader(C, np.zeros((len(C), self.x_dim))),
                )
                for i in range(len(self.models))
            ]
        )
        if individual_preds:
            return precisions
        return np.mean(precisions, axis=0)

    def measure_mses(
        self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Measures mean-squared errors.

        Each feature i is predicted from the full sample vector via
        ``X[s, :] . betas[b, s, i, :] + mus[b, s, i]``; squared residuals are
        averaged over features.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            X (np.ndarray): The data matrix (n_samples, n_features)
            individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples).
        """
        X = np.asarray(X)
        betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True)
        # betas are n_bootstraps x n_samples x n_features x n_features,
        # mus are indexed [bootstrap, sample, feature].
        # preds[b, s, i] = X[s, :].dot(betas[b, s, i, :]) + mus[b, s, i]
        preds = np.einsum("sk,bsik->bsi", X, np.asarray(betas)) + np.asarray(mus)
        residuals = X[np.newaxis, :, :] - preds
        # n_bootstraps x n_samples, accumulated in float64 like the original loop.
        mses = np.mean(residuals**2, axis=-1, dtype=np.float64)
        if not individual_preds:
            mses = np.mean(mses, axis=0)
        return mses

241 

242 

class ContextualizedBayesianNetworks(ContextualizedNetworks):
    """
    Contextualized Bayesian Networks and Directed Acyclic Graphs (DAGs) reveal context-dependent causal relationships, effect sizes, and variable ordering.
    Uses the NOTMAD model, see the `paper <https://doi.org/10.48550/arXiv.2111.01104>`__ for detailed estimation procedures.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 16. Values <= 0 are replaced by 16 because NOTMAD requires archetypes.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        archetype_dag_loss_type (str, optional): The type of DAG loss used for the archetypes. Defaults to DEFAULT_DAG_LOSS_TYPE (contextualized.dags.lightning_modules).
        archetype_l1 (float, optional): The strength of the l1 regularization for the archetype loss. Defaults to 0.0.
        archetype_dag_params (dict, optional): Complete archetype DAG-loss specification of the form {"loss_type": str, "params": dict}. When given, archetype_dag_loss_type and archetype_dag_loss_params are not used to build it. The dict is copied, never modified.
        archetype_dag_loss_params (dict, optional): Parameters for the archetype DAG loss. Defaults to DEFAULT_DAG_LOSS_PARAMS[archetype_dag_loss_type]. The dict is copied, never modified.
        archetype_alpha (optional): Overrides the "alpha" entry of the archetype DAG loss params, only if that entry exists.
        archetype_rho (optional): Overrides the "rho" entry of the archetype DAG loss params, only if that entry exists.
        archetype_s (optional): Overrides the "s" entry of the archetype DAG loss params, only if that entry exists.
        archetype_tol (optional): Overrides the "tol" entry of the archetype DAG loss params, only if that entry exists.
        archetype_use_dynamic_alpha_rho (optional): Overrides the "use_dynamic_alpha_rho" entry of the archetype DAG loss params, only if that entry exists.
        init_mat (np.ndarray, optional): The initial adjacency matrix for the archetype loss. Defaults to None.
        num_factors (int, optional): The number of factors for the archetype loss. Defaults to 0.
        factor_mat_l1 (float, optional): The strength of the l1 regularization for the factor matrix for the archetype loss. Defaults to 0.
        sample_specific_dag_loss_type (str, optional): The type of DAG loss used for the sample-specific networks. Defaults to DEFAULT_DAG_LOSS_TYPE.
        sample_specific_l1 (float, optional): The strength of the l1 regularization for the sample-specific loss. Defaults to 0.0.
        sample_specific_loss_params (dict, optional): Complete sample-specific DAG-loss specification of the form {"loss_type": str, "params": dict}. When given, sample_specific_dag_loss_type and sample_specific_dag_loss_params are not used to build it. The dict is copied, never modified.
        sample_specific_dag_loss_params (dict, optional): Parameters for the sample-specific DAG loss. Defaults to DEFAULT_DAG_LOSS_PARAMS[sample_specific_dag_loss_type]. The dict is copied, never modified.
        sample_specific_alpha (optional): Overrides the "alpha" entry of the sample-specific DAG loss params, only if that entry exists.
        sample_specific_rho (optional): Overrides the "rho" entry of the sample-specific DAG loss params, only if that entry exists.
        sample_specific_s (optional): Overrides the "s" entry of the sample-specific DAG loss params, only if that entry exists.
        sample_specific_tol (optional): Overrides the "tol" entry of the sample-specific DAG loss params, only if that entry exists.
        sample_specific_use_dynamic_alpha_rho (optional): Overrides the "use_dynamic_alpha_rho" entry of the sample-specific DAG loss params, only if that entry exists.
        learning_rate (float, optional): Optimizer learning rate. Defaults to 1e-3.
        step (int, optional): Optimizer "step" parameter. Defaults to 50.
    """

    def _parse_private_init_kwargs(self, **kwargs):
        """
        Parses the kwargs for the NOTMAD model.

        Fills ``self._init_kwargs["model"]`` with the encoder, archetype-loss,
        sample-specific-loss, and optimization settings.

        Args:
            **kwargs: Keyword arguments for the NOTMAD model, including the encoder, archetype loss, sample-specific loss, and optimization parameters.

        Returns:
            List[str]: Names of the keyword arguments recognized by this method.
        """
        import copy

        model_kwargs = self._init_kwargs["model"]

        def _dag_loss_spec(spec_key, params_key, loss_type):
            # Builds a private {"loss_type", "params"} DAG-loss spec. Defaults are
            # only looked up when needed, so explicit params for a loss type absent
            # from DEFAULT_DAG_LOSS_PARAMS do not raise KeyError.
            spec = kwargs.pop(spec_key, None)
            params = kwargs.pop(params_key, None)
            if spec is None:
                if params is None:
                    params = DEFAULT_DAG_LOSS_PARAMS[loss_type]
                spec = {"loss_type": loss_type, "params": params}
            # Deep copy: the overrides below write into spec["params"], which must
            # never alias a caller-owned dict or the module-level defaults.
            return copy.deepcopy(spec)

        def _apply_convenience_overrides(loss_params, prefix):
            # Replaces each existing DAG param with kwargs[f"{prefix}_{name}"] if given.
            dag_params = loss_params["dag"]["params"]
            for name in list(dag_params):
                dag_params[name] = kwargs.pop(f"{prefix}_{name}", dag_params[name])

        # Encoder Parameters
        model_kwargs["encoder_kwargs"] = {
            "type": kwargs.pop("encoder_type", model_kwargs["encoder_type"]),
            "params": {
                "width": self.constructor_kwargs["encoder_kwargs"]["width"],
                "layers": self.constructor_kwargs["encoder_kwargs"]["layers"],
                "link_fn": self.constructor_kwargs["encoder_kwargs"]["link_fn"],
            },
        }

        # Archetype-specific parameters
        archetype_dag_loss_type = kwargs.pop(
            "archetype_dag_loss_type", DEFAULT_DAG_LOSS_TYPE
        )
        archetype_loss_params = {
            # .get (not .pop) kept as before so "archetype_l1" stays visible to
            # the convenience-override pass below.
            "l1": kwargs.get("archetype_l1", 0.0),
            "dag": _dag_loss_spec(
                "archetype_dag_params",
                "archetype_dag_loss_params",
                archetype_dag_loss_type,
            ),
            "init_mat": kwargs.pop("init_mat", None),
            "num_factors": kwargs.pop("num_factors", 0),
            "factor_mat_l1": kwargs.pop("factor_mat_l1", 0),
            "num_archetypes": kwargs.pop("num_archetypes", 16),
        }
        model_kwargs["archetype_loss_params"] = archetype_loss_params

        if archetype_loss_params["num_archetypes"] <= 0:
            print(
                f"WARNING: num_archetypes is {archetype_loss_params['num_archetypes']}. NOTMAD requires archetypes. Setting num_archetypes to 16."
            )
            archetype_loss_params["num_archetypes"] = 16

        # Possibly update values with convenience parameters
        _apply_convenience_overrides(archetype_loss_params, "archetype")

        # Sample-specific parameters
        sample_specific_dag_loss_type = kwargs.pop(
            "sample_specific_dag_loss_type", DEFAULT_DAG_LOSS_TYPE
        )
        sample_specific_loss_params = {
            "l1": kwargs.pop("sample_specific_l1", 0.0),
            "dag": _dag_loss_spec(
                "sample_specific_loss_params",
                "sample_specific_dag_loss_params",
                sample_specific_dag_loss_type,
            ),
        }
        model_kwargs["sample_specific_loss_params"] = sample_specific_loss_params

        # Possibly update values with convenience parameters
        _apply_convenience_overrides(sample_specific_loss_params, "sample_specific")

        # Optimization parameters
        model_kwargs["opt_params"] = {
            "learning_rate": kwargs.pop("learning_rate", 1e-3),
            "step": kwargs.pop("step", 50),
        }

        return [
            "archetype_dag_loss_type",
            "archetype_l1",
            "archetype_dag_params",
            "archetype_dag_loss_params",
            "archetype_alpha",
            "archetype_rho",
            "archetype_s",
            "archetype_tol",
            "archetype_loss_params",
            "archetype_use_dynamic_alpha_rho",
            "init_mat",
            "num_factors",
            "factor_mat_l1",
            "sample_specific_dag_loss_type",
            "sample_specific_l1",
            "sample_specific_alpha",
            "sample_specific_rho",
            "sample_specific_s",
            "sample_specific_tol",
            "sample_specific_loss_params",
            "sample_specific_dag_loss_params",
            "sample_specific_use_dynamic_alpha_rho",
        ]

    def __init__(self, **kwargs):
        super().__init__(
            NOTMAD,
            extra_model_kwargs=[
                "sample_specific_loss_params",
                "archetype_loss_params",
                "opt_params",
            ],
            extra_data_kwargs=[],
            trainer_constructor=GraphTrainer,
            remove_model_kwargs=[
                "link_fn",
                "univariate",
                "loss_fn",
                "model_regularizer",
            ],
            **kwargs,
        )

    def predict_params(
        self, C: np.ndarray, **kwargs
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Predicts context-specific Bayesian network parameters as linear coefficients in a linear structural equation model (SEM).

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The linear coefficients of the predicted context-specific Bayesian network parameters (n_samples, n_features, n_features). Returned as lists of individual bootstraps if individual_preds is True.
        """
        # No mus for NOTMAD at present.
        return super().predict_params(C, model_includes_mus=False, **kwargs)

    def predict_networks(
        self, C: np.ndarray, project_to_dag: bool = True, **kwargs
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Predicts context-specific Bayesian networks.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            project_to_dag (bool, optional): If True, guarantees returned graphs are DAGs by trimming edges until acyclicity is satisfied. Defaults to True.
            **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method. A with_offsets flag is accepted but ignored, since NOTMAD has no offsets.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The linear coefficients of the predicted context-specific Bayesian network parameters (n_samples, n_features, n_features). Returned as lists of individual bootstraps if individual_preds is True.
        """
        if kwargs.pop("with_offsets", False):
            print("No offsets can be returned by NOTMAD.")
        betas = self.predict_params(
            C, uses_y=False, project_to_dag=project_to_dag, **kwargs
        )
        return betas

    def measure_mses(
        self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Measures mean-squared errors.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            X (np.ndarray): The data matrix (n_samples, n_features)
            individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False.
            **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples).
        """
        betas = self.predict_networks(C, individual_preds=True, **kwargs)
        mses = np.zeros((len(betas), len(C)))  # n_bootstraps x n_samples
        for bootstrap in range(len(betas)):
            X_pred = dag_pred_np(X, betas[bootstrap])
            mses[bootstrap, :] = np.mean((X - X_pred) ** 2, axis=1)
        if not individual_preds:
            mses = np.mean(mses, axis=0)
        return mses