Coverage for contextualized/easy/ContextualizedClassifier.py: 100%

14 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:38 -0400

1""" 

2sklearn-like interface to Contextualized Classifiers. 

3""" 

4 

5import numpy as np 

6 

7from contextualized.functions import LINK_FUNCTIONS 

8from contextualized.easy import ContextualizedRegressor 

9from contextualized.regression import LOSSES 

10 

11 

class ContextualizedClassifier(ContextualizedRegressor):
    """
    Contextualized Logistic Regression reveals context-dependent decisions and decision boundaries.
    Implemented as a ContextualizedRegressor with logistic link function and binary cross-entropy loss.

    All keyword arguments are forwarded to ContextualizedRegressor.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
        encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp".
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range [0.0, 1.0], governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.

    Note:
        Any ``link_fn`` or ``loss_fn`` supplied by the caller is replaced by the
        logistic link and the binary cross-entropy loss, respectively.
    """

    def __init__(self, **kwargs):
        # Force the logistic-regression configuration; caller-supplied values
        # for these two keys are intentionally overwritten.
        kwargs["link_fn"] = LINK_FUNCTIONS["logistic"]
        kwargs["loss_fn"] = LOSSES["bceloss"]
        super().__init__(**kwargs)

    def predict(self, C, X, individual_preds=False, **kwargs):
        """Predict binary outcomes from context C and predictors X.

        Args:
            C (np.ndarray): Context array of shape (n_samples, n_context_features).
            X (np.ndarray): Predictor array of shape (n_samples, n_features).
            individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False.
            **kwargs: Forwarded to ContextualizedRegressor.predict.

        Returns:
            np.ndarray: The binary outcomes predicted by the context-specific models (n_samples, y_dim),
            obtained by rounding the predicted probabilities with ``np.round``. Note that
            ``np.round`` rounds half to even, so a probability of exactly 0.5 maps to 0.
            When individual_preds is True, the per-bootstrap predictions are rounded instead.
        """
        return np.round(super().predict(C, X, individual_preds, **kwargs))

    def predict_proba(self, C, X, individual_preds=False, **kwargs):
        """
        Predict probabilities of outcomes from context C and predictors X.

        Args:
            C (np.ndarray): Context array of shape (n_samples, n_context_features).
            X (np.ndarray): Predictor array of shape (n_samples, n_features).
            individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False.
            **kwargs: Forwarded to ContextualizedRegressor.predict.

        Returns:
            np.ndarray: The outcome probabilities predicted by the context-specific models,
            with ``[P(y=0), P(y=1)]`` stacked on a new trailing axis: shape
            (n_samples, y_dim, 2) by default. When individual_preds is True, the
            per-bootstrap predictions keep their extra leading axis, e.g.
            (n_bootstraps, n_samples, y_dim, 2) — assuming the parent returns one
            (n_samples, y_dim) prediction per bootstrap (TODO confirm).
        """
        # np.asarray lets a list-of-arrays result from the parent be used in
        # arithmetic (``1 - list`` would raise TypeError).
        probs = np.asarray(super().predict(C, X, individual_preds, **kwargs))
        # Stacking on axis=-1 appends the class axis while keeping the
        # probability array's own axes in order, so the result is correct for
        # any rank (including the bootstrap axis), not only for 2-D input.
        return np.stack([1 - probs, probs], axis=-1)