Coverage for contextualized/analysis/embeddings.py: 11%

70 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-21 13:49 -0400

1""" 

2Utilities for plotting embeddings of fitted Contextualized models. 

3""" 

4 

5from typing import * 

6 

7import numpy as np 

8import pandas as pd 

9import matplotlib.pyplot as plt 

10import matplotlib as mpl 

11 

12 

def convert_to_one_hot(col: Collection[Any]) -> Tuple[np.ndarray, List[Any]]:
    """
    Converts a categorical variable to integer class indices.

    Despite the name, no one-hot matrix is built: every element of ``col`` is
    replaced by the position of its value in the returned list of distinct
    values (a label encoding).

    Args:
        col (Collection[Any]): The categorical variable. It is iterated twice,
            and its values must be hashable.

    Returns:
        Tuple[np.ndarray, List[Any]]: A float32 array holding one class index
        per element of ``col``, and the list of distinct values, so that index
        ``k`` stands for ``vals[k]``. ``vals`` follows ``set`` iteration order
        and is not sorted.

    Raises:
        ValueError: If an element of ``col`` matches none of the distinct
            values (possible with NaNs, which compare unequal to themselves).
    """
    vals = list(set(col))
    # Hash lookup keeps the encoding O(n) instead of O(n * n_classes).
    index_of = {val: idx for idx, val in enumerate(vals)}

    def _code(x: Any) -> int:
        try:
            return index_of[x]
        except KeyError:
            # Linear scan keeps list.index semantics for anything the dict
            # misses; it raises ValueError when no value matches.
            return vals.index(x)

    one_hot_vars = np.array([_code(x) for x in col], dtype=np.float32)
    return one_hot_vars, vals

26 

27 

def _covar_param(param: Any, i: int, n_covars: int) -> Any:
    """Return covariate ``i``'s entry of a per-covariate ``param``, else ``param`` unchanged."""
    values = np.asarray(param)
    if values.ndim >= 1 and values.shape[0] == n_covars:
        return values[i]
    return param


def plot_embedding_for_all_covars(
    reps: np.ndarray,
    covars_df: pd.DataFrame,
    covars_stds: np.ndarray = None,
    covars_means: np.ndarray = None,
    covars_encoders: List[Callable] = None,
    **kwargs,
) -> None:
    """
    Plot embeddings of representations for all covariates in a Pandas dataframe.

    One figure is drawn per column of ``covars_df`` (via ``plot_lowdim_rep``),
    colouring the first two embedding dimensions by that column's values.
    Neither ``reps`` nor ``covars_df`` is modified.

    Args:
        reps (np.ndarray): Embeddings of shape (n_samples, n_dims); only the
            first two dimensions are plotted.
        covars_df (pd.DataFrame): DataFrame of covariates, one row per sample.
        covars_stds (np.ndarray, optional): Standard deviations used to undo
            standardization (each column's values are multiplied by them).
            A 1-D array with one entry per column is indexed by column
            position; anything else is broadcast against the column as-is.
            Defaults to None.
        covars_means (np.ndarray, optional): Means added back after the
            scaling step; same shape rules as ``covars_stds``. Defaults to None.
        covars_encoders (List[LabelEncoder], optional): One encoder per column
            whose ``inverse_transform`` maps integer codes back to the original
            labels. Defaults to None.
        kwargs: Keyword arguments for plotting, forwarded to ``plot_lowdim_rep``
            with ``cbar_label`` always set to the column name. If
            ``dithering_pct`` (default 0.0) is > 0, Gaussian jitter with std
            ``dithering_pct * std`` of each plotted dimension is drawn once and
            shared by every plot.

    Returns:
        None. A covariate whose plotting raises ``TypeError`` is reported on
        stdout and skipped.
    """
    n_covars = covars_df.shape[1]
    coords = reps[:, :2]
    dithering_pct = kwargs.get("dithering_pct", 0.0)
    if dithering_pct > 0:
        # Jitter a private copy once, so the caller's ``reps`` is untouched and
        # every covariate's plot shows the same point positions.
        coords = np.array(coords, dtype=float)
        for dim in range(2):
            coords[:, dim] += np.random.normal(
                0, dithering_pct * np.std(coords[:, dim]), size=coords.shape[0]
            )
    for i, covar in enumerate(covars_df.columns):
        my_labels = covars_df.iloc[:, i].values
        # Out-of-place arithmetic: ``.values`` may be a view of the caller's
        # DataFrame, so ``*=`` / ``+=`` would silently rescale their data.
        if covars_stds is not None:
            my_labels = my_labels * _covar_param(covars_stds, i, n_covars)
        if covars_means is not None:
            my_labels = my_labels + _covar_param(covars_means, i, n_covars)
        if covars_encoders is not None:
            my_labels = covars_encoders[i].inverse_transform(my_labels.astype(int))
        # Merge instead of passing cbar_label separately: a caller-supplied
        # cbar_label would otherwise raise "multiple values" TypeError.
        plot_kwargs = {**kwargs, "cbar_label": covar}
        try:
            plot_lowdim_rep(coords, my_labels, **plot_kwargs)
        except TypeError:
            print(f"Error with covar {covar}")

74 

75 

def _encode_discrete_labels(
    labels: np.ndarray, min_samples: int = 0
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Encode discrete labels as contiguous class indices, dropping rare classes.

    Args:
        labels (np.ndarray): Labels of shape (n_samples,).
        min_samples (int): Classes with at most this many samples are dropped.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]:
            - float32 class index for each kept sample, in ``[0, n_kept)``,
              where index ``k`` corresponds to ``tag_names[k]``;
            - the kept class names, sorted;
            - boolean mask of shape (n_samples,) selecting the kept samples.
    """
    codes, names = convert_to_one_hot(labels)
    order = np.argsort(names)
    tag_names = np.array(names)[order]
    # Inverse permutation: rank[c] is the sorted position of unsorted code c.
    rank = np.empty(len(order), dtype=int)
    rank[order] = np.arange(len(order))
    tag = rank[codes.astype(int)]
    good_tags = np.bincount(tag, minlength=len(tag_names)) > min_samples
    good_idxs = good_tags[tag]
    # Renumber kept classes 0..n_kept-1 in sorted order so that tag value k
    # always matches tag_names[k] after filtering.
    new_index = np.cumsum(good_tags) - 1
    new_tag = new_index[tag[good_idxs]].astype(np.float32)
    return new_tag, tag_names[good_tags], good_idxs


def plot_lowdim_rep(
    low_dim: np.ndarray,
    labels: np.ndarray,
    **kwargs,
):
    """
    Plot a low-dimensional representation of a dataset.

    Labels with fewer than ``max_classes_for_discrete`` distinct values are
    drawn as discrete classes (jet colormap, one colorbar entry per class);
    otherwise they are treated as continuous (coolwarm colormap). The colorbar
    is placed in an extra axes to the right of the plot.

    Args:
        low_dim (np.ndarray): Low-dimensional representation of shape (n_samples, 2).
        labels (np.ndarray): Labels of shape (n_samples,).
        kwargs: Keyword arguments for plotting. Keys read here:
            ``max_classes_for_discrete`` (default 10), ``min_samples``
            (default 0; discrete classes with at most this many samples are
            dropped), ``figsize`` (default (12, 12)), ``alpha`` (default 1.0),
            ``xlabel``/``ylabel`` (default "X"/"Y"),
            ``xlabel_fontsize``/``ylabel_fontsize`` (default 48), ``title``
            (default ""), ``title_fontsize`` (default 52), ``cbar_label``,
            ``cbar_fontsize`` (default 32), and ``figname`` (if present, the
            figure is saved to ``f"{figname}.pdf"``).

    Returns:
        None
    """
    discrete = len(set(labels)) < kwargs.get("max_classes_for_discrete", 10)
    fig = plt.figure(figsize=kwargs.get("figsize", (12, 12)))
    alpha = kwargs.get("alpha", 1.0)
    if discrete:
        jet = plt.cm.jet
        cmap = mpl.colors.LinearSegmentedColormap.from_list(
            "Custom cmap", [jet(i) for i in range(jet.N)], jet.N
        )
        tag, tag_names, good_idxs = _encode_discrete_labels(
            labels, kwargs.get("min_samples", 0)
        )
        # One unit-wide colour bin per kept class: [0, 1), [1, 2), ...
        bounds = np.linspace(0, len(tag_names), len(tag_names) + 1)
        try:
            norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
        except ValueError:
            print(
                "Not enough values for a colorbar (needs at least 2 values), quitting."
            )
            # Do not leave an empty figure open behind the early return.
            plt.close(fig)
            return
        plt.scatter(
            low_dim[good_idxs, 0],
            low_dim[good_idxs, 1],
            c=tag,
            alpha=alpha,
            s=100,
            cmap=cmap,
            norm=norm,
        )
    else:
        cmap = plt.cm.coolwarm
        points = plt.scatter(
            low_dim[:, 0],
            low_dim[:, 1],
            c=labels,
            alpha=alpha,
            s=100,
            cmap=cmap,
        )
        # Share the scatter's data-scaled norm with the colorbar; a colorbar
        # built without a norm falls back to a 0-1 scale that does not match
        # the plotted label values.
        norm = points.norm
    plt.xlabel(kwargs.get("xlabel", "X"), fontsize=kwargs.get("xlabel_fontsize", 48))
    plt.ylabel(kwargs.get("ylabel", "Y"), fontsize=kwargs.get("ylabel_fontsize", 48))
    plt.xticks([])
    plt.yticks([])
    plt.title(kwargs.get("title", ""), fontsize=kwargs.get("title_fontsize", 52))

    # create a second axes for the colorbar
    ax2 = fig.add_axes([0.95, 0.15, 0.03, 0.7])
    if discrete:
        tick_positions = bounds[:-1] + 0.5  # centre of each class bin
        color_bar = mpl.colorbar.ColorbarBase(
            ax2,
            cmap=cmap,
            norm=norm,
            spacing="proportional",
            ticks=tick_positions,
            format="%1i",
        )
        try:
            color_bar.ax.set(yticks=tick_positions, yticklabels=np.round(tag_names))
        except (TypeError, ValueError):
            # Non-numeric class names (e.g. strings) cannot be rounded
            # (np.round raises TypeError for them); use them verbatim.
            color_bar.ax.set(yticks=tick_positions, yticklabels=tag_names)
    else:
        color_bar = mpl.colorbar.ColorbarBase(ax2, cmap=cmap, norm=norm, format="%.1f")
    if kwargs.get("cbar_label", None) is not None:
        color_bar.ax.set_ylabel(
            kwargs["cbar_label"], fontsize=kwargs.get("cbar_fontsize", 32)
        )
    if "figname" in kwargs:
        plt.savefig(f"{kwargs['figname']}.pdf", dpi=300, bbox_inches="tight")