Coverage for contextualized/easy/ContextualizedClassifier.py: 100%
14 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:38 -0400
1"""
2sklearn-like interface to Contextualized Classifiers.
3"""
5import numpy as np
7from contextualized.functions import LINK_FUNCTIONS
8from contextualized.easy import ContextualizedRegressor
9from contextualized.regression import LOSSES
class ContextualizedClassifier(ContextualizedRegressor):
    """
    Contextualized Logistic Regression reveals context-dependent decisions and decision boundaries.
    Implemented as a ContextualizedRegressor with logistic link function and binary cross-entropy loss.

    Any ``link_fn`` or ``loss_fn`` supplied by the caller is overridden; every other
    keyword argument is forwarded unchanged to ContextualizedRegressor.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        # Pin the model family: a logistic link plus binary cross-entropy loss
        # turns the contextualized regressor into a logistic classifier.
        # These deliberately override any caller-supplied values.
        kwargs["link_fn"] = LINK_FUNCTIONS["logistic"]
        kwargs["loss_fn"] = LOSSES["bceloss"]
        super().__init__(**kwargs)

    def predict(self, C, X, individual_preds=False, **kwargs):
        """Predict binary outcomes from context C and predictors X.

        Args:
            C (np.ndarray): Context array of shape (n_samples, n_context_features)
            X (np.ndarray): Predictor array of shape (n_samples, n_features)
            individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False.

        Returns:
            Union[np.ndarray, List[np.ndarray]]: The binary outcomes predicted by the context-specific models (n_samples, y_dim). Returned as lists of individual bootstraps if individual_preds is True.
        """
        # With the logistic link set in __init__, the parent's predict yields
        # probabilities (predict_proba relies on the same). Rounding thresholds
        # them at 0.5; note np.round rounds an exact 0.5 half-to-even, i.e. to 0.0.
        return np.round(super().predict(C, X, individual_preds, **kwargs))

    def predict_proba(self, C, X, individual_preds=False, **kwargs):
        """
        Predict probabilities of outcomes from context C and predictors X.

        Args:
            C (np.ndarray): Context array of shape (n_samples, n_context_features)
            X (np.ndarray): Predictor array of shape (n_samples, n_features)
            individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False.

        Returns:
            np.ndarray: The outcome probabilities predicted by the context-specific models,
            shape (n_samples, y_dim, 2), where ``[..., 0]`` is P(y=0) and ``[..., 1]`` is P(y=1).
            If individual_preds is True, the per-bootstrap axis returned by the parent's
            predict is kept in front, giving (n_bootstraps, n_samples, y_dim, 2)
            (assumes the parent returns an array-like of shape (n_bootstraps, n_samples, y_dim)).
        """
        # Forward individual_preds the same way predict() does; asarray also
        # tolerates a list-of-arrays return from the parent.
        probs = np.asarray(super().predict(C, X, individual_preds, **kwargs))
        # Stack P(y=0), P(y=1) along a new trailing axis. Unlike a transpose-based
        # reshuffle, this preserves every leading axis (samples, outcomes, and the
        # bootstrap axis when present) regardless of the array's rank.
        return np.stack([1 - probs, probs], axis=-1)