1"""
2sklearn-like interface to Contextualized Regressors.
3"""
5from contextualized.regression import (
6 NaiveContextualizedRegression,
7 ContextualizedRegression,
8)
9from contextualized.easy.wrappers import SKLearnWrapper
10from contextualized.regression import RegressionTrainer

# TODO: Multitask metamodels
# TODO: Task-specific link functions.

class ContextualizedRegressor(SKLearnWrapper):
    """
    Contextualized Linear Regression quantifies context-varying linear relationships.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0,
            which uses the NaiveMetaModel. If > 0, uses archetypes in the
            ContextualizedMetaModel.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear").
            Defaults to "mlp".
        loss_fn (torch.nn.Module, optional): Loss function. Defaults to LOSSES["mse"].
        link_fn (torch.nn.Module, optional): Link function. Defaults to
            LINK_FUNCTIONS["identity"].
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in (0.0, 1.0) governing how much of the
            regularization applies to context-specific parameters versus
            context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in (0.0, 1.0) governing how much the
            regularization penalizes the l1 vs. l2 parameter norms. Defaults to 0.0.
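
    Example:
        A minimal usage sketch, assuming ``C``, ``X``, and ``Y`` are numpy
        arrays with one row per sample, and that ``fit``, ``predict``, and
        ``predict_params`` are inherited from SKLearnWrapper as in the
        package documentation::

            import numpy as np
            from contextualized.easy import ContextualizedRegressor

            C = np.random.normal(size=(100, 5))  # contexts
            X = np.random.normal(size=(100, 3))  # explanatory features
            Y = np.random.normal(size=(100, 2))  # outcomes

            model = ContextualizedRegressor(num_archetypes=4)
            model.fit(C, X, Y, max_epochs=5)
            betas, mus = model.predict_params(C)  # context-specific coefficients
            Y_hat = model.predict(C, X)           # context-specific predictions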
29 """

    def __init__(self, **kwargs):
        self.num_archetypes = kwargs.get("num_archetypes", 0)
        if self.num_archetypes == 0:
            # No archetypes: use the naive metamodel.
            constructor = NaiveContextualizedRegression
        elif self.num_archetypes > 0:
            # Positive archetype count: use the archetype-based metamodel.
            constructor = ContextualizedRegression
        else:
            raise ValueError(
                f"Was told to construct a ContextualizedRegressor with "
                f"{self.num_archetypes} archetypes, but this should be a "
                "non-negative integer."
            )

        # Keyword names routed by SKLearnWrapper to the model and data modules.
        extra_model_kwargs = ["base_param_predictor", "base_y_predictor", "y_dim"]
        extra_data_kwargs = ["Y_val"]
        trainer_constructor = RegressionTrainer
        super().__init__(
            constructor,
            extra_model_kwargs,
            extra_data_kwargs,
            trainer_constructor,
            **kwargs,
        )

    def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs):
        # Regression always requires outcome labels, so Y_required is forced to True.
        return super()._split_train_data(C, X, Y, Y_required=True, **kwargs)