Coverage for src / autoencodix / stackix.py: 67%

45 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Dict, Optional, Type, Union, List 

2import torch 

3import numpy as np 

4 

5import anndata as ad # type: ignore 

6from autoencodix.base._base_dataset import BaseDataset 

7from autoencodix.utils._utils import config_method 

8from autoencodix.base._base_loss import BaseLoss 

9from autoencodix.base._base_pipeline import BasePipeline 

10from autoencodix.base._base_trainer import BaseTrainer 

11from autoencodix.base._base_visualizer import BaseVisualizer 

12from autoencodix.base._base_preprocessor import BasePreprocessor 

13from autoencodix.base._base_autoencoder import BaseAutoencoder 

14from autoencodix.base._base_evaluator import BaseEvaluator 

15from autoencodix.data._datasetcontainer import DatasetContainer 

16from autoencodix.data._datasplitter import DataSplitter 

17from autoencodix.data.datapackage import DataPackage 

18from autoencodix.data._numeric_dataset import NumericDataset 

19from autoencodix.evaluate._general_evaluator import GeneralEvaluator 

20from autoencodix.modeling._varix_architecture import VarixArchitecture 

21from autoencodix.utils._result import Result 

22from autoencodix.configs.default_config import DefaultConfig 

23 

24from autoencodix.configs.stackix_config import StackixConfig 

25from autoencodix.utils._losses import VarixLoss 

26from autoencodix.data._stackix_preprocessor import StackixPreprocessor 

27from autoencodix.data._stackix_dataset import StackixDataset 

28from autoencodix.trainers._stackix_trainer import StackixTrainer 

29from autoencodix.visualize._general_visualizer import GeneralVisualizer 

30 

31 

class Stackix(BasePipeline):
    """Stackix pipeline for training multiple VAEs on different modalities and stacking their latent spaces.

    This pipeline uses:
    1. StackixPreprocessor to prepare data for multi-modality training
    2. StackixTrainer to train individual VAEs, extract latent spaces, and train the final stacked model

    Like other pipelines, it follows the standard BasePipeline interface and workflow.

    Additional Attributes:
        _default_config: Is set to StackixConfig here.

    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = StackixTrainer,
        dataset_type: Type[BaseDataset] = StackixDataset,
        model_type: Type[BaseAutoencoder] = VarixArchitecture,
        loss_type: Type[BaseLoss] = VarixLoss,
        preprocessor_type: Type[BasePreprocessor] = StackixPreprocessor,
        visualizer: Type[BaseVisualizer] = GeneralVisualizer,
        evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        config: Optional[DefaultConfig] = None,
        ontologies: Optional[Union[List, Dict]] = None,
    ) -> None:
        """Initialize the Stackix pipeline.

        All arguments are forwarded to the parent initializer; see the parent
        class for the full list of Args. Note that ``custom_splits`` is handed
        to the parent under its ``custom_split`` keyword.

        Raises:
            TypeError: If ``self.config`` is not a ``StackixConfig`` once the
                parent initializer has run.
        """
        # Must be set before calling the parent initializer.
        # NOTE(review): presumably the parent reads `_default_config` when no
        # config is passed -- confirm against BasePipeline.
        self._default_config = StackixConfig()
        super().__init__(
            data=data,
            dataset_type=dataset_type
            or NumericDataset,  # Fallback, but not directly used
            trainer_type=trainer_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            config=config,
            custom_split=custom_splits,
            ontologies=ontologies,
        )
        if not isinstance(self.config, StackixConfig):
            raise TypeError(
                f"For Stackix Pipeline, we only allow StackixConfig as type for config, got {type(self.config)}"
            )

    def _process_latent_results(
        self, predictor_results: Result, predict_data: DatasetContainer
    ) -> None:
        """Processes the latent spaces from the StackixTrainer prediction results.

        Creates a correctly annotated AnnData object.
        This method overrides the BasePipeline implementation to specifically handle
        the aligned latent space from the unpaired/stacked workflow.

        The latent space and sample ids are read from ``predictor_results`` for
        the last epoch (``epoch=-1``) of the ``"test"`` split. The resulting
        AnnData object is stored in ``self.result.adata_latent`` with one
        ``Latent_<i>`` variable per latent column, after which
        ``predictor_results`` is merged into ``self.result``.

        If no latent space is available, a warning is emitted and the method
        returns without touching ``self.result``. If no sample ids are
        available, a warning is emitted and the observation names that AnnData
        assigns by default are kept.

        Args:
            predictor_results: Result object after predict step
            predict_data: not used here, only to keep interface structure

        """
        import warnings

        latent = predictor_results.latentspaces.get(epoch=-1, split="test")
        sample_ids = predictor_results.sample_ids.get(epoch=-1, split="test")
        if latent is None:
            warnings.warn(
                "No latent space found in predictor results. Cannot create AnnData object."
            )
            # NOTE(review): returning here also skips the result update below,
            # so the remaining prediction results are not merged -- confirm
            # that this is intended.
            return

        self.result.adata_latent = ad.AnnData(X=latent)
        if sample_ids is None:
            # Assigning None to obs_names would raise; keep the default names
            # instead of failing after the latent object has been created.
            warnings.warn(
                "No sample ids found in predictor results. Keeping default obs_names for adata_latent."
            )
        else:
            self.result.adata_latent.obs_names = sample_ids
        self.result.adata_latent.var_names = [
            f"Latent_{i}" for i in range(latent.shape[1])  # ty: ignore
        ]

        # Update the main result object with the rest of the prediction results.
        self.result.update(predictor_results)

        print("Successfully created annotated latent space object (adata_latent).")

121 

122 print("Successfully created annotated latent space object (adata_latent).")