Coverage for src / autoencodix / trainers / _stackix_orchestrator.py: 15%

172 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import os 

2import pickle 

3from typing import Dict, List, Optional, Tuple, Type, Any 

4 

5import lightning_fabric 

6import numpy as np 

7import pandas as pd 

8import torch 

9 

10from autoencodix.base._base_autoencoder import BaseAutoencoder 

11from autoencodix.base._base_dataset import BaseDataset 

12from autoencodix.base._base_loss import BaseLoss 

13from autoencodix.data._numeric_dataset import NumericDataset 

14from autoencodix.data._multimodal_dataset import MultiModalDataset 

15from autoencodix.trainers._general_trainer import GeneralTrainer 

16from autoencodix.utils._result import Result 

17from autoencodix.configs.default_config import DefaultConfig 

18 

19 

class StackixOrchestrator:
    """Coordinate the training of multi-modality VAE stacking.

    The orchestrator trains one autoencoder per modality (sequentially, or
    spread across processes via Lightning Fabric), aligns the resulting latent
    spaces on the sample IDs shared by all modalities, concatenates them into
    the input of the stacked model, and can map a reconstruction of that
    concatenated space back to per-modality reconstructions.

    Attributes:
        modality_trainers: Trainer object per modality, keyed by modality name.
        modality_results: Training result per modality, keyed by modality name.
        concatenated_latent_spaces: Concatenated latent spaces by split
            (initialised empty; not written by this class).
        stacked_model: The final stacked autoencoder model (set by callers).
        stacked_trainer: Trainer for the stacked model (set by callers).
        concat_idx: Column range ``(start, end)`` of each modality inside the
            concatenated latent space of the most recent extraction.
        dropped_indices_map: Per modality, the original row positions of the
            samples that are not shared by all modalities.
        kept_indices_map: Per modality, the original row position of every
            common sample, listed in the row order of the concatenated space.
        reconstruction_shapes: Per modality, the shape of the unaligned latent
            space of the most recent extraction.
        common_sample_ids: Sorted sample IDs shared by all modalities.
        _trainset: Training dataset containing multiple modalities.
        _validset: Validation dataset containing multiple modalities.
        _testset: Test dataset containing multiple modalities.
        _config: Configuration for training and model architecture.
        _model_type: Autoencoder class used for each modality.
        _loss_type: Loss class used for training.
        _trainer_class: Trainer class used for each modality (dependency injection).
        _dataset_type: Injected dataset class (stored; not used by this class).
        _workdir: Directory for intermediate models and results.
        _fabric: Lightning Fabric instance used for device placement.
        _modality_latent_dims: Per modality, the value of the training
            dataset's ``get_input_dim()`` (despite the attribute name).
    """

    def __init__(
        self,
        trainset: Optional[MultiModalDataset],
        validset: Optional[MultiModalDataset],
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        testset: Optional[MultiModalDataset] = None,
        trainer_type: Type[GeneralTrainer] = GeneralTrainer,
        dataset_type: Type[BaseDataset] = NumericDataset,
        workdir: str = "./stackix_work",
    ):
        """Initialize the StackixOrchestrator with datasets and configuration.

        Args:
            trainset: Training dataset containing multiple modalities
            validset: Validation dataset containing multiple modalities
            config: Configuration parameters for training and model architecture
            model_type: Type of autoencoder model to use for each modality
            loss_type: Type of loss function to use for training
            testset: Dataset with test split
            trainer_type: Type to use for training (default is GeneralTrainer)
            dataset_type: Type to use for creating datasets (default is NumericDataset)
            workdir: Directory to save intermediate models and results (default is "./stackix_work")
        """
        self._trainset = trainset
        self._validset = validset
        self._testset = testset
        self._config = config
        self._model_type = model_type
        self._loss_type = loss_type
        self._workdir = workdir
        self._trainer_class = trainer_type
        self._dataset_type = dataset_type

        # Fabric instance used for device placement. Every optional config
        # field falls back to a value Fabric accepts; in particular the
        # accelerator falls back to "auto" (Fabric's own default), not None.
        self._fabric = lightning_fabric.Fabric(
            accelerator=str(config.device) if hasattr(config, "device") else "auto",
            devices=getattr(config, "n_gpus", 1),
            precision=getattr(config, "float_precision", "32"),
            strategy=getattr(config, "gpu_strategy", "auto"),
        )

        # Storage for per-modality trainers and results
        self.modality_trainers: Dict[str, GeneralTrainer] = {}
        self.modality_results: Dict[str, Result] = {}
        self._modality_latent_dims: Dict[str, int] = {}
        self.concatenated_latent_spaces: Dict[str, torch.Tensor] = {}

        # Stacked model and trainer will be created during the process
        self.stacked_model: Optional[BaseAutoencoder] = None
        self.stacked_trainer: Optional[GeneralTrainer] = None

        # Alignment bookkeeping, filled by _extract_latent_spaces
        self.concat_idx: Dict[str, Optional[Tuple[int, int]]] = {}
        self.dropped_indices_map: Dict[str, np.ndarray] = {}
        self.kept_indices_map: Dict[str, np.ndarray] = {}
        self.reconstruction_shapes: Dict[str, torch.Size] = {}
        self.common_sample_ids: Optional[pd.Index] = None

        # Create the directory if it doesn't exist
        os.makedirs(name=self._workdir, exist_ok=True)

    def set_testset(self, testset: MultiModalDataset) -> None:
        """Set the test dataset for the orchestrator.

        Args:
            testset: Test dataset containing multiple modalities
        """
        self._testset = testset

    def _valid_dataset_for(self, modality: str) -> Optional[BaseDataset]:
        """Return the validation dataset of ``modality``, or None if unavailable."""
        validset = self._validset
        if validset and hasattr(validset, "datasets"):
            return validset.datasets.get(modality)
        return None

    def _artifact_paths(self, modality: str) -> Tuple[str, str]:
        """Return the (model checkpoint, pickled result) paths of ``modality``."""
        return (
            os.path.join(self._workdir, f"{modality}_model.ckpt"),
            os.path.join(self._workdir, f"{modality}_result.pkl"),
        )

    def _train_modality(
        self,
        modality: str,
        modality_dataset: BaseDataset,
        valid_modality_dataset: Optional[BaseDataset] = None,
    ) -> None:
        """Train a single modality and register its trainer and result.

        Nothing is returned and nothing is written to disk here: the trainer is
        stored in ``self.modality_trainers`` and the result of ``train()`` in
        ``self.modality_results``, both under ``modality``.

        Args:
            modality: Modality name/identifier
            modality_dataset: Training dataset for this modality
            valid_modality_dataset: Validation dataset for this modality (default is None)
        """
        print(f"Training modality: {modality}")
        trainer = self._trainer_class(
            trainset=modality_dataset,
            validset=valid_modality_dataset,
            result=Result(),
            config=self._config,
            model_type=self._model_type,
            loss_type=self._loss_type,
        )

        self.modality_results[modality] = trainer.train()
        self.modality_trainers[modality] = trainer

    def _save_modality_artifacts(self, modality: str) -> None:
        """Write the trained weights and result of ``modality`` to the workdir.

        Other ranks read these files back in ``_load_modality_artifacts``.
        """
        model_path, result_path = self._artifact_paths(modality)
        trainer = self.modality_trainers[modality]
        torch.save(obj=trainer.get_model().state_dict(), f=model_path)
        with open(file=result_path, mode="wb") as f:
            pickle.dump(self.modality_results[modality], f)

    def _load_modality_artifacts(self, modality: str) -> None:
        """Rebuild the trainer and result of a modality trained by another rank.

        A trainer (not a bare model) is registered, because
        ``predict_modalities`` and ``reconstruct_from_stack`` call trainer
        methods on the entries of ``self.modality_trainers``.
        """
        model_path, result_path = self._artifact_paths(modality)
        train_dataset = self._trainset.datasets[modality]

        # NOTE(review): assumes the trainer builds its model in its constructor
        # so that get_model() is usable before train() — confirm against the
        # trainer implementation.
        trainer = self._trainer_class(
            trainset=train_dataset,
            validset=self._valid_dataset_for(modality),
            result=Result(),
            config=self._config,
            model_type=self._model_type,
            loss_type=self._loss_type,
        )
        trainer.get_model().load_state_dict(
            state_dict=torch.load(f=model_path, map_location="cpu")
        )

        # NOTE(review): pickle.load can execute arbitrary code. The file is
        # expected to come from _save_modality_artifacts on another rank of
        # this same run; do not point workdir at untrusted content.
        with open(file=result_path, mode="rb") as f:
            result = pickle.load(file=f)

        self.modality_trainers[modality] = trainer
        self.modality_results[modality] = result
        self._modality_latent_dims[modality] = train_dataset.get_input_dim()

    def _train_distributed(self, keys: List[str]) -> None:
        """Train modality models spread across processes using Lightning Fabric.

        Each process trains the modalities whose position in ``keys`` maps to
        its global rank. With more than one process, every trained modality is
        written to the workdir, all processes synchronise, and each process
        then loads the modalities it did not train itself.

        Args:
            keys: List of modality keys to train
        """
        keys = list(keys)
        strategy = getattr(self._config, "gpu_strategy", "auto")
        n_gpus = getattr(self._config, "n_gpus", 1)
        device = str(self._config.device) if hasattr(self._config, "device") else "auto"

        if strategy == "auto" and device in ("cuda", "gpu"):
            strategy = "ddp" if n_gpus > 1 else "dp"

        # Create a fabric instance with appropriate parallelism strategy
        fabric = lightning_fabric.Fabric(
            accelerator=device,
            devices=min(n_gpus, len(keys)),  # No more devices than modalities
            precision=getattr(self._config, "float_precision", "32"),
            strategy=strategy,
        )

        # Launch distributed training
        fabric.launch()

        # The global rank is paired with the (global) world size; the local
        # rank repeats on every node and would leave modalities untrained on
        # multi-node runs.
        rank = getattr(fabric, "global_rank", 0)
        world_size = getattr(fabric, "world_size", 1)
        multi_process = world_size > 1

        # Train the modalities assigned to this rank (round-robin by position)
        for position, modality in enumerate(keys):
            if position % world_size != rank:
                continue
            train_dataset = self._trainset.datasets[modality]
            self._train_modality(
                modality=modality,
                modality_dataset=train_dataset,
                valid_modality_dataset=self._valid_dataset_for(modality),
            )
            self._modality_latent_dims[modality] = train_dataset.get_input_dim()
            if multi_process:
                # Make the outcome available to the other ranks
                self._save_modality_artifacts(modality)

        # Synchronize across processes to ensure all modalities are trained
        if multi_process:
            fabric.barrier()

        # Load models and results trained by other processes
        for modality in keys:
            if modality not in self.modality_trainers:
                self._load_modality_artifacts(modality)

    def _train_sequential(self, keys: List[str]) -> None:
        """Train modality models one after another in the current process.

        Used when distributed training is not available or not necessary.

        Args:
            keys: List of modality keys to train
        """
        for modality in keys:
            train_dataset = self._trainset.datasets[modality]
            # _train_modality announces the modality itself, so no extra print here
            self._train_modality(
                modality=modality,
                modality_dataset=train_dataset,
                valid_modality_dataset=self._valid_dataset_for(modality),
            )
            self._modality_latent_dims[modality] = train_dataset.get_input_dim()

    def _extract_latent_spaces(
        self, result_dict: Dict[str, Any], split: str = "train"
    ) -> torch.Tensor:
        """Extract, align, and concatenate the latent spaces of all modalities.

        Rows are restricted to the sample IDs shared by every modality and
        ordered by the sorted common IDs; columns are the modalities'
        latent dimensions in the iteration order of ``result_dict``. The method
        also fills ``reconstruction_shapes``, ``common_sample_ids``,
        ``concat_idx``, ``dropped_indices_map`` and ``kept_indices_map`` for
        later use by ``reconstruct_from_stack``.

        Args:
            result_dict: Dictionary of result objects by modality; each must
                provide ``latentspaces.get(split=..., epoch=-1)`` and
                ``sample_ids.get(split=..., epoch=-1)``
            split: Data split to extract latent spaces from ("train", "valid", "test"), default is "train"

        Returns:
            Tensor with one row per common sample and the concatenated latent
            columns of all modalities.

        Raises:
            ValueError: If extraction fails for a modality, nothing could be
                extracted, a modality has duplicate sample IDs, or the
                modalities share no sample.
        """
        # --- Step 1: Extract Latent Spaces and Sample IDs ---
        all_latents = {}
        all_sample_ids = {}
        for name, result in result_dict.items():
            try:
                latent_space = result.latentspaces.get(split=split, epoch=-1)
                sample_ids = result.sample_ids.get(split=split, epoch=-1)

                if latent_space.shape[0] != len(sample_ids):
                    raise ValueError(
                        f"Mismatch between latent space rows ({latent_space.shape[0]}) and sample IDs ({len(sample_ids)}) for modality '{name}'."
                    )

                all_latents[name] = latent_space
                all_sample_ids[name] = sample_ids
                self.reconstruction_shapes[name] = latent_space.shape
            except Exception as e:
                # Chain the cause so the original traceback stays visible
                raise ValueError(
                    f"Failed to extract data for modality {name} on split {split}: {e}"
                ) from e

        if not all_latents:
            raise ValueError("No latent spaces were extracted.")

        # --- Step 2: Find Common Sample IDs ---
        common_ids_set = set.intersection(
            *(set(ids) for ids in all_sample_ids.values())
        )
        if not common_ids_set:
            raise ValueError(
                "No common samples found across all modalities for concatenation."
            )
        self.common_sample_ids = pd.Index(sorted(common_ids_set))
        print(
            f"Found {len(self.common_sample_ids)} common samples for the stacked autoencoder."
        )

        # --- Step 3 & 4: Align Latent Spaces and Track Kept/Dropped Indices ---
        aligned_latents_list = []
        start_idx = 0
        for name, latent_space in all_latents.items():
            original_ids = pd.Index(all_sample_ids[name])
            if not original_ids.is_unique:
                # A duplicated ID has no single row to align on
                raise ValueError(
                    f"Duplicate sample IDs found for modality '{name}'; cannot align latent spaces."
                )

            if isinstance(latent_space, torch.Tensor):
                latent_array = latent_space.detach().cpu().numpy()
            else:
                latent_array = np.asarray(latent_space)
            if latent_array.ndim == 1:
                latent_array = latent_array.reshape(-1, 1)

            # Original row position of every common sample, in aligned order.
            # Every common ID exists in this modality, so no position is -1.
            kept_positions = original_ids.get_indexer(self.common_sample_ids)
            aligned = latent_array[kept_positions]
            aligned_latents_list.append(torch.from_numpy(aligned))

            end_idx = start_idx + aligned.shape[1]
            self.concat_idx[name] = (start_idx, end_idx)
            start_idx = end_idx

            self.kept_indices_map[name] = kept_positions
            dropped_mask = ~original_ids.isin(self.common_sample_ids)
            self.dropped_indices_map[name] = np.where(dropped_mask)[0]

        # --- Step 5: Concatenate ---
        return torch.cat(aligned_latents_list, dim=1)

    def train_modalities(self) -> Tuple[Dict[str, GeneralTrainer], Dict[str, Result]]:
        """Train all modality-specific models.

        This is the first phase of Stackix training where each modality is
        trained independently before their latent spaces are combined.
        Distributed training is used when the config requests more than one
        GPU and there is more than one modality; otherwise training is
        sequential.

        Returns:
            Dictionary of trainers for each modality and dictionary of training
            results for each modality.

        Raises:
            ValueError: If trainset does not expose modality datasets or has no modalities
        """
        datasets = getattr(self._trainset, "datasets", None)
        if datasets is None:
            raise ValueError(
                "Trainset must be a MultiModalDataset for Stackix training"
            )

        # Materialise the keys: the training helpers expect a list
        keys = list(datasets.keys())
        if not keys:
            raise ValueError("No modalities found in trainset")

        # Determine if we should use distributed training
        n_gpus = getattr(self._config, "n_gpus", 0)
        use_distributed = bool(n_gpus and n_gpus > 1 and len(keys) > 1)

        if use_distributed:
            self._train_distributed(keys=keys)
        else:
            self._train_sequential(keys=keys)

        return self.modality_trainers, self.modality_results

    def prepare_latent_datasets(self, split: str) -> NumericDataset:
        """Build the dataset of concatenated latent spaces for one split.

        This is the second phase of Stackix training where latent spaces from
        all modalities are extracted and concatenated. For ``split == "test"``
        the latent spaces come from running ``predict_modalities`` on the test
        set; for any other split they come from the stored training results.

        Args:
            split: Data split to prepare ("train", "valid" or "test")

        Returns:
            Dataset holding the concatenated latent space of ``split``, indexed
            by the common sample IDs.

        Raises:
            ValueError: If no modality models have been trained, no test set is
                available for the test split, or no latent spaces could be extracted
        """
        if not self.modality_trainers:
            raise ValueError(
                "No modality models have been trained. Call train_modalities() first."
            )

        if split == "test":
            if self._testset is None:
                raise ValueError(
                    "No test dataset available. Please provide a test dataset for prediction."
                )
            result_dict = self.predict_modalities(data=self._testset)
        else:
            result_dict = self.modality_results

        # Extract and concatenate latent spaces
        latent = self._extract_latent_spaces(result_dict=result_dict, split=split)

        # One feature name per column of the concatenated latent space
        feature_ids = [
            f"{k}_latent_{i}"
            for k, (start, end) in self.concat_idx.items()
            for i in range(start, end)
        ]
        # NOTE(review): the injected dataset_type is not used here; the
        # concatenated latent space is always wrapped in a NumericDataset.
        return NumericDataset(
            data=latent,
            config=self._config,
            sample_ids=self.common_sample_ids,
            feature_ids=feature_ids,
        )

    def predict_modalities(self, data: MultiModalDataset) -> Dict[str, Any]:
        """Run every modality trainer's ``predict`` on its part of ``data``.

        Args:
            data: Multi-modal input data; must contain a dataset for every
                trained modality

        Returns:
            Dictionary mapping each modality to the return value of its
            trainer's ``predict``. ``prepare_latent_datasets`` passes this
            dictionary to ``_extract_latent_spaces``, so the values are
            presumably result objects exposing ``latentspaces`` and
            ``sample_ids`` — confirm against the trainer implementation.

        Raises:
            ValueError: If no modality models have been trained yet
        """
        if not self.modality_trainers:
            raise ValueError(
                "No modality models have been trained. Call train_modalities() first."
            )

        return {
            modality: trainer.predict(
                data=data.datasets[modality],
                model=trainer._model,
            )
            for modality, trainer in self.modality_trainers.items()
        }

    def reconstruct_from_stack(
        self, reconstructed_stack: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Decode each modality from a reconstruction of the stacked latent space.

        Relies on the bookkeeping of the most recent ``_extract_latent_spaces``
        call, so ``reconstructed_stack`` must have the row order and column
        layout of that extraction.

        Args:
            reconstructed_stack: Tensor with the reconstructed concatenated latent space

        Returns:
            Dictionary of reconstructed tensors by modality (on the CPU). Each
            has one row per original sample of the modality; samples that were
            dropped during alignment are decoded from an all-zero latent vector.
        """
        modality_reconstructions: Dict[str, torch.Tensor] = {}

        for trainer in self.modality_trainers.values():
            trainer._model.eval()

        device = self._fabric.device
        for name, (start_idx, end_idx) in self.concat_idx.items():
            # Step 1: slice this modality's columns out of the stacked tensor.
            aligned_latent_recon = self._fabric.to_device(
                reconstructed_stack[:, start_idx:end_idx]
            )

            # Step 2: scatter the aligned rows back to their original positions.
            original_shape = tuple(self.reconstruction_shapes[name])
            kept_indices = self.kept_indices_map.get(name)
            if kept_indices is None:
                # Fallback when only the dropped positions are known; this is
                # correct only if the original sample order was already sorted.
                kept_indices = np.delete(
                    np.arange(original_shape[0]), self.dropped_indices_map[name]
                )

            # Rows of dropped samples stay zero.
            full_latent_recon = torch.zeros(
                size=original_shape,
                dtype=aligned_latent_recon.dtype,
                device=device,
            )
            row_index = torch.as_tensor(kept_indices, dtype=torch.long, device=device)
            # Row i of the aligned tensor belongs to common_sample_ids[i], whose
            # original position is kept_indices[i].
            full_latent_recon[row_index] = aligned_latent_recon

            # Step 3: decode the full-size latent tensor.
            with torch.no_grad():
                model = self._fabric.to_device(self.modality_trainers[name].get_model())
                final_recon = model.decode(full_latent_recon)

            modality_reconstructions[name] = final_recon.cpu()

        return modality_reconstructions