Coverage for src / autoencodix / visualize / _imagix_visualizer.py: 33%
46 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from dataclasses import field
3import matplotlib.pyplot as plt
6from typing import Any, Dict, no_type_check
7from autoencodix.visualize._general_visualizer import GeneralVisualizer
8from autoencodix.utils._result import Result
9from autoencodix.utils._utils import nested_dict, get_dataset
10from autoencodix.configs.default_config import DefaultConfig
class ImagixVisualizer(GeneralVisualizer):
    """Visualizer for Imagix results.

    Builds loss plots from a ``Result`` and can render a grid comparing
    original test images with their reconstructions. Every figure produced is
    stored in ``self.plots`` under a string key.

    Attributes:
        plots: Nested dictionary of plots as figure handles.
    """

    # NOTE(review): ``dataclasses.field`` only takes effect when the class is
    # processed by ``@dataclass``; no decorator is visible on this class, and
    # ``__init__`` assigns the instance attribute explicitly. Confirm whether
    # this declaration is still needed.
    plots: Dict[str, Any] = field(
        default_factory=nested_dict
    )  ## Nested dictionary of plots as figure handles

    def __init__(self):
        """Create the visualizer with an empty nested plot dictionary."""
        self.plots = nested_dict()

    def __setitem__(self, key, elem):
        """Store ``elem`` in ``self.plots`` under ``key``."""
        self.plots[key] = elem

    def visualize(self, result: Result, config: DefaultConfig) -> Result:
        """Create the absolute and relative loss plots for ``result``.

        The figures are stored in ``self.plots["loss_absolute"]`` and
        ``self.plots["loss_relative"]``.

        Args:
            result: Training result passed on to ``_make_loss_format``.
            config: Configuration passed on to ``_make_loss_format``.

        Returns:
            The ``result`` object that was passed in.
        """
        ## Make Model Weights plot
        ## TODO needs to be adjusted for Imagix ##
        ## Plot Model weights for each sub-VAE ##
        # self.plots["ModelWeights"] = self._plot_model_weights(model=result.model)

        ## Make long format of losses
        loss_df_melt = self._make_loss_format(result=result, config=config)

        ## Make plot loss absolute
        self.plots["loss_absolute"] = self._make_loss_plot(
            df_plot=loss_df_melt, plot_type="absolute"
        )
        ## Make plot loss relative
        self.plots["loss_relative"] = self._make_loss_plot(
            df_plot=loss_df_melt, plot_type="relative"
        )

        return result

    ## Latent space visualization via GeneralVisualizer

    def show_weights(self) -> None:
        """Weight visualization placeholder.

        Raises:
            NotImplementedError: Always; weight plots are not available here.
        """
        ## TODO
        raise NotImplementedError(
            "Weight visualization for Imagix is not implemented."
        )

    @no_type_check
    def show_image_recon_grid(self, result: Result, n_samples: int = 3) -> None:
        """Show original test images above their reconstructions.

        Samples ``n_samples`` rows from the test metadata (fixed seed 42, so
        the selection is reproducible), draws a two-row grid (originals on
        top, reconstructions below), stores the figure in
        ``self.plots["Image_recon_grid"]`` and displays it.

        Args:
            result: Result holding ``datasets``, ``new_datasets`` and
                ``reconstructions``.
            n_samples: Number of images to sample from the test split.

        Raises:
            ValueError: If no test dataset or test metadata is available, or
                if none of the sampled ids occurs in the test sample order.
        """
        ## TODO add similar labels/param logic from other visualizations
        datasets = result.datasets
        new_test = result.new_datasets.test

        ## Overwrite original datasets with new_datasets if available after predict with other data
        if bool(new_test):
            test_data = new_test
            if datasets is not None:
                # Kept for backward compatibility: the stored container is
                # updated in place, as before.
                datasets.test = new_test
        else:
            # No container at all simply means there is no test split; no
            # placeholder container object is needed.
            test_data = None if datasets is None else datasets.test

        if test_data is None:
            raise ValueError("test of dataset is None")

        meta = test_data.metadata
        if meta is None:
            raise ValueError("metadata of test dataset is None")
        sampled_ids = meta.sample(n=n_samples, random_state=42).index

        # Map every sample id to its first position once, instead of scanning
        # the full sample order twice per sampled id.
        first_position = {}
        for position, sid in enumerate(test_data.sample_ids):
            first_position.setdefault(sid, position)

        # Keep id and position together so titles can never drift away from
        # the images when a sampled id is missing from the sample order.
        matched = [
            (sid, first_position[sid])
            for sid in sampled_ids
            if sid in first_position
        ]
        if not matched:
            raise ValueError(
                "none of the sampled ids were found in the test sample order"
            )

        # Fetch the reconstructions once rather than once per column.
        reconstructions = result.reconstructions.get(split="test", epoch=-1)

        n_cols = len(matched)
        fig, axes = plt.subplots(
            ncols=n_cols,  # One column per plotted sample
            nrows=2,  # Original, Reconstructed
            figsize=(n_cols * 4, 2 * 4),
            squeeze=False,  # always 2-D, so axes[row, col] works for one column
        )

        for col, (sid, idx) in enumerate(matched):
            ## Original image
            original_ax = axes[0, col]
            original_ax.imshow(test_data.raw_data[idx].img.squeeze())
            original_ax.set_title(f"Original: {sid}")
            original_ax.axis("off")

            ## Reconstructed image
            recon_ax = axes[1, col]
            recon_ax.imshow(reconstructions[idx].squeeze())
            recon_ax.set_title(f"Reconstructed: {sid}")
            recon_ax.axis("off")

        self.plots["Image_recon_grid"] = fig
        # show_figure(fig)
        plt.show()