Coverage for contextualized/easy/ContextualGAM.py: 100%
9 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:38 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:38 -0400
1"""
2Contextual Generalized Additive Model.
3See https://www.sciencedirect.com/science/article/pii/S1532046422001022
4for more details.
5"""
7from contextualized.easy import ContextualizedClassifier, ContextualizedRegressor
class ContextualGAMClassifier(ContextualizedClassifier):
    """
    Contextual Generalized Additive Model with a classifier on top.

    Separates and interprets the effect of context in context-varying
    decisions and classifiers, such as heterogeneous disease diagnoses.
    The encoder is always a Neural Additive Model ("ngam") for
    interpretability; any ``encoder_type`` supplied by the caller is
    overridden.
    See `this paper <https://www.sciencedirect.com/science/article/pii/S1532046422001022>`__
    for more details.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float between 0.0 and 1.0, governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float between 0.0 and 1.0, governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        # Pin the encoder to "ngam" no matter what the caller passed, then
        # hand every other keyword argument through to the parent unchanged.
        forced_kwargs = {**kwargs, "encoder_type": "ngam"}
        super().__init__(**forced_kwargs)
class ContextualGAMRegressor(ContextualizedRegressor):
    """
    Contextual Generalized Additive Model with a linear regressor on top.

    Separates and interprets the effect of context in context-varying
    relationships, such as heterogeneous treatment effects.
    The encoder is always a Neural Additive Model ("ngam") for
    interpretability; any ``encoder_type`` supplied by the caller is
    overridden.
    See `this paper <https://www.sciencedirect.com/science/article/pii/S1532046422001022>`__
    for more details.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float between 0.0 and 1.0, governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float between 0.0 and 1.0, governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        # Pin the encoder to "ngam" no matter what the caller passed, then
        # hand every other keyword argument through to the parent unchanged.
        forced_kwargs = {**kwargs, "encoder_type": "ngam"}
        super().__init__(**forced_kwargs)