Coverage for src / autoencodix / trainers / _stackix_trainer.py: 31%

61 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Dict, Optional, Tuple, Type 

2 

3import torch 

4from lightning_fabric import Fabric 

5 

6from autoencodix.base._base_autoencoder import BaseAutoencoder 

7from autoencodix.base._base_dataset import BaseDataset 

8from autoencodix.base._base_loss import BaseLoss 

9from autoencodix.base._base_trainer import BaseTrainer 

10from autoencodix.data._stackix_dataset import StackixDataset 

11from autoencodix.trainers._general_trainer import GeneralTrainer 

12from autoencodix.trainers._stackix_orchestrator import StackixOrchestrator 

13from autoencodix.utils._result import Result 

14from autoencodix.configs.default_config import DefaultConfig 

15 

16 

class StackixTrainer(GeneralTrainer):
    """Wrapper around StackixOrchestrator that conforms to the BaseTrainer interface.

    This trainer maintains compatibility with the BasePipeline interface while
    delegating the per-modality work (training, latent-space preparation and
    reconstruction from the stack) to an orchestrator instance. The stacked
    model itself is trained by an inner trainer of type ``trainer_type``.

    Attributes:
        _workdir: Directory for saving intermediate models and results
        _result: Result object to store training outcomes
        _config: Configuration parameters for training and model architecture
        _model_type: Type of autoencoder model to use for each modality
        _loss_type: Type of loss function to use for training
        _trainset: Training dataset containing multiple modalities
        _validset: Validation dataset containing multiple modalities
        _orchestrator_type: Type to use for orchestrating modality training
            (default is StackixOrchestrator)
        _trainer_type: Type to use for training the stacked model
            (default is GeneralTrainer)
        _orchestrator: The orchestrator that manages modality model training
            and latent space preparation
        _modality_trainers: First element returned by the orchestrator's
            ``train_modalities()``; None until ``train()`` has run
        _modality_results: Second element returned by the orchestrator's
            ``train_modalities()``; None until ``train()`` has run
        _trainer: Trainer for the stacked model (set in ``train()``)
        _fabric: Lightning Fabric wrapper for device and precision management
        _train_latent_ds: Training dataset with concatenated latent spaces
        _valid_latent_ds: Validation dataset with concatenated latent spaces
        concat_idx: Indices used for concatenating latent spaces, copied from
            the orchestrator
        _model: The stacked model held by the inner trainer
        n_test: Number of samples of the dataset last passed to ``predict()``
        testset: Latent-space test dataset built during ``predict()``
    """

    def __init__(
        self,
        trainset: Optional[StackixDataset],
        validset: Optional[StackixDataset],
        result: Result,
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        orchestrator_type: Type[StackixOrchestrator] = StackixOrchestrator,
        trainer_type: Type[BaseTrainer] = GeneralTrainer,
        workdir: str = "./stackix_work",
        ontologies: Optional[Tuple] = None,
        **kwargs,
    ) -> None:
        """Initialize the StackixTrainer with datasets and configuration.

        Args:
            trainset: Training dataset containing multiple modalities
            validset: Validation dataset containing multiple modalities
            result: Result object to store training outcomes
            config: Configuration parameters for training and model architecture
            model_type: Type of autoencoder model to use for each modality
            loss_type: Type of loss function to use for training
            orchestrator_type: Type to use for orchestrating modality training
                (default is StackixOrchestrator)
            trainer_type: Type to use for training the stacked model
                (default is GeneralTrainer)
            workdir: Directory to save intermediate models and results
                (default is "./stackix_work")
            ontologies: Ontology information, if provided for Ontix
                compatibility. Accepted for interface compatibility only;
                this constructor does not use it.
            **kwargs: Accepted for interface compatibility and ignored by
                this constructor.
        """
        self._workdir = workdir
        self._result = result
        self._config = config
        self._model_type = model_type
        self._loss_type = loss_type
        self._trainset = trainset
        self._validset = validset
        self._orchestrator_type = orchestrator_type
        self._trainer_type = trainer_type

        self._orchestrator = self._orchestrator_type(
            trainset=trainset,
            validset=validset,
            config=config,
            model_type=model_type,
            loss_type=loss_type,
            workdir=workdir,
        )
        # Filled by _train_modalities(); None signals "not trained yet".
        self._modality_trainers: Optional[Dict[str, BaseAutoencoder]] = None
        self._modality_results: Optional[Dict[str, Result]] = None

        self._fabric = Fabric(
            accelerator=self._config.device,
            devices=self._config.n_gpus,
            precision=self._config.float_precision,
            strategy=self._config.gpu_strategy,
        )

        super().__init__(
            trainset=trainset,
            validset=validset,
            result=result,
            config=config,
            model_type=model_type,
            loss_type=loss_type,
        )

    def get_model(self) -> torch.nn.Module:
        """Getter for the trained model.

        Returns:
            The model currently stored in ``self._model``
        """
        return self._model

    def train(self) -> Result:
        """Train the stacked model on the concatenated latent space.

        First trains the individual modality models and builds the latent
        datasets via ``_train_modalities()``, then instantiates an inner
        trainer of type ``_trainer_type`` on those latent datasets and runs
        its ``train()`` method.

        Returns:
            The Result object returned by the inner trainer's ``train()``
        """
        print("Training each modality model...")
        self._train_modalities()
        print("finished training each modality model")
        self._trainer = self._trainer_type(
            trainset=self._train_latent_ds,
            validset=self._valid_latent_ds,
            result=self._result,
            config=self._config,
            model_type=self._model_type,
            loss_type=self._loss_type,
        )

        # Expose the inner trainer's model so get_model() returns it.
        self._model = self._trainer._model
        self._result = self._trainer.train()

        return self._result

    def _train_modalities(self) -> None:
        """Train an autoencoder for each modality and prepare latent datasets.

        This method orchestrates the per-modality part of the training process:
            1. Train individual modality models via the orchestrator
            2. Build the concatenated latent-space train/valid datasets
            3. Populate ``self._modality_trainers``, ``self._modality_results``,
               ``self._result.sub_results`` and ``self.concat_idx``
        """
        # Step 1: Train individual modality models
        self._modality_trainers, self._modality_results = (
            self._orchestrator.train_modalities()
        )

        self._result.sub_results = self._modality_results
        # Step 2: Prepare concatenated latent space datasets
        self._train_latent_ds = self._orchestrator.prepare_latent_datasets(
            split="train"
        )
        self._valid_latent_ds = self._orchestrator.prepare_latent_datasets(
            split="valid"
        )
        self.concat_idx = self._orchestrator.concat_idx

    def _reconstruct(self, split: str) -> None:
        """Reconstruct the modalities from the stacked reconstruction.

        Looks up the stacked reconstruction stored in the result for
        ``epoch=-1`` (presumably the latest epoch; confirm against the Result
        API) and delegates de-concatenation and decoding to the orchestrator.
        The outcome is stored in ``self._result.sub_reconstructions``. If no
        stacked reconstruction is found, a warning is printed and nothing is
        stored.

        Args:
            split: The data split to reconstruct ('train', 'valid', 'test').
        """
        stacked_recon = self._result.reconstructions.get(epoch=-1, split=split)
        if stacked_recon is None:
            print(f"Warning: No reconstruction found for split '{split}'. Skipping.")
            return

        # The orchestrator handles the de-concatenation, re-assembly, and decoding.
        modality_reconstructions = self._orchestrator.reconstruct_from_stack(
            reconstructed_stack=torch.from_numpy(stacked_recon).to(self._fabric.device)
        )

        # The result is the final, full-sized data reconstructions.
        self._result.sub_reconstructions = modality_reconstructions

    def predict(
        self, data: BaseDataset, model: Optional[torch.nn.Module] = None, **kwargs
    ) -> Result:
        """Make predictions on the given dataset.

        Registers ``data`` as test set on the orchestrator, builds the
        latent-space test dataset from it, runs the parent class ``predict``
        on that latent dataset, merges the outcome into ``self._result`` and
        finally reconstructs the individual modalities for the test split.

        Args:
            data: The dataset to make predictions on.
            model: The model to use for predictions; passed through to the
                parent class ``predict``.
            **kwargs: Additional keyword arguments. Accepted for interface
                compatibility; not forwarded to the parent class.
        Returns:
            Result: The updated result object of this trainer.
        """
        self.n_test = len(data) if data is not None else 0
        self._orchestrator.set_testset(testset=data)
        test_ds = self._orchestrator.prepare_latent_datasets(split="test")
        self.testset = test_ds
        pred_result = super().predict(data=test_ds, model=model)
        self._result.update(other=pred_result)
        self._reconstruct(split="test")
        return self._result