Coverage for src / autoencodix / xmodalix.py: 37%

95 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Dict, Optional, Type, Union, Literal 

2import numpy as np 

3 

4import anndata as ad # type: ignore 

5 

6from autoencodix.base._base_dataset import BaseDataset, DataSetTypes 

7from autoencodix.base._base_loss import BaseLoss 

8from autoencodix.base._base_pipeline import BasePipeline 

9from autoencodix.base._base_trainer import BaseTrainer 

10from autoencodix.base._base_visualizer import BaseVisualizer 

11from autoencodix.base._base_preprocessor import BasePreprocessor 

12from autoencodix.base._base_autoencoder import BaseAutoencoder 

13from autoencodix.data._datasetcontainer import DatasetContainer 

14from autoencodix.data._multimodal_dataset import MultiModalDataset 

15 

16from autoencodix.modeling._imagevae_architecture import ImageVAEArchitecture 

17from autoencodix.data._datasplitter import DataSplitter 

18from autoencodix.data.datapackage import DataPackage 

19from autoencodix.modeling._varix_architecture import VarixArchitecture 

20from autoencodix.data._xmodal_preprocessor import XModalPreprocessor 

21from autoencodix.evaluate._xmodalix_evaluator import XModalixEvaluator 

22from autoencodix.trainers._xmodal_trainer import XModalTrainer 

23from autoencodix.utils._result import Result 

24from autoencodix.configs.default_config import DefaultConfig 

25from autoencodix.configs.xmodalix_config import XModalixConfig 

26from autoencodix.utils._losses import XModalLoss 

27from autoencodix.utils._utils import find_translation_keys 

28from autoencodix.visualize._xmodal_visualizer import XModalVisualizer 

29 

30from mudata import MuData 

31import torch # type: ignore 

32 

33 

class XModalix(BasePipeline):
    """XModalix specific version of the BasePipeline class.

    Inherits preprocess, fit, evaluate, and visualize methods from BasePipeline.
    Overrides predict.

    This class extends BasePipeline. See the parent class for a full list
    of attributes and methods.

    Additional Attributes:
        _default_config: Is set to XModalixConfig here.

    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = XModalTrainer,
        dataset_type: Type[BaseDataset] = MultiModalDataset,
        model_type: Type[BaseAutoencoder] = VarixArchitecture,
        loss_type: Type[BaseLoss] = XModalLoss,
        preprocessor_type: Type[BasePreprocessor] = XModalPreprocessor,
        visualizer: Optional[Type[BaseVisualizer]] = XModalVisualizer,
        evaluator: Optional[Type[XModalixEvaluator]] = XModalixEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        config: Optional[DefaultConfig] = None,
        model_map: Optional[Dict[DataSetTypes, Type[BaseAutoencoder]]] = None,
    ) -> None:
        """See base class for full list of Args.

        Args:
            model_map: Mapping from dataset type to the architecture class used
                for it. When omitted, NUM maps to VarixArchitecture and IMG maps
                to ImageVAEArchitecture.

        Raises:
            TypeError: If the resolved config is not an XModalixConfig.
        """
        # Build the default mapping per call: a dict literal in the signature
        # would be one shared object across all pipeline instances.
        if model_map is None:
            model_map = {
                DataSetTypes.NUM: VarixArchitecture,
                DataSetTypes.IMG: ImageVAEArchitecture,
            }

        # Set before the base initialiser runs, as in the original ordering.
        self._default_config = XModalixConfig()
        super().__init__(
            data=data,
            dataset_type=dataset_type,
            trainer_type=trainer_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            config=config,
            custom_split=custom_splits,
            model_map=model_map,
        )
        if not isinstance(self.config, XModalixConfig):
            raise TypeError(
                f"For XModalix Pipeline, we only allow XModalixConfig as type for config, got {type(self.config)}"
            )

    def show_result(self):
        """Displays key visualizations of model results.

        This method generates the following visualizations:
        1. Loss Curves: Displays the absolute loss curves to provide insights into
           the model's training and validation performance over epochs.
        2. Latent Space Ridgeline Plot: Visualizes the distribution of the latent
           space representations across different dimensions, offering a high-level
           overview of the learned embeddings.
        3. Latent Space 2D Scatter Plot: Projects the latent space into two dimensions
           for a detailed view of the clustering or separation of data points.
        4. Image Translation: Shown for the test split only when the 'to'
           translation key contains "IMG".

        These visualizations help in understanding the model's performance and
        the structure of the latent space representations.
        """
        print("Creating plots ...")

        self.visualizer.show_loss(plot_type="absolute")

        self.visualizer.show_latent_space(result=self.result, plot_type="Ridgeline")

        self.visualizer.show_latent_space(result=self.result, plot_type="2D-scatter")

        dm_keys = find_translation_keys(
            config=self.config,
            trained_modalities=self._trainer._modality_dynamics.keys(),
        )
        if "IMG" in dm_keys["to"]:
            self.visualizer.show_image_translation(
                result=self.result,
                from_key=dm_keys["from"],
                to_key=dm_keys["to"],
                split="test",
            )

    def _process_latent_results(
        self, predictor_results: Result, predict_data: DatasetContainer
    ) -> None:
        """Processes latent space results into a single AnnData object for the source modality.

        This method identifies the source ('from') modality used for translation,
        extracts its latent space and sample IDs, and creates a single, informative
        AnnData object. The original feature names from the source dataset are
        also stored for reference.

        The final AnnData object is stored in `self.result.adata_latent`, and
        `predictor_results` is merged into `self.result` afterwards.

        Args:
            predictor_results: The Result object returned by the `predict` method.
            predict_data: The DatasetContainer used for prediction; its `test`
                split is read to access the source dataset's metadata.

        Raises:
            TypeError: If the latent spaces, sample IDs or reconstructions of the
                test split are not dictionaries.
            ValueError: If no translate direction can be found or no latent space
                is stored for the specified modality.
        """
        print("Processing latent space results into a single AnnData object...")
        all_latents = predictor_results.latentspaces.get(epoch=-1, split="test")
        all_sample_ids = predictor_results.sample_ids.get(epoch=-1, split="test")
        all_recons = predictor_results.reconstructions.get(epoch=-1, split="test")

        if not all(
            isinstance(res, dict) for res in (all_latents, all_sample_ids, all_recons)
        ):
            raise TypeError("Expected prediction results to be dictionaries.")

        try:
            predict_keys = find_translation_keys(
                config=self.config,
                trained_modalities=all_latents.keys(),  # type: ignore
            )
            from_key = predict_keys["from"]
        except Exception as e:
            # Provide a helpful error if the key can't be determined.
            raise ValueError(
                "Could not determine the 'from_key' for translation. "
                "Ensure the translation direction is specified in the config."
            ) from e

        print(f"Identified source modality for latent space: '{from_key}'")

        latent = all_latents.get(from_key)  # type: ignore
        sample_ids = all_sample_ids.get(from_key)  # type: ignore

        if latent is None:
            raise ValueError(
                f"Latent space for source modality '{from_key}' not found."
            )

        self.result.adata_latent = ad.AnnData(latent)
        self.result.adata_latent.var_names = [
            f"latent_{i}" for i in range(latent.shape[1])
        ]

        if sample_ids is not None:
            self.result.adata_latent.obs_names = sample_ids

        source_dataset = predict_data.test.datasets.get(from_key)  # type: ignore
        # Explicit None checks: truthiness of a dataset object may depend on its
        # length, which would skip an existing dataset that reports as empty.
        feature_ids = (
            getattr(source_dataset, "feature_ids", None)
            if source_dataset is not None
            else None
        )
        if feature_ids is not None:
            self.result.adata_latent.uns["feature_ids"] = feature_ids
            print(
                f"  - Added {len(feature_ids)} source feature IDs to .uns"
            )

        self.result.update(predictor_results)
        print("Finished processing latent results.")

    def _validate_prediction_data(self, predict_data: Optional[BaseDataset]):
        """Validate that prediction data is present.

        Args:
            predict_data: Dataset for prediction.
        Raises:
            ValueError: If `predict_data` is None.

        """
        if predict_data is None:
            raise ValueError(
                "The data for prediction need to be a DatasetContainer with a test attribute containing a `BaseDataset` child "
                f"attribute, got: {predict_data}"
            )

    def _predict(
        self,
        predict_data: BaseDataset,
        from_key: Optional[str] = None,
        to_key: Optional[str] = None,
        split: Literal["train", "valid", "test"] = "test",
    ):
        """Utility wrapper around the `predict` method of the trainer instance.

        Args:
            predict_data: Dataset handed to the trainer as `data`.
            from_key: string indicator of 'from' translation direction.
            to_key: string indicator of 'to' translation direction.
            split: Name of the split forwarded to the trainer.

        Returns:
            Whatever `self._trainer.predict` returns.

        Raises:
            ValueError: If `predict_data` is None.
        """
        self._validate_prediction_data(predict_data=predict_data)
        return self._trainer.predict(
            data=predict_data,
            model=self.result.model,
            from_key=from_key,
            to_key=to_key,
            split=split,
        )  # type: ignore

    def predict(
        self,
        data: Optional[
            Union[
                DataPackage,
                DatasetContainer,
                ad.AnnData,
                MuData,  # ty: ignore[invalid-type-form]
            ]
        ] = None,
        config: Optional[DefaultConfig] = None,
        from_key: Optional[str] = None,
        to_key: Optional[str] = None,
        predict_all: bool = False,
        **kwargs,
    ):
        """Generates predictions using the trained model.

        Uses the trained model to make predictions on test data or new data
        provided by the user. Processes the results and stores them in the
        result container.

        Args:
            data: Optional new data for predictions.
            config: Optional custom configuration for prediction. Not referenced
                in this override.
            from_key: string indicator of 'from' translation direction.
            to_key: string indicator of 'to' translation direction.
            predict_all: If True, additionally predict on the stored train and
                valid splits and merge those results into the result container.
            **kwargs: Additional configuration parameters as keyword arguments.
                Not referenced in this override.

        Returns:
            The pipeline's result container (`self.result`).

        Raises:
            NotImplementedError: If required components aren't initialized.
            ValueError: If no test data is available or data format is invalid,
                or if `predict_all` is set and train/valid data is missing.
        """
        self._validate_prediction_requirements()

        original_input = data
        predict_data = self._prepare_prediction_data(data=data)
        if predict_data.test is None:
            raise ValueError("No test data available for predictions.")
        predictor_results = self._predict(
            predict_data=predict_data.test,
            from_key=from_key,
            to_key=to_key,
            split="test",
        )

        self._process_latent_results(
            predictor_results=predictor_results, predict_data=predict_data
        )
        self._postprocess_reconstruction(
            predictor_results=predictor_results,
            original_input=original_input,
            predict_data=predict_data,
        )

        if predict_all:
            if self._datasets is None:
                raise ValueError(
                    "No training/validation data available for predictions."
                )
            # Fail fast: check the validation split before running the train
            # prediction so a missing split does not leave the result container
            # half-updated.
            if self._datasets.valid is None:
                raise ValueError("No validation data available for predictions.")
            train_pred_results = self._predict(
                predict_data=self._datasets.train,
                from_key=from_key,
                to_key=to_key,
                split="train",
            )
            self.result.update(other=train_pred_results)
            valid_pred_results = self._predict(
                predict_data=self._datasets.valid,
                from_key=from_key,
                to_key=to_key,
                split="valid",
            )
            self.result.update(other=valid_pred_results)

        return self.result

    def sample_latent_space(
        self,
        n_samples: int,
        split: str = "test",
        epoch: int = -1,
    ) -> torch.Tensor:
        """Latent space sampling is not available for XModalix.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Sampling latent space is not implemented for XModalix."
        )