Coverage for src / autoencodix / xmodalix.py: 37%
95 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Dict, Optional, Type, Union, Literal
2import numpy as np
4import anndata as ad # type: ignore
6from autoencodix.base._base_dataset import BaseDataset, DataSetTypes
7from autoencodix.base._base_loss import BaseLoss
8from autoencodix.base._base_pipeline import BasePipeline
9from autoencodix.base._base_trainer import BaseTrainer
10from autoencodix.base._base_visualizer import BaseVisualizer
11from autoencodix.base._base_preprocessor import BasePreprocessor
12from autoencodix.base._base_autoencoder import BaseAutoencoder
13from autoencodix.data._datasetcontainer import DatasetContainer
14from autoencodix.data._multimodal_dataset import MultiModalDataset
16from autoencodix.modeling._imagevae_architecture import ImageVAEArchitecture
17from autoencodix.data._datasplitter import DataSplitter
18from autoencodix.data.datapackage import DataPackage
19from autoencodix.modeling._varix_architecture import VarixArchitecture
20from autoencodix.data._xmodal_preprocessor import XModalPreprocessor
21from autoencodix.evaluate._xmodalix_evaluator import XModalixEvaluator
22from autoencodix.trainers._xmodal_trainer import XModalTrainer
23from autoencodix.utils._result import Result
24from autoencodix.configs.default_config import DefaultConfig
25from autoencodix.configs.xmodalix_config import XModalixConfig
26from autoencodix.utils._losses import XModalLoss
27from autoencodix.utils._utils import find_translation_keys
28from autoencodix.visualize._xmodal_visualizer import XModalVisualizer
30from mudata import MuData
31import torch # type: ignore
34class XModalix(BasePipeline):
35 """XModalix specific version of the BasePipeline class.
37 Inherits preprocess, fit, evaluate, and visualize methods from BasePipeline.
38 Overrides predict
40 This class extends BasePipeline. See the parent class for a full list
41 of attributes and methods.
43 Additional Attributes:
44 _default_config: Is set to XModalixConfig here.
46 """
48 def __init__(
49 self,
50 data: Optional[Union[DataPackage, DatasetContainer]] = None,
51 trainer_type: Type[BaseTrainer] = XModalTrainer,
52 dataset_type: Type[BaseDataset] = MultiModalDataset,
53 model_type: Type[BaseAutoencoder] = VarixArchitecture,
54 loss_type: Type[BaseLoss] = XModalLoss,
55 preprocessor_type: Type[BasePreprocessor] = XModalPreprocessor,
56 visualizer: Optional[Type[BaseVisualizer]] = XModalVisualizer,
57 evaluator: Optional[Type[XModalixEvaluator]] = XModalixEvaluator,
58 result: Optional[Result] = None,
59 datasplitter_type: Type[DataSplitter] = DataSplitter,
60 custom_splits: Optional[Dict[str, np.ndarray]] = None,
61 config: Optional[DefaultConfig] = None,
62 model_map: Dict[DataSetTypes, Type[BaseAutoencoder]] = {
63 DataSetTypes.NUM: VarixArchitecture,
64 DataSetTypes.IMG: ImageVAEArchitecture,
65 },
66 ) -> None:
67 """See base class for full list of Args."""
68 self._default_config = XModalixConfig()
69 super().__init__(
70 data=data,
71 dataset_type=dataset_type,
72 trainer_type=trainer_type,
73 model_type=model_type,
74 loss_type=loss_type,
75 preprocessor_type=preprocessor_type,
76 visualizer=visualizer,
77 evaluator=evaluator,
78 result=result,
79 datasplitter_type=datasplitter_type,
80 config=config,
81 custom_split=custom_splits,
82 model_map=model_map,
83 )
84 if not isinstance(self.config, XModalixConfig):
85 raise TypeError(
86 f"For XModalix Pipeline, we only allow XModalixConfig as type for config, got {type(self.config)}"
87 )
89 def show_result(self):
90 """Displays key visualizations of model results.
92 This method generates the following visualizations:
93 1. Loss Curves: Displays the absolute loss curves to provide insights into
94 the model's training and validation performance over epochs.
95 2. Latent Space Ridgeline Plot: Visualizes the distribution of the latent
96 space representations across different dimensions, offering a high-level
97 overview of the learned embeddings.
98 3. Latent Space 2D Scatter Plot: Projects the latent space into two dimensions
99 for a detailed view of the clustering or separation of data points.
101 These visualizations help in understanding the model's performance and
102 the structure of the latent space representations.
103 """
104 print("Creating plots ...")
106 self.visualizer.show_loss(plot_type="absolute")
108 self.visualizer.show_latent_space(result=self.result, plot_type="Ridgeline")
110 self.visualizer.show_latent_space(result=self.result, plot_type="2D-scatter")
112 dm_keys = find_translation_keys(
113 config=self.config,
114 trained_modalities=self._trainer._modality_dynamics.keys(),
115 )
116 if "IMG" in dm_keys["to"]:
117 self.visualizer.show_image_translation(
118 result=self.result,
119 from_key=dm_keys["from"],
120 to_key=dm_keys["to"],
121 split="test",
122 )
124 def _process_latent_results(
125 self, predictor_results: Result, predict_data: DatasetContainer
126 ):
127 """Processes latent space results into a single AnnData object for the source modality.
129 This method identifies the source ('from') modality used for translation,
130 extracts its latent space and sample IDs, and creates a single, informative
131 AnnData object. The original feature names from the source dataset are
132 also stored for reference.
134 The final AnnData object is stored in `self.result.adata_latent`.
136 Args:
137 predictor_results: The Result object returned by the `predict` method.
138 predict_data: The MultiModalDataset used for prediction, to access metadata.
140 Raises:
141 TypeError: if predicitonsresults arr no Dicts.
142 ValueError: if no translate direction can be found or no latentspace is stored for the specified modality.
143 """
144 print("Processing latent space results into a single AnnData object...")
145 all_latents = predictor_results.latentspaces.get(epoch=-1, split="test")
146 all_sample_ids = predictor_results.sample_ids.get(epoch=-1, split="test")
147 all_recons = predictor_results.reconstructions.get(epoch=-1, split="test")
149 if not all(
150 [isinstance(res, dict) for res in [all_latents, all_sample_ids, all_recons]]
151 ):
152 raise TypeError("Expected prediction results to be dictionaries.")
154 try:
155 predict_keys = find_translation_keys(
156 config=self.config,
157 trained_modalities=all_latents.keys(), # type: ignore
158 )
159 from_key = predict_keys["from"]
160 except Exception as e:
161 # Provide a helpful error if the key can't be determined.
162 raise ValueError(
163 "Could not determine the 'from_key' for translation. "
164 "Ensure the translation direction is specified in the config."
165 ) from e
167 print(f"Identified source modality for latent space: '{from_key}'")
169 latent = all_latents.get(from_key) # type: ignore
170 sample_ids = all_sample_ids.get(from_key) # type: ignore
172 if latent is None:
173 raise ValueError(
174 f"Latent space for source modality '{from_key}' not found."
175 )
177 self.result.adata_latent = ad.AnnData(latent)
178 self.result.adata_latent.var_names = [
179 f"latent_{i}" for i in range(latent.shape[1])
180 ]
182 if sample_ids is not None:
183 self.result.adata_latent.obs_names = sample_ids
185 source_dataset = predict_data.test.datasets.get(from_key) # type: ignore
186 if (
187 source_dataset
188 and hasattr(source_dataset, "feature_ids")
189 and source_dataset.feature_ids is not None
190 ):
191 self.result.adata_latent.uns["feature_ids"] = source_dataset.feature_ids
192 print(
193 f" - Added {len(source_dataset.feature_ids)} source feature IDs to .uns"
194 )
196 self.result.update(predictor_results)
197 print("Finished processing latent results.")
199 def _validate_prediction_data(self, predict_data: BaseDataset):
200 """Validate that prediction data has required test split.
202 Args:
203 predict_data: Dataset for prediciton
204 Raises:
205 ValueError: if prediciton data is empty.
207 """
208 if predict_data is None:
209 raise ValueError(
210 f"The data for prediction need to be a DatasetContainer with a test attribute containing a `BaseDataset` child"
211 f"attribute, got: {predict_data}"
212 )
214 def _predict(
215 self,
216 predict_data: BaseDataset,
217 from_key: Optional[str] = None,
218 to_key: Optional[str] = None,
219 split: Literal["train", "valid", "test"] = "test",
220 ):
221 """Utility warpper of predict method of the trainer instance"""
222 self._validate_prediction_data(predict_data=predict_data)
223 return self._trainer.predict(
224 data=predict_data,
225 model=self.result.model,
226 from_key=from_key,
227 to_key=to_key,
228 split=split,
229 ) # type: ignore
231 def predict(
232 self,
233 data: Optional[
234 Union[
235 DataPackage,
236 DatasetContainer,
237 ad.AnnData,
238 MuData, # ty: ignore[invalid-type-form]
239 ]
240 ] = None,
241 config: Optional[Union[None, DefaultConfig]] = None,
242 from_key: Optional[str] = None,
243 to_key: Optional[str] = None,
244 predict_all: bool = False,
245 **kwargs,
246 ):
247 """Generates predictions using the trained model.
249 Uses the trained model to make predictions on test data or new data
250 provided by the user. Processes the results and stores them in the
251 result container.
253 Args:
254 data: Optional new data for predictions.
255 config: Optional custom configuration for prediction.
256 from_key: string indicator of 'from' translation direction.
257 to_key: string indicator of 'to' translation direction.
258 **kwargs: Additional configuration parameters as keyword arguments.
260 Raises:
261 NotImplementedError: If required components aren't initialized.
262 ValueError: If no test data is available or data format is invalid.
263 """
264 self._validate_prediction_requirements()
266 original_input = data
267 predict_data = self._prepare_prediction_data(data=data)
268 if predict_data.test is None:
269 raise ValueError("No test data available for predictions.")
270 predictor_results = self._predict(
271 predict_data=predict_data.test,
272 from_key=from_key,
273 to_key=to_key,
274 split="test",
275 )
277 self._process_latent_results(
278 predictor_results=predictor_results, predict_data=predict_data
279 )
280 self._postprocess_reconstruction(
281 predictor_results=predictor_results,
282 original_input=original_input,
283 predict_data=predict_data,
284 )
286 if predict_all:
287 if self._datasets is None:
288 raise ValueError(
289 "No training/validation data available for predictions."
290 )
291 train_pred_results = self._predict(
292 predict_data=self._datasets.train,
293 from_key=from_key,
294 to_key=to_key,
295 split="train",
296 )
297 self.result.update(other=train_pred_results)
298 if self._datasets.valid is None:
299 raise ValueError("No validation data available for predictions.")
300 valid_pred_results = self._predict(
301 predict_data=self._datasets.valid,
302 from_key=from_key,
303 to_key=to_key,
304 split="valid",
305 )
306 self.result.update(other=valid_pred_results)
308 return self.result
310 def sample_latent_space(
311 self,
312 n_samples: int,
313 split: str = "test",
314 epoch: int = -1,
315 ) -> torch.Tensor:
316 raise NotImplementedError(
317 "Sampling latent space is not implemented for XModalix."
318 )