Coverage for src / autoencodix / visualize / _general_visualizer.py: 11%
288 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from dataclasses import field
2from typing import Any, Dict, Optional, Union, Literal, no_type_check
3import warnings
5import matplotlib.figure
6import numpy as np
7import pandas as pd
8import seaborn as sns # type: ignore
9from matplotlib import pyplot as plt
10from umap import UMAP # type: ignore
12from autoencodix.base._base_visualizer import BaseVisualizer
13from autoencodix.utils._result import Result
14from autoencodix.utils._utils import nested_dict, show_figure
15from autoencodix.configs.default_config import DefaultConfig
class GeneralVisualizer(BaseVisualizer):
    """Visualizer producing model-weight, loss, latent-space and evaluation plots.

    All generated figures are cached in ``self.plots``, a nested dictionary of
    figure handles keyed by plot type (and, for latent-space plots, by epoch,
    split and parameter).
    """

    # Nested dictionary of plots as figure handles.
    # NOTE(review): ``field(...)`` only takes effect if the class is processed
    # as a dataclass; instances always get their own container in __init__.
    plots: Dict[str, Any] = field(default_factory=nested_dict)

    def __init__(self):
        self.plots = nested_dict()

    def __setitem__(self, key, elem):
        """Store ``elem`` in the plot dictionary under ``key``."""
        self.plots[key] = elem

    def visualize(self, result: Result, config: DefaultConfig) -> Result:
        """Create the model-weight and loss plots and cache them in ``self.plots``.

        Args:
            result: Result object holding the trained model and the losses.
            config: Configuration forwarded to the loss formatting helper.

        Returns:
            The unchanged ``result`` object.
        """
        # Model weights heatmap (skipped for very wide inputs).
        input_dim = result.model.input_dim
        if input_dim <= 3000:
            self.plots["ModelWeights"] = self._plot_model_weights(model=result.model)
        else:
            warnings.warn(
                f"Model weights plot is skipped since input dimension {input_dim} is larger than 3000 and heatmap would be too large."
            )

        # Loss plots are best effort: a failure here must not abort the pipeline,
        # but the underlying error is reported so it can be diagnosed.
        try:
            loss_df_melt = self._make_loss_format(result=result, config=config)
            self.plots["loss_absolute"] = self._make_loss_plot(
                df_plot=loss_df_melt, plot_type="absolute"
            )
            self.plots["loss_relative"] = self._make_loss_plot(
                df_plot=loss_df_melt, plot_type="relative"
            )
        except Exception as e:
            warnings.warn(
                "We could not create visualizations for the loss plots.\n"
                "This usually happens if you try to visualize after saving and loading "
                "the pipeline object with `save_all=False`. This memory-efficient saving mode "
                "does not retain past training loss data.\n\n"
                f"Original error message: {type(e).__name__}: {e}"
            )

        return result

    ## Plotting methods ##
    @no_type_check
    def show_latent_space(
        self,
        result: Result,
        plot_type: Literal[
            "2D-scatter", "Ridgeline", "Coverage-Correlation", "Clustermap"
        ] = "2D-scatter",
        labels: Optional[Union[list, pd.Series, None]] = None,
        focus_labels: Optional[Union[list, None]] = None,
        param: Optional[Union[list, str]] = None,
        epoch: Optional[Union[int, None]] = None,
        split: str = "all",
        n_downsample: Optional[int] = 10000,
        **kwargs,
    ) -> None:
        """Visualizes the latent space of the given result using different types of plots.

        Args:
            result: The result object containing latent spaces and losses.
            plot_type: The type of plot to generate. Options are "2D-scatter", "Ridgeline", "Clustermap" and "Coverage-Correlation". Default is "2D-scatter".
            labels: List of labels for the data points in the latent space. Default is None.
            focus_labels: List of labels which should be considered for coloring. All other labels are set to 'other'. Defaults to None where all labels are considered.
            param: List of parameters provided and stored as metadata. Strings must match column names. If not a list, string "all" is expected for convenient way to make plots for all parameters available. Default is None where no colored labels are plotted.
            epoch: The epoch number to visualize. If None, the last epoch is taken from the model config. Default is None.
            split: The data split to visualize. Options are "train", "valid", "test", and "all". Default is "all".
            n_downsample: If provided, downsample the data to this number of samples for faster visualization. Default is 10000. Set to None to disable downsampling.
            **kwargs: additional arguments (currently unused).

        Raises:
            ValueError: If ``plot_type`` is unknown, if both ``labels`` and
                ``param`` are given, or if ``param`` is a string other than "all".
            TypeError: If ``param`` cannot be resolved to a list.
        """
        supported = ("2D-scatter", "Ridgeline", "Clustermap", "Coverage-Correlation")
        if plot_type not in supported:
            raise ValueError(
                f"Unknown plot_type '{plot_type}'. Supported plot types are: {', '.join(supported)}."
            )

        plt.ioff()
        if plot_type == "Coverage-Correlation":
            self._show_coverage_correlation(result)
            return

        # Set defaults
        if epoch is None:
            epoch = result.model.config.epochs - 1

        clin_data = self._collect_all_metadata(result=result)

        # The test split is always read with epoch=-1.
        if split == "all":
            df_latent = pd.concat(
                [
                    result.get_latent_df(epoch=epoch, split="train"),
                    result.get_latent_df(epoch=epoch, split="valid"),
                    result.get_latent_df(epoch=-1, split="test"),
                ]
            )
        elif split == "test":
            df_latent = result.get_latent_df(epoch=-1, split=split)
        else:
            df_latent = result.get_latent_df(epoch=epoch, split=split)

        ## Label options
        if labels is None and param is None:
            labels = ["all"] * df_latent.shape[0]

        if labels is None and isinstance(param, str):
            if param == "all":
                param = list(clin_data.columns)
            else:
                raise ValueError(
                    "Please provide parameter to plot as a list not as string. If you want to plot all parameters, set param to 'all' and labels to None."
                )

        if labels is not None and param is not None:
            raise ValueError(
                "Please provide either labels or param, not both. If you want to plot all parameters, set param to 'all' and labels to None."
            )

        if labels is not None and param is None:
            if isinstance(labels, pd.Series):
                # An unnamed Series would otherwise yield ``None`` as the plot key.
                param = [labels.name if labels.name is not None else "user_label"]
                # Order by index of df_latent first, fill missing with "unknown"
                labels = labels.reindex(df_latent.index, fill_value="unknown").tolist()
            else:
                param = ["user_label"]  # Default label if none provided

        if not isinstance(param, list):
            raise TypeError("Param needs to be converted to a list")

        # Labels supplied by the caller (None when labels come from metadata).
        user_labels = labels

        # Downsample once, before any per-parameter work.
        if n_downsample is not None and df_latent.shape[0] > n_downsample:
            sample_idx = np.random.choice(
                df_latent.shape[0], n_downsample, replace=False
            )
            df_latent = df_latent.iloc[sample_idx]
            if user_labels is not None:
                user_labels = [user_labels[i] for i in sample_idx]

        embedding = None  # 2D embedding, computed lazily and shared by all params
        for p in param:
            if p in clin_data.columns:
                point_labels = clin_data.loc[df_latent.index, p].tolist()  # ty: ignore
            elif user_labels is not None:
                point_labels = user_labels
            else:
                # Never fall back to the labels of a previously plotted parameter.
                warnings.warn(
                    f"Parameter '{p}' was not found in the metadata columns and is skipped."
                )
                continue

            if plot_type == "2D-scatter":
                if embedding is None:
                    if df_latent.shape[1] > 2:
                        ## Make 2D Embedding with UMAP
                        reducer = UMAP(n_components=2)
                        embedding = pd.DataFrame(reducer.fit_transform(df_latent))
                    else:
                        embedding = df_latent
                fig = self._plot_2D(
                    embedding=embedding,
                    labels=point_labels,
                    focus_labels=focus_labels,
                    param=p,
                    # we start counting epochs at 0, so add 1 for display
                    layer=f"2D latent space (epoch {epoch+1})",
                    figsize=(12, 8),
                    center=True,
                )
                displayed = fig
            elif plot_type == "Ridgeline":
                fig = self._plot_latent_ridge(
                    lat_space=df_latent,
                    labels=point_labels,
                    focus_labels=focus_labels,
                    param=p,
                )
                # The ridgeline helper returns a FacetGrid, not a Figure.
                displayed = fig.figure
            else:  # "Clustermap"
                fig = self._plot_latent_clustermap(
                    lat_space=df_latent,
                    labels=point_labels,
                    focus_labels=focus_labels,
                    param=p,
                )
                displayed = fig

            self._store_latent_plot(
                plot_type=plot_type,
                epoch=epoch,
                split=split,
                param=p,
                fig=fig,
                focus_labels=focus_labels,
            )
            show_figure(displayed)
            plt.show()

    def _show_coverage_correlation(self, result: Result) -> None:
        """Show (and cache) total correlation and coverage per checkpoint epoch.

        The figure is built once and afterwards served from
        ``self.plots["Coverage-Correlation"]``.
        """
        if "Coverage-Correlation" in self.plots:
            fig = self.plots["Coverage-Correlation"]
        else:
            model_config = result.model.config
            interval = model_config.checkpoint_interval
            records = []
            for ckpt_epoch in range(interval, model_config.epochs + 1, interval):
                for ckpt_split in ("train", "valid"):
                    # Latent spaces are stored with 0-based epoch numbers.
                    latent_df = result.get_latent_df(
                        epoch=ckpt_epoch - 1, split=ckpt_split
                    )
                    records.append(
                        {
                            "epoch": ckpt_epoch,
                            "split": ckpt_split,
                            "total_correlation": self._total_correlation(latent_df),
                            "coverage": self._coverage_calc(latent_df),
                        }
                    )

            df_metrics = pd.DataFrame(records)
            fig, axes = plt.subplots(1, 2, figsize=(12, 5))

            # Total Correlation plot
            sns.lineplot(
                data=df_metrics,
                x="epoch",
                y="total_correlation",
                hue="split",
                ax=axes[0],
            )
            axes[0].set_title("Total Correlation")
            axes[0].set_xlabel("Epoch")
            axes[0].set_ylabel("Total Correlation")

            # Coverage plot
            sns.lineplot(
                data=df_metrics, x="epoch", y="coverage", hue="split", ax=axes[1]
            )
            axes[1].set_title("Coverage")
            axes[1].set_xlabel("Epoch")
            axes[1].set_ylabel("Coverage")

            fig.tight_layout()
            self.plots["Coverage-Correlation"] = fig

        show_figure(fig)
        plt.show()

    def _store_latent_plot(
        self, plot_type, epoch, split, param, fig, focus_labels
    ) -> None:
        """Cache a latent-space figure under plots[plot_type][epoch][split].

        Without focus labels the figure is stored under ``param``. With focus
        labels it is appended as ``group_<n>`` under ``<param>_focus`` so that
        successive focus selections do not overwrite each other.
        """
        slot = self.plots[plot_type][epoch][split]
        if focus_labels is None:
            slot[param] = fig
        else:
            focus_slot = slot[f"{param}_focus"]
            focus_slot["group_" + str(len(focus_slot.keys()) + 1)] = fig

    def show_weights(self) -> None:
        """Display the model weights plot if it exists in the plots dictionary."""
        if "ModelWeights" not in self.plots:
            print("Model weights not found in the plots dictionary")
            print("You need to run visualize() method first")
        else:
            fig = self.plots["ModelWeights"]
            show_figure(fig)
            plt.show()

    ### Utilities ###
    @staticmethod
    def _plot_2D(
        embedding: pd.DataFrame,
        labels: list,
        focus_labels: Optional[Union[list, None]] = None,
        param: Optional[Union[str, None]] = None,
        layer: str = "latent space",
        figsize: tuple = (24, 15),
        center: bool = True,
        plot_numeric: bool = False,
        xlim: Optional[Union[tuple, None]] = None,
        ylim: Optional[Union[tuple, None]] = None,
        scale: Optional[Union[str, None]] = None,
        no_leg: bool = False,
    ) -> matplotlib.figure.Figure:
        """Plots a 2D scatter plot of the given embedding with labels.

        Args:
            embedding: DataFrame containing the 2D embedding coordinates.
            labels: List of labels corresponding to each point in the embedding.
            focus_labels: List of labels which should be considered for coloring. All other labels are set to 'other'. Defaults to None where all labels are considered.
            param: Title for the legend. Defaults to None.
            layer: Title for the plot. Defaults to "latent space".
            figsize: Size of the figure. Defaults to (24, 15).
            center: If True, marks the mean position of each label group. Defaults to True.
            plot_numeric: If True, treats labels as numeric. Defaults to False.
            xlim: Limits for the x-axis. Defaults to None.
            ylim: Limits for the y-axis. Defaults to None.
            scale: Scale for the axes (e.g., 'log'). Defaults to None.
            no_leg: If True, no legend is displayed. Defaults to False.

        Returns:
            The resulting matplotlib figure.

        Raises:
            ValueError: If ``labels`` is empty.
        """
        labels = list(labels)
        if not labels:
            raise ValueError("labels must contain at least one entry.")

        numeric = False
        if not isinstance(labels[0], str):
            if len(np.unique(labels)) > 3:
                if not plot_numeric:
                    print(
                        "The provided label column is numeric and converted to categories."
                    )
                    # Coerce to float; anything non-numeric becomes NaN. This keeps
                    # integers and NumPy scalars, unlike an isinstance(x, float) test.
                    numeric_labels = pd.to_numeric(pd.Series(labels), errors="coerce")
                    labels = (
                        pd.qcut(
                            x=numeric_labels,
                            q=4,
                            labels=["1stQ", "2ndQ", "3rdQ", "4thQ"],
                        )
                        .astype(str)
                        .to_list()
                    )
                else:
                    center = False  ## Disable centering for numeric params
                    numeric = True
            else:
                labels = [str(x) for x in labels]

        # Check if labels or embedding is longer and adjust the labels.
        # NOTE(review): this only yields matching lengths when the sample count
        # is a multiple of the label count; the "longer" branch deduplicates.
        if len(labels) < embedding.shape[0]:
            print(
                "Given labels do not have the same length as given sample size. Labels will be duplicated."
            )
            repeats = embedding.shape[0] // len(labels)
            labels = [label for label in labels for _ in range(repeats)]
        elif len(labels) > embedding.shape[0]:
            labels = list(set(labels))

        # Continuous numeric hues must not be collapsed into "other".
        n_unique = len(np.unique(labels))
        if not numeric and n_unique > 20 and focus_labels is None:
            warnings.warn(
                f"The provided label column has {n_unique} unique labels which might make the scatter plot unclear."
            )
            # Restrict to top 20 labels
            focus_labels = pd.Series(labels).value_counts().nlargest(20).index.tolist()
            print("Focusing on top 20 labels instead")

        if focus_labels is not None:
            focus_set = set(focus_labels)
            labels = [label if label in focus_set else "other" for label in labels]

        unique_labels = np.unique(labels)

        # Increase figure width if legend has more than 10 labels (two columns)
        if len(unique_labels) > 10:
            figsize = (figsize[0] * 1.5, figsize[1])
        # Increase figure width if legend labels are very long
        max_label_length = max(len(str(label)) for label in unique_labels)
        figsize = (int(figsize[0] + max_label_length * 0.2), figsize[1])

        fig, ax2 = plt.subplots(1, 1, figsize=figsize)
        if numeric:
            sns.scatterplot(
                x=embedding.iloc[:, 0],
                y=embedding.iloc[:, 1],
                hue=labels,
                palette="bwr",
                s=40,
                alpha=0.5,
                ec="black",
                ax=ax2,
            )
        else:
            palette_name = "tab20" if len(unique_labels) > 8 else "tab10"
            cat_pal = sns.color_palette(palette_name, n_colors=len(unique_labels))

            if "other" in unique_labels:
                # Render the catch-all "other" group in dark grey.
                cat_pal[list(unique_labels).index("other")] = (0.3, 0.3, 0.3)

            # Adjust alpha and marker size depending on number of points
            if len(labels) > 10000:
                point_alpha, point_size = 0.2, 10
            elif len(labels) > 5000:
                point_alpha, point_size = 0.4, 20
            else:
                point_alpha, point_size = 0.7, 40

            sns.scatterplot(
                x=embedding.iloc[:, 0],
                y=embedding.iloc[:, 1],
                hue=labels,
                hue_order=unique_labels,
                palette=cat_pal,
                s=point_size,
                alpha=point_alpha,
                ec="black",
                ax=ax2,
            )
            if center:
                # Group means come back sorted by label, matching unique_labels.
                means = embedding.groupby(by=np.asarray(labels)).mean()
                sns.scatterplot(
                    x=means.iloc[:, 0],
                    y=means.iloc[:, 1],
                    hue=unique_labels,
                    hue_order=unique_labels,
                    palette=cat_pal,
                    s=200,
                    ec="black",
                    alpha=0.7,
                    marker="*",
                    legend=False,
                    ax=ax2,
                )

        if xlim is not None:
            ax2.set_xlim(xlim[0], xlim[1])

        if ylim is not None:
            ax2.set_ylim(ylim[0], ylim[1])

        if scale is not None:
            ax2.set_yscale(scale)
            ax2.set_xscale(scale)
        ax2.set_xlabel("Dim 1")
        ax2.set_ylabel("Dim 2")

        if no_leg:
            ax2.legend([], [], frameon=False)
        else:
            sns.move_legend(
                ax2,
                "upper left",
                bbox_to_anchor=(1, 1),
                ncol=2 if len(unique_labels) > 10 else 1,
                title=param,
                frameon=False,
            )

        # Add title to the plot
        ax2.set_title(layer)
        fig.tight_layout()

        # Detach from pyplot so the figure is only shown on explicit request.
        plt.close(fig)
        return fig

    @staticmethod
    def _plot_latent_clustermap(
        lat_space: pd.DataFrame,
        labels: Optional[Union[list, pd.Series, None]] = None,
        focus_labels: Optional[Union[list, None]] = None,
        param: Optional[Union[str, None]] = None,
    ) -> matplotlib.figure.Figure:
        """Creates a clustermap of the mean latent intensity per label group.

        Rows are label groups (clustered when there is more than one group),
        columns are latent dimensions. ``lat_space`` is not modified.

        Args:
            lat_space: DataFrame containing the latent space intensities for samples (rows) and latent dimensions (columns)
            labels: List of labels for each sample. If None, all samples are considered as one group.
            focus_labels: List of labels which should be considered for coloring. All other labels are set to 'other'. Defaults to None where all labels are considered.
            param: Name of the grouping parameter; used as the name of the row index.
        Returns:
            fig: Figure object containing the clustermap
        """
        if labels is None:
            labels = ["all"] * len(lat_space)
        else:
            labels = list(labels)

        n_unique = len(np.unique(labels))
        if n_unique > 50 and focus_labels is None:
            warnings.warn(
                f"The provided label column has {n_unique} unique labels which might make the clustermap plot too big."
            )
            # Restrict to top 50 labels
            focus_labels = pd.Series(labels).value_counts().nlargest(50).index.tolist()
            print("Focusing on top 50 labels instead")

        if focus_labels is not None:
            focus_set = set(focus_labels)
            labels = [label if label in focus_set else "other" for label in labels]

        # Group via an aligned key instead of writing a temporary column into
        # the caller's DataFrame.
        group_key = pd.Series(labels, index=lat_space.index, name=param)
        group_means = lat_space.groupby(group_key).mean()

        cluster_figure = sns.clustermap(
            group_means,
            col_cluster=False,
            # Hierarchical clustering needs at least two rows.
            row_cluster=group_means.shape[0] > 1,
            # "+ 1" keeps the width used when the label column was part of the frame.
            figsize=(1 * (lat_space.shape[1] + 1), 4 + 0.5 * len(set(labels))),
            dendrogram_ratio=0.1,
            cmap="icefire",
            cbar_kws={"orientation": "horizontal"},
            cbar_pos=(0.2, 0.95, 0.3, 0.02),
        ).figure

        plt.close(cluster_figure)
        return cluster_figure

    @staticmethod
    def _plot_latent_ridge(
        lat_space: pd.DataFrame,
        labels: Optional[Union[list, pd.Series, None]] = None,
        focus_labels: Optional[Union[list, None]] = None,
        param: Optional[Union[str, None]] = None,
    ) -> sns.FacetGrid:
        """Creates a ridge line plot of latent space dimension where each row shows the density of a latent dimension and groups (ridges).

        Samples labelled "unknown" or "nan" are excluded from the plot.

        Args:
            lat_space: DataFrame containing the latent space intensities for samples (rows) and latent dimensions (columns)
            labels: List of labels for each sample. If None, all samples are considered as one group.
            focus_labels: List of labels which should be considered for coloring. All other labels are set to 'other'. Defaults to None where all labels are considered.
            param: Name of the grouping parameter; used as the hue column and legend title.
        Returns:
            g: FacetGrid object containing the ridge line plot
        """
        n_dims = len(lat_space.columns)

        # Long format: one row per (latent dim, sample), dims stacked in column order.
        df = pd.melt(lat_space, var_name="latent dim", value_name="latent intensity")
        df["sample"] = n_dims * list(lat_space.index)

        if labels is None:
            param = "all"
            # One label per sample; they are tiled per latent dimension below.
            labels = ["all"] * len(lat_space)
        else:
            labels = list(labels)

        if labels and not isinstance(labels[0], str):
            if len(np.unique(labels)) > 3:
                # Coerce to float; anything non-numeric becomes NaN. This keeps
                # integers and NumPy scalars, unlike an isinstance(x, float) test.
                numeric_labels = pd.to_numeric(pd.Series(labels), errors="coerce")
                labels = list(
                    pd.qcut(
                        x=numeric_labels,
                        q=4,
                        labels=["1stQ", "2ndQ", "3rdQ", "4thQ"],
                    ).astype(str)
                )
            else:
                labels = [str(x) for x in labels]

        n_unique = len(np.unique(labels))
        if n_unique > 20 and focus_labels is None:
            warnings.warn(
                f"The provided label column has {n_unique} unique labels which might make the ridgeline plot unclear."
            )
            # Restrict to top 20 labels
            focus_labels = pd.Series(labels).value_counts().nlargest(20).index.tolist()
            print("Focusing on top 20 labels instead")

        if focus_labels is not None:
            focus_set = set(focus_labels)
            labels = [label if label in focus_set else "other" for label in labels]

        df[param] = n_dims * labels  # type: ignore

        exclude_missing_info = (df[param] == "unknown") | (df[param] == "nan")
        plot_df = df[~exclude_missing_info]

        # Shared x-range from per-group quantiles so outliers do not stretch the axis.
        # NOTE(review): lower bound uses the 5% and upper bound the 90% quantile —
        # confirm the asymmetry is intended.
        intensity_by_group = plot_df[["latent intensity", "latent dim", param]].groupby(
            [param, "latent dim"], observed=False
        )["latent intensity"]
        xmin = intensity_by_group.quantile(0.05).min()
        xmax = intensity_by_group.quantile(0.9).max()

        # Fix the hue order explicitly so palette entries (in particular the grey
        # for "other") are assigned to the intended groups.
        hue_levels = sorted(pd.unique(plot_df[param].dropna()), key=str)
        palette_name = "tab20" if len(hue_levels) > 8 else "tab10"
        cat_pal = sns.color_palette(palette_name, n_colors=len(hue_levels))
        if "other" in hue_levels:
            # Render the catch-all "other" group in dark grey.
            cat_pal[hue_levels.index("other")] = (0.3, 0.3, 0.3)

        # Length of longest latent dim string for aspect ratio
        len_longest_latent_dim = max(len(str(x)) for x in lat_space.columns)

        def label(data, color, label, text="latent dim"):
            # Write the facet's latent dimension name left of its ridge.
            ax = plt.gca()
            label_text = data[text].unique()[0]
            ax.text(
                0.0,
                0.2,
                label_text,
                fontweight="bold",
                ha="right",
                va="center",
                transform=ax.transAxes,
            )

        # Transparent axes are necessary to enforce overplotting; rc_context
        # restores the global rcParams afterwards so later plots are unaffected.
        with plt.rc_context():
            sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

            g = sns.FacetGrid(
                plot_df,
                row="latent dim",
                hue=param,
                hue_order=hue_levels,
                aspect=12 + len_longest_latent_dim / 4,
                height=0.8,
                xlim=(xmin, xmax),
                palette=cat_pal,
            )

            g.map_dataframe(
                sns.kdeplot,
                "latent intensity",
                bw_adjust=0.5,
                clip_on=True,
                fill=True,
                alpha=0.5,
                warn_singular=False,
                ec="k",
                lw=1,
            )
            g.map_dataframe(label, text="latent dim")

            g.set(xlim=(xmin, xmax))
            # Set the subplots to overlap
            g.figure.subplots_adjust(hspace=-0.5)

            # Remove axes details that don't play well with overlap
            g.set_titles("")
            g.set(yticks=[], ylabel="")
            g.despine(bottom=True, left=True)

            g.add_legend()

        plt.close(g.figure)
        return g

    def _plot_evaluation(
        self,
        result: Result,
    ) -> dict:
        """Plots the evaluation results from the Result object.

        One bar plot is created per (clinical parameter, metric, ML algorithm)
        combination and stored as ``plots[param][metric][algorithm]``. The
        nested dictionary is also cached in ``self.plots["ML_Evaluation"]``.

        Args:
            result: The Result object containing evaluation data.

        Returns:
            The generated dictionary containing the evaluation plots, or an
            empty dictionary if no evaluation data is available.
        """
        plt.ioff()
        evaluation = result.embedding_evaluation
        if not hasattr(evaluation, "CLINIC_PARAM"):
            warnings.warn(
                "We could not create visualizations for the evaluation plots.\n"
                "This usually happens if you try to visualize after saving and loading "
                "the pipeline object with `save_all=False`. This memory-efficient saving mode "
                "does not retain the evaluation results. "
                "Set save_all=True to avoid this, also this might be fixed soon."
            )
            return {}

        ml_plots: dict = {}
        for c in pd.unique(evaluation.CLINIC_PARAM):
            is_param = evaluation.CLINIC_PARAM == c
            ml_plots[c] = {}
            for m in pd.unique(evaluation.loc[is_param, "metric"]):  # ty: ignore
                is_metric = is_param & (evaluation.metric == m)
                ml_plots[c][m] = {}
                for alg in pd.unique(evaluation.loc[is_metric, "ML_ALG"]):  # ty: ignore
                    data = evaluation[is_metric & (evaluation.ML_ALG == alg)]

                    # Only rows without a score are dropped; NaNs in other
                    # columns do not remove otherwise valid scores.
                    if data["value"].isnull().any():
                        warnings.warn(
                            f"Missing values found in evaluation data for parameter '{c}', metric '{m}', and algorithm '{alg}'. These will be ignored in the plot."
                        )
                        data = data.dropna(subset=["value"])

                    sns_plot = sns.catplot(
                        data=data,
                        x="score_split",
                        y="value",
                        col="ML_TASK",
                        hue="score_split",
                        kind="bar",
                    )

                    # Bars start at zero unless there are negative scores.
                    min_y = data.value.min()
                    if min_y > 0:
                        min_y = 0

                    ml_plots[c][m][alg] = sns_plot.set(ylim=(min_y, None))
                    # Detach from pyplot like the other plotting helpers do.
                    plt.close(sns_plot.figure)

        self.plots["ML_Evaluation"] = ml_plots

        return ml_plots

    @staticmethod
    def _total_correlation(latent_space: pd.DataFrame) -> float:
        """Function to compute the total correlation as described here (Equation2): https://doi.org/10.3390/e21100921

        Computed as 0.5 * (sum(log(diag(C))) - log|det(C)|) with C the
        covariance matrix of the latent dimensions.

        Args:
            latent_space: latent space with dimension sample vs. latent dimensions
        Returns:
            tc: total correlation across latent dimensions
        """
        # np.cov returns a 0-d array for a single latent dimension.
        lat_cov = np.atleast_2d(np.cov(latent_space.T))
        log_det = np.linalg.slogdet(lat_cov)[1]
        return float(0.5 * (np.sum(np.log(np.diag(lat_cov))) - log_det))

    @staticmethod
    def _coverage_calc(latent_space: pd.DataFrame) -> float:
        """Function to compute the coverage as described here (Equation3): https://doi.org/10.3390/e21100921

        Each latent dimension is cut into ``floor(n_samples ** (1 / n_dims))``
        equal-width bins; coverage is the fraction of grid cells that contain
        at least one sample.

        Args:
            latent_space: latent space with dimension sample vs. latent dimensions
        Returns:
            cov: coverage across latent dimensions, NaN if fewer than 2 bins
                per dimension are possible
        """
        n_samples = len(latent_space.index)
        n_dims = len(latent_space.columns)

        if n_samples == 0 or n_dims == 0:
            bins_per_dim = 0
        else:
            # The epsilon guards against roots such as 1000 ** (1 / 3) evaluating
            # to 9.999... and being truncated to 9.
            bins_per_dim = int(np.floor(np.power(n_samples, 1 / n_dims) + 1e-9))

        if bins_per_dim < 2:
            warnings.warn(
                "Coverage calculation fails since combination of sample size and latent dimension results in less than 2 bins."
            )
            return float("nan")

        latent_bins = latent_space.apply(lambda x: pd.cut(x, bins=bins_per_dim))
        # One tuple of bin intervals per sample identifies its grid cell.
        cells = pd.Series(list(zip(*[latent_bins[col] for col in latent_bins])))
        return float(len(cells.unique()) / bins_per_dim**n_dims)