Coverage for src / autoencodix / evaluate / _xmodalix_evaluator.py: 23%

128 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Union, Tuple, Optional, no_type_check 

2 

3 

4import pandas as pd 

5import torch 

6import torch.nn.functional as F 

7 

8from matplotlib import pyplot as plt 

9from matplotlib.figure import Figure 

10import seaborn as sns 

11 

12import sklearn 

13from sklearn.decomposition import PCA 

14from umap import UMAP 

15from sklearn.manifold import TSNE 

16from sklearn.base import ClassifierMixin, RegressorMixin 

17 

18from autoencodix.utils._result import Result 

19from autoencodix.data._datasetcontainer import DatasetContainer 

20from autoencodix.evaluate._general_evaluator import GeneralEvaluator 

21 

# Turn on scikit-learn's metadata routing for the whole process as soon as
# this module is imported (module-level side effect).
sklearn.set_config(enable_metadata_routing=True)

23 

24 

class XModalixEvaluator(GeneralEvaluator):
    """Evaluation helpers for XModalix (cross-modal) results.

    Offers a reconstruction-error comparison against a pure VAE and the
    XModalix-specific helpers that prepare inputs and result tables for
    downstream ML tasks.
    """

    def __init__(self):
        # NOTE(review): the parent initializer is not invoked here (the call is
        # commented out) -- confirm that the inherited methods do not rely on
        # state set up by GeneralEvaluator.__init__.
        # super().__init__()
        pass

    @staticmethod
    def _first_positions(sample_ids: list) -> list:
        """Return, for every id, the position of its first occurrence.

        Equivalent to ``[sample_ids.index(s) for s in sample_ids]`` but runs in
        linear time. For unique ids this is simply ``0..n-1``; a duplicated id
        maps to the position of its first appearance.
        """
        first_seen: dict = {}
        for position, sample_id in enumerate(sample_ids):
            first_seen.setdefault(sample_id, position)
        return [first_seen[sample_id] for sample_id in sample_ids]

    @staticmethod
    @no_type_check
    def pure_vae_comparison(
        xmodalix_result: Result,
        pure_vae_result: Result,
        to_key: str,
        param: Optional[str] = None,
    ) -> Tuple[Figure, pd.DataFrame]:
        """Compare per-sample reconstruction MSE of a pure VAE and an xmodalix model.

        For each test sample the mean squared error between the original and
        the reconstructed image is computed for:
            - Pure VAE reconstructions ("imagix")
            - xmodalix reference reconstructions ("xmodalix_reference")
            - xmodalix translated reconstructions ("xmodalix_translated")
        The values are joined with the sample metadata, combined into one
        table and melted into long format for plotting.

        Args:
            xmodalix_result: Result holding the xmodalix test datasets and the
                "reference_<to_key>_to_<to_key>" / "translation" reconstructions.
            pure_vae_result: Result holding the pure VAE test dataset and its
                reconstructions.
            to_key: Key of the target modality in the xmodalix dataset. Must
                contain the substring "img".
            param: Metadata column to group the boxplots by. If falsy, a single
                boxplot per model is drawn.

        Returns:
            A tuple ``(plot, df_long)``:
            - plot: the object returned by ``seaborn.boxplot``.
              NOTE(review): seaborn documents that return value as a matplotlib
              Axes, while this function is annotated as returning a Figure --
              confirm what callers expect before changing either.
            - df_long: long-format DataFrame with one row per (merged row,
              model) holding "model", "mse_value" and the metadata columns.

        Raises:
            NotImplementedError: If ``to_key`` does not contain "img".
            ValueError: If the metadata of either test dataset is None.
        """
        if "img" not in to_key:
            raise NotImplementedError(
                "Comparison is currently only implemented for the image case."
            )

        # --- Pure VAE ("imagix") per-sample MSE -------------------------------
        pure_test = pure_vae_result.datasets.test
        meta_imagix = pure_test.metadata
        if meta_imagix is None:
            raise ValueError("metadata cannot be None")
        sample_ids = list(meta_imagix.index)
        positions = XModalixEvaluator._first_positions(sample_ids)

        # Looked up once instead of once per sample.
        raw_data = pure_test.raw_data
        pure_recons = pure_vae_result.reconstructions.get(split="test", epoch=-1)

        imagix_records = []
        for sample_id, pos in zip(sample_ids, positions):
            orig = torch.Tensor(raw_data[pos].img.squeeze())
            recon = torch.Tensor(pure_recons[pos].squeeze())
            mse_sample = F.mse_loss(orig, recon, reduction="mean")
            imagix_records.append(
                {"sample_id": sample_id, "mse_imagix": mse_sample.item()}
            )

        df_imagix_mse = pd.DataFrame(imagix_records).set_index("sample_id")
        # Attach the metadata row of every sample (matched on the index name).
        df_imagix_mse = df_imagix_mse.join(meta_imagix, on="sample_id")

        # --- xmodalix per-sample MSE (reference and translation) -------------
        xmodalix_test = xmodalix_result.datasets.test.datasets[to_key]
        meta_xmodalix = xmodalix_test.metadata
        if meta_xmodalix is None:
            # Same explicit failure as for the pure VAE metadata above.
            raise ValueError("metadata cannot be None")
        sample_ids = list(meta_xmodalix.index)
        positions = XModalixEvaluator._first_positions(sample_ids)

        xmodalix_recons = xmodalix_result.reconstructions.get(epoch=-1, split="test")
        reference_recons = xmodalix_recons[f"reference_{to_key}_to_{to_key}"]
        translation_recons = xmodalix_recons["translation"]

        xmodalix_records = []
        for sample_id, pos in zip(sample_ids, positions):
            # Element 1 of a dataset item is used as the original image.
            orig = torch.Tensor(xmodalix_test[pos][1].squeeze())
            reference = torch.Tensor(reference_recons[pos].squeeze())
            translation = torch.Tensor(translation_recons[pos].squeeze())

            mse_translated = F.mse_loss(orig, translation, reduction="mean")
            mse_reference = F.mse_loss(orig, reference, reduction="mean")
            xmodalix_records.append(
                {
                    "sample_id": sample_id,
                    "mse_xmodalix_translated": mse_translated.item(),
                    "mse_xmodalix_reference": mse_reference.item(),
                }
            )

        df_xmodalix_mse = pd.DataFrame(xmodalix_records).set_index("sample_id")
        df_xmodalix_mse = df_xmodalix_mse.join(meta_xmodalix, on="sample_id")

        # Outer merge keeps rows that only exist on one side.
        # NOTE(review): the join keys are the pure-VAE metadata columns, not
        # sample_id; a column-on-column merge does not carry the sample_id
        # index into the result, and samples sharing identical metadata values
        # are paired with each other -- confirm this is intended.
        df_both_mse = df_imagix_mse.merge(
            df_xmodalix_mse, on=list(meta_imagix.columns), how="outer"
        )

        # Wide MSE column name -> model label used in the long table.
        model_labels = {
            "mse_imagix": "imagix",
            "mse_xmodalix_translated": "xmodalix_translated",
            "mse_xmodalix_reference": "xmodalix_reference",
        }
        mse_columns = list(model_labels)

        df_long = df_both_mse.melt(
            id_vars=[col for col in df_both_mse.columns if col not in mse_columns],
            value_vars=mse_columns,
            var_name="model",
            value_name="mse_value",
        )
        df_long["model"] = df_long["model"].map(model_labels)

        if param:
            # Figure width scales with the number of groups on the x axis.
            plt.figure(figsize=(2 * len(df_long[param].unique()), 8))

            plot = sns.boxplot(data=df_long, x=param, y="mse_value", hue="model")
            sns.move_legend(
                plot,
                "lower center",
                bbox_to_anchor=(0.5, 1),
                ncol=3,
                title=None,
                frameon=False,
            )
        else:
            plt.figure(figsize=(5, 8))

            plot = sns.boxplot(data=df_long, x="model", y="mse_value")
            # Rotate tick labels
            # NOTE(review): rotation and x-label removal are applied to the
            # ungrouped plot only -- confirm the grouped plot should not get
            # them as well.
            plt.xticks(rotation=-45)
            plt.xlabel("")

        return plot, df_long

    @staticmethod
    def _get_clin_data(datasets) -> Union[pd.Series, pd.DataFrame]:
        """Collect the clinical annotation table from an XModalix dataset container.

        Walks the train, valid and test splits and, for every modality in each
        split, appends its metadata rows. Once columns have been collected,
        only the columns shared with the running table are taken from a new
        metadata frame; if nothing is shared, all of its columns are taken.
        Rows with a duplicated index are dropped, keeping the first occurrence.

        Args:
            datasets: Container whose ``train`` / ``valid`` / ``test`` splits
                each expose a ``datasets`` mapping of modality key to an object
                with a ``metadata`` DataFrame.

        Returns:
            The combined metadata with a unique index.

        Raises:
            ValueError: If ``datasets.train`` has no ``datasets`` attribute,
                i.e. the container is not in the XModalix layout.
        """
        # Only the XModalix layout is handled; anything else is rejected.
        if not hasattr(datasets.train, "datasets"):
            raise ValueError(
                "No annotation data found. Please provide a valid annotation data type."
            )

        clin_data = pd.DataFrame()
        for split in (datasets.train, datasets.valid, datasets.test):
            for key in split.datasets.keys():
                print(f"Processing dataset: {key}")
                metadata = split.datasets[key].metadata
                # Restrict to the columns already collected so far ...
                overlap = clin_data.columns.intersection(metadata.columns)
                if overlap.empty:
                    # ... unless there are none (e.g. very first frame).
                    overlap = metadata.columns
                clin_data = pd.concat([clin_data, metadata[overlap]], axis=0)

        # The same sample can appear under several modalities; keep it once.
        return clin_data[~clin_data.index.duplicated(keep="first")]

    def _enrich_results(
        self,
        results: pd.DataFrame,
        sklearn_ml: Union[ClassifierMixin, RegressorMixin],
        ml_type: str,
        task: str,
        sub: str,
    ) -> pd.DataFrame:
        """Annotate a result table with the ML setup that produced it.

        Adds the constant columns "ML_ALG", "ML_TYPE", "MODALITY", "ML_TASK"
        and "ML_SUBTASK" to ``results``. The frame is modified in place and
        also returned.

        Args:
            results: Table to annotate (mutated).
            sklearn_ml: Estimator; its ``str()`` representation is stored.
            ml_type: Value for the "ML_TYPE" column.
            task: String of the form "<task>_$_<modality>"; the two parts fill
                "ML_TASK" and "MODALITY".
            sub: Value for the "ML_SUBTASK" column.

        Returns:
            The same ``results`` object with the extra columns.

        Raises:
            IndexError: If ``task`` does not contain the "_$_" separator.
        """
        # Scalars are broadcast to every row; the estimator is rendered once.
        results["ML_ALG"] = str(sklearn_ml)
        results["ML_TYPE"] = ml_type

        task_parts = task.split("_$_")
        modality = task_parts[1]
        task_xmodal = task_parts[0]

        results["MODALITY"] = modality
        results["ML_TASK"] = task_xmodal
        results["ML_SUBTASK"] = sub

        return results

    @staticmethod
    @no_type_check
    def _expand_reference_methods(reference_methods: list, result: Result) -> list:
        """Expand every reference method into one entry per data modality.

        Each method name is combined with every key of the train latent space
        of the final epoch, producing names of the form "<method>_$_<key>".
        All keys of one method are emitted before moving to the next method.

        Args:
            reference_methods (list): Reference method names to expand.
            result (Result): Result whose train latent spaces (epoch -1) are
                expected to be a dict keyed by modality.

        Returns:
            list: The expanded method names.

        Raises:
            NotImplementedError: If the train latent space is not a dict.
        """
        train_latents = result.latentspaces.get(epoch=-1, split="train")
        if not isinstance(train_latents, dict):
            raise NotImplementedError(
                "This evaluate feature does not support .save(save_all=False) results."
            )

        return [
            f"{method}_$_{key}"
            for method in reference_methods
            for key in train_latents
        ]

    ## New for x-modalix
    @staticmethod
    def _load_input_for_ml(
        task: str, dataset: DatasetContainer, result: Result
    ) -> pd.DataFrame:
        """Build the feature table for a downstream ML task on one modality.

        ``task`` has the form "<task>_$_<modality>". Supported task names:
            - "Latent": latent representations of train, valid and test
              (epoch -1), concatenated in that order.
            - "UMAP" / "PCA" / "TSNE": the raw splits reduced to as many
              components as the train latent space has columns.
            - "RandomFeature": that many randomly sampled raw feature columns.

        For the non-latent tasks the raw splits are concatenated in the order
        train, test, valid.

        Args:
            task: Task name and modality joined by "_$_".
            dataset: Container holding the train, valid and test splits.
            result: Result object used to retrieve the latent representations.

        Returns:
            DataFrame with the features for the requested task.

        Raises:
            IndexError: If ``task`` does not contain the "_$_" separator.
            ValueError: If the task name is not supported, or if a required
                split of ``dataset`` is None.
        """
        # final_epoch = result.model.config.epochs - 1
        task_parts = task.split("_$_")
        modality = task_parts[1]
        task = task_parts[0]

        if task == "Latent":
            return pd.concat(
                [
                    result.get_latent_df(epoch=-1, split=split, modality=modality)
                    for split in ("train", "valid", "test")
                ]
            )

        if task not in ("UMAP", "PCA", "TSNE", "RandomFeature"):
            raise ValueError(
                f"Your ML task {task} is not supported. Please use Latent, UMAP, PCA, TSNE or RandomFeature."
            )

        # Reference methods get the same dimensionality as the latent space.
        latent_dim = result.get_latent_df(
            epoch=-1, split="train", modality=modality
        ).shape[1]
        if dataset.train is None:
            raise ValueError("train attribute of dataset cannot be None")
        if dataset.valid is None:
            raise ValueError("valid attribute of dataset cannot be None")
        if dataset.test is None:
            raise ValueError("test attribute of dataset cannot be None")

        df_processed = pd.concat(
            [
                dataset.train._to_df(modality=modality),
                dataset.test._to_df(modality=modality),
                dataset.valid._to_df(modality=modality),
            ]
        )

        if task == "RandomFeature":
            return df_processed.sample(n=latent_dim, axis=1)

        if task == "UMAP":
            reducer = UMAP(n_components=latent_dim)
        elif task == "PCA":
            reducer = PCA(n_components=latent_dim)
        else:  # "TSNE"
            # scikit-learn's default Barnes-Hut t-SNE rejects more than 3
            # components, so fall back to the exact method for larger latent
            # spaces instead of failing.
            tsne_method = "barnes_hut" if latent_dim < 4 else "exact"
            reducer = TSNE(n_components=latent_dim, method=tsne_method)

        return pd.DataFrame(
            reducer.fit_transform(df_processed), index=df_processed.index
        )