Coverage for contextualized/analysis/embeddings.py: 11%
70 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-21 13:49 -0400
1"""
2Utilities for plotting embeddings of fitted Contextualized models.
3"""
5from typing import *
7import numpy as np
8import pandas as pd
9import matplotlib.pyplot as plt
10import matplotlib as mpl
def convert_to_one_hot(col: Collection[Any]) -> Tuple[np.ndarray, List[Any]]:
    """
    Encode a categorical variable as integer category codes.

    Despite the name, the result is not a one-hot matrix: every element of
    ``col`` is replaced by the position of its value in the returned list of
    distinct values.

    Args:
        col (Collection[Any]): The categorical variable. Elements must be
            hashable and the collection must be iterable more than once.

    Returns:
        Tuple[np.ndarray, List[Any]]: A float32 array holding one code per
            element of ``col``, and the distinct values in order of first
            appearance, so that ``vals[int(code)]`` equals the element.

    Raises:
        ValueError: If an element cannot be matched to any distinct value,
            e.g. a NaN that does not compare equal to itself.
    """
    # A dict keeps insertion order, so the category order is deterministic
    # (first appearance) instead of depending on set iteration order.
    vals = list(dict.fromkeys(col))
    # Constant-time lookup per element instead of a linear ``list.index`` scan.
    code_of = {val: code for code, val in enumerate(vals)}
    try:
        codes = [code_of[x] for x in col]
    except KeyError as err:
        # Keep the exception type that ``list.index`` used to raise.
        raise ValueError(f"{err.args[0]!r} is not in list") from err
    return np.array(codes, dtype=np.float32), vals
def _covar_stat(stats, idx):
    """Return the statistic for covariate ``idx`` (a lone value applies to all)."""
    flat = np.asarray(stats).ravel()
    return flat[0] if flat.size == 1 else flat[idx]


def plot_embedding_for_all_covars(
    reps: np.ndarray,
    covars_df: pd.DataFrame,
    covars_stds: np.ndarray = None,
    covars_means: np.ndarray = None,
    covars_encoders: List[Callable] = None,
    **kwargs,
) -> None:
    """
    Plot embeddings of representations for all covariates in a Pandas dataframe.

    One figure is produced per column of ``covars_df`` by calling
    ``plot_lowdim_rep`` on the first two embedding dimensions, coloured by
    that column. Neither ``reps`` nor ``covars_df`` is modified.

    Args:
        reps (np.ndarray): Embeddings of shape (n_samples, n_dims); only the
            first two dimensions are plotted.
        covars_df (pd.DataFrame): DataFrame of covariates.
        covars_stds (np.ndarray, optional): Standard deviations of covariates,
            one per column of ``covars_df`` (a single value is applied to
            every column). Column ``i`` is multiplied by entry ``i``.
            Defaults to None.
        covars_means (np.ndarray, optional): Means of covariates, one per
            column of ``covars_df`` (a single value is applied to every
            column). Entry ``i`` is added to column ``i``. Defaults to None.
        covars_encoders (List[LabelEncoder], optional): Encoders for
            covariates; ``covars_encoders[i].inverse_transform`` is applied to
            the integer-cast labels of column ``i``. Defaults to None.
        kwargs: Keyword arguments for plotting, forwarded to
            ``plot_lowdim_rep``. ``dithering_pct`` (default 0.0), when
            positive, adds Gaussian noise to the plotted coordinates with a
            per-axis standard deviation of ``dithering_pct * std(axis)``.

    Returns:
        None
    """
    coords = reps[:, :2]
    dithering_pct = kwargs.get("dithering_pct", 0.0)
    if dithering_pct > 0:
        # Dither a private float copy exactly once: the caller's array stays
        # untouched and noise does not pile up from one covariate to the next.
        coords = np.array(coords, dtype=float)
        coords += np.random.normal(
            0, dithering_pct * np.std(coords, axis=0), size=coords.shape
        )

    for i, covar in enumerate(covars_df.columns):
        my_labels = covars_df.iloc[:, i].values
        # Out-of-place arithmetic: ``.values`` may be a view into the frame,
        # and in-place ops would also fail on integer columns.
        if covars_stds is not None:
            my_labels = my_labels * _covar_stat(covars_stds, i)
        if covars_means is not None:
            my_labels = my_labels + _covar_stat(covars_means, i)
        if covars_encoders is not None:
            my_labels = covars_encoders[i].inverse_transform(my_labels.astype(int))
        try:
            plot_lowdim_rep(
                coords,
                my_labels,
                cbar_label=covar,
                **kwargs,
            )
        except TypeError:
            # Best effort: a covariate that cannot be plotted is reported and
            # the remaining covariates are still processed.
            print(f"Error with covar {covar}")
def plot_lowdim_rep(
    low_dim: np.ndarray,
    labels: np.ndarray,
    **kwargs,
):
    """
    Plot a low-dimensional representation of a dataset.

    Labels with fewer than ``max_classes_for_discrete`` distinct values are
    treated as classes and drawn with a segmented colormap and one colorbar
    tick per class; otherwise they are drawn on a continuous colormap.

    Args:
        low_dim (np.ndarray): Low-dimensional representation of shape (n_samples, 2).
        labels (np.ndarray): Labels of shape (n_samples,).
        kwargs: Keyword arguments for plotting. Recognised keys:
            max_classes_for_discrete (default 10): class-count threshold
                below which labels are treated as discrete.
            min_samples (default 0): discrete classes with this many samples
                or fewer are dropped from the plot.
            figsize (default (12, 12)), alpha (default 1.0).
            xlabel / ylabel (default "X" / "Y"), xlabel_fontsize /
                ylabel_fontsize (default 48).
            title (default ""), title_fontsize (default 52).
            cbar_label (default None), cbar_fontsize (default 32).
            figname: if present, the figure is saved to "<figname>.pdf".

    Returns:
        None
    """
    discrete = len(set(labels)) < kwargs.get("max_classes_for_discrete", 10)

    if discrete:
        base_cmap = plt.cm.jet
        cmap = mpl.colors.LinearSegmentedColormap.from_list(
            "Custom cmap", [base_cmap(i) for i in range(base_cmap.N)], base_cmap.N
        )
        codes, tag_names = convert_to_one_hot(labels)
        # Re-index classes so that code k refers to the k-th sorted class name.
        order = np.argsort(tag_names)
        tag_names = np.array(tag_names)[order]
        rank = np.argsort(order)  # inverse permutation of ``order``
        tag = rank[codes.astype(int)]
        # Keep only classes with more than ``min_samples`` members.
        counts = np.bincount(tag, minlength=len(tag_names))
        keep_class = counts > kwargs.get("min_samples", 0)
        tag_names = tag_names[keep_class]
        good_idxs = keep_class[tag]
        # Compact the surviving codes to 0..n_kept-1 while preserving the
        # sorted order, so colours stay aligned with ``tag_names``.
        tag = (np.cumsum(keep_class) - 1)[tag[good_idxs]]
        bounds = np.linspace(0, len(tag_names), len(tag_names) + 1)
        try:
            norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
        except ValueError:
            # Raised before any figure exists, so nothing is left open.
            print(
                "Not enough values for a colorbar (needs at least 2 values), quitting."
            )
            return
    else:
        cmap = plt.cm.coolwarm

    fig = plt.figure(figsize=kwargs.get("figsize", (12, 12)))
    if discrete:
        plt.scatter(
            low_dim[good_idxs, 0],
            low_dim[good_idxs, 1],
            c=tag,
            alpha=kwargs.get("alpha", 1.0),
            s=100,
            cmap=cmap,
            norm=norm,
        )
    else:
        points = plt.scatter(
            low_dim[:, 0],
            low_dim[:, 1],
            c=labels,
            alpha=kwargs.get("alpha", 1.0),
            s=100,
            cmap=cmap,
        )
    plt.xlabel(kwargs.get("xlabel", "X"), fontsize=kwargs.get("xlabel_fontsize", 48))
    plt.ylabel(kwargs.get("ylabel", "Y"), fontsize=kwargs.get("ylabel_fontsize", 48))
    plt.xticks([])
    plt.yticks([])
    plt.title(kwargs.get("title", ""), fontsize=kwargs.get("title_fontsize", 52))

    # create a second axes for the colorbar
    ax2 = fig.add_axes([0.95, 0.15, 0.03, 0.7])
    if discrete:
        color_bar = mpl.colorbar.ColorbarBase(
            ax2,
            cmap=cmap,
            norm=norm,
            spacing="proportional",
            ticks=bounds[:-1] + 0.5,
            format="%1i",
        )
        try:
            color_bar.ax.set(yticks=bounds[:-1] + 0.5, yticklabels=np.round(tag_names))
        except (TypeError, ValueError):
            # Non-numeric class names cannot be rounded; show them verbatim.
            color_bar.ax.set(yticks=bounds[:-1] + 0.5, yticklabels=tag_names)
    else:
        # Reuse the scatter's normalisation so the colorbar shows the actual
        # label range rather than the default 0-1 interval.
        color_bar = mpl.colorbar.ColorbarBase(
            ax2, cmap=cmap, norm=points.norm, format="%.1f"
        )
    if kwargs.get("cbar_label", None) is not None:
        color_bar.ax.set_ylabel(
            kwargs["cbar_label"], fontsize=kwargs.get("cbar_fontsize", 32)
        )
    if "figname" in kwargs:
        plt.savefig(f"{kwargs['figname']}.pdf", dpi=300, bbox_inches="tight")