Coverage for contextualized/easy/ContextualizedRegressor.py: 94%

17 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2sklearn-like interface to Contextualized Regressors. 

3""" 

4 

5from contextualized.regression import ( 

6 NaiveContextualizedRegression, 

7 ContextualizedRegression, 

8) 

9from contextualized.easy.wrappers import SKLearnWrapper 

10from contextualized.regression import RegressionTrainer 

11 

12# TODO: Multitask metamodels 

13# TODO: Task-specific link functions. 

14 

15 

class ContextualizedRegressor(SKLearnWrapper):
    """
    Contextualized Linear Regression quantifies context-varying linear relationships.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        loss_fn (torch.nn.Module, optional): Loss function. Defaults to LOSSES["mse"].
        link_fn (torch.nn.Module, optional): Link function. Defaults to LINK_FUNCTIONS["identity"].
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.

    Raises:
        ValueError: If ``num_archetypes`` is neither 0 nor a positive number
            (e.g. a negative integer).
    """

    def __init__(self, **kwargs):
        """
        Select the regression model class from ``num_archetypes`` and hand all
        configuration to ``SKLearnWrapper``.

        Args:
            **kwargs: Configuration options (see the class docstring). All of
                them, including ``num_archetypes``, are forwarded unchanged to
                ``SKLearnWrapper.__init__``.

        Raises:
            ValueError: If ``num_archetypes`` is not 0 and not > 0.
        """
        self.num_archetypes = kwargs.get("num_archetypes", 0)
        if self.num_archetypes == 0:
            constructor = NaiveContextualizedRegression
        elif self.num_archetypes > 0:
            constructor = ContextualizedRegression
        else:
            # Fail fast with a clear error. Merely printing here used to leave
            # ``constructor`` unassigned, so the super().__init__ call below
            # crashed with an unrelated UnboundLocalError.
            raise ValueError(
                f"Was told to construct a ContextualizedRegressor with "
                f"{self.num_archetypes} archetypes, but this should be a "
                "non-negative integer."
            )

        # Additional keyword names this regressor declares to SKLearnWrapper.
        # NOTE(review): presumably the wrapper routes these to the model
        # constructor / data setup respectively — confirm in SKLearnWrapper.
        extra_model_kwargs = ["base_param_predictor", "base_y_predictor", "y_dim"]
        extra_data_kwargs = ["Y_val"]
        trainer_constructor = RegressionTrainer
        super().__init__(
            constructor,
            extra_model_kwargs,
            extra_data_kwargs,
            trainer_constructor,
            **kwargs,
        )

    def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs):
        """
        Delegate to ``SKLearnWrapper._split_train_data`` with targets required.

        Regression cannot be fit without targets, so the ``Y_required``
        argument supplied by the caller is ignored and ``True`` is always
        passed to the parent implementation.
        """
        return super()._split_train_data(C, X, Y, Y_required=True, **kwargs)