Coverage for contextualized/easy/ContextualGAM.py: 100%

9 statements  

coverage.py v7.4.4, created at 2024-04-21 13:38 -0400

"""
Contextual Generalized Additive Model.
See https://www.sciencedirect.com/science/article/pii/S1532046422001022
for more details.
"""

from contextualized.easy import ContextualizedClassifier, ContextualizedRegressor


class ContextualGAMClassifier(ContextualizedClassifier):
    """
    The Contextual GAM Classifier separates and interprets the effect of context in context-varying decisions and classifiers, such as heterogeneous disease diagnoses.
    Implemented as a Contextual Generalized Additive Model with a classifier on top.
    Always uses a Neural Additive Model ("ngam") encoder for interpretability.
    See `this paper <https://www.sciencedirect.com/science/article/pii/S1532046422001022>`__
    for more details.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        # Always use the "ngam" encoder, overriding any caller-supplied encoder_type.
        kwargs["encoder_type"] = "ngam"
        super().__init__(**kwargs)
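The `alpha`, `mu_ratio`, and `l1_ratio` arguments documented above can be read as an elastic-net style penalty. The sketch below is one plausible interpretation of the docstring, not the library's confirmed formula: the function names `elastic_norm` and `penalty` are hypothetical, and it assumes that `mu_ratio` weights the context-specific offsets while `1 - mu_ratio` weights the context-specific parameters.

```python
from math import sqrt


def elastic_norm(values, l1_ratio):
    """Blend the l1 and l2 norms of a flat list of floats (hypothetical helper)."""
    l1 = sum(abs(v) for v in values)
    l2 = sqrt(sum(v * v for v in values))
    return l1_ratio * l1 + (1.0 - l1_ratio) * l2


def penalty(betas, mus, alpha=0.0, mu_ratio=0.0, l1_ratio=0.0):
    """Illustrative regularization term; defaults mirror the docstring defaults.

    Assumption: mu_ratio weights the offsets (mus) and 1 - mu_ratio weights
    the parameters (betas). The docstring does not state which way it goes.
    """
    offset_term = mu_ratio * elastic_norm(mus, l1_ratio)
    param_term = (1.0 - mu_ratio) * elastic_norm(betas, l1_ratio)
    return alpha * (offset_term + param_term)


# With the default alpha of 0.0 the penalty vanishes regardless of the ratios.
no_reg = penalty([3.0, 4.0], [1.0])
# l1_ratio=0.0 uses only the l2 norm of the parameters: sqrt(9 + 16).
l2_only = penalty([3.0, 4.0], [1.0], alpha=1.0)
# l1_ratio=1.0 uses only the l1 norm of the parameters: 3 + 4.
l1_only = penalty([3.0, 4.0], [1.0], alpha=1.0, l1_ratio=1.0)
```

Under this reading, leaving `alpha` at its default disables regularization entirely, so `mu_ratio` and `l1_ratio` only matter once `alpha` is positive.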


class ContextualGAMRegressor(ContextualizedRegressor):
    """
    The Contextual GAM Regressor separates and interprets the effect of context in context-varying relationships, such as heterogeneous treatment effects.
    Implemented as a Contextual Generalized Additive Model with a linear regressor on top.
    Always uses a Neural Additive Model ("ngam") encoder for interpretability.
    See `this paper <https://www.sciencedirect.com/science/article/pii/S1532046422001022>`__
    for more details.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        # Always use the "ngam" encoder, overriding any caller-supplied encoder_type.
        kwargs["encoder_type"] = "ngam"
        super().__init__(**kwargs)
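Both classes above pin the encoder by overwriting `kwargs["encoder_type"]` before delegating to the parent constructor. A minimal, self-contained sketch of that pattern follows. `BaseEstimator` and `GAMEstimator` are hypothetical stand-ins, not the real `contextualized.easy` classes, and the `"mlp"` default is an assumption made only for illustration.

```python
class BaseEstimator:
    """Hypothetical stand-in for ContextualizedClassifier / ContextualizedRegressor."""

    def __init__(self, **kwargs):
        # Assumed default encoder; the real base classes may differ.
        self.encoder_type = kwargs.pop("encoder_type", "mlp")
        self.options = kwargs


class GAMEstimator(BaseEstimator):
    """Subclass that always forces the "ngam" encoder."""

    def __init__(self, **kwargs):
        # Overwrite, rather than default, so a caller-supplied value is ignored.
        kwargs["encoder_type"] = "ngam"
        super().__init__(**kwargs)


# A caller-supplied encoder_type is silently replaced; other options pass through.
model = GAMEstimator(encoder_type="mlp", alpha=0.1)
```

Because the key is overwritten instead of set with `setdefault`, passing a different `encoder_type` has no effect and raises no error. That keeps the interpretability guarantee stated in the docstrings, at the cost of ignoring the caller's choice without warning.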