Coverage for src / autoencodix / visualize / _xmodal_visualizer.py: 11%
292 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from dataclasses import field
2import pandas as pd
3import numpy as np
4import seaborn as sns
5import matplotlib.pyplot as plt
6from umap import UMAP
7import warnings
8import torch
9from sklearn.decomposition import PCA
10from sklearn.manifold import TSNE
12from typing import Any, Dict, Optional, Union, List, no_type_check
13from autoencodix.base._base_visualizer import BaseVisualizer
14from autoencodix.utils._result import Result
15from autoencodix.utils._utils import nested_dict, show_figure
16from autoencodix.configs.default_config import DefaultConfig
17from autoencodix.data._datasetcontainer import DatasetContainer
20class XModalVisualizer(BaseVisualizer):
21 plots: Dict[str, Any] = field(
22 default_factory=nested_dict
23 ) ## Nested dictionary of plots as figure handles
def __init__(self):
    """Create the visualizer with an empty nested plot registry."""
    # nested_dict() provides the container that the plotting methods fill
    # through chained keys, e.g. self.plots["2D-scatter"][epoch][split][p].
    plot_registry = nested_dict()
    self.plots = plot_registry
def __setitem__(self, key, elem):
    """Store ``elem`` in the plot registry under ``key`` (``viz[key] = elem``)."""
    registry = self.plots
    registry[key] = elem
def visualize(self, result: Result, config: DefaultConfig) -> Result:
    """Create the absolute and relative loss plots of an X-Modalix run.

    Loss terms whose name starts with one of the training modality keys are
    dropped, so only the combined loss terms are plotted. The figures are
    stored in ``self.plots["loss_absolute"]`` and
    ``self.plots["loss_relative"]``.

    Args:
        result: Result object holding the losses and the datasets.
        config: Configuration forwarded to ``self._make_loss_format``.

    Returns:
        The unchanged ``result`` object.

    Raises:
        ValueError: If loss data exists but ``result.datasets`` has no
            (non-None) ``train`` attribute.
    """
    ## TODO needs to be adjusted for X-Modalix ##
    ## Plot Model weights for each sub-VAE ##
    # self.plots["ModelWeights"] = self._plot_model_weights(model=result.model)

    # Check for missing loss data FIRST: without it there is nothing to plot,
    # and the dataset checks / loss formatting below must not fail before the
    # user sees the explanatory warning.
    if not result.losses._data:
        warnings.warn(
            "No loss data: This usually happens if you try to visualize after saving and loading the pipeline object with `save_all=False`. This memory-efficient saving mode does not retain past training loss data.",
            stacklevel=2,
        )
        return result

    if not hasattr(result.datasets, "train"):
        raise ValueError("result.datasets has no attribute train")
    if result.datasets.train is None:
        raise ValueError("Train attribute of datasets is None")

    ## Make long format of losses
    loss_df_melt = self._make_loss_format(result=result, config=config)

    ## X-Modalix specific ##
    # Filter loss terms which are specific for each modality VAE
    # Plot only combined loss terms as in old autoencodix framework
    modality_prefixes = tuple(result.datasets.train.datasets.keys())
    is_modality_term = loss_df_melt["Loss Term"].str.startswith(modality_prefixes)
    loss_df_melt = loss_df_melt[~is_modality_term]

    for plot_type in ("absolute", "relative"):
        self.plots[f"loss_{plot_type}"] = self._make_loss_plot(
            df_plot=loss_df_melt, plot_type=plot_type
        )

    return result
def show_latent_space(
    self,
    result: Result,
    plot_type: str = "2D-scatter",
    labels: Optional[Union[list, pd.Series, None]] = None,
    param: Optional[Union[list, str]] = None,
    epoch: Optional[Union[int, None]] = None,
    split: str = "all",
) -> None:
    """Plot the joint latent space of all modalities, one figure per parameter.

    The latent spaces of every modality in the requested split(s) are stacked
    into one table, annotated with the datasets' metadata and drawn either as
    a 2D scatter plot or as a ridgeline plot. Each figure is stored in
    ``self.plots[plot_type][epoch][split][<parameter>]`` and shown.

    Args:
        result: Result object providing ``datasets``, ``model`` and
            ``get_latent_df``.
        plot_type: "2D-scatter", "Ridgeline" or "Coverage-Correlation" (the
            latter is not implemented yet and only prints a notice).
        labels: Optional user-defined labels (list or Series). Mutually
            exclusive with ``param``.
        param: List of metadata columns to colour by, or the string "all" for
            every metadata column. If both ``param`` and ``labels`` are None,
            the annotation columns of the first sub-model's config are used.
        epoch: Epoch used for the train/valid latent spaces; defaults to -1.
            The test split is always read with ``epoch=-1``.
        split: "train", "valid", "test" or "all" (all three splits).

    Raises:
        ValueError: If no latent space or no metadata could be collected, if
            both ``labels`` and ``param`` are given, if ``param`` is a string
            other than "all", or if the labels do not match the number of
            latent samples for the ridgeline plot.
    """
    plt.ioff()
    if plot_type == "Coverage-Correlation":
        # TODO: port the per-checkpoint total-correlation / coverage line
        # plots (train vs. valid) from the single-modality visualizer.
        print("TODO: Implement Coverage-Correlation plot for X-Modalix")
        return

    # Set Defaults
    if epoch is None:
        epoch = -1

    ## Collect all metadata and latent spaces from datasets
    clin_frames = []
    latent_frames = []
    split_list = ["train", "test", "valid"] if split == "all" else [split]
    for s in split_list:
        split_ds = getattr(result.datasets, s, None)
        if split_ds is None:
            continue
        # The test split is always read at epoch=-1, as in the original code
        # (presumably test latents only exist for the final model).
        split_epoch = -1 if s == "test" else epoch
        for key, ds in split_ds.datasets.items():
            df_latent = result.get_latent_df(
                epoch=split_epoch, split=s, modality=key
            )
            df_latent["modality"] = key
            # Each sample can occur multiple times in latent space
            df_latent["sample_ids"] = df_latent.index
            latent_frames.append(df_latent)
            if hasattr(ds, "metadata") and ds.metadata is not None:
                df = ds.metadata.copy()
                df["sample_ids"] = df.index.astype(str)
                df["split"] = s
                df["modality"] = key
                clin_frames.append(df)

    if not latent_frames or not clin_frames:
        # Previously both tables were silently replaced by empty frames, which
        # only surfaced later as an unrelated KeyError/TypeError.
        raise ValueError(
            f"No latent space or metadata found for split '{split}'. "
            "Cannot show latent space."
        )
    latent_data = pd.concat(latent_frames, axis=0, ignore_index=True)
    clin_data = pd.concat(clin_frames, axis=0, ignore_index=True)
    if "sample_ids" in clin_data.columns:
        clin_data = clin_data.drop_duplicates(subset="sample_ids").set_index(
            "sample_ids"
        )

    ## Label options
    # Only fall back to the configured annotation columns when the user gave
    # neither param nor labels; otherwise a labels-only call would always be
    # rejected below as "labels and param both given".
    if param is None and labels is None:
        # Take the first since configs are same for all sub-VAEs
        modality = list(result.model.keys())[0]
        model = result.model.get(modality, None)
        if model is None:
            raise ValueError(
                f"Model for modality {modality} not found in result.model"
            )
        param = model.config.data_config.annotation_columns

    if labels is None and param is None:
        labels = ["all"] * latent_data["sample_ids"].unique().shape[0]

    if labels is None and isinstance(param, str):
        if param == "all":
            param = list(clin_data.columns)
        else:
            raise ValueError(
                "Please provide parameter to plot as a list not as string. If you want to plot all parameters, set param to 'all' and labels to None."
            )

    if labels is not None and param is not None:
        raise ValueError(
            "Please provide either labels or param, not both. If you want to plot all parameters, set param to 'all' and labels to None."
        )

    if labels is not None and param is None:
        if isinstance(labels, pd.Series):
            param = [labels.name]
            # Order by index of latent_data first, fill missing with "unknown"
            labels = labels.reindex(  # ty: ignore
                latent_data["sample_ids"],  # ty: ignore
                fill_value="unknown",  # ty: ignore
            ).tolist()
        else:
            param = ["user_label"]  # Default label if none provided

    if not isinstance(param, list):
        raise ValueError(f"param: should be converted to list, got: {param}")

    for p in param:
        if p in clin_data.columns:
            labels = clin_data.loc[latent_data["sample_ids"], p].tolist()
        elif labels is not None and clin_data.shape[0] == len(labels):
            # NOTE(review): labels are attached positionally to the
            # de-duplicated metadata rows - confirm the expected ordering.
            clin_data[p] = labels
        else:
            clin_data[p] = ["all"] * clin_data.shape[0]

        if plot_type == "2D-scatter":
            latent_dims = latent_data.drop(columns=["sample_ids", "modality"])
            if latent_dims.shape[1] > 2:
                ## Make 2D Embedding with UMAP
                reducer = UMAP(n_components=2)
                embedding = pd.DataFrame(reducer.fit_transform(latent_dims))
                embedding.columns = ["DIM1", "DIM2"]
                embedding["sample_ids"] = latent_data["sample_ids"]
                embedding["modality"] = latent_data["modality"]
            else:
                # The plotting helper expects columns DIM1/DIM2, so an already
                # two-dimensional latent space must be renamed accordingly.
                embedding = latent_data.rename(
                    columns=dict(zip(latent_dims.columns, ["DIM1", "DIM2"]))
                )

            # Merge with clinical data via sample_ids
            clin_data["sample_ids"] = clin_data.index.astype(str)
            clin_data.index = clin_data.index.astype(str)
            embedding["sample_ids"] = embedding["sample_ids"].astype(str)

            embedding = embedding.merge(
                clin_data.drop(columns=["modality"]),  # ty: ignore
                left_on="sample_ids",
                right_index=True,
                how="left",
            )

            self.plots["2D-scatter"][epoch][split][p] = self._plot_translate_latent(
                embedding=embedding,
                color_param=p,
                style_param="modality",
            )
            plt.show()

        if plot_type == "Ridgeline":
            ## Make ridgeline plot
            if labels is None:
                labels = ["all"] * latent_data.shape[0]
            if len(labels) != latent_data.shape[0]:  # ty: ignore
                if labels[0] == "all":  # ty: ignore
                    labels = ["all"] * latent_data.shape[0]  # ty: ignore
                else:
                    raise ValueError(
                        "Labels must match the number of samples in the latent space."
                    )

            self.plots["Ridgeline"][epoch][split][p] = self._plot_latent_ridge_multi(
                lat_space=latent_data.drop(columns=["sample_ids"]),  # ty: ignore
                labels=labels,
                modality="modality",
                param=p,
            )

            fig = self.plots["Ridgeline"][epoch][split][p].figure
            show_figure(fig)
            plt.show()
def show_weights(self) -> None:
    """Placeholder: weight plots are not available for X-Modalix yet.

    Raises:
        NotImplementedError: Always.
    """
    ## TODO
    message = "Weight visualization for X-Modalix is not implemented."
    raise NotImplementedError(message)
@no_type_check
def show_image_translation(  # ty: ignore
    self,
    result: Result,
    from_key: str,
    to_key: str,
    n_sample_per_class: int = 3,
    param: Optional[str] = None,
) -> None:  # ty: ignore
    """Show a grid of original, translated and reference images per class.

    For every class value of ``param`` up to ``n_sample_per_class`` paired
    test samples are drawn (fixed random seed). Each sample becomes one
    column; the three rows show the original image, the translated image and
    the reference reconstruction. Only the test split is supported. The
    figure is stored in ``self.plots["Image-translation"][to_key]["test"][param]``.

    Args:
        result: The result object containing datasets, sample ids and
            reconstructions.
        from_key: The source modality key (not used by this method).
        to_key: The target modality key; must contain "img"
            (case-insensitive).
        n_sample_per_class: Number of samples to display per class value.
        param: Metadata column used to group samples by class. If None, all
            samples are treated as one class "all".

    Raises:
        ValueError: If ``to_key`` does not refer to an image dataset, if no
            sample is available to plot, or if a selected sample id is
            missing from the stored sample ids.
    """
    if "img" not in to_key.lower():
        raise ValueError(
            f"You provided as 'to_key' {to_key} a non-image dataset. "
            "Image translation grid visualization is only possible for translation to IMG data type."
        )

    split = "test"  # Currently only test split is supported
    test_data = result.datasets.test
    target_ds = test_data.datasets[to_key]

    # Restrict meta to only paired sample ids; copy so that adding the default
    # label column below never touches the dataset's own metadata.
    meta = target_ds.metadata.loc[test_data.paired_sample_ids].copy()

    if param is None:
        param = "user-label"
        meta[param] = "all"  # Default to all samples if no parameter is provided

    # Get possible class values
    class_values = meta[param].unique()
    if len(class_values) > 10:
        warnings.warn(
            f"Found {len(class_values)} class values for parameter '{param}'. Only first 10 will be used to limit figure size"
        )
        class_values = class_values[:10]

    # Build dictionary of sample_ids per class value (max n_sample_per_class per class)
    sample_per_class = {
        val: meta[meta[param] == val]
        .sample(
            n=min(n_sample_per_class, (meta[param] == val).sum()),
            random_state=42,
        )
        .index.tolist()
        for val in class_values
    }

    print(f"Sample per class: {sample_per_class}")

    # One grid column per selected sample. The ids are kept in their own list
    # instead of being parsed back out of the label text, and the labels are
    # built with f-strings so non-string class values / ids do not fail.
    column_ids = [sid for ids in sample_per_class.values() for sid in ids]
    col_labels = [
        f"{class_value} {split}-sample:{sid}"
        for class_value, ids in sample_per_class.items()
        for sid in ids
    ]
    n_test_samples = len(column_ids)
    if n_test_samples == 0:
        raise ValueError(
            "No paired test samples available for the image translation grid."
        )

    # Lookups hoisted out of the plotting loop
    stored_ids = result.sample_ids.get(epoch=-1, split=split)
    reference_key = f"reference_{to_key}_to_{to_key}"
    original_ids = list(target_ds.sample_ids)
    translated_ids = list(stored_ids["translation"])
    reference_ids = list(stored_ids[reference_key])
    reconstructions = result.reconstructions.get(epoch=-1, split=split)

    row_labels = ["Original", "Translated", "Reference"]

    ## Generate Image Grid
    fig, axes = plt.subplots(
        ncols=n_test_samples,  # One column per selected sample
        nrows=3,  # Original, translated, reference
        figsize=(n_test_samples * 2, 3 * 2),
    )

    for i, ax in enumerate(axes.flat):
        row, col = divmod(i, n_test_samples)
        sample_id = column_ids[col]

        if row == 0:
            # Original image; dataset items are (index, tensor, sample_id)
            idx_original = original_ids.index(sample_id)
            img_temp = target_ds[idx_original][1].squeeze()
            ax.imshow(np.asarray(img_temp))
            # Sample label above the column
            ax.text(
                0.5,
                1.1,
                col_labels[col],
                va="bottom",
                ha="center",
                rotation=45,
                transform=ax.transAxes,
            )
        elif row == 1:
            # Translated image
            idx_translated = translated_ids.index(sample_id)
            ax.imshow(reconstructions["translation"][idx_translated].squeeze())
        else:
            # Reference image reconstruction
            idx_reference = reference_ids.index(sample_id)
            ax.imshow(reconstructions[reference_key][idx_reference].squeeze())

        ax.axis("off")
        # Row label next to the first column
        if col == 0:
            ax.text(
                -0.1,
                0.5,
                row_labels[row],
                va="center",
                ha="right",
                transform=ax.transAxes,
            )

    self.plots["Image-translation"][to_key][split][param] = fig
    # show_figure(fig)
    plt.show()
@no_type_check
def show_2D_translation(
    self,
    result: Result,
    translated_modality: str,
    split: str = "test",
    param: Optional[str] = None,
    reducer: str = "UMAP",
) -> None:
    """Compare input and translated data of one modality in a shared 2D embedding.

    The processed input of ``translated_modality`` and the flattened
    "translation" reconstructions are restricted to their common sample ids,
    stacked and reduced to two dimensions together. Input and translation are
    drawn as two facets. The grid is stored in
    ``self.plots["2D-translation"][translated_modality][split][param]``.

    Args:
        result: Result object with datasets, reconstructions and sample ids.
            If ``result.new_datasets.test`` is set it replaces the test
            dataset of ``result.datasets``.
        translated_modality: Key of the modality that was translated to.
        split: Only "test" is implemented.
        param: Metadata column used for colouring; None draws without hue.
        reducer: "UMAP", "PCA" or "TSNE".

    Raises:
        ValueError: For an unknown reducer or split, a missing test dataset
            or missing translated reconstructions.
        NotImplementedError: For any valid split other than "test".
    """
    ## TODO add similar labels/param logic from other visualizations
    supported_reducers = ("UMAP", "PCA", "TSNE")
    if reducer not in supported_reducers:
        # Fail early; previously this surfaced as an UnboundLocalError.
        raise ValueError(
            f"Unknown reducer: {reducer}. Choose one of {supported_reducers}."
        )

    dataset = result.datasets

    ## Overwrite original datasets with new_datasets if available after predict with other data
    if dataset is None:
        dataset = DatasetContainer()

    if bool(result.new_datasets.test):
        dataset.test = result.new_datasets.test

    if split not in ["train", "valid", "test", "all"]:
        raise ValueError(f"Unknown split: {split}")

    if dataset.test is None:
        raise ValueError("test of dataset is None")

    if split != "test":
        raise NotImplementedError(
            "2D translation visualization is currently only implemented for the 'test' split since reconstruction is only performed on test-split."
        )
    df_processed = dataset.test._to_df(modality=translated_modality)

    # Get translated reconstruction
    tensor_list = result.reconstructions.get(epoch=-1, split=split)[  # ty: ignore
        "translation"
    ]  # ty: ignore
    print(f"len of tensor-list: {len(tensor_list)}")
    tensor_ids = result.sample_ids.get(epoch=-1, split=split)["translation"]
    print(f"len of tensor_ids: {len(tensor_ids)}")

    # Flatten each tensor and collect as rows (for image case)
    rows = [
        t.detach().flatten().cpu().numpy()
        if isinstance(t, torch.Tensor)
        else t.flatten()
        for t in tensor_list
    ]
    if not rows:
        raise ValueError(
            f"No translated reconstructions found for split '{split}'."
        )

    df_translate_flat = pd.DataFrame(
        rows,
        columns=[f"Feature_{i}" for i in range(len(rows[0]))],
        index=tensor_ids,
    )

    if reducer == "UMAP":
        reducer_model = UMAP(n_components=2)
    elif reducer == "PCA":
        reducer_model = PCA(n_components=2)
    else:
        reducer_model = TSNE(n_components=2)

    # Restrict both tables to the shared sample ids, in the same row order
    common_ids = df_processed.index.intersection(df_translate_flat.index)
    df_processed = df_processed.loc[common_ids]
    df_translate_flat = df_translate_flat.loc[common_ids].reindex(df_processed.index)

    X = np.vstack([df_processed.values, df_translate_flat.values])
    df_red_comb = pd.DataFrame(reducer_model.fit_transform(X))

    df_red_comb["origin"] = ["input"] * df_processed.shape[0] + [
        "translated"
    ] * df_translate_flat.shape[0]

    if param is not None:
        # Align the labels with the plotted rows by sample id instead of
        # taking the metadata positionally; ids are compared as strings and
        # samples without metadata get NaN.
        meta_column = dataset.test.datasets[translated_modality].metadata[param]
        meta_labels = pd.Series(
            meta_column.to_numpy(), index=meta_column.index.astype(str)
        )
        meta_labels = meta_labels[~meta_labels.index.duplicated(keep="first")]
        sample_labels = meta_labels.reindex(
            df_processed.index.astype(str)
        ).tolist()
        # Same labels for the input half and the translated half
        df_red_comb[param] = sample_labels * 2

    g = sns.FacetGrid(
        df_red_comb,
        col="origin",
        hue=param,
        sharex=True,
        sharey=True,
        height=8,
        aspect=1,
    )
    g.map_dataframe(sns.scatterplot, x=0, y=1, alpha=0.7)
    if param is not None:
        g.add_legend()
    g.set_axis_labels(reducer + " DIM 1", reducer + " DIM 2")
    g.set_titles(col_template="{col_name}")

    self.plots["2D-translation"][translated_modality][split][param] = g
    plt.show()
591 ## Utilities specific for X-Modalix
@staticmethod
def _plot_translate_latent(
    embedding,
    color_param,
    style_param=None,
):
    """Create a scatter plot of a 2D embedding of the latent space.

    Args:
        embedding: DataFrame with the columns "DIM1" and "DIM2" (prior 2D
            dimension reduction is assumed), ``color_param`` and, if given,
            ``style_param``. The ``color_param`` column is rewritten in place
            with the labels used for plotting.
        color_param: Column used to colour the points. Non-string labels with
            more than three unique values are treated as numeric ("bwr"
            palette); other non-string labels are converted to strings.
        style_param: Optional column, e.g. "modality", used both for the
            marker style and to split the plot into one facet per value.

    Returns:
        The seaborn grid returned by ``sns.relplot``.
    """
    labels = list(embedding[color_param])
    numeric = False
    if labels and not isinstance(labels[0], str):
        if len(np.unique(labels)) > 3:
            # TODO Decide if numeric to category should be optional in new Package
            numeric = True
        else:
            labels = [str(x) for x in labels]

    # The labels come from the embedding itself, so their length always
    # matches and no duplication / truncation is required.
    embedding[color_param] = labels
    palette = "bwr" if numeric else None

    # Facet and marker style are only requested when a style column is given.
    # Previously no plot was created at all without one, which made the final
    # return fail for the default style_param=None.
    facet_kwargs = {}
    if style_param is not None:
        facet_kwargs = {"col": style_param, "style": style_param}

    plot = sns.relplot(
        data=embedding,
        x="DIM1",
        y="DIM2",
        hue=color_param,
        palette=palette,
        markers=True,
        alpha=0.4,
        ec="black",
        height=10,
        aspect=1,
        s=150,
        **facet_kwargs,
    )

    return plot
@staticmethod
def _plot_latent_ridge_multi(
    lat_space: pd.DataFrame,
    modality: Optional[str] = None,
    labels: Optional[Union[list, pd.Series, None]] = None,
    param: Optional[Union[str, None]] = None,
) -> sns.FacetGrid:
    """Create a ridgeline plot of the latent dimensions, one column per modality.

    Each row shows the density of one latent dimension; the ridges within a
    row are the label groups. Samples labelled "unknown" or "nan" are left
    out of the plot.

    Args:
        lat_space: DataFrame with samples as rows and latent dimensions as
            columns, plus the column named by ``modality``.
        modality: Name of the column holding the modality of each row. It is
            used for melting and faceting, so a column name is required in
            practice.
        labels: One label per row of ``lat_space``. If None, all samples form
            one group "all". Non-string labels with more than three unique
            values are binned into quartiles; other non-string labels are
            converted to strings.
        param: Name of the grouping/colouring column created from ``labels``.

    Returns:
        g (sns.FacetGrid): FacetGrid object containing the ridge line plot
    """
    sns.set_theme(
        style="white", rc={"axes.facecolor": (0, 0, 0, 0)}
    )  ## Necessary to enforce overplotting

    n_latent_dims = len(lat_space.drop(columns=modality).columns)
    df = pd.melt(
        lat_space,
        id_vars=modality,  # ty: ignore
        var_name="latent dim",
        value_name="latent intensity",
    )
    # The melted frame stacks the samples once per latent dimension
    df["sample"] = n_latent_dims * list(lat_space.index)

    if labels is None:
        param = "all"
        # One label per sample; repeated per latent dimension further below
        labels = ["all"] * len(lat_space)
    else:
        # Plain list, so that the repetition below concatenates the labels
        # instead of multiplying a Series element-wise
        labels = list(labels)

    if labels and not isinstance(labels[0], str):
        if len(np.unique(labels)) > 3:
            # Keep real numbers (ints included), change everything else to NaN
            numeric_labels = [
                x
                if isinstance(x, (int, float, np.integer, np.floating))
                and not isinstance(x, bool)
                else float("nan")
                for x in labels
            ]
            labels = (
                pd.qcut(
                    x=pd.Series(numeric_labels),
                    q=4,
                    labels=["1stQ", "2ndQ", "3rdQ", "4thQ"],
                )
                .astype(str)
                .tolist()
            )
        else:
            labels = [str(x) for x in labels]

    df[param] = n_latent_dims * labels  # type: ignore

    exclude_missing_info = (df[param] == "unknown") | (df[param] == "nan")

    # Shared x-limits from the 5% / 90% quantiles per group and latent dim
    grouped = df.loc[
        ~exclude_missing_info, ["latent intensity", "latent dim", param]
    ].groupby([param, "latent dim"], observed=False)
    xmin = grouped.quantile(0.05).min()
    xmax = grouped.quantile(0.9).max()

    n_groups = len(np.unique(df[param]))
    if n_groups > 8:
        cat_pal = sns.husl_palette(n_groups)
    else:
        cat_pal = sns.color_palette(n_colors=n_groups)

    g = sns.FacetGrid(
        df[~exclude_missing_info],
        row="latent dim",
        col=modality,
        hue=param,
        aspect=12,
        height=0.8,
        xlim=(xmin.iloc[0], xmax.iloc[0]),
        palette=cat_pal,
    )

    g.map_dataframe(
        sns.kdeplot,
        "latent intensity",
        bw_adjust=0.5,
        clip_on=True,
        fill=True,
        alpha=0.5,
        warn_singular=False,
        ec="k",
        lw=1,
    )

    def label(data, color, label, text="latent dim"):
        # Write the name of the latent dimension to the left of its row
        ax = plt.gca()
        label_text = data[text].unique()[0]
        ax.text(
            0.0,
            0.2,
            label_text,
            fontweight="bold",
            ha="right",
            va="center",
            transform=ax.transAxes,
        )

    g.map_dataframe(label, text="latent dim")

    g.set(xlim=(xmin.iloc[0], xmax.iloc[0]))
    # Set the subplots to overlap
    g.figure.subplots_adjust(hspace=-0.5)

    # Remove axes details that don't play well with overlap
    g.set_titles("")
    g.set(yticks=[], ylabel="")
    g.despine(bottom=True, left=True)

    # Title the first axes with the modality names
    for i, m in enumerate(df[modality].unique()):
        g.figure.get_axes()[i].set_title(m)

    g.add_legend()

    plt.close()
    return g
def _plot_evaluation(
    self,
    result: Result,
) -> dict:
    """Draw one bar-plot grid per clinical parameter, metric and ML algorithm.

    Args:
        result: The Result object whose ``embedding_evaluation`` table is
            plotted.

    Returns:
        Nested dictionary ``plots[clinic_param][metric][ml_alg]`` holding the
        generated plots; the same dictionary is stored in
        ``self.plots["ML_Evaluation"]``.
    """
    plt.ioff()
    ml_plots = {}

    for clin_param in pd.unique(result.embedding_evaluation.CLINIC_PARAM):
        per_metric = {}
        ml_plots[clin_param] = per_metric
        metrics = pd.unique(
            result.embedding_evaluation.loc[
                result.embedding_evaluation.CLINIC_PARAM == clin_param, "metric"
            ]
        )  # ty: ignore
        for metric in metrics:
            per_alg = {}
            per_metric[metric] = per_alg
            algorithms = pd.unique(
                result.embedding_evaluation.loc[
                    (result.embedding_evaluation.CLINIC_PARAM == clin_param)
                    & (result.embedding_evaluation.metric == metric),
                    "ML_ALG",
                ]
            )  # ty: ignore
            for alg in algorithms:
                # Rows of exactly this parameter / metric / algorithm
                selected = (
                    (result.embedding_evaluation.metric == metric)
                    & (result.embedding_evaluation.CLINIC_PARAM == clin_param)
                    & (result.embedding_evaluation.ML_ALG == alg)
                )
                data = result.embedding_evaluation[selected]

                grid = sns.catplot(
                    data=data,
                    x="score_split",
                    y="value",
                    col="ML_TASK",
                    row="MODALITY",
                    hue="score_split",
                    kind="bar",
                )

                # The y-axis starts at zero unless there are negative scores
                lowest = data.value.min()
                y_floor = 0 if lowest > 0 else lowest

                per_alg[alg] = grid.set(ylim=(y_floor, None))

    self.plots["ML_Evaluation"] = ml_plots

    return ml_plots