Coverage for src / autoencodix / trainers / _stackix_trainer.py: 31%
61 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Dict, Optional, Tuple, Type
3import torch
4from lightning_fabric import Fabric
6from autoencodix.base._base_autoencoder import BaseAutoencoder
7from autoencodix.base._base_dataset import BaseDataset
8from autoencodix.base._base_loss import BaseLoss
9from autoencodix.base._base_trainer import BaseTrainer
10from autoencodix.data._stackix_dataset import StackixDataset
11from autoencodix.trainers._general_trainer import GeneralTrainer
12from autoencodix.trainers._stackix_orchestrator import StackixOrchestrator
13from autoencodix.utils._result import Result
14from autoencodix.configs.default_config import DefaultConfig
class StackixTrainer(GeneralTrainer):
    """StackixTrainer is a wrapper for StackixOrchestrator that conforms to the BaseTrainer interface.

    This trainer maintains compatibility with the BasePipeline interface while
    delegating per-modality training, latent-space concatenation and
    reconstruction to a StackixOrchestrator.

    Attributes:
        _workdir: Directory for saving intermediate models and results
        _result: Result object to store training outcomes
        _config: Configuration parameters for training and model architecture
        _model_type: Type of autoencoder model to use for each modality
        _loss_type: Type of loss function to use for training
        _trainset: Training dataset containing multiple modalities
        _validset: Validation dataset containing multiple modalities
        _orchestrator_type: Type to use for orchestrating modality training (default is StackixOrchestrator)
        _trainer_type: Type to use for training the stacked model (default is GeneralTrainer)
        _orchestrator: The orchestrator that manages modality model training and latent space preparation
        _modality_trainers: First value returned by the orchestrator's ``train_modalities()``,
            keyed by modality name; None until ``_train_modalities`` has run
        _modality_results: Dictionary of training results for each modality;
            None until ``_train_modalities`` has run
        _trainer: Trainer for the stacked model (created in ``train``)
        _fabric: Lightning Fabric wrapper for device and precision management
        _train_latent_ds: Training dataset with concatenated latent spaces
        _valid_latent_ds: Validation dataset with concatenated latent spaces
        concat_idx: Indices used for concatenating latent spaces (copied from the orchestrator)
        _model: The instantiated stacked model architecture
        _optimizer: The optimizer used for training (not assigned in this class;
            presumably managed by GeneralTrainer -- confirm)
        n_test: Number of samples in the dataset passed to ``predict`` (0 if it was None)
        testset: Latent test dataset built by the orchestrator during ``predict``
    """

    def __init__(
        self,
        trainset: Optional[StackixDataset],
        validset: Optional[StackixDataset],
        result: Result,
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        orchestrator_type: Type[StackixOrchestrator] = StackixOrchestrator,
        trainer_type: Type[BaseTrainer] = GeneralTrainer,
        workdir: str = "./stackix_work",
        ontologies: Optional[Tuple] = None,
        **kwargs,
    ) -> None:
        """Initialize the StackixTrainer with datasets and configuration.

        Args:
            trainset: Training dataset containing multiple modalities
            validset: Validation dataset containing multiple modalities
            result: Result object to store training outcomes
            config: Configuration parameters for training and model architecture
            model_type: Type of autoencoder model to use for each modality
            loss_type: Type of loss function to use for training
            orchestrator_type: Type to use for orchestrating modality training (default is StackixOrchestrator)
            trainer_type: Type to use for training the stacked model (default is GeneralTrainer)
            workdir: Directory to save intermediate models and results (default is "./stackix_work")
            ontologies: Ontology information, accepted for Ontix compatibility.
                Not used by this trainer.
            **kwargs: Accepted for interface compatibility; ignored by this trainer.
        """
        self._workdir = workdir
        self._result = result
        self._config = config
        self._model_type = model_type
        self._loss_type = loss_type
        self._trainset = trainset
        self._validset = validset
        self._orchestrator_type = orchestrator_type
        self._trainer_type = trainer_type

        # The orchestrator owns the per-modality training and the
        # construction of the concatenated latent datasets.
        self._orchestrator = self._orchestrator_type(
            trainset=trainset,
            validset=validset,
            config=config,
            model_type=model_type,
            loss_type=loss_type,
            workdir=workdir,
        )
        # Populated by _train_modalities().
        self._modality_trainers: Optional[Dict[str, BaseAutoencoder]] = None
        self._modality_results: Optional[Dict[str, Result]] = None

        # NOTE(review): this Fabric is created before the parent initializer
        # runs; GeneralTrainer.__init__ may create its own and overwrite
        # self._fabric -- confirm which instance is intended to win.
        self._fabric = Fabric(
            accelerator=self._config.device,
            devices=self._config.n_gpus,
            precision=self._config.float_precision,
            strategy=self._config.gpu_strategy,
        )

        super().__init__(
            trainset=trainset,
            validset=validset,
            result=result,
            config=config,
            model_type=model_type,
            loss_type=loss_type,
        )

    def get_model(self) -> torch.nn.Module:
        """Getter for the trained model.

        Returns:
            The model stored in ``self._model``. ``train`` sets it to the
            stacked trainer's model.
        """
        return self._model

    def train(self) -> Result:
        """Train the modality models, then the stacked model on the concatenated latent space.

        First trains one model per modality through the orchestrator and builds
        the latent datasets, then instantiates ``_trainer_type`` on those latent
        datasets and runs its ``train()``.

        Returns:
            The Result returned by the stacked trainer's ``train()``.
        """
        print("Training each modality model...")
        self._train_modalities()
        print("finished training each modality model")

        # The stacked trainer works on the latent datasets prepared above,
        # not on the raw multi-modality datasets.
        self._trainer = self._trainer_type(
            trainset=self._train_latent_ds,
            validset=self._valid_latent_ds,
            result=self._result,
            config=self._config,
            model_type=self._model_type,
            loss_type=self._loss_type,
        )

        # Keep a handle on the stacked trainer's model so get_model() and the
        # inherited predict() operate on it.
        self._model = self._trainer._model
        self._result = self._trainer.train()

        return self._result

    def _train_modalities(self) -> None:
        """Train an autoencoder for each modality and prepare the latent datasets.

        Steps:
            1. Train individual modality models via the orchestrator and store
               the outcome in ``self._modality_trainers`` and
               ``self._modality_results`` (the latter is also attached to
               ``self._result.sub_results``).
            2. Build the concatenated latent datasets for the "train" and
               "valid" splits.
            3. Copy the orchestrator's ``concat_idx`` onto this trainer.
        """
        # Step 1: Train individual modality models
        self._modality_trainers, self._modality_results = (
            self._orchestrator.train_modalities()
        )
        self._result.sub_results = self._modality_results

        # Step 2: Prepare concatenated latent space datasets
        self._train_latent_ds = self._orchestrator.prepare_latent_datasets(
            split="train"
        )
        self._valid_latent_ds = self._orchestrator.prepare_latent_datasets(
            split="valid"
        )

        # Step 3: Expose the concatenation indices on the trainer
        self.concat_idx = self._orchestrator.concat_idx

    def _reconstruct(self, split: str) -> None:
        """Orchestrates the reconstruction by delegating the task to the StackixOrchestrator.

        Reads the last-epoch stacked reconstruction for ``split`` from
        ``self._result.reconstructions`` and, if present, stores the
        orchestrator's output in ``self._result.sub_reconstructions``.
        If none is found, a warning is printed and nothing is changed.

        Args:
            split: The data split to reconstruct ('train', 'valid', 'test').
        """
        stacked_recon = self._result.reconstructions.get(epoch=-1, split=split)
        if stacked_recon is None:
            print(f"Warning: No reconstruction found for split '{split}'. Skipping.")
            return

        # The orchestrator handles the de-concatenation, re-assembly, and decoding.
        # torch.from_numpy requires stacked_recon to be a numpy array.
        modality_reconstructions = self._orchestrator.reconstruct_from_stack(
            reconstructed_stack=torch.from_numpy(stacked_recon).to(self._fabric.device)
        )

        # The result is the final, full-sized data reconstructions.
        self._result.sub_reconstructions = modality_reconstructions

    def predict(
        self, data: BaseDataset, model: Optional[torch.nn.Module] = None, **kwargs
    ) -> Result:
        """Make predictions on the given dataset.

        The dataset is registered with the orchestrator as the test set and
        converted into a latent test dataset. The parent class's ``predict``
        runs on that latent dataset, its result is merged into
        ``self._result``, and the "test" split is reconstructed.

        Args:
            data: The dataset to make predictions on.
            model: The model to use for predictions; passed through to the
                parent ``predict``.
            **kwargs: Additional keyword arguments. Accepted for interface
                compatibility; not forwarded to the parent ``predict``.

        Returns:
            Result: ``self._result`` after merging the prediction result and
            attaching the per-modality reconstructions.
        """
        self.n_test = len(data) if data is not None else 0
        self._orchestrator.set_testset(testset=data)
        test_ds = self._orchestrator.prepare_latent_datasets(split="test")
        self.testset = test_ds
        pred_result = super().predict(data=test_ds, model=model)
        self._result.update(other=pred_result)
        self._reconstruct(split="test")
        return self._result