Coverage for src / autoencodix / evaluate / _xmodalix_evaluator.py: 23%
128 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Union, Tuple, Optional, no_type_check
4import pandas as pd
5import torch
6import torch.nn.functional as F
8from matplotlib import pyplot as plt
9from matplotlib.figure import Figure
10import seaborn as sns
12import sklearn
13from sklearn.decomposition import PCA
14from umap import UMAP
15from sklearn.manifold import TSNE
16from sklearn.base import ClassifierMixin, RegressorMixin
18from autoencodix.utils._result import Result
19from autoencodix.data._datasetcontainer import DatasetContainer
20from autoencodix.evaluate._general_evaluator import GeneralEvaluator
# Enable scikit-learn metadata routing. NOTE(review): set_config changes
# scikit-learn's global configuration, so importing this module affects every
# other scikit-learn user in the same process.
_SKLEARN_GLOBAL_CONFIG = {"enable_metadata_routing": True}
sklearn.set_config(**_SKLEARN_GLOBAL_CONFIG)
class XModalixEvaluator(GeneralEvaluator):
    """Evaluation helpers specialised for X-Modalix (cross-modal) results.

    Builds on ``GeneralEvaluator`` and adds utilities that understand the
    per-modality layout of X-Modalix datasets, latent spaces and
    reconstructions. ML task names handled here are encoded as
    ``"<task>_$_<modality>"``.
    """

    def __init__(self):
        # NOTE(review): GeneralEvaluator.__init__ is not called here (the call
        # was commented out in the original code); confirm this is intended.
        pass

    @staticmethod
    def _as_float_tensor(array) -> torch.Tensor:
        """Convert a NumPy array or tensor into a squeezed float32 tensor.

        ``torch.as_tensor`` accepts tensors of any dtype, whereas the legacy
        ``torch.Tensor(...)`` constructor raises for tensor inputs that are not
        CPU float32.
        """
        return torch.as_tensor(array, dtype=torch.float32).squeeze()

    @staticmethod
    def _mse_frame_with_metadata(
        records: list, value_columns: list, metadata: pd.DataFrame
    ) -> pd.DataFrame:
        """Build a per-sample MSE table and left-join ``metadata`` on sample id.

        The sample id is exposed as a regular ``sample_id`` column so that it
        can be used as a merge key afterwards. Passing explicit columns keeps
        this working for an empty ``records`` list.
        """
        return (
            pd.DataFrame(records, columns=["sample_id", *value_columns])
            .set_index("sample_id")
            .join(metadata)
            .rename_axis("sample_id")
            .reset_index()
        )

    @staticmethod
    @no_type_check
    def pure_vae_comparison(
        xmodalix_result: Result,
        pure_vae_result: Result,
        to_key: str,
        param: Optional[str] = None,
    ) -> Tuple[Figure, pd.DataFrame]:
        """Compare per-sample reconstruction MSE of a pure VAE and an X-Modalix model.

        For every test sample, the mean squared error between the original and
        the reconstructed image is computed for:

        - the pure VAE reconstruction (``"imagix"``),
        - the X-Modalix reference reconstruction ``to_key -> to_key``
          (``"xmodalix_reference"``),
        - the X-Modalix translation (``"xmodalix_translated"``).

        The errors are joined with the sample metadata and combined per sample.
        They are returned in long format suitable for plotting. A boxplot of
        the MSE distributions is drawn on a new matplotlib figure, optionally
        grouped by a metadata column.

        Args:
            xmodalix_result: Result object with X-Modalix reconstructions and
                test datasets.
            pure_vae_result: Result object with pure VAE reconstructions and
                test datasets.
            to_key: Key of the target modality in the X-Modalix datasets. It
                must contain ``"img"``.
            param: Metadata column used to group the boxplots. If None, the
                boxplots are grouped by model only.

        Returns:
            A tuple ``(ax, df_long)``:

            - ``ax``: the seaborn/matplotlib Axes holding the boxplot. It is
              annotated as ``Figure`` for backward compatibility; use
              ``ax.figure`` to get the parent Figure.
            - ``df_long``: long-format DataFrame with one row per sample and
              model. It contains ``sample_id``, the metadata columns,
              ``model`` and ``mse_value``.

        Raises:
            NotImplementedError: If ``to_key`` does not contain ``"img"``
                (only the image case is supported).
            ValueError: If the test metadata of either result is None.
        """
        if "img" not in to_key:
            raise NotImplementedError(
                "Comparison is currently only implemented for the image case."
            )
        to_tensor = XModalixEvaluator._as_float_tensor

        # --- Pure VAE ("imagix") reconstruction error ---
        imagix_test = pure_vae_result.datasets.test
        meta_imagix = imagix_test.metadata
        if meta_imagix is None:
            raise ValueError("metadata cannot be None")
        # Fetch the reconstructions once instead of once per sample.
        imagix_recons = pure_vae_result.reconstructions.get(split="test", epoch=-1)
        imagix_records = []
        # The i-th metadata row is assumed to describe the i-th test sample.
        # enumerate keeps that alignment, also for duplicate ids, in O(n).
        for idx, sample_id in enumerate(meta_imagix.index):
            orig = to_tensor(imagix_test.raw_data[idx].img)
            recon = to_tensor(imagix_recons[idx])
            imagix_records.append(
                {
                    "sample_id": sample_id,
                    "mse_imagix": F.mse_loss(orig, recon, reduction="mean").item(),
                }
            )
        df_imagix_mse = XModalixEvaluator._mse_frame_with_metadata(
            imagix_records, ["mse_imagix"], meta_imagix
        )

        # --- X-Modalix reference and translation reconstruction errors ---
        xmodalix_test = xmodalix_result.datasets.test.datasets[to_key]
        meta_xmodalix = xmodalix_test.metadata
        if meta_xmodalix is None:
            raise ValueError("metadata cannot be None")
        xmodalix_recons = xmodalix_result.reconstructions.get(epoch=-1, split="test")
        reference_recons = xmodalix_recons[f"reference_{to_key}_to_{to_key}"]
        translated_recons = xmodalix_recons["translation"]
        xmodalix_records = []
        for idx, sample_id in enumerate(meta_xmodalix.index):
            # Element [1] of a dataset item is used as the original image.
            orig = to_tensor(xmodalix_test[idx][1])
            translation = to_tensor(translated_recons[idx])
            reference = to_tensor(reference_recons[idx])
            xmodalix_records.append(
                {
                    "sample_id": sample_id,
                    "mse_xmodalix_translated": F.mse_loss(
                        orig, translation, reduction="mean"
                    ).item(),
                    "mse_xmodalix_reference": F.mse_loss(
                        orig, reference, reduction="mean"
                    ).item(),
                }
            )
        df_xmodalix_mse = XModalixEvaluator._mse_frame_with_metadata(
            xmodalix_records,
            ["mse_xmodalix_translated", "mse_xmodalix_reference"],
            meta_xmodalix,
        )

        # Merge on sample_id plus the shared metadata columns. This keeps a
        # single copy of each metadata column (needed for `param`). It also
        # stops different samples with identical metadata from being
        # cross-joined, which merging on metadata alone would do. The outer
        # merge keeps samples that appear in only one of the two test sets.
        shared_meta = [
            col for col in meta_imagix.columns if col in meta_xmodalix.columns
        ]
        df_both_mse = df_imagix_mse.merge(
            df_xmodalix_mse, on=["sample_id", *shared_meta], how="outer"
        )

        # Long format for plotting. The insertion order defines the model order.
        model_names = {
            "mse_imagix": "imagix",
            "mse_xmodalix_translated": "xmodalix_translated",
            "mse_xmodalix_reference": "xmodalix_reference",
        }
        df_long = df_both_mse.melt(
            id_vars=[col for col in df_both_mse.columns if col not in model_names],
            value_vars=list(model_names),
            var_name="model",
            value_name="mse_value",
        )
        df_long["model"] = df_long["model"].map(model_names)

        if param:
            plt.figure(figsize=(2 * len(df_long[param].unique()), 8))
            ax = sns.boxplot(data=df_long, x=param, y="mse_value", hue="model")
            sns.move_legend(
                ax,
                "lower center",
                bbox_to_anchor=(0.5, 1),
                ncol=3,
                title=None,
                frameon=False,
            )
        else:
            plt.figure(figsize=(5, 8))
            ax = sns.boxplot(data=df_long, x="model", y="mse_value")
            # Rotate tick labels and drop the redundant "model" axis label.
            plt.xticks(rotation=-45)
            plt.xlabel("")

        return ax, df_long

    @staticmethod
    def _get_clin_data(datasets) -> Union[pd.Series, pd.DataFrame]:
        """Collect the clinical annotation (metadata) of all X-Modalix splits.

        The metadata of every modality in the train, valid and test splits is
        stacked row-wise. After the first table, only columns that overlap with
        the already collected columns are taken. If there is no overlap, all
        columns of that table are appended. For duplicated index values
        (sample ids), the first occurrence is kept.

        Args:
            datasets: Container with ``train``/``valid``/``test`` splits. Each
                split exposes a ``datasets`` mapping of modality -> dataset
                with a ``metadata`` DataFrame (the X-Modalix layout). Splits
                that are None and datasets without metadata are skipped.

        Returns:
            DataFrame with one row per unique index value. It is empty if no
            metadata was found.

        Raises:
            ValueError: If ``datasets.train`` has no ``datasets`` attribute,
                i.e. it is not an X-Modalix dataset structure.
        """
        if not hasattr(datasets.train, "datasets"):
            raise ValueError(
                "No annotation data found. Please provide a valid annotation data type."
            )

        clin_data = None
        for split in (datasets.train, datasets.valid, datasets.test):
            if split is None:
                continue
            for key in split.datasets:
                metadata = split.datasets[key].metadata
                if metadata is None:
                    continue
                if clin_data is None:
                    # The first table defines the initial column set.
                    clin_data = metadata
                    continue
                overlap = clin_data.columns.intersection(metadata.columns)
                if overlap.empty:
                    overlap = metadata.columns
                clin_data = pd.concat([clin_data, metadata[overlap]], axis=0)

        if clin_data is None:
            return pd.DataFrame()
        return clin_data[~clin_data.index.duplicated(keep="first")]

    def _enrich_results(
        self,
        results: pd.DataFrame,
        sklearn_ml: Union[ClassifierMixin, RegressorMixin],
        ml_type: str,
        task: str,
        sub: str,
    ) -> pd.DataFrame:
        """Annotate ``results`` in place with columns describing the ML run.

        Args:
            results: Table of ML results. It is modified in place.
            sklearn_ml: The estimator used; its ``str()`` goes into ``ML_ALG``.
            ml_type: Value for the ``ML_TYPE`` column.
            task: Task name of the form ``"<task>_$_<modality>"``. The parts
                fill ``ML_TASK`` and ``MODALITY``.
            sub: Value for the ``ML_SUBTASK`` column.

        Returns:
            The same ``results`` object with the added constant columns.

        Raises:
            IndexError: If ``task`` does not contain the ``"_$_"`` separator.
                This is raised before ``results`` is modified.
        """
        task_parts = task.split("_$_")
        task_xmodal, modality = task_parts[0], task_parts[1]

        # Scalars broadcast to every row, so there is no need to build lists.
        results["ML_ALG"] = str(sklearn_ml)
        results["ML_TYPE"] = ml_type
        results["MODALITY"] = modality
        results["ML_TASK"] = task_xmodal
        results["ML_SUBTASK"] = sub
        return results

    @staticmethod
    @no_type_check
    def _expand_reference_methods(reference_methods: list, result: Result) -> list:
        """Expand each reference method into one entry per data modality.

        Every method name is combined with every key of the final-epoch train
        latent spaces as ``"<method>_$_<key>"``.

        Args:
            reference_methods: Reference method names to expand.
            result: Result object whose ``latentspaces`` provide the modality
                keys.

        Returns:
            The expanded method names, ordered by method, then by key.

        Raises:
            NotImplementedError: If the train latent spaces are not a dict
                (e.g. results saved with ``save_all=False``).
        """
        train_latents = result.latentspaces.get(epoch=-1, split="train")
        if not isinstance(train_latents, dict):
            raise NotImplementedError(
                "This evaluate feature does not support .save(save_all=False) results."
            )
        return [
            f"{method}_$_{key}" for method in reference_methods for key in train_latents
        ]

    # Added for X-Modalix
    @staticmethod
    def _load_input_for_ml(
        task: str, dataset: DatasetContainer, result: Result
    ) -> pd.DataFrame:
        """Load the input features for an ML task on one modality.

        ``task`` has the form ``"<task>_$_<modality>"``. The task part selects
        how the input is built:

        - ``"Latent"``: latent representations of the train, valid and test
          splits at the final epoch.
        - ``"UMAP"`` / ``"PCA"`` / ``"TSNE"``: the concatenated dataset splits,
          reduced to the latent dimensionality.
        - ``"RandomFeature"``: a random subset of ``latent_dim`` columns of
          the concatenated dataset splits.

        Args:
            task: Task name of the form ``"<task>_$_<modality>"``.
            dataset: Dataset container holding the train, valid and test
                splits.
            result: Result object used to retrieve latent representations.

        Returns:
            A DataFrame with the processed input data, indexed like the
            source rows.

        Raises:
            ValueError: If the task is not supported, or a required split of
                ``dataset`` is None.
        """
        task_parts = task.split("_$_")
        modality = task_parts[1]
        task = task_parts[0]

        if task == "Latent":
            return pd.concat(
                [
                    result.get_latent_df(epoch=-1, split="train", modality=modality),
                    result.get_latent_df(epoch=-1, split="valid", modality=modality),
                    result.get_latent_df(epoch=-1, split="test", modality=modality),
                ]
            )
        if task not in ("UMAP", "PCA", "TSNE", "RandomFeature"):
            raise ValueError(
                f"Your ML task {task} is not supported. Please use Latent, UMAP, PCA, TSNE or RandomFeature."
            )

        # The baselines are reduced to the same dimensionality as the latent space.
        latent_dim = result.get_latent_df(
            epoch=-1, split="train", modality=modality
        ).shape[1]
        for split_name in ("train", "valid", "test"):
            if getattr(dataset, split_name) is None:
                raise ValueError(f"{split_name} attribute of dataset cannot be None")

        df_processed = pd.concat(
            [
                dataset.train._to_df(modality=modality),
                dataset.test._to_df(modality=modality),
                dataset.valid._to_df(modality=modality),
            ]
        )
        if task == "RandomFeature":
            return df_processed.sample(n=latent_dim, axis=1)

        if task == "UMAP":
            reducer = UMAP(n_components=latent_dim)
        elif task == "PCA":
            reducer = PCA(n_components=latent_dim)
        else:  # TSNE
            # The default Barnes-Hut t-SNE only supports < 4 components, and
            # latent spaces are usually larger, so fall back to the (slower,
            # O(n^2)) exact method in that case.
            reducer = TSNE(
                n_components=latent_dim,
                method="barnes_hut" if latent_dim < 4 else "exact",
            )
        return pd.DataFrame(
            reducer.fit_transform(df_processed), index=df_processed.index
        )