Coverage for src/autoencodix/trainers/_stackix_orchestrator.py: 15%
172 statements
coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import os
2import pickle
3from typing import Dict, List, Optional, Tuple, Type, Any
5import lightning_fabric
6import numpy as np
7import pandas as pd
8import torch
10from autoencodix.base._base_autoencoder import BaseAutoencoder
11from autoencodix.base._base_dataset import BaseDataset
12from autoencodix.base._base_loss import BaseLoss
13from autoencodix.data._numeric_dataset import NumericDataset
14from autoencodix.data._multimodal_dataset import MultiModalDataset
15from autoencodix.trainers._general_trainer import GeneralTrainer
16from autoencodix.utils._result import Result
17from autoencodix.configs.default_config import DefaultConfig
class StackixOrchestrator:
    """StackixOrchestrator coordinates the training of multi-modality VAE stacking.

    This orchestrator manages both parallel and sequential training of
    modality-specific autoencoders, followed by creating a concatenated latent
    space for the final stacked model training. It uses Lightning Fabric for
    device placement and, when several GPUs are configured, for distributing
    modalities across processes.

    Attributes:
        _workdir: Directory for intermediate models and results. In distributed
            training it is used to exchange trained modalities between processes.
        modality_trainers: Trainer of each trained modality, keyed by modality name.
        modality_results: Training result of each modality.
        _modality_latent_dims: Value of ``get_input_dim()`` of each modality's
            training dataset.
        concatenated_latent_spaces: Dictionary of concatenated latent spaces by
            split (not populated by this class).
        _dataset_type: Dataset class stored for creating datasets.
        _fabric: Lightning Fabric instance used for device placement.
        _trainer_class: Class used to train each modality (dependency injection).
        _trainset: Training dataset containing multiple modalities.
        _validset: Validation dataset containing multiple modalities.
        _testset: Test dataset containing multiple modalities.
        _loss_type: Type of loss function used for training.
        _model_type: Type of autoencoder model used for each modality.
        _config: Configuration parameters for training and model architecture.
        stacked_model: The final stacked autoencoder model (initialized later,
            outside this class).
        stacked_trainer: Trainer for the stacked model (initialized later,
            outside this class).
        concat_idx: ``(start, end)`` column range of each modality in the
            concatenated latent space.
        dropped_indices_map: Row positions, in the original (unaligned) latent
            space, of each modality's samples that were dropped during alignment.
        kept_indices_map: For each modality, the row position in the original
            latent space of every aligned sample, in ``common_sample_ids`` order.
        reconstruction_shapes: Shape of each modality's original (unaligned)
            latent space.
        common_sample_ids: Sorted sample IDs shared by all modalities.

    Note:
        ``concat_idx``, ``dropped_indices_map``, ``kept_indices_map``,
        ``reconstruction_shapes`` and ``common_sample_ids`` are overwritten by
        every call to ``prepare_latent_datasets``; ``reconstruct_from_stack``
        uses the state left by the most recent call.
    """

    def __init__(
        self,
        trainset: Optional[MultiModalDataset],
        validset: Optional[MultiModalDataset],
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        testset: Optional[MultiModalDataset] = None,
        trainer_type: Type[GeneralTrainer] = GeneralTrainer,
        dataset_type: Type[BaseDataset] = NumericDataset,
        workdir: str = "./stackix_work",
    ):
        """Initialize the StackixOrchestrator with datasets and configuration.

        Args:
            trainset: Training dataset containing multiple modalities.
            validset: Validation dataset containing multiple modalities.
            config: Configuration parameters for training and model architecture.
            model_type: Type of autoencoder model to use for each modality.
            loss_type: Type of loss function to use for training.
            testset: Dataset with the test split.
            trainer_type: Type to use for training (default is GeneralTrainer).
            dataset_type: Type to use for creating datasets (default is NumericDataset).
            workdir: Directory to save intermediate models and results
                (default is "./stackix_work"); created if missing.
        """
        self._trainset = trainset
        self._validset = validset
        self._testset = testset
        self._config = config
        self._model_type = model_type
        self._loss_type = loss_type
        self._workdir = workdir
        self._trainer_class = trainer_type
        self._dataset_type = dataset_type

        # Fabric used for device placement. Config fields that are absent fall
        # back to Fabric's own defaults ("auto") instead of passing None.
        self._fabric = lightning_fabric.Fabric(
            accelerator=str(getattr(config, "device", "auto")),
            devices=getattr(config, "n_gpus", 1),
            precision=getattr(config, "float_precision", "32"),
            strategy=getattr(config, "gpu_strategy", "auto"),
        )

        # Storage for per-modality trainers and results.
        self.modality_trainers: Dict[str, GeneralTrainer] = {}
        self.modality_results: Dict[str, Result] = {}
        self._modality_latent_dims: Dict[str, int] = {}
        self.concatenated_latent_spaces: Dict[str, torch.Tensor] = {}

        # The stacked model and trainer are created later in the process.
        self.stacked_model: Optional[BaseAutoencoder] = None
        self.stacked_trainer: Optional[GeneralTrainer] = None

        # Alignment bookkeeping written by _extract_latent_spaces.
        self.concat_idx: Dict[str, Optional[Tuple[int, int]]] = {}
        self.dropped_indices_map: Dict[str, np.ndarray] = {}
        self.kept_indices_map: Dict[str, np.ndarray] = {}
        self.reconstruction_shapes: Dict[str, Tuple[int, ...]] = {}
        self.common_sample_ids: Optional[pd.Index] = None

        os.makedirs(name=self._workdir, exist_ok=True)

    def set_testset(self, testset: MultiModalDataset) -> None:
        """Set the test dataset for the orchestrator.

        Args:
            testset: Test dataset containing multiple modalities.
        """
        self._testset = testset

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _get_valid_dataset(self, modality: str) -> Optional[BaseDataset]:
        """Return the validation dataset of ``modality``, or None if unavailable."""
        if not self._validset or not hasattr(self._validset, "datasets"):
            return None
        return self._validset.datasets.get(modality)

    def _model_path(self, modality: str) -> str:
        """Return the file used to exchange a modality's model weights."""
        return os.path.join(self._workdir, f"{modality}_model.ckpt")

    def _result_path(self, modality: str) -> str:
        """Return the file used to exchange a modality's training result."""
        return os.path.join(self._workdir, f"{modality}_result.pkl")

    def _save_modality(
        self, modality: str, trainer: GeneralTrainer, result: Result
    ) -> None:
        """Write a modality's model state dict and pickled result to the workdir."""
        torch.save(trainer._model.state_dict(), self._model_path(modality))
        with open(self._result_path(modality), "wb") as f:
            pickle.dump(result, f)

    def _load_modality(self, modality: str) -> None:
        """Rebuild the trainer of a modality that another process trained.

        The trainer is constructed exactly as in ``_train_modality`` (so
        ``predict`` / ``get_model`` keep working downstream) and then receives
        the weights saved by ``_save_modality``.
        """
        train_dataset = self._trainset.datasets[modality]

        # NOTE(review): pickle.load can execute arbitrary code; these files are
        # expected to have been written by _save_modality into our own workdir.
        with open(self._result_path(modality), "rb") as f:
            result = pickle.load(f)

        trainer = self._trainer_class(
            trainset=train_dataset,
            validset=self._get_valid_dataset(modality),
            result=result,
            config=self._config,
            model_type=self._model_type,
            loss_type=self._loss_type,
        )
        state_dict = torch.load(self._model_path(modality), map_location="cpu")
        trainer._model.load_state_dict(state_dict)

        self.modality_results[modality] = result
        self.modality_trainers[modality] = trainer
        self._modality_latent_dims[modality] = train_dataset.get_input_dim()

    # ------------------------------------------------------------------
    # Phase 1: per-modality training
    # ------------------------------------------------------------------
    def _train_modality(
        self,
        modality: str,
        modality_dataset: BaseDataset,
        valid_modality_dataset: Optional[BaseDataset] = None,
        save_to_disk: bool = False,
    ) -> None:
        """Train a single modality and store its trainer and result.

        Args:
            modality: Modality name/identifier.
            modality_dataset: Training dataset for this modality.
            valid_modality_dataset: Validation dataset for this modality
                (default is None).
            save_to_disk: If True, also write the model weights and the result
                to the workdir so other processes can load them (used by
                distributed training).
        """
        print(f"Training modality: {modality}")
        trainer = self._trainer_class(
            trainset=modality_dataset,
            validset=valid_modality_dataset,
            result=Result(),
            config=self._config,
            model_type=self._model_type,
            loss_type=self._loss_type,
        )

        result = trainer.train()
        self.modality_results[modality] = result
        self.modality_trainers[modality] = trainer

        if save_to_disk:
            self._save_modality(modality=modality, trainer=trainer, result=result)

    def _train_distributed(self, keys: List[str]) -> None:
        """Train modality models in a distributed fashion using Lightning Fabric.

        Each process trains a subset of the modalities and saves it to the
        workdir. After a barrier, every process loads the modalities trained by
        the other processes.

        Args:
            keys: Modality keys to train.
        """
        keys = list(keys)
        strategy = getattr(self._config, "gpu_strategy", "auto")
        n_gpus = getattr(self._config, "n_gpus", 1)
        device = str(getattr(self._config, "device", "auto"))

        if strategy == "auto" and device in ("cuda", "gpu"):
            strategy = "ddp" if n_gpus > 1 else "dp"

        fabric = lightning_fabric.Fabric(
            accelerator=device,
            devices=min(n_gpus, len(keys)),  # No more devices than modalities
            precision=getattr(self._config, "float_precision", "32"),
            strategy=strategy,
        )
        fabric.launch()

        # Split work by *global* rank: local_rank repeats on every node, which
        # would make processes on different nodes train the same modalities.
        rank = getattr(fabric, "global_rank", 0)
        world_size = getattr(fabric, "world_size", 1)
        my_modalities = keys[rank::world_size]

        for modality in my_modalities:
            train_dataset = self._trainset.datasets[modality]
            self._train_modality(
                modality=modality,
                modality_dataset=train_dataset,
                valid_modality_dataset=self._get_valid_dataset(modality),
                # Only needed when other processes must read our results.
                save_to_disk=world_size > 1,
            )
            self._modality_latent_dims[modality] = train_dataset.get_input_dim()

        # Make sure every process has written its modalities before loading.
        if world_size > 1:
            fabric.barrier()

        for modality in keys:
            if modality not in self.modality_trainers:
                self._load_modality(modality)

    def _train_sequential(self, keys: List[str]) -> None:
        """Train modality models one after another on a single device.

        Args:
            keys: Modality keys to train.
        """
        for modality in keys:
            train_dataset = self._trainset.datasets[modality]
            # _train_modality prints the progress message itself.
            self._train_modality(
                modality=modality,
                modality_dataset=train_dataset,
                valid_modality_dataset=self._get_valid_dataset(modality),
            )
            self._modality_latent_dims[modality] = train_dataset.get_input_dim()

    # ------------------------------------------------------------------
    # Phase 2: latent space alignment and concatenation
    # ------------------------------------------------------------------
    def _extract_latent_spaces(
        self, result_dict: Dict[str, Any], split: str = "train"
    ) -> torch.Tensor:
        """Extract, align and concatenate the per-modality latent spaces.

        Only samples present in every modality are kept; they are ordered by
        sorted sample ID. The alignment bookkeeping (``concat_idx``,
        ``dropped_indices_map``, ``kept_indices_map``, ``reconstruction_shapes``,
        ``common_sample_ids``) is stored for ``reconstruct_from_stack``.

        Args:
            result_dict: Result objects (exposing ``latentspaces`` and
                ``sample_ids``) keyed by modality.
            split: Data split to extract latent spaces from ("train", "valid",
                "test"); default is "train".

        Returns:
            Tensor with one row per common sample and the modalities' latent
            columns concatenated in ``result_dict`` order.

        Raises:
            ValueError: If a modality's data cannot be extracted, latent rows and
                sample IDs disagree, a modality has duplicate sample IDs, nothing
                was extracted, or no sample is shared by all modalities.
        """
        # --- Step 1: Extract latent spaces and sample IDs ---
        all_latents: Dict[str, Any] = {}
        all_sample_ids: Dict[str, Any] = {}
        for name, result in result_dict.items():
            try:
                latent_space = result.latentspaces.get(split=split, epoch=-1)
                sample_ids = result.sample_ids.get(split=split, epoch=-1)

                if latent_space.shape[0] != len(sample_ids):
                    raise ValueError(
                        f"Mismatch between latent space rows ({latent_space.shape[0]}) and sample IDs ({len(sample_ids)}) for modality '{name}'."
                    )

                all_latents[name] = latent_space
                all_sample_ids[name] = sample_ids
                self.reconstruction_shapes[name] = latent_space.shape
            except Exception as e:
                raise ValueError(
                    f"Failed to extract data for modality {name} on split {split}: {e}"
                ) from e

        if not all_latents:
            raise ValueError("No latent spaces were extracted.")

        # --- Step 2: Find sample IDs shared by all modalities ---
        common_ids_set = set.intersection(
            *(set(ids) for ids in all_sample_ids.values())
        )
        if not common_ids_set:
            raise ValueError(
                "No common samples found across all modalities for concatenation."
            )
        self.common_sample_ids = pd.Index(sorted(common_ids_set))
        print(
            f"Found {len(self.common_sample_ids)} common samples for the stacked autoencoder."
        )

        # --- Steps 3 & 4: Align latent spaces and track kept/dropped rows ---
        aligned_latents_list = []
        start_idx = 0
        for name, latent_space in all_latents.items():
            original_ids = pd.Index(all_sample_ids[name])
            if not original_ids.is_unique:
                raise ValueError(
                    f"Duplicate sample IDs found for modality '{name}' on split {split}; cannot align latent spaces."
                )

            latent_df = pd.DataFrame(latent_space, index=original_ids)
            aligned_df = latent_df.loc[self.common_sample_ids]
            aligned_latents_list.append(torch.from_numpy(aligned_df.to_numpy()))

            end_idx = start_idx + aligned_df.shape[1]
            self.concat_idx[name] = (start_idx, end_idx)
            start_idx = end_idx

            # Original row of each aligned row (aligned rows follow the sorted
            # common IDs, not the original order) - needed to scatter back.
            self.kept_indices_map[name] = original_ids.get_indexer(
                self.common_sample_ids
            )
            dropped_mask = ~original_ids.isin(self.common_sample_ids)
            self.dropped_indices_map[name] = np.where(dropped_mask)[0]

        # --- Step 5: Concatenate along the feature axis ---
        return torch.cat(aligned_latents_list, dim=1)

    def train_modalities(self) -> Tuple[Dict[str, GeneralTrainer], Dict[str, Result]]:
        """Train all modality-specific models.

        This is the first phase of Stackix training, where each modality is
        trained independently before their latent spaces are combined.
        Distributed training is used when more than one GPU is configured and
        there is more than one modality.

        Returns:
            Trainer of each modality and training result of each modality.

        Raises:
            ValueError: If trainset is missing, is not multi-modal, or has no
                modalities.
        """
        if self._trainset is None or not hasattr(self._trainset, "datasets"):
            raise ValueError("Trainset must be a MultiModalDataset for Stackix training")

        keys = list(self._trainset.datasets.keys())
        if not keys:
            raise ValueError("No modalities found in trainset")

        n_gpus = getattr(self._config, "n_gpus", 0)
        use_distributed = bool(n_gpus and n_gpus > 1 and len(keys) > 1)

        if use_distributed:
            self._train_distributed(keys=keys)
        else:
            self._train_sequential(keys=keys)

        return self.modality_trainers, self.modality_results

    def prepare_latent_datasets(self, split: str) -> NumericDataset:
        """Build a dataset of concatenated latent spaces for the stacked model.

        This is the second phase of Stackix training. For "test", latent spaces
        come from predicting the test set with the trained modality models; for
        any other split they come from the stored training results.

        Args:
            split: Data split to build the dataset for ("train", "valid", "test").

        Returns:
            NumericDataset with one row per common sample and one feature per
            latent dimension, named ``"{modality}_latent_{column}"``.

        Raises:
            ValueError: If no modality models have been trained, the test set is
                missing for split "test", or no latent spaces could be extracted.
        """
        if not self.modality_trainers:
            raise ValueError(
                "No modality models have been trained. Call train_modalities() first."
            )

        if split == "test":
            if self._testset is None:
                raise ValueError(
                    "No test dataset available. Please provide a test dataset for prediction."
                )
            result_dict = self.predict_modalities(data=self._testset)
        else:
            result_dict = self.modality_results

        latent = self._extract_latent_spaces(result_dict=result_dict, split=split)

        # Order feature names by column so they always match the tensor layout.
        feature_ids = [
            f"{k}_latent_{i}"
            for k, (start, end) in sorted(
                self.concat_idx.items(), key=lambda item: item[1][0]
            )
            for i in range(start, end)
        ]
        return NumericDataset(
            data=latent,
            config=self._config,
            sample_ids=self.common_sample_ids,
            feature_ids=feature_ids,
        )

    def predict_modalities(self, data: MultiModalDataset) -> Dict[str, Any]:
        """Run each trained modality model on the matching modality of ``data``.

        Args:
            data: Multi-modal dataset to predict on; must contain every trained
                modality.

        Returns:
            The trainer's prediction output for each modality (consumed as
            Result-like objects by ``_extract_latent_spaces``).

        Raises:
            ValueError: If no model has been trained yet or ``data`` lacks a
                trained modality.
        """
        if not self.modality_trainers:
            raise ValueError(
                "No modality models have been trained. Call train_modalities() first."
            )

        missing = [m for m in self.modality_trainers if m not in data.datasets]
        if missing:
            raise ValueError(
                f"Modalities {missing} are missing from the prediction data."
            )

        return {
            modality: trainer.predict(
                data=data.datasets[modality],
                model=trainer._model,
            )
            for modality, trainer in self.modality_trainers.items()
        }

    def reconstruct_from_stack(
        self, reconstructed_stack: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Decode each modality from a reconstructed concatenated latent space.

        The columns of each modality are cut out of ``reconstructed_stack``,
        scattered back to the samples' original row positions in a zero-filled
        latent tensor of the original size, and decoded with the modality's
        model. Rows of samples dropped during alignment are decoded from zeros
        and therefore do not represent real reconstructions.

        Args:
            reconstructed_stack: Reconstructed concatenated latent space, one row
                per sample in ``common_sample_ids`` order (as produced by the
                most recent ``prepare_latent_datasets`` call).

        Returns:
            Decoded reconstruction of each modality, on CPU.

        Raises:
            ValueError: If the number of rows does not match the aligned samples.
        """
        modality_reconstructions: Dict[str, torch.Tensor] = {}

        for trainer in self.modality_trainers.values():
            trainer._model.eval()

        for name, (start_idx, end_idx) in self.concat_idx.items():
            # Step 1: de-concatenate this modality's aligned columns.
            aligned_latent_recon = reconstructed_stack[:, start_idx:end_idx]

            # Step 2: re-assemble the full-size latent space. Rows must go back
            # to the samples' ORIGINAL positions (aligned rows are in sorted-ID
            # order, which generally differs from the original order).
            kept_positions = self.kept_indices_map[name]
            if aligned_latent_recon.shape[0] != len(kept_positions):
                raise ValueError(
                    f"Reconstructed stack has {aligned_latent_recon.shape[0]} rows, "
                    f"but {len(kept_positions)} aligned samples were expected for "
                    f"modality '{name}'."
                )

            full_latent_recon = torch.zeros(
                size=tuple(self.reconstruction_shapes[name]),
                dtype=aligned_latent_recon.dtype,
                device=self._fabric.device,
            )
            row_index = torch.as_tensor(
                kept_positions, dtype=torch.long, device=full_latent_recon.device
            )
            full_latent_recon[row_index] = self._fabric.to_device(aligned_latent_recon)

            # Step 3: decode the full-size latent tensor.
            with torch.no_grad():
                model = self._fabric.to_device(self.modality_trainers[name].get_model())
                final_recon = model.decode(full_latent_recon)
            modality_reconstructions[name] = final_recon.cpu()

        return modality_reconstructions