Coverage for src / autoencodix / visualize / visualize.py: 10%
428 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import os
2from dataclasses import field
3from typing import Any, Dict, Optional, Union, Literal, no_type_check
4import warnings
6import matplotlib.figure
7import numpy as np
8import pandas as pd
9import seaborn as sns # type: ignore
10import torch
11from matplotlib import pyplot as plt
12from umap import UMAP # type: ignore
14from autoencodix.base._base_visualizer import BaseVisualizer
15from autoencodix.utils._result import Result
16from autoencodix.utils._utils import nested_dict, nested_to_tuple, show_figure
17from autoencodix.configs.default_config import DefaultConfig
20class Visualizer(BaseVisualizer):
# Nested dictionary of plot name -> figure handle; latent-space plots add
# deeper levels (epoch -> split -> parameter). Every instance receives its own
# container in __init__. Only the annotation is declared here: this class is
# not decorated with @dataclass, so a dataclasses.field(default_factory=...)
# default would never be applied and would merely leave a Field object behind
# as the class attribute.
plots: Dict[str, Any]
def __init__(self):
    """Create a visualizer with an empty plot registry.

    The registry is built with the project helper ``nested_dict`` so that
    multi-level keys (e.g. plot type, epoch, split, parameter) can be
    assigned directly.
    """
    fresh_registry = nested_dict()
    self.plots = fresh_registry
def __setitem__(self, key, elem):
    """Store ``elem`` in the plot registry under ``key`` (``vis[key] = fig``)."""
    registry = self.plots
    registry[key] = elem
def visualize(self, result: Result, config: DefaultConfig) -> Result:
    """Build the standard training figures and store them in ``self.plots``.

    Stores the model-weight heatmaps under "ModelWeights" and the loss plots
    under "loss_absolute" and "loss_relative".

    Args:
        result: Training result providing the model and the loss histories.
        config: Configuration passed on to ``make_loss_format``.

    Returns:
        The same ``result`` object, unchanged.
    """
    self.plots["ModelWeights"] = self.plot_model_weights(model=result.model)

    # One long-format loss table feeds both loss plots.
    long_losses = self.make_loss_format(result=result, config=config)
    for loss_view in ("absolute", "relative"):
        self.plots[f"loss_{loss_view}"] = self.make_loss_plot(
            df_plot=long_losses, plot_type=loss_view
        )

    return result
49 ## Plotting methods ##
def save_plots(
    self, path: str, which: Union[str, list] = "all", format: str = "png"
) -> None:
    """Save stored plots to ``path`` as ``<name>.<format>`` files.

    Nested plot entries (e.g. per epoch/split/parameter) are flattened with
    ``nested_to_tuple``; the keys leading to each figure are joined with "_"
    to build its file name.

    Args:
        path: The directory path where the plots will be saved.
        which: "all" to save every stored plot, a single plot name, or a
            list/tuple of plot names. Unknown names are reported and skipped.
        format: The file format in which to save the plots (e.g., 'png', 'jpg').

    Raises:
        ValueError: If ``which`` is neither a string nor a list/tuple of names.
    """

    def _save(tree: Any, prefix: Optional[str]) -> None:
        # Each flattened item is a tuple whose last element is the figure and
        # whose leading elements are the keys leading to it.
        for item in nested_to_tuple(tree):
            fig = item[-1]
            suffix = "_".join(str(x) for x in item[0:-1])
            filename = suffix if prefix is None else prefix + "_" + suffix
            fig.savefig(f"{os.path.join(path, filename)}.{format}")

    if isinstance(which, str) and which == "all":
        if len(self.plots) == 0:
            print("No plots found in the plots dictionary")
            print("You need to run visualize() method first")
            return
        # Top-level keys are already part of the flattened tuples, so no prefix.
        _save(self.plots, None)
        return

    if isinstance(which, str):
        names = [which]
    elif isinstance(which, (list, tuple)):
        names = list(which)
    else:
        raise ValueError(
            "'which' must be 'all', a plot name or a list of plot names, "
            f"got {type(which).__name__}"
        )

    for name in names:
        if name not in self.plots.keys():
            print(f"Plot {name} not found in the plots dictionary")
            print(f"All available plots are: {list(self.plots.keys())}")
            continue
        # Save all epochs and splits stored below this plot name.
        _save(self.plots[name], str(name))
def show_loss(
    self, plot_type: Literal["absolute", "relative"] = "absolute"
) -> None:
    """Show one of the stored loss figures.

    Args:
        plot_type: "absolute" shows ``self.plots["loss_absolute"]``,
            "relative" shows ``self.plots["loss_relative"]``. Any other value
            only prints a hint. Defaults to "absolute".
    """
    if plot_type not in ["absolute", "relative"]:
        print(
            "Type of loss plot not recognized. Please use 'absolute' or 'relative'"
        )
        return

    is_absolute = plot_type == "absolute"
    plot_key = "loss_absolute" if is_absolute else "loss_relative"

    if plot_key in self.plots.keys():
        loss_fig = self.plots[plot_key]
        show_figure(loss_fig)
        plt.show()
        return

    missing_msg = (
        "Absolute loss plot not found in the plots dictionary"
        if is_absolute
        else "Relative loss plot not found in the plots dictionary"
    )
    print(missing_msg)
    print("You need to run visualize() method first")
@no_type_check
def show_latent_space(
    self,
    result: Result,
    plot_type: str = "2D-scatter",
    labels: Optional[Union[list, pd.Series, None]] = None,
    param: Optional[Union[list, str]] = None,
    epoch: Optional[Union[int, None]] = None,
    split: str = "all",
    **kwargs,
) -> None:
    """Visualize the latent space of ``result`` and show the figure(s).

    Figures are stored in ``self.plots``: "Coverage-Correlation" as a single
    (cached) figure, "2D-scatter"/"Ridgeline" under
    ``[plot_type][epoch][split][param]``.

    Args:
        result: The result object containing latent spaces, datasets and model.
        plot_type: "2D-scatter", "Ridgeline" or "Coverage-Correlation".
            Default is "2D-scatter".
        labels: Per-sample labels used for colouring. A ``pd.Series`` is
            aligned to the latent-space index (samples missing from it are
            labelled "unknown"); a list is used as given. User labels are
            always used as given. Mutually exclusive with ``param``. If both
            ``labels`` and ``param`` are None, every sample is labelled "all".
        param: List of metadata column names to colour by, or the string
            "all" for one plot per metadata column. Names not present in the
            metadata are reported and skipped.
        epoch: Epoch index passed to ``result.get_latent_df``. Defaults to
            ``result.model.config.epochs - 1``. The test split is always read
            with ``epoch=-1``.
        split: "train", "valid", "test" or "all" (train + valid + test).
        **kwargs: Accepted for backward compatibility; currently unused.

    Raises:
        ValueError: If a dataset split or its metadata is missing, if the
            metadata format is unsupported, if both ``labels`` and ``param``
            are given, or if ``param`` is a string other than "all".
    """
    plt.ioff()
    if plot_type == "Coverage-Correlation":
        if "Coverage-Correlation" not in self.plots:
            cfg = result.model.config
            rows = []
            for ckpt in range(
                cfg.checkpoint_interval, cfg.epochs + 1, cfg.checkpoint_interval
            ):
                for part in ("train", "valid"):
                    # Checkpoints are counted from 1 here, stored latents from 0.
                    latent_df = result.get_latent_df(epoch=ckpt - 1, split=part)
                    rows.append(
                        {
                            "epoch": ckpt,
                            "split": part,
                            "total_correlation": self._total_correlation(latent_df),
                            "coverage": self._coverage_calc(latent_df),
                        }
                    )
            df_metrics = pd.DataFrame(rows)

            fig, axes = plt.subplots(1, 2, figsize=(12, 5))
            panels = (
                (axes[0], "total_correlation", "Total Correlation"),
                (axes[1], "coverage", "Coverage"),
            )
            for ax, column, title in panels:
                sns.lineplot(data=df_metrics, x="epoch", y=column, hue="split", ax=ax)
                ax.set_title(title)
                ax.set_xlabel("Epoch")
                ax.set_ylabel(title)
            plt.tight_layout()
            self.plots["Coverage-Correlation"] = fig

        show_figure(self.plots["Coverage-Correlation"])
        plt.show()
        return

    if epoch is None:
        epoch = result.model.config.epochs - 1

    clin_data = self._collect_clin_data(result)

    if split == "all":
        df_latent = pd.concat(
            [
                result.get_latent_df(epoch=epoch, split="train"),
                result.get_latent_df(epoch=epoch, split="valid"),
                result.get_latent_df(epoch=-1, split="test"),
            ]
        )
    elif split == "test":
        df_latent = result.get_latent_df(epoch=-1, split=split)
    else:
        df_latent = result.get_latent_df(epoch=epoch, split=split)

    if labels is not None and param is not None:
        raise ValueError(
            "Please provide either labels or param, not both. If you want to plot all parameters, set param to 'all' and labels to None."
        )

    # Resolve what to plot: one set of user labels, or one plot per metadata column.
    user_labels = None
    if labels is not None:
        if isinstance(labels, pd.Series):
            params = [labels.name if labels.name is not None else "user_label"]
            # Order by index of df_latent first, fill missing with "unknown"
            user_labels = labels.reindex(df_latent.index, fill_value="unknown").tolist()
        else:
            params = ["user_label"]
            user_labels = list(labels)
    elif param is None:
        params = ["user_label"]
        user_labels = ["all"] * df_latent.shape[0]
    elif isinstance(param, str):
        if param != "all":
            raise ValueError(
                "Please provide parameter to plot as a list not as string. If you want to plot all parameters, set param to 'all' and labels to None."
            )
        params = list(clin_data.columns)
    else:
        params = list(param)

    if plot_type not in ("2D-scatter", "Ridgeline"):
        print(
            f"Plot type {plot_type} not recognized. Please use '2D-scatter', 'Ridgeline' or 'Coverage-Correlation'."
        )
        return

    embedding = None
    if plot_type == "2D-scatter":
        # Reduce once so every parameter is coloured on the same coordinates.
        if df_latent.shape[1] > 2:
            reducer = UMAP(n_components=2)
            embedding = pd.DataFrame(reducer.fit_transform(df_latent))
        else:
            embedding = df_latent

    for p in params:
        # Pick this parameter's labels fresh each iteration so nothing leaks
        # from a previous parameter.
        if user_labels is not None:
            p_labels = user_labels
        elif p in clin_data.columns:
            p_labels = clin_data.loc[df_latent.index, p].tolist()
        else:
            print(f"Parameter {p} not found in annotation data and will be skipped.")
            print(f"All available parameters are: {list(clin_data.columns)}")
            continue

        if plot_type == "2D-scatter":
            fig = self.plot_2D(
                embedding=embedding,
                labels=p_labels,
                param=p,
                layer=f"2D latent space (epoch {epoch + 1})",  # epochs are counted from 0; +1 for display
                figsize=(12, 8),
                center=True,
            )
            self.plots["2D-scatter"][epoch][split][p] = fig
        else:
            grid = self.plot_latent_ridge(lat_space=df_latent, labels=p_labels, param=p)
            self.plots["Ridgeline"][epoch][split][p] = grid
            fig = grid.figure

        show_figure(fig)
        plt.show()


@staticmethod
def _collect_clin_data(result: Result) -> pd.DataFrame:
    """Validate the dataset splits and stack their annotation metadata.

    When ``metadata`` is a dict, the table under key "paired" is used. When
    it is a DataFrame, it is used directly. Tables are concatenated in the
    order train, test, valid.

    Raises:
        ValueError: If a split or its metadata is missing, or the metadata is
            neither a dict holding "paired" nor a DataFrame.
    """
    datasets = result.datasets
    if not hasattr(datasets, "train"):
        raise ValueError("no train split in datasets")
    if not hasattr(datasets, "valid"):
        raise ValueError("no valid split in datasets")
    if datasets.train is None:
        raise ValueError("train is None")
    if datasets.valid is None:
        raise ValueError("valid is None")
    if getattr(datasets, "test", None) is None:
        raise ValueError("test is None")
    for split_name in ("train", "valid", "test"):
        if not hasattr(getattr(datasets, split_name), "metadata"):
            raise ValueError(f"{split_name} dataset has no metadata")

    ordered = (datasets.train, datasets.test, datasets.valid)
    train_meta = datasets.train.metadata
    if isinstance(train_meta, dict):
        if "paired" not in train_meta:
            raise ValueError(
                "Please provide paired annotation data with key 'paired' in metadata dictionary."
            )
        frames = [ds.metadata["paired"] for ds in ordered]
    elif isinstance(train_meta, pd.DataFrame):
        frames = [ds.metadata for ds in ordered]
    elif train_meta is None:
        raise ValueError(
            "No annotation data found. Please provide a valid annotation data type."
        )
    else:
        raise ValueError(
            "Metadata is not a dictionary or DataFrame. Please provide a valid annotation data type."
        )
    return pd.concat(frames, axis=0)
def show_weights(self) -> None:
    """Show the stored "ModelWeights" figure, or explain how to create it."""
    if "ModelWeights" in self.plots.keys():
        weights_fig = self.plots["ModelWeights"]
        show_figure(weights_fig)
        plt.show()
        return

    print("Model weights not found in the plots dictionary")
    print("You need to run visualize() method first")
371 # def plot_model_weights(model: torch.nn.Module) -> matplotlib.figure.Figure:
372 # """
373 # Visualization of model weights in encoder and decoder layers as heatmap for each layer as subplot.
374 # ARGS:
375 # model (torch.nn.Module): PyTorch model instance.
376 # filepath (str): Path specifying save name and location.
377 # RETURNS:
378 # fig (matplotlib.figure): Figure handle (of last plot)
379 # """
380 # all_weights = []
381 # names = []
382 # if hasattr(model, "ontologies"):
383 # if model.ontologies is not None:
384 # # If model is Ontix
385 # # Get node names from ontologies
386 # node_names = list()
387 # for ontology in model.ontologies:
388 # node_names.append(ontology.keys())
390 # node_names.append(model.feature_order) # Add feature order as last layer
392 # for name, param in model.named_parameters():
393 # if "weight" in name and len(param.shape) == 2:
394 # if "var" not in name: ## For VAE plot only mu weights
395 # all_weights.append(param.detach().cpu().numpy())
396 # names.append(name[:-7])
398 # layers = int(len(all_weights) / 2)
399 # fig, axes = plt.subplots(2, layers, sharex=False, figsize=(20, 10))
401 # for layer in range(layers):
402 # ## Encoder Layer
403 # if layers > 1:
404 # sns.heatmap(
405 # all_weights[layer],
406 # cmap=sns.color_palette("Spectral", as_cmap=True),
407 # ax=axes[0, layer],
408 # ).set(title=names[layer])
409 # ## Decoder Layer
410 # sns.heatmap(
411 # all_weights[layers + layer],
412 # cmap=sns.color_palette("Spectral", as_cmap=True),
413 # ax=axes[1, layer],
414 # ).set(title=names[layers + layer])
415 # axes[1, layer].set_xlabel("In Node", size=12)
416 # if model.ontologies is not None:
417 # axes[1, layer].set_xticks(
418 # ticks=range(len(node_names[layer])),
419 # labels=node_names[layer],
420 # rotation=90,
421 # fontsize=8,
422 # )
423 # axes[1, layer].set_yticks(
424 # ticks=range(len(node_names[layer + 1])),
425 # labels=node_names[layer + 1],
426 # rotation=0,
427 # fontsize=8,
428 # )
429 # else:
430 # sns.heatmap(
431 # all_weights[layer],
432 # cmap=sns.color_palette("Spectral", as_cmap=True),
433 # ax=axes[layer],
434 # ).set(title=names[layer])
435 # ## Decoder Layer
436 # sns.heatmap(
437 # all_weights[layer + 2],
438 # cmap=sns.color_palette("Spectral", as_cmap=True),
439 # ax=axes[layer + 1],
440 # ).set(title=names[layer + 2])
441 # axes[1].set_xlabel("In Node", size=12)
443 # if layers > 1:
444 # axes[1, 0].set_ylabel("Out Node", size=12)
445 # axes[0, 0].set_ylabel("Out Node", size=12)
446 # else:
447 # axes[1].set_ylabel("Out Node", size=12)
448 # axes[0].set_ylabel("Out Node", size=12)
450 # ## Add title
451 # fig.suptitle("Model Weights", size=20)
452 # plt.close()
453 # return fig
455 ## NEW VERSION
456 # @staticmethod
457 # def plot_model_weights(model: torch.nn.Module) -> matplotlib.figure.Figure:
458 # """
459 # Visualization of model weights in encoder and decoder layers as heatmap for each layer as subplot.
460 # ARGS:
461 # model (torch.nn.Module): PyTorch model instance.
462 # filepath (str): Path specifying save name and location.
463 # RETURNS:
464 # fig (matplotlib.figure): Figure handle (of last plot)
465 # """
466 # all_weights = []
467 # names = []
468 # if hasattr(model, "ontologies"):
469 # if model.ontologies is not None:
470 # # If model is Ontix
471 # # Get node names from ontologies
472 # node_names = list()
473 # for ontology in model.ontologies:
474 # node_names.append(ontology.keys())
476 # node_names.append(model.feature_order) # Add feature order as last layer
478 # for name, param in model.named_parameters():
479 # if "weight" in name and len(param.shape) == 2:
480 # if "var" not in name: ## For VAE plot only mu weights
481 # all_weights.append(param.detach().cpu().numpy())
482 # names.append(name[:-7])
484 # layers = int(len(all_weights) / 2)
485 # fig, axes = plt.subplots(2, layers, sharex=False, figsize=(20, 10))
487 # for layer in range(layers):
488 # ## Encoder Layer
489 # if layers > 1:
490 # sns.heatmap(
491 # all_weights[layer],
492 # cmap=sns.color_palette("Spectral", as_cmap=True),
493 # ax=axes[0, layer],
494 # ).set(title=names[layer])
495 # ## Decoder Layer
496 # sns.heatmap(
497 # all_weights[layers + layer],
498 # cmap=sns.color_palette("Spectral", as_cmap=True),
499 # ax=axes[1, layer],
500 # ).set(title=names[layers + layer])
501 # axes[1, layer].set_xlabel("In Node", size=12)
502 # if model.ontologies is not None:
503 # axes[1, layer].set_xticks(
504 # ticks=range(len(node_names[layer])),
505 # labels=node_names[layer],
506 # rotation=90,
507 # fontsize=8,
508 # )
509 # axes[1, layer].set_yticks(
510 # ticks=range(len(node_names[layer + 1])),
511 # labels=node_names[layer + 1],
512 # rotation=0,
513 # fontsize=8,
514 # )
515 # else:
516 # sns.heatmap(
517 # all_weights[layer],
518 # cmap=sns.color_palette("Spectral", as_cmap=True),
519 # ax=axes[layer],
520 # ).set(title=names[layer])
521 # ## Decoder Layer
522 # sns.heatmap(
523 # all_weights[layer + 2],
524 # cmap=sns.color_palette("Spectral", as_cmap=True),
525 # ax=axes[layer + 1],
526 # ).set(title=names[layer + 2])
527 # axes[1].set_xlabel("In Node", size=12)
529 # if layers > 1:
530 # axes[1, 0].set_ylabel("Out Node", size=12)
531 # axes[0, 0].set_ylabel("Out Node", size=12)
532 # else:
533 # axes[1].set_ylabel("Out Node", size=12)
534 # axes[0].set_ylabel("Out Node", size=12)
536 # ## Add title
537 # fig.suptitle("Model Weights", size=20)
538 # plt.close()
539 # return fig
541 ## NEW VERSION
@staticmethod
def plot_model_weights(model: torch.nn.Module) -> matplotlib.figure.Figure:
    """Plot the 2-D weight matrices of ``model`` as heatmaps.

    Layer assignment:
        - Top row (encoder): parameters whose name contains "encoder", plus any
          "_mu" layer.
        - Bottom row (decoder): parameters whose name contains "decoder".
        - Layers whose name contains "var" are skipped.
        - Encoder and decoder may have different depths; unused panels are
          hidden.

    If no encoder/decoder/_mu weights are found, the other 2-D weights are
    split in half instead: the first half goes in the top row, the second
    half in the bottom row.

    If ``model.ontologies`` exists and is not None, decoder heatmaps are
    labelled with the ontology node names, with ``model.feature_order`` used
    for the last level. Labels are only added where the name counts match
    the weight matrix.

    Declared as a staticmethod so that ``self.plot_model_weights(model=...)``
    works.

    Args:
        model: PyTorch model instance.

    Returns:
        The figure, closed via ``plt.close(fig)`` before being returned.
    """
    spectral = sns.color_palette("Spectral", as_cmap=True)

    # Only some models (e.g. Ontix) have ``ontologies``; getattr avoids an
    # AttributeError for the others.
    ontologies = getattr(model, "ontologies", None)
    node_names = []
    if ontologies is not None:
        node_names = [list(ontology.keys()) for ontology in ontologies]
        node_names.append(model.feature_order)

    encoder_weights, encoder_names = [], []
    decoder_weights, decoder_names = [], []
    other_weights, other_names = [], []
    for name, param in model.named_parameters():
        if "weight" not in name or len(param.shape) != 2:
            continue
        if "_mu" in name or ("encoder" in name and "var" not in name):
            target_w, target_n = encoder_weights, encoder_names
        elif "decoder" in name and "var" not in name:
            target_w, target_n = decoder_weights, decoder_names
        elif "encoder" not in name and "decoder" not in name and "var" not in name:
            # fallback for models without explicit encoder/decoder in name
            target_w, target_n = other_weights, other_names
        else:
            continue
        target_w.append(param.detach().cpu().numpy())
        # name[:-7] drops a trailing ".weight" (assumes "<layer>.weight" naming).
        target_n.append(name[:-7])

    if encoder_weights or decoder_weights:
        n_enc, n_dec = len(encoder_weights), len(decoder_weights)
        n_cols = max(n_enc, n_dec)
        # squeeze=False keeps ``axes`` 2-D even for a single column.
        fig, axes = plt.subplots(
            2, n_cols, sharex=False, figsize=(15 * n_cols, 15), squeeze=False
        )
        for i in range(n_cols):
            top, bottom = axes[0, i], axes[1, i]

            if i < n_enc:
                sns.heatmap(
                    encoder_weights[i], cmap=spectral, center=0, ax=top
                ).set(title=encoder_names[i])
                top.set_ylabel("Out Node", size=12)
            else:
                top.axis("off")

            if i >= n_dec:
                bottom.axis("off")
                continue
            weights = decoder_weights[i]
            sns.heatmap(weights, cmap=spectral, center=0, ax=bottom).set(
                title=decoder_names[i]
            )
            # Columns (inputs) of decoder layer i are named by node_names[i],
            # rows (outputs) by node_names[i + 1]; torch Linear weights are
            # stored as (out, in). Only label when the counts match the matrix.
            if i + 1 < len(node_names):
                in_names, out_names = node_names[i], node_names[i + 1]
                n_out, n_in = weights.shape
                if len(in_names) == n_in and len(out_names) == n_out:
                    # seaborn centres heatmap cells at k + 0.5
                    bottom.set_xticks(
                        ticks=np.arange(n_in) + 0.5,
                        labels=in_names,
                        rotation=90,
                        fontsize=8,
                    )
                    bottom.set_yticks(
                        ticks=np.arange(n_out) + 0.5,
                        labels=out_names,
                        rotation=0,
                        fontsize=8,
                    )
            bottom.set_xlabel("In Node", size=12)
            bottom.set_ylabel("Out Node", size=12)
    else:
        # fallback: plot all weights in order, split in half for encoder/decoder
        # (with an odd count the last matrix is not shown).
        n_layers = len(other_weights) // 2
        if n_layers == 0:
            # Nothing to pair up; plt.subplots(2, 0) would raise.
            fig = plt.figure()
        else:
            fig, axes = plt.subplots(
                2, n_layers, sharex=False, figsize=(5 * n_layers, 10), squeeze=False
            )
            for layer in range(n_layers):
                sns.heatmap(
                    other_weights[layer], cmap=spectral, center=0, ax=axes[0, layer]
                ).set(title=other_names[layer])
                sns.heatmap(
                    other_weights[n_layers + layer],
                    cmap=spectral,
                    center=0,
                    ax=axes[1, layer],
                ).set(title=other_names[n_layers + layer])
                axes[1, layer].set_xlabel("In Node", size=12)
                axes[0, layer].set_ylabel("Out Node", size=12)
                axes[1, layer].set_ylabel("Out Node", size=12)

    fig.suptitle("Model Weights", size=20)
    plt.close(fig)
    return fig
@staticmethod
def plot_2D(
    embedding: pd.DataFrame,
    labels: list,
    param: Optional[Union[str, None]] = None,
    layer: str = "latent space",
    figsize: tuple = (24, 15),
    center: bool = True,
    plot_numeric: bool = False,
    xlim: Optional[Union[tuple, None]] = None,
    ylim: Optional[Union[tuple, None]] = None,
    scale: Optional[Union[str, None]] = None,
    no_leg: bool = False,
) -> matplotlib.figure.Figure:
    """Plot a 2D scatter of ``embedding`` coloured by ``labels``.

    How labels are treated:
        - Non-string labels with more than 3 distinct values are numeric.
          Unless ``plot_numeric`` is set, they are binned into quartiles
          ("1stQ".."4thQ"). Ints and floats are kept; other values become NaN
          and end up as "nan".
        - All other labels are converted to strings and treated as
          categories.

    Args:
        embedding: DataFrame whose first two columns are the 2D coordinates.
        labels: One label per point in ``embedding``.
        param: Title for the legend. Defaults to None.
        layer: Title for the plot. Defaults to "latent space".
        figsize: Size of the figure. Defaults to (24, 15).
        center: If True, mark each category's mean with a star (disabled for
            numeric colouring). Defaults to True.
        plot_numeric: If True, numeric labels are drawn with a continuous
            "bwr" palette instead of quartile bins. Defaults to False.
        xlim: Optional (min, max) for the x axis.
        ylim: Optional (min, max) for the y axis.
        scale: Optional axis scale (e.g. "log") applied to both axes.
        no_leg: If True, hide the legend. Defaults to False.

    Returns:
        The figure, closed via ``plt.close(fig)`` before being returned.

    Raises:
        ValueError: If ``labels`` is empty.
    """
    import numbers

    labels = list(labels)
    if not labels:
        raise ValueError("labels must not be empty")

    numeric = False
    is_numeric_like = (
        not isinstance(labels[0], str)
        and pd.Series(labels, dtype=object).nunique(dropna=False) > 3
    )
    if is_numeric_like and plot_numeric:
        center = False  ## Disable centering for numeric params
        numeric = True
    elif is_numeric_like:
        print("The provided label column is numeric and converted to categories.")
        # Keep real numbers (ints and floats, but not bools); others -> NaN.
        values = [
            float(x)
            if isinstance(x, numbers.Real) and not isinstance(x, bool)
            else float("nan")
            for x in labels
        ]
        labels = (
            pd.qcut(
                x=pd.Series(values),
                q=4,
                labels=["1stQ", "2ndQ", "3rdQ", "4thQ"],
            )
            .astype(str)
            .to_list()
        )
    else:
        # Uniform strings also make mixed str/NaN columns sortable by np.unique.
        labels = [str(x) for x in labels]

    fig, ax = plt.subplots(figsize=figsize)

    n_points = embedding.shape[0]
    if len(labels) < n_points:
        print(
            "Given labels do not have the same length as given sample size. Labels will be duplicated."
        )
        repeats = n_points // len(labels)
        labels = [label for label in labels for _ in range(repeats)]
    elif len(labels) > n_points:
        # NOTE(review): this keeps only the distinct labels (unordered); it only
        # lines up with the embedding if every point has its own label - confirm intent.
        labels = list(set(labels))

    x_coords, y_coords = embedding.iloc[:, 0], embedding.iloc[:, 1]
    if numeric:
        sns.scatterplot(
            x=x_coords,
            y=y_coords,
            hue=labels,
            palette="bwr",
            s=40,
            alpha=0.5,
            ec="black",
            ax=ax,
        )
    else:
        hue_order = np.unique(labels)
        sns.scatterplot(
            x=x_coords,
            y=y_coords,
            hue=labels,
            hue_order=hue_order,
            s=40,
            alpha=0.5,
            ec="black",
            ax=ax,
        )
        if center:
            # Mark each category's mean position with a large star.
            means = embedding.groupby(by=labels).mean()
            sns.scatterplot(
                x=means.iloc[:, 0],
                y=means.iloc[:, 1],
                hue=hue_order,
                hue_order=hue_order,
                s=200,
                ec="black",
                alpha=0.9,
                marker="*",
                legend=False,
                ax=ax,
            )

    if xlim is not None:
        ax.set_xlim(xlim[0], xlim[1])
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])
    if scale is not None:
        ax.set_yscale(scale)
        ax.set_xscale(scale)
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")

    if no_leg:
        ax.legend([], [], frameon=False)
    else:
        legend_cols = 2 if len(np.unique(labels)) > 10 else 1
        sns.move_legend(
            ax,
            "upper left",
            bbox_to_anchor=(1, 1),
            ncol=legend_cols,
            title=param,
            frameon=False,
        )

    ax.set_title(layer)

    plt.close(fig)
    return fig
@staticmethod
def plot_latent_ridge(
    lat_space: pd.DataFrame,
    labels: Optional[Union[list, pd.Series, None]] = None,
    param: Optional[Union[str, None]] = None,
) -> sns.FacetGrid:
    """Create a ridgeline plot with one row per latent dimension.

    Each row shows the density of one latent dimension, with one overlaid
    ridge per label group. Samples labelled "unknown" or "nan" are left out.
    Non-string labels with more than 3 distinct values are binned into
    quartiles ("1stQ".."4thQ"); everything else is treated as a category.

    Note: calls ``sns.set_theme``, which changes the global plotting style.

    Args:
        lat_space: Latent space with samples as rows (index) and latent
            dimensions as columns.
        labels: One label per sample (row of ``lat_space``). If None, all
            samples form one group named "all".
        param: Name of the label column and legend title (e.g. a metadata
            column name). Defaults to "all" when ``labels`` is None, otherwise
            to "label".

    Returns:
        g: FacetGrid object containing the ridge line plot
    """
    import numbers

    # Transparent axes backgrounds are needed so overlapping ridges stay visible.
    sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

    n_dims = len(lat_space.columns)
    df = pd.melt(lat_space, var_name="latent dim", value_name="latent intensity")
    # melt stacks the columns one after another, so the sample index repeats
    # once per latent dimension.
    df["sample"] = n_dims * list(lat_space.index)

    if labels is None:
        param = "all"
        labels = ["all"] * len(lat_space)  # one label per sample
    else:
        labels = list(labels)
        if param is None:
            param = "label"

    is_numeric_like = (
        len(labels) > 0
        and not isinstance(labels[0], str)
        and pd.Series(labels, dtype=object).nunique(dropna=False) > 3
    )
    if is_numeric_like:
        # Keep real numbers (ints and floats, but not bools); others -> NaN.
        values = [
            float(x)
            if isinstance(x, numbers.Real) and not isinstance(x, bool)
            else float("nan")
            for x in labels
        ]
        labels = (
            pd.qcut(
                x=pd.Series(values),
                q=4,
                labels=["1stQ", "2ndQ", "3rdQ", "4thQ"],
            )
            .astype(str)
            .tolist()
        )
    else:
        labels = [str(x) for x in labels]

    # Repeat the per-sample labels once per latent dimension (melted layout).
    df[param] = n_dims * labels
    missing = (df[param] == "unknown") | (df[param] == "nan")

    intensity_by_group = df.loc[
        ~missing, ["latent intensity", "latent dim", param]
    ].groupby([param, "latent dim"], observed=False)["latent intensity"]
    xmin = intensity_by_group.quantile(0.05).min()
    xmax = intensity_by_group.quantile(0.9).max()

    n_groups = df[param].nunique()
    if n_groups > 8:
        cat_pal = sns.husl_palette(n_groups)
    else:
        cat_pal = sns.color_palette(n_colors=n_groups)

    g = sns.FacetGrid(
        df[~missing],
        row="latent dim",
        hue=param,
        aspect=12,
        height=0.8,
        xlim=(xmin, xmax),
        palette=cat_pal,
    )
    g.map_dataframe(
        sns.kdeplot,
        "latent intensity",
        bw_adjust=0.5,
        clip_on=True,
        fill=True,
        alpha=0.5,
        warn_singular=False,
        ec="k",
        lw=1,
    )

    def _label_row(data, color, label, text="latent dim"):
        # Write the latent-dimension name at the left edge of each ridge row.
        ax = plt.gca()
        ax.text(
            0.0,
            0.2,
            data[text].unique()[0],
            fontweight="bold",
            ha="right",
            va="center",
            transform=ax.transAxes,
        )

    g.map_dataframe(_label_row, text="latent dim")

    g.set(xlim=(xmin, xmax))
    # Negative spacing makes the rows overlap into ridges.
    g.figure.subplots_adjust(hspace=-0.5)

    # Remove axes details that don't play well with overlap
    g.set_titles("")
    g.set(yticks=[], ylabel="")
    g.despine(bottom=True, left=True)

    g.add_legend()

    plt.close(g.figure)
    return g
@staticmethod
def make_loss_plot(
    df_plot: pd.DataFrame, plot_type: str
) -> matplotlib.figure.Figure:
    """Generates a plot for visualizing loss values from a DataFrame.

    Args:
        df_plot: Long-format loss table with columns "Loss Term", "Epoch",
            "Loss Value" and "Split". It is not modified.
        plot_type: "absolute" or "relative".
            - "absolute": one line plot per loss term, one line per split.
            - "relative": one filled density plot per split showing each
              term's share over epochs. Excluded are "total_loss", terms
              containing "_factor", and terms with no non-zero, non-NaN
              value. Negative values are clipped to zero with a warning.

    Returns:
        The figure, closed via ``plt.close(fig)`` before being returned.

    Raises:
        ValueError: If ``plot_type`` is unknown or ``df_plot`` is empty.
    """
    if plot_type not in ("absolute", "relative"):
        raise ValueError(
            f"plot_type must be 'absolute' or 'relative', got {plot_type!r}"
        )
    if df_plot.empty:
        raise ValueError("df_plot contains no loss values to plot")

    if plot_type == "absolute":
        loss_terms = df_plot["Loss Term"].unique()
        # squeeze=False keeps ``axes`` indexable even with a single loss term.
        fig, axes = plt.subplots(
            1,
            len(loss_terms),
            figsize=(5 * len(loss_terms), 5),
            sharey=False,
            squeeze=False,
        )
        for ax, term in zip(axes[0], loss_terms):
            sns.lineplot(
                data=df_plot[df_plot["Loss Term"] == term],
                x="Epoch",
                y="Loss Value",
                hue="Split",
                ax=ax,
            )
            ax.set_title(term)
        plt.close(fig)
        return fig

    # Work on a copy so clipping does not leak into the caller's DataFrame.
    df_rel = df_plot.copy()
    if (df_rel["Loss Value"] < 0).any():
        warnings.warn(
            "Loss values contain negative values. Check your loss function if correct. Loss will be clipped to zero for plotting."
        )
        df_rel["Loss Value"] = df_rel["Loss Value"].clip(lower=0)

    # Keep only terms with at least one value that is neither zero nor NaN.
    informative = df_rel["Loss Value"].notna() & (df_rel["Loss Value"] != 0)
    valid_terms = df_rel.loc[informative, "Loss Term"].unique()
    keep = (
        (df_rel["Loss Term"] != "total_loss")
        & ~(df_rel["Loss Term"].str.contains("_factor"))
        & (df_rel["Loss Term"].isin(valid_terms))
    )

    splits = df_rel["Split"].unique()
    fig, axes = plt.subplots(
        1, len(splits), figsize=(5 * len(splits), 5), sharey=True, squeeze=False
    )
    epoch_max = df_rel["Epoch"].max()
    for ax, split in zip(axes[0], splits):
        sns.kdeplot(
            data=df_rel[keep & (df_rel["Split"] == split)],
            x="Epoch",
            hue="Loss Term",
            multiple="fill",
            weights="Loss Value",
            clip=[0, epoch_max],
            ax=ax,
        )
        ax.set_title(split)

    plt.close(fig)
    return fig
@staticmethod
def make_loss_format(result: Result, config: DefaultConfig) -> pd.DataFrame:
    """Collect all loss histories of ``result`` into one long-format table.

    Inputs used:
        - Each sub-loss ``term`` from ``result.sub_losses`` contributes one
          block. The term "var_loss" is multiplied by ``config.beta``.
        - ``result.losses`` contributes a block labelled "total_loss".

    Each per-term table is indexed by epoch; its columns (presumably the data
    splits) become the "Split" column.

    Args:
        result: Result object providing ``sub_losses`` and ``losses``.
        config: Configuration providing ``beta``.

    Returns:
        DataFrame with columns "Epoch" (index + 1), "Loss Term", "Split" and
        "Loss Value" (float).
    """

    def _as_dict(values: Any) -> Any:
        # Loss containers may return dicts, objects with to_dict() or arrays;
        # anything else is passed through unchanged.
        if isinstance(values, dict):
            return values
        if hasattr(values, "to_dict"):
            return values.to_dict()
        if hasattr(values, "shape"):
            # For numpy arrays, create a dict with indices as keys
            return {i: val for i, val in enumerate(values)}
        return values

    def _to_long(loss_df: pd.DataFrame, term: str) -> pd.DataFrame:
        # Add epoch/term identifiers and melt the remaining columns into rows.
        loss_df["Epoch"] = loss_df.index + 1
        loss_df["Loss Term"] = term
        return loss_df.melt(
            id_vars=["Epoch", "Loss Term"],
            var_name="Split",
            value_name="Loss Value",
        )

    frames = []
    for term in result.sub_losses.keys():
        loss_df = pd.DataFrame.from_dict(
            _as_dict(result.sub_losses.get(key=term).get()), orient="index"
        )
        if term == "var_loss":
            loss_df = loss_df * config.beta
        frames.append(_to_long(loss_df, term))

    total_df = pd.DataFrame.from_dict(_as_dict(result.losses.get()), orient="index")
    frames.append(_to_long(total_df, "total_loss"))

    # A single concat avoids the quadratic cost of concatenating inside the loop.
    loss_df_melt = pd.concat(frames, axis=0).reset_index(drop=True)
    loss_df_melt["Loss Value"] = loss_df_melt["Loss Value"].astype(float)
    return loss_df_melt
@no_type_check
def plot_evaluation(
    self,
    result: Result,
) -> dict:
    """Plot the embedding-evaluation scores stored in a Result object.

    One seaborn bar ``catplot`` (``value`` per ``score_split``, one column per
    ``ML_TASK``) is built for every combination of ``CLINIC_PARAM``,
    ``metric`` and ``ML_ALG`` present in ``result.embedding_evaluation``.
    Rows whose ``CLINIC_PARAM``, ``metric`` or ``ML_ALG`` is missing are
    skipped.

    Args:
        result: The Result object containing evaluation data.

    Returns:
        Nested dictionary ``{clinic_param: {metric: {ml_alg: FacetGrid}}}``;
        the same dictionary is stored in ``self.plots["ML_Evaluation"]``.
    """
    ml_plots = dict()
    plt.ioff()

    # One pass over the frame instead of re-masking it for every level;
    # sort=False keeps groups in order of first appearance, so the nested
    # dictionary is filled in the same order as before.
    grouped = result.embedding_evaluation.groupby(
        ["CLINIC_PARAM", "metric", "ML_ALG"], sort=False, observed=True
    )
    for (c, m, alg), data in grouped:
        sns_plot = sns.catplot(
            data=data,
            x="score_split",
            y="value",
            col="ML_TASK",
            hue="score_split",
            kind="bar",
        )

        # Anchor the y-axis at 0 unless some score is negative; min() keeps
        # 0 when all scores are NaN (a NaN limit would be rejected).
        min_y = min(0, data.value.min())

        ml_plots.setdefault(c, dict()).setdefault(m, dict())[alg] = sns_plot.set(
            ylim=(min_y, None)
        )
        # Drop the figure from pyplot's open-figure registry, as the loss
        # plots do; the FacetGrid still references it for saving/showing.
        plt.close(sns_plot.figure)

    self.plots["ML_Evaluation"] = ml_plots

    return ml_plots
def show_evaluation(
    self,
    param: str,
    metric: str,
    ml_alg: Optional[str] = None,
) -> None:
    """Show stored ML-evaluation plots for one clinical parameter and metric.

    Looks the plots up in ``self.plots["ML_Evaluation"]``, which is nested as
    ``{param: {metric: {ml_alg: grid}}}``. When a requested key is missing, a
    message listing the available options is printed and nothing is shown.

    Args:
        param: The clinical parameter to visualize.
        metric: The metric to visualize.
        ml_alg: Name of a single ML algorithm to show. If None, every stored
            algorithm for ``param``/``metric`` is shown in turn.
    """
    plt.ioff()
    store = self.plots

    if "ML_Evaluation" not in store.keys():
        print("ML Evaluation plots not found in the plots dictionary")
        print("You need to run evaluate() method first")
        return None

    by_param = store["ML_Evaluation"]
    if param not in by_param.keys():
        print(f"Parameter {param} not found in the ML Evaluation plots")
        print(f"Available parameters: {list(by_param.keys())}")
        return None

    by_metric = by_param[param]
    if metric not in by_metric.keys():
        print(f"Metric {metric} not found in the ML Evaluation plots for {param}")
        print(f"Available metrics: {list(by_metric.keys())}")
        return None

    by_alg = by_metric[metric]
    available = list(by_alg.keys())

    if ml_alg is None:
        # No algorithm requested: walk through all of them.
        for name in available:
            print(f"Showing plot for ML algorithm: {name}")
            show_figure(by_alg[name].figure)
            plt.show()
        return None

    if ml_alg not in available:
        print(f"ML algorithm {ml_alg} not found for {param} and {metric}")
        print(f"Available ML algorithms: {available}")
        return None

    show_figure(by_alg[ml_alg].figure)
    plt.show()
@staticmethod
def _total_correlation(latent_space: pd.DataFrame) -> float:
    """Compute the total correlation of a latent space (Gaussian estimate).

    Implements Equation 2 of https://doi.org/10.3390/e21100921:
    ``TC = 0.5 * (sum(log(diag(C))) - log|det(C)|)`` with ``C`` the sample
    covariance matrix of the latent dimensions.

    Args:
        latent_space: Latent space with samples as rows and latent
            dimensions as columns.

    Returns:
        Total correlation across latent dimensions (0.0 for a single
        dimension).
    """
    # atleast_2d: np.cov returns a 0-d array for a single latent dimension,
    # which np.diag / slogdet would reject.
    lat_cov = np.atleast_2d(np.cov(latent_space, rowvar=False))
    log_var_sum = np.sum(np.log(np.diag(lat_cov)))
    log_det = np.linalg.slogdet(lat_cov)[1]
    return float(0.5 * (log_var_sum - log_det))
@staticmethod
def _coverage_calc(latent_space: pd.DataFrame) -> float:
    """Compute the coverage of a latent space (Equation 3 of https://doi.org/10.3390/e21100921).

    Each latent dimension is cut into ``floor(n_samples ** (1 / n_dims))``
    equal-width bins; coverage is the fraction of the resulting grid cells
    that contain at least one sample.

    Args:
        latent_space: Latent space with samples as rows and latent
            dimensions as columns.

    Returns:
        Coverage across latent dimensions, or NaN (with a warning) when
        fewer than 2 bins per dimension are possible.
    """
    n_samples = len(latent_space.index)
    n_dims = len(latent_space.columns)

    bins_per_dim = 0
    if n_dims > 0:
        # Exact integer d-th root: the float power alone truncates wrongly,
        # e.g. int(64 ** (1 / 3)) == 3 and int(1000 ** (1 / 3)) == 9.
        bins_per_dim = int(round(n_samples ** (1 / n_dims)))
        while bins_per_dim > 0 and bins_per_dim**n_dims > n_samples:
            bins_per_dim -= 1
        while (bins_per_dim + 1) ** n_dims <= n_samples:
            bins_per_dim += 1

    if bins_per_dim < 2:
        warnings.warn(
            "Coverage calculation fails since combination of sample size and latent dimension results in less than 2 bins."
        )
        return np.nan

    latent_bins = latent_space.apply(lambda x: pd.cut(x, bins=bins_per_dim))
    # One tuple of bin intervals per sample = the grid cell it falls into.
    occupied = pd.Series(zip(*[latent_bins[col] for col in latent_bins]))
    return len(occupied.unique()) / (bins_per_dim**n_dims)