Coverage for contextualized/easy/ContextualizedNetworks.py: 95%
97 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
1"""
2sklearn-like interface to Contextualized Networks.
3"""
5from typing import *
7import numpy as np
9from contextualized.easy.wrappers import SKLearnWrapper
10from contextualized.regression.trainers import CorrelationTrainer, MarkovTrainer
11from contextualized.regression.lightning_modules import (
12 ContextualizedCorrelation,
13 ContextualizedMarkovGraph,
14)
15from contextualized.dags.lightning_modules import (
16 NOTMAD,
17 DEFAULT_DAG_LOSS_TYPE,
18 DEFAULT_DAG_LOSS_PARAMS,
19)
20from contextualized.dags.trainers import GraphTrainer
21from contextualized.dags.graph_utils import dag_pred_np
class ContextualizedNetworks(SKLearnWrapper):
    """
    Common sklearn-style base class shared by the Contextualized Networks estimators.
    """

    def _split_train_data(
        self, C: np.ndarray, X: np.ndarray, **kwargs
    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Partition the contexts and the data matrix into train and test portions.

        Delegates to the parent implementation with ``Y_required=False``,
        since only C and X are passed for network models.

        Args:
            C (np.ndarray): Contextual features for each sample.
            X (np.ndarray): The data matrix.
            **kwargs: Forwarded unchanged to the parent ``_split_train_data``.

        Returns:
            Tuple[List[np.ndarray], List[np.ndarray]]: The train and test sets for C and X as ([C_train, X_train], [C_test, X_test]).
        """
        split = super()._split_train_data(C, X, Y_required=False, **kwargs)
        return split

    def predict_networks(
        self,
        C: np.ndarray,
        with_offsets: bool = False,
        individual_preds: bool = False,
        **kwargs,
    ) -> Union[
        np.ndarray,
        List[np.ndarray],
        Tuple[np.ndarray, np.ndarray],
        Tuple[List[np.ndarray], List[np.ndarray]],
    ]:
        """Predict the context-specific networks for the given contexts.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            with_offsets (bool, optional): If True, the offsets are returned together with the network parameters. Defaults to False.
            individual_preds (bool, optional): If True, predictions are returned per bootstrap. Defaults to False.
            **kwargs: Forwarded unchanged to ``predict_params``.

        Returns:
            Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]: The predicted network parameters, or a ``(parameters, offsets)`` pair when with_offsets is True. Returned per bootstrap when individual_preds is True.
        """
        # predict_params is called with uses_y=False: networks are predicted from C alone.
        network_params, offsets = self.predict_params(
            C, individual_preds=individual_preds, uses_y=False, **kwargs
        )
        return (network_params, offsets) if with_offsets else network_params

    def predict_X(
        self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Reconstruct the data matrix from the predicted contextualized networks and the observed data.

        Handy for measuring reconstruction error or for imputation.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            X (np.ndarray): The data matrix (n_samples, n_features)
            individual_preds (bool, optional): If True, predictions are returned per bootstrap. Defaults to False.
            **kwargs: Forwarded unchanged to ``predict``.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The predicted data matrix, or one matrix per bootstrap if individual_preds is True (n_samples, n_features).
        """
        reconstruction = self.predict(
            C, X, individual_preds=individual_preds, **kwargs
        )
        return reconstruction
class ContextualizedCorrelationNetworks(ContextualizedNetworks):
    """
    Contextualized Correlation Networks reveal context-varying feature correlations, interaction strengths, dependencies in feature groups.
    Uses the Contextualized Networks model, see the `paper <https://doi.org/10.1101/2023.12.01.569658>`__ for detailed estimation procedures.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 10. Always uses archetypes in the ContextualizedMetaModel.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        """Configure the wrapper with the correlation model and its trainer."""
        super().__init__(
            ContextualizedCorrelation, [], [], CorrelationTrainer, **kwargs
        )

    def predict_correlation(
        self, C: np.ndarray, individual_preds: bool = True, squared: bool = True
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Predict the context-specific correlations between features.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            individual_preds (bool, optional): If True, predictions are returned per bootstrap. Defaults to True.
            squared (bool, optional): If True, the returned values are squared. Defaults to True.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The predicted context-specific correlation matrices, or matrices for each bootstrap if individual_preds is True (n_samples, n_features, n_features).
        """
        per_bootstrap = []
        for idx, model in enumerate(self.models):
            # The data half of the loader is an all-zero placeholder; only C carries information here.
            placeholder_x = np.zeros((len(C), self.x_dim))
            loader = model.dataloader(C, placeholder_x)
            # predict_params returns a tuple; only its first element is kept.
            per_bootstrap.append(self.trainers[idx].predict_params(model, loader)[0])
        rhos = np.array(per_bootstrap)

        # When bootstraps are combined, the average is taken before any squaring.
        if not individual_preds:
            rhos = np.mean(rhos, axis=0)
        return np.square(rhos) if squared else rhos

    def measure_mses(
        self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Measure mean-squared errors, averaged over every (i, j) feature pair.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            X (np.ndarray): The data matrix (n_samples, n_features)
            individual_preds (bool, optional): If True, errors are returned per bootstrap. Defaults to False.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples).
        """
        betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True)
        n_features = X.shape[-1]
        n_pairs = n_features**2
        errors = np.zeros((len(betas), len(C)))  # n_bootstraps x n_samples
        for i in range(n_features):
            target = X[:, i]
            for j in range(n_features):
                # Feature i is predicted from feature j alone; the (n_samples,) columns
                # broadcast against the (n_bootstraps, n_samples) parameter slices.
                residuals = target - betas[:, :, i, j] * X[:, j] - mus[:, :, i, j]
                errors += residuals**2 / n_pairs
        if individual_preds:
            return errors
        return np.mean(errors, axis=0)
class ContextualizedMarkovNetworks(ContextualizedNetworks):
    """
    Contextualized Markov Networks reveal context-varying feature dependencies, cliques, and modules.
    Implemented as Contextualized Gaussian Precision Matrices, directly interpretable as Markov Networks.
    Uses the Contextualized Networks model, see the `paper <https://doi.org/10.1101/2023.12.01.569658>`__ for detailed estimation procedures.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 10. Always uses archetypes in the ContextualizedMetaModel.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        """Configure the wrapper with the Markov graph model and its trainer."""
        super().__init__(ContextualizedMarkovGraph, [], [], MarkovTrainer, **kwargs)

    def predict_precisions(
        self, C: np.ndarray, individual_preds: bool = True
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Predicts context-specific precision matrices.
        Can be converted to context-specific Markov networks by binarizing the networks and setting all non-zero entries to 1.
        Can be converted to context-specific covariance matrices by taking the inverse.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to True.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The predicted context-specific Markov networks as precision matrices, or matrices for each bootstrap if individual_preds is True (n_samples, n_features, n_features).
        """
        per_bootstrap = []
        for idx, model in enumerate(self.models):
            # The data half of the loader is an all-zero placeholder; only C carries information here.
            loader = model.dataloader(C, np.zeros((len(C), self.x_dim)))
            per_bootstrap.append(self.trainers[idx].predict_precision(model, loader))
        precisions = np.array(per_bootstrap)
        if individual_preds:
            return precisions
        return np.mean(precisions, axis=0)

    def measure_mses(
        self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Measures mean-squared errors.

        Each feature is predicted from the full data row using that sample's
        network row plus its offset; squared residuals are averaged over features.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            X (np.ndarray): The data matrix (n_samples, n_features)
            individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples).
        """
        betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True)
        n_features = X.shape[-1]
        mses = np.zeros((len(betas), len(C)))  # n_bootstraps x n_samples
        for bootstrap in range(len(betas)):
            for i in range(n_features):
                # betas are n_bootstraps x n_samples x n_features x n_features
                # preds[sample] = X[sample, :].dot(betas[bootstrap, sample, i, :]) + mus[bootstrap, sample, i]
                # The row-wise dot product is done for all samples at once
                # instead of looping over samples in Python.
                preds = (
                    np.einsum("sf,sf->s", X, betas[bootstrap, :, i, :])
                    + mus[bootstrap, :, i]
                )
                residuals = X[:, i] - preds
                mses[bootstrap, :] += residuals**2 / n_features
        if not individual_preds:
            mses = np.mean(mses, axis=0)
        return mses
class ContextualizedBayesianNetworks(ContextualizedNetworks):
    """
    Contextualized Bayesian Networks and Directed Acyclic Graphs (DAGs) reveal context-dependent causal relationships, effect sizes, and variable ordering.
    Uses the NOTMAD model, see the `paper <https://doi.org/10.48550/arXiv.2111.01104>`__ for detailed estimation procedures.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 16. Always uses archetypes in the ContextualizedMetaModel.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        archetype_dag_loss_type (str, optional): The type of loss to use for the archetype loss. Defaults to "l1".
        archetype_l1 (float, optional): The strength of the l1 regularization for the archetype loss. Defaults to 0.0.
        archetype_dag_params (dict, optional): Parameters for the archetype loss. Defaults to {"loss_type": "l1", "params": {"alpha": 0.0, "rho": 0.0, "s": 0.0, "tol": 1e-4}}.
        archetype_dag_loss_params (dict, optional): Parameters for the archetype loss. Defaults to {"alpha": 0.0, "rho": 0.0, "s": 0.0, "tol": 1e-4}.
        archetype_alpha (float, optional): The strength of the alpha regularization for the archetype loss. Defaults to 0.0.
        archetype_rho (float, optional): The strength of the rho regularization for the archetype loss. Defaults to 0.0.
        archetype_s (float, optional): The strength of the s regularization for the archetype loss. Defaults to 0.0.
        archetype_tol (float, optional): The tolerance for the archetype loss. Defaults to 1e-4.
        archetype_use_dynamic_alpha_rho (bool, optional): Whether to use dynamic alpha and rho for the archetype loss. Defaults to False.
        init_mat (np.ndarray, optional): The initial adjacency matrix for the archetype loss. Defaults to None.
        num_factors (int, optional): The number of factors for the archetype loss. Defaults to 0.
        factor_mat_l1 (float, optional): The strength of the l1 regularization for the factor matrix for the archetype loss. Defaults to 0.
        sample_specific_dag_loss_type (str, optional): The type of loss to use for the sample-specific loss. Defaults to "l1".
        sample_specific_l1 (float, optional): The strength of the l1 regularization for the sample-specific loss. Defaults to 0.0.
        sample_specific_loss_params (dict, optional): Full DAG-loss configuration for the sample-specific loss, as a dict with "loss_type" and "params" keys. Defaults to one built from sample_specific_dag_loss_type and sample_specific_dag_loss_params.
        sample_specific_dag_loss_params (dict, optional): Parameters for the sample-specific loss. Defaults to a copy of DEFAULT_DAG_LOSS_PARAMS for the chosen loss type.
        sample_specific_alpha (float, optional): The strength of the alpha regularization for the sample-specific loss. Defaults to 0.0.
        sample_specific_rho (float, optional): The strength of the rho regularization for the sample-specific loss. Defaults to 0.0.
        sample_specific_s (float, optional): The strength of the s regularization for the sample-specific loss. Defaults to 0.0.
        sample_specific_tol (float, optional): The tolerance for the sample-specific loss. Defaults to 1e-4.
        sample_specific_use_dynamic_alpha_rho (bool, optional): Whether to use dynamic alpha and rho for the sample-specific loss. Defaults to False.
        learning_rate (float, optional): Learning rate stored in the model's optimization parameters. Defaults to 1e-3.
        step (int, optional): Step value stored in the model's optimization parameters. Defaults to 50.
    """

    def _parse_private_init_kwargs(self, **kwargs):
        """
        Parses the kwargs for the NOTMAD model.

        Fills ``self._init_kwargs["model"]`` with the ``encoder_kwargs``,
        ``archetype_loss_params``, ``sample_specific_loss_params`` and
        ``opt_params`` entries. Dictionaries supplied by the caller are copied
        before the convenience overrides are applied, so the caller's objects
        are never modified.

        Args:
            **kwargs: Keyword arguments for the NOTMAD model, including the encoder, archetype loss, sample-specific loss, and optimization parameters.

        Returns:
            List[str]: Names of the NOTMAD-specific keyword arguments handled here.
            NOTE(review): presumably the base wrapper uses this list to tell recognized
            keyword arguments from unknown ones -- confirm against SKLearnWrapper.
        """
        model_kwargs = self._init_kwargs["model"]

        # Encoder Parameters
        encoder_defaults = self.constructor_kwargs["encoder_kwargs"]
        model_kwargs["encoder_kwargs"] = {
            "type": kwargs.pop("encoder_type", model_kwargs["encoder_type"]),
            "params": {
                "width": encoder_defaults["width"],
                "layers": encoder_defaults["layers"],
                "link_fn": encoder_defaults["link_fn"],
            },
        }

        # Archetype-specific parameters
        archetype_dag_loss_type = kwargs.pop(
            "archetype_dag_loss_type", DEFAULT_DAG_LOSS_TYPE
        )
        # Shallow-copy the DAG config: it may be the caller's own dict.
        archetype_dag = dict(
            kwargs.get(
                "archetype_dag_params",
                {
                    "loss_type": archetype_dag_loss_type,
                    "params": kwargs.get(
                        "archetype_dag_loss_params",
                        DEFAULT_DAG_LOSS_PARAMS[archetype_dag_loss_type].copy(),
                    ),
                },
            )
        )
        model_kwargs["archetype_loss_params"] = {
            "l1": kwargs.get("archetype_l1", 0.0),
            "dag": archetype_dag,
            "init_mat": kwargs.pop("init_mat", None),
            "num_factors": kwargs.pop("num_factors", 0),
            "factor_mat_l1": kwargs.pop("factor_mat_l1", 0),
            "num_archetypes": kwargs.pop("num_archetypes", 16),
        }

        if model_kwargs["archetype_loss_params"]["num_archetypes"] <= 0:
            print(
                "WARNING: num_archetypes is 0. NOTMAD requires archetypes. Setting num_archetypes to 16."
            )
            model_kwargs["archetype_loss_params"]["num_archetypes"] = 16

        # Possibly update values with convenience parameters (archetype_alpha, ...).
        # A new "params" dict is built so a caller-supplied one is left untouched.
        archetype_dag["params"] = {
            param: kwargs.pop(f"archetype_{param}", value)
            for param, value in archetype_dag["params"].items()
        }

        # Sample-specific parameters
        sample_specific_dag_loss_type = kwargs.pop(
            "sample_specific_dag_loss_type", DEFAULT_DAG_LOSS_TYPE
        )
        sample_specific_l1 = kwargs.pop("sample_specific_l1", 0.0)
        # Shallow-copy the DAG config: it may be the caller's own dict.
        sample_specific_dag = dict(
            kwargs.pop(
                "sample_specific_loss_params",
                {
                    "loss_type": sample_specific_dag_loss_type,
                    "params": kwargs.pop(
                        "sample_specific_dag_loss_params",
                        DEFAULT_DAG_LOSS_PARAMS[sample_specific_dag_loss_type].copy(),
                    ),
                },
            )
        )
        # Possibly update values with convenience parameters (sample_specific_alpha, ...).
        # A new "params" dict is built so a caller-supplied one is left untouched.
        sample_specific_dag["params"] = {
            param: kwargs.pop(f"sample_specific_{param}", value)
            for param, value in sample_specific_dag["params"].items()
        }
        model_kwargs["sample_specific_loss_params"] = {
            "l1": sample_specific_l1,
            "dag": sample_specific_dag,
        }

        # Optimization parameters
        model_kwargs["opt_params"] = {
            "learning_rate": kwargs.pop("learning_rate", 1e-3),
            "step": kwargs.pop("step", 50),
        }

        # Every NOTMAD-specific keyword read above is listed once, including
        # sample_specific_l1 and sample_specific_dag_loss_params.
        return [
            "archetype_dag_loss_type",
            "archetype_l1",
            "archetype_dag_params",
            "archetype_dag_loss_params",
            "archetype_alpha",
            "archetype_rho",
            "archetype_s",
            "archetype_tol",
            "archetype_loss_params",
            "archetype_use_dynamic_alpha_rho",
            "init_mat",
            "num_factors",
            "factor_mat_l1",
            "sample_specific_dag_loss_type",
            "sample_specific_l1",
            "sample_specific_dag_loss_params",
            "sample_specific_alpha",
            "sample_specific_rho",
            "sample_specific_s",
            "sample_specific_tol",
            "sample_specific_loss_params",
            "sample_specific_use_dynamic_alpha_rho",
        ]

    def __init__(self, **kwargs):
        """Configure the wrapper with the NOTMAD model and the graph trainer."""
        super().__init__(
            NOTMAD,
            extra_model_kwargs=[
                "sample_specific_loss_params",
                "archetype_loss_params",
                "opt_params",
            ],
            extra_data_kwargs=[],
            trainer_constructor=GraphTrainer,
            remove_model_kwargs=[
                "link_fn",
                "univariate",
                "loss_fn",
                "model_regularizer",
            ],
            **kwargs,
        )

    def predict_params(
        self, C: np.ndarray, **kwargs
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Predicts context-specific Bayesian network parameters as linear coefficients in a linear structural equation model (SEM).

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The linear coefficients of the predicted context-specific Bayesian network parameters (n_samples, n_features, n_features). Returned as lists of individual bootstraps if individual_preds is True.
        """
        # No mus for NOTMAD at present.
        return super().predict_params(C, model_includes_mus=False, **kwargs)

    def predict_networks(
        self, C: np.ndarray, project_to_dag: bool = True, **kwargs
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Predicts context-specific Bayesian networks.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            project_to_dag (bool, optional): If True, guarantees returned graphs are DAGs by trimming edges until acyclicity is satisfied. Defaults to True.
            **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method. A ``with_offsets`` entry is accepted for compatibility with the parent signature but only triggers a notice.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The linear coefficients of the predicted context-specific Bayesian network parameters (n_samples, n_features, n_features). Returned as lists of individual bootstraps if individual_preds is True.
        """
        # with_offsets is removed from kwargs so it is not forwarded to predict_params.
        if kwargs.pop("with_offsets", False):
            print("No offsets can be returned by NOTMAD.")
        betas = self.predict_params(
            C, uses_y=False, project_to_dag=project_to_dag, **kwargs
        )
        return betas

    def measure_mses(
        self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """Measures mean-squared errors.

        Args:
            C (np.ndarray): Contextual features for each sample (n_samples, n_context_features)
            X (np.ndarray): The data matrix (n_samples, n_features)
            individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False.
            **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples).
        """
        # Per-bootstrap networks are always requested; averaging happens below.
        betas = self.predict_networks(C, individual_preds=True, **kwargs)
        mses = np.zeros((len(betas), len(C)))  # n_bootstraps x n_samples
        for bootstrap in range(len(betas)):
            X_pred = dag_pred_np(X, betas[bootstrap])
            mses[bootstrap, :] = np.mean((X - X_pred) ** 2, axis=1)
        if not individual_preds:
            mses = np.mean(mses, axis=0)
        return mses