Coverage for src / autoencodix / visualize / _imagix_visualizer.py: 33%

46 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from dataclasses import field 

2 

3import matplotlib.pyplot as plt 

4 

5 

6from typing import Any, Dict, no_type_check 

7from autoencodix.visualize._general_visualizer import GeneralVisualizer 

8from autoencodix.utils._result import Result 

9from autoencodix.utils._utils import nested_dict, get_dataset 

10from autoencodix.configs.default_config import DefaultConfig 

11 

12 

class ImagixVisualizer(GeneralVisualizer):
    """Visualizer for Imagix results.

    Builds loss plots and an original-vs-reconstruction image grid and stores
    the resulting matplotlib figures in ``self.plots``.
    """

    # NOTE(review): this class is not decorated with @dataclass here, so the
    # class-level value below is a ``dataclasses.Field`` placeholder, not a
    # dict; ``__init__`` assigns the real per-instance nested dict.
    plots: Dict[str, Any] = field(
        default_factory=nested_dict
    )  ## Nested dictionary of plots as figure handles

    def __init__(self):
        # NOTE(review): GeneralVisualizer.__init__ is not called; confirm the
        # base class needs no initialisation of its own.
        self.plots = nested_dict()

    def __setitem__(self, key, elem):
        """Store ``elem`` under ``key`` in ``self.plots``."""
        self.plots[key] = elem

    def visualize(self, result: Result, config: DefaultConfig) -> Result:
        """Create the absolute and relative loss plots for ``result``.

        The figures are stored in ``self.plots["loss_absolute"]`` and
        ``self.plots["loss_relative"]``.

        Args:
            result: Result whose losses are plotted.
            config: Configuration forwarded to the loss-formatting helper.

        Returns:
            The same ``result`` object that was passed in.
        """
        # TODO: the model-weights plot (self._plot_model_weights) still needs
        # to be adjusted for Imagix before it can be enabled here.

        # Long-format loss table shared by both loss plots.
        loss_df_melt = self._make_loss_format(result=result, config=config)

        self.plots["loss_absolute"] = self._make_loss_plot(
            df_plot=loss_df_melt, plot_type="absolute"
        )
        self.plots["loss_relative"] = self._make_loss_plot(
            df_plot=loss_df_melt, plot_type="relative"
        )

        return result

    # Latent space visualization is provided by GeneralVisualizer.

    def show_weights(self) -> None:
        """Plot model weights (not yet supported for Imagix).

        Raises:
            NotImplementedError: Always.
        """
        # TODO: implement weight visualization for Imagix.
        raise NotImplementedError(
            "Weight visualization for Imagix is not implemented."
        )

    @no_type_check
    def show_image_recon_grid(self, result: Result, n_samples: int = 3) -> None:
        """Show test images next to their reconstructions.

        Draws a grid with 2 rows: original images on top and reconstructions of
        the test split at the last epoch (``epoch=-1``) below. The figure is
        stored in ``self.plots["Image_recon_grid"]`` and shown via
        ``plt.show()``.

        If ``result.new_datasets.test`` is set (e.g. after predicting on new
        data), that split is plotted instead of ``result.datasets.test``.
        ``result`` itself is not modified.

        Args:
            result: Result holding the datasets and reconstructions.
            n_samples: Number of test samples to draw from the test metadata
                (fixed seed 42). Capped at the number of metadata rows.

        Raises:
            ValueError: If ``n_samples`` < 1, if no test split is available,
                or if none of the drawn samples appears in the split's
                ``sample_ids``.
        """
        # TODO: add similar labels/param logic from other visualizations.
        if n_samples < 1:
            raise ValueError(f"n_samples must be at least 1, got {n_samples}")

        # Prefer the split produced by predicting on new data, else fall back
        # to the original test split. A local variable is used so that
        # plotting does not overwrite result.datasets.test as a side effect.
        test_data = None
        if result.datasets is not None:
            test_data = result.datasets.test
        new_datasets = result.new_datasets
        if new_datasets is not None and bool(new_datasets.test):
            test_data = new_datasets.test

        if test_data is None:
            raise ValueError("test of dataset is None")

        meta = test_data.metadata
        # DataFrame.sample raises if n exceeds the number of rows, so cap it.
        n_draw = min(n_samples, len(meta))
        sample_ids = meta.sample(n=n_draw, random_state=42).index

        # First position of every sample id in the split (same result as
        # list.index, but computed in a single pass).
        position = {}
        for i, sid in enumerate(test_data.sample_ids):
            position.setdefault(sid, i)

        # Keep each id paired with its row index so titles stay aligned with
        # images even when some drawn ids are missing from sample_ids.
        selected = [(sid, position[sid]) for sid in sample_ids if sid in position]
        if not selected:
            raise ValueError(
                "None of the sampled metadata ids were found in the test "
                "split's sample_ids"
            )

        # Fetch the reconstructions once, before creating the figure.
        reconstructions = result.reconstructions.get(split="test", epoch=-1)

        n_cols = len(selected)
        fig, axes = plt.subplots(
            nrows=2,  # Original, Reconstructed
            ncols=n_cols,  # One column per sample
            figsize=(n_cols * 4, 2 * 4),
            squeeze=False,  # keep axes 2-D even when n_cols == 1
        )

        for col, (sid, idx) in enumerate(selected):
            ax_orig = axes[0, col]
            ax_orig.imshow(test_data.raw_data[idx].img.squeeze())
            ax_orig.set_title(f"Original: {sid}")
            ax_orig.axis("off")

            ax_recon = axes[1, col]
            ax_recon.imshow(reconstructions[idx].squeeze())
            ax_recon.set_title(f"Reconstructed: {sid}")
            ax_recon.axis("off")

        self.plots["Image_recon_grid"] = fig
        plt.show()