Coverage for contextualized/analysis/accuracy_split.py: 15%

39 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2Utilities for post-hoc analysis of trained Contextualized models. 

3""" 

4 

5from typing import * 

6 

7import numpy as np 

8import pandas as pd 

9from sklearn.metrics import roc_auc_score as roc 

10 

11 

def get_roc(Y_true: np.ndarray, Y_pred: np.ndarray) -> float:
    """Compute the ROC AUC of ``Y_pred`` against ``Y_true``.

    Thin wrapper around ``sklearn.metrics.roc_auc_score`` (imported in this
    module as ``roc``) that never propagates the scorer's ``IndexError`` or
    ``ValueError``.

    Args:
        Y_true (np.ndarray): True labels.
        Y_pred (np.ndarray): Predicted values to score against ``Y_true``.

    Returns:
        float: The ROC AUC, or ``np.nan`` when the scorer raises
        ``IndexError`` or ``ValueError``.
    """
    try:
        score = roc(Y_true, Y_pred)
    except (IndexError, ValueError):
        # No valid ROC value for this input (presumably e.g. an empty subset
        # or a single-class subset -- confirm against the scorer's docs).
        score = np.nan
    return score

18 

19 

def print_acc_by_covars(
    Y_true: np.ndarray, Y_pred: np.ndarray, covar_df: pd.DataFrame, **kwargs
) -> None:
    """
    Print the ROC AUC within each distinct value of each covariate.

    For every column of ``covar_df`` that has at most ``max_classes`` distinct
    values, the samples are grouped by that column's value and the ROC AUC of
    ``Y_pred`` against ``Y_true`` is printed per group. When both ``train_idx``
    and ``test_idx`` are given, a train and a test ROC are printed per group
    instead of a single overall ROC. Groups with no valid ROC print ``nan``.

    ``covar_df`` is never modified: un-normalization is applied to a copy.

    Args:
        Y_true (np.ndarray): True labels; squeezed before use.
        Y_pred (np.ndarray): Predicted values; squeezed before use.
        covar_df (pd.DataFrame): DataFrame of covariates, one row per sample.
        max_classes (int, optional): Columns with more distinct values than
            this are skipped. Defaults to 20.
        covar_stds (np.ndarray, optional): Per-column factors the covariate
            values are multiplied by before grouping. Defaults to None.
        covar_means (np.ndarray, optional): Per-column offsets added to the
            covariate values before grouping. Defaults to None.
        covar_encoders (List[LabelEncoder], optional): Per-column encoders
            whose ``inverse_transform`` maps integer codes back to labels.
            Columns whose encoder cannot decode keep their numeric values.
            Defaults to None.
        train_idx (np.ndarray, optional): Boolean array indicating training
            data. Defaults to None.
        test_idx (np.ndarray, optional): Boolean array indicating testing
            data. Defaults to None.

    Returns:
        None
    """
    # Read the options once instead of on every loop iteration.
    max_classes = kwargs.get("max_classes", 20)
    covar_stds = kwargs.get("covar_stds")
    covar_means = kwargs.get("covar_means")
    covar_encoders = kwargs.get("covar_encoders")
    train_idx = kwargs.get("train_idx")
    test_idx = kwargs.get("test_idx")
    has_split = train_idx is not None and test_idx is not None

    Y_true = np.squeeze(Y_true)
    Y_pred = np.squeeze(Y_pred)
    # Materialize the value matrix once; it does not change between columns.
    covar_values = covar_df.values
    for i, covar in enumerate(covar_df.columns):
        my_labels = covar_values[:, i]
        if len(set(my_labels)) > max_classes:
            continue
        # Out-of-place arithmetic on purpose: ``my_labels`` may be a view of
        # the caller's data (or read-only), and an in-place ``*=`` on an
        # integer column with float factors raises a casting error.
        if covar_stds is not None:
            my_labels = my_labels * covar_stds[i]
        if covar_means is not None:
            my_labels = my_labels + covar_means[i]
        if covar_encoders is not None:
            try:
                # Round before casting: un-normalized codes can land a hair
                # below the integer (e.g. 1.9999999), which a bare
                # ``astype(int)` would truncate to the wrong label.
                my_labels = covar_encoders[i].inverse_transform(
                    np.rint(my_labels.astype(float)).astype(int)
                )
            except (AttributeError, TypeError, ValueError):
                # Best effort: this column has no usable encoder (or is not
                # categorical), so keep the numeric values as the labels.
                pass
        print("=" * 20)
        print(covar)
        print("-" * 10)

        for label in sorted(set(my_labels)):
            label_idxs = my_labels == label
            if has_split:
                my_train_idx = np.logical_and(label_idxs, train_idx)
                my_test_idx = np.logical_and(label_idxs, test_idx)
                train_roc = get_roc(Y_true[my_train_idx], Y_pred[my_train_idx])
                test_roc = get_roc(Y_true[my_test_idx], Y_pred[my_test_idx])
                print(
                    f"{label}:\t Train ROC: {train_roc:.2f}, Test ROC: {test_roc:.2f}"
                )
            else:
                overall_roc = get_roc(Y_true[label_idxs], Y_pred[label_idxs])
                print(f"{label}:\t ROC: {overall_roc:.2f}")
        print("=" * 20)