Coverage for src / autoencodix / trainers / _xmodal_trainer.py: 11%

436 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from collections import defaultdict 

2from packaging.version import Version 

3from torch.profiler import profile, record_function, ProfilerActivity 

4 

5from lightning_fabric.wrappers import _FabricModule 

6import gc 

7import os 

8from typing import Any, Dict, List, Optional, Tuple, Type, Union 

9 

10import numpy as np 

11import torch 

12from torch.utils.data import DataLoader 

13 

14from autoencodix.base._base_autoencoder import BaseAutoencoder 

15from autoencodix.base._base_dataset import BaseDataset, DataSetTypes 

16from autoencodix.base._base_loss import BaseLoss 

17from autoencodix.base._base_trainer import BaseTrainer 

18from autoencodix.data._multimodal_dataset import ( 

19 CoverageEnsuringSampler, 

20 MultiModalDataset, 

21 create_multimodal_collate_fn, 

22) 

23from autoencodix.modeling._classifier import Classifier 

24from autoencodix.modeling._imagevae_architecture import ImageVAEArchitecture 

25from autoencodix.modeling._varix_architecture import VarixArchitecture 

26from autoencodix.utils._losses import VarixLoss 

27from autoencodix.utils._model_output import ModelOutput 

28from autoencodix.utils._result import Result 

29from autoencodix.configs.default_config import DefaultConfig 

30from autoencodix.utils._internals import model_trainer_map 

31from autoencodix.utils._utils import find_translation_keys 

32 

33 

34class XModalTrainer(BaseTrainer): 

35 """Trainer for cross-modal autoencoders, implements multimodal training with adversarial component. 

36 

37 Attributes: 

38 _trainset: The dataset used for training, must be a MultiModalDataset. 

39 _n_train_modalities: Number of modalities in the training dataset. 

40 model_map: Mapping from DataSetTypes to specific autoencoder architectures. 

41 model_trainer_map: Mapping from autoencoder architectures to their corresponding trainer classes. 

42 _n_cpus: Number of CPU cores available for data loading. 

43 latent_dim: Dimensionality of the shared latent space. 

44 n_test: Number of samples in the test set, if provided. 

45 n_train: Number of samples in the training set. 

46 n_valid: Number of samples in the validation set, if provided. 

47 n_features: Total number of features across all modalities. 

48 _cur_epoch: Current epoch number during training. 

49 _is_checkpoint_epoch: Flag indicating if the current epoch is a checkpoint epoch. 

50 _sub_loss_type: Loss function type for individual modality autoencoders. 

51 sub_loss: Instantiated loss function for individual modality autoencoders. 

52 _clf_epoch_loss: Cumulative classifier loss for the current epoch. 

53 _epoch_loss: Cumulative total loss for the current epoch. 

54 _epoch_loss_valid: Cumulative total validation loss for the current epoch. 

55 _modality_dynamics: Dictionary holding model, optimizer, and training state for each modality. 

56 _latent_clf: Classifier model for adversarial training on latent spaces. 

57 _clf_optim: Optimizer for the classifier model. 

58 _clf_loss_fn: Loss function for the classifier. 

59 _trainloader: DataLoader for the training dataset. 

60 _validloader: DataLoader for the validation dataset, if provided. 

61 _model: The instantiated stacked model architecture. 

62 _validset: The dataset used for validation, if provided, must be a MultiModalDataset. 

63 _fabric: Lightning Fabric wrapper for device and precision management. 

64 

65 """ 

66 

def __init__(
    self,
    trainset: MultiModalDataset,
    validset: MultiModalDataset,
    result: Result,
    config: DefaultConfig,
    model_type: Type[BaseAutoencoder],
    loss_type: Type[BaseLoss],
    model_map: Dict[DataSetTypes, Type[BaseAutoencoder]],
    sub_loss_type: Type[BaseLoss] = VarixLoss,
    ontologies: Optional[Union[Tuple, List]] = None,
    **kwargs,
):
    """Record datasets and mappings, run the base initializer, then reset epoch state.

    Args:
        trainset: Multi-modal training dataset; its ``n_modalities`` is read immediately.
        validset: Multi-modal validation dataset; a falsy value yields ``n_valid == 0``.
        result: Result container forwarded to the base trainer.
        config: Configuration object; ``latent_dim`` is read from it here.
        model_type: Autoencoder type forwarded to the base trainer.
        loss_type: Loss type forwarded to the base trainer.
        model_map: Mapping from dataset type to the autoencoder class used for it.
        sub_loss_type: Loss class instantiated once as ``self.sub_loss``.
        ontologies: Forwarded unchanged to the base trainer.
        **kwargs: Accepted for call compatibility; not used in this method.
    """
    # These attributes are set before the base initializer runs.
    # NOTE(review): presumably the parent constructor invokes hooks that
    # read them — confirm against BaseTrainer.
    self._trainset = trainset
    self._n_train_modalities = self._trainset.n_modalities  # type: ignore
    self.model_map = model_map
    self.model_trainer_map = model_trainer_map
    # os.cpu_count() may return None; fall back to 0 in that case.
    self._n_cpus = os.cpu_count() or 0

    super().__init__(
        trainset, validset, result, config, model_type, loss_type, ontologies
    )

    # Sizes and latent dimensionality.
    self.latent_dim = config.latent_dim
    self.n_test: Optional[int] = None
    self.n_train = len(trainset.data) if trainset else 0
    self.n_valid = len(validset.data) if validset else 0

    # Per-epoch bookkeeping.
    self._cur_epoch: int = 0
    self._is_checkpoint_epoch: Optional[bool] = None
    self._clf_epoch_loss: float = 0.0
    self._epoch_loss: float = 0.0
    self._epoch_loss_valid: float = 0.0

    # Loss used for every single-modality autoencoder.
    self._sub_loss_type = sub_loss_type
    self.sub_loss = sub_loss_type(config=self._config)

    # One epoch-indexed store per split.
    self._all_dynamics: Dict[str, Dict[int, Dict]] = {
        split_name: defaultdict(dict) for split_name in ("test", "train", "valid")
    }

121 

122 def _init_model_architecture(self, ontologies: Optional[Tuple] = None): 

123 """Override parent's model init - we don't use a single model, needs to be there because is abstract in parent.""" 

124 pass 

125 

126 def _setup_fabric(self, old_model=None): 

127 """Sets up the models, optimizers, and data loaders with Lightning Fabric.""" 

128 

129 self._fabric.launch() 

130 self._init_adversarial_training() 

131 self._init_modality_training() 

132 self._latent_clf, self._clf_optim = self._fabric.setup( 

133 self._latent_clf, self._clf_optim 

134 ) 

135 

136 self._trainloader = self._fabric.setup_dataloaders(self._trainloader) 

137 if self._validloader is not None: 

138 self._validloader = self._fabric.setup_dataloaders(self._validloader) 

139 

140 for dynamics in self._modality_dynamics.values(): 

141 dynamics["model"], dynamics["optim"] = self._fabric.setup( 

142 dynamics["model"], dynamics["optim"] 

143 ) 

144 dynamics["model"].mark_forward_method("decode") 

145 

def _init_loaders(self):
    """Create the train and validation DataLoaders, choosing the sampler by pairing.

    A fully paired dataset gets a plain DataLoader (shuffled and with
    ``drop_last`` only for training). Any other dataset is batched through a
    ``CoverageEnsuringSampler`` passed as ``batch_sampler``.

    When no validation set is present, ``self._validloader`` is set to
    ``None`` instead of failing on an attribute access of ``None``; other
    methods of this class already check the validation loader/set for
    presence before using it.
    """

    def _build_loader(dataset: MultiModalDataset, is_train: bool):
        """Build one DataLoader for ``dataset``."""
        batch_size = self._config.batch_size
        collate_fn = create_multimodal_collate_fn(multimodal_dataset=dataset)

        if dataset.is_fully_paired:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=is_train,
                drop_last=is_train,
                collate_fn=collate_fn,
                pin_memory=self._config.pin_memory,
            )

        print(
            f"Dataset has UNPAIRED samples → using CoverageEnsuringSampler "
            f"({len(dataset.paired_sample_ids)} paired + {len(dataset.unpaired_sample_ids)} unpaired)"
        )
        sampler = CoverageEnsuringSampler(
            multimodal_dataset=dataset,
            batch_size=batch_size,
        )
        # batch_sampler (not sampler): the sampler yields whole batches, so
        # batch_size/shuffle/drop_last must not be passed here.
        return DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=self._config.pin_memory,
        )

    self._trainloader = _build_loader(self._trainset, is_train=True)
    # The validation set is optional; skip loader creation when it is absent.
    self._validloader = (
        _build_loader(self._validset, is_train=False)
        if self._validset is not None
        else None
    )

    # Exposed for logging.
    self._train_is_fully_paired = len(self._trainset.unpaired_sample_ids) == 0

184 

def _init_modality_training(self):
    """Create a model, an AdamW optimizer and a state dict for every training modality.

    Also resets ``self._result.sub_results`` to one ``None`` entry per modality.

    Raises:
        ValueError: If ``self.model_map`` has no entry for a dataset's ``mytype``.
    """
    modality_names = list(self._trainset.datasets.keys())
    self._modality_dynamics = dict.fromkeys(modality_names)
    self._result.sub_results = dict.fromkeys(modality_names)

    data_info = self._config.data_config.data_info
    for mod_name, dataset in self._trainset.datasets.items():
        # The second dot-separated part of the modality name keys data_info.
        config_name = mod_name.split(".")[1]

        # A modality-specific pretrain setting wins over the global one.
        pretrain_epochs = data_info[config_name].pretrain_epochs
        if pretrain_epochs is None:
            pretrain_epochs = self._config.pretrain_epochs

        architecture = self.model_map.get(dataset.mytype)
        if architecture is None:
            raise ValueError(
                f"No Mapping exists for {dataset.mytype}, you passed this mapping: {self.model_map}"
            )
        model = architecture(config=self._config, input_dim=dataset.get_input_dim())

        # Compile only on torch >= 2.0, with CUDA available, and when enabled in config.
        wants_compile = (
            Version(torch.__version__) >= Version("2.0")
            and torch.cuda.is_available()
            and self._config.compile_model
        )
        if wants_compile:
            model = torch.compile(model)

        optimizer = torch.optim.AdamW(
            params=model.parameters(),
            lr=self._config.learning_rate,
            weight_decay=self._config.weight_decay,
        )
        self._modality_dynamics[mod_name] = dict(
            model=model,
            optim=optimizer,
            mp=[],
            losses=[],
            config_name=config_name,
            pretrain_epochs=pretrain_epochs,
            mytype=dataset.mytype,
            pretrain_result=None,
        )

233 

def _init_adversarial_training(self):
    """Create the latent-space classifier, its AdamW optimizer and its cross-entropy loss."""
    cfg = self._config

    classifier = Classifier(
        input_dim=cfg.latent_dim, n_modalities=self._n_train_modalities
    )
    self._latent_clf = classifier
    self._clf_optim = torch.optim.AdamW(
        params=classifier.parameters(),
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
    )
    self._clf_loss_fn = torch.nn.CrossEntropyLoss(reduction=cfg.loss_reduction)

247 

248 def _modalities_forward(self, batch: Dict[str, Dict[str, Any]]): 

249 """Performs forward pass for each modality in the batch and computes losses.""" 

250 for k, v in batch.items(): 

251 model = self._modality_dynamics[k]["model"] 

252 data = v["data"] 

253 

254 mp = model(data) 

255 

256 loss, loss_stats = self.sub_loss( 

257 model_output=mp, targets=data, epoch=self._cur_epoch 

258 ) 

259 

260 self._modality_dynamics[k]["loss_stats"] = loss_stats 

261 self._modality_dynamics[k]["loss"] = loss 

262 self._modality_dynamics[k]["mp"] = mp 

263 

264 def _prep_adver_training(self) -> Tuple[torch.Tensor, torch.Tensor]: 

265 """Prepares concatenated latent spaces and corresponding labels for adversarial training. 

266 

267 Returns: 

268 A tuple containing: 

269 - latents: Concatenated latent space tensor of shape (total_samples, latent_dim) 

270 - labels: Tensor of modality labels of shape (total_samples,) 

271 """ 

272 all_latents: List[torch.Tensor] = [] 

273 all_labels: List[torch.Tensor] = [] 

274 mod2idx = { 

275 mod_name: idx for idx, mod_name in enumerate(self._modality_dynamics.keys()) 

276 } 

277 

278 for mod_name, helper in self._modality_dynamics.items(): 

279 output: ModelOutput = helper["mp"] 

280 latents = output.latentspace 

281 label_id = mod2idx[mod_name] 

282 all_latents.append(latents) 

283 all_labels.append( 

284 torch.full( 

285 (latents.size(0),), 

286 fill_value=label_id, 

287 dtype=torch.long, 

288 device=self._fabric.device, 

289 ) 

290 ) 

291 

292 return torch.cat(all_latents, dim=0), torch.cat(all_labels, dim=0) 

293 

294 def _train_clf(self, latents: torch.Tensor, labels: torch.Tensor) -> None: 

295 """Performs a single optimization step for the classifier. 

296 Args: 

297 latents: Concatenated latent space tensor of shape (total_samples, latent_dim) 

298 labels: Tensor of modality labels of shape (total_samples,) 

299 

300 """ 

301 self._clf_optim.zero_grad() 

302 # We must detach the latents from the computation graph. This serves two purposes: 

303 # 1. It isolates the classifier, ensuring that gradients only update its weights. 

304 # 2. It prevents a "two backward passes" error by leaving the graph to the 

305 # autoencoders untouched, ready for use in the next stage. 

306 detached_latents = latents.detach() 

307 clf_scores = self._latent_clf(detached_latents) 

308 clf_loss = self._clf_loss_fn(clf_scores, labels) 

309 self._fabric.backward(clf_loss) 

310 self._clf_optim.step() 

311 self._clf_epoch_loss += clf_loss.item() 

312 

313 def _validate_one_epoch(self) -> Tuple[List[Dict], Dict[str, float], int]: 

314 """Runs a single, combined validation pass for the entire model, including dedicated metrics for the classifier. 

315 

316 Returns: 

317 A tuple containing: 

318 - epoch_dynamics: List of dictionaries capturing dynamics for each batch 

319 - sub_losses: Dictionary of accumulated sub-losses for the epoch 

320 - n_samples_total: Total number of samples processed in validation 

321 """ 

322 for dynamics in self._modality_dynamics.values(): 

323 dynamics["model"].eval() 

324 self._latent_clf.eval() 

325 

326 self._epoch_loss_valid = 0.0 

327 total_clf_loss = 0.0 

328 n_samples_total: int = 0 

329 epoch_dynamics: List[Dict] = [] 

330 sub_losses: Dict[str, float] = defaultdict(float) 

331 

332 with torch.no_grad(): 

333 for batch in self._validloader: 

334 with self._fabric.autocast(): 

335 self._modalities_forward(batch=batch) 

336 latents, labels = self._prep_adver_training() 

337 n_samples_total += len(labels) 

338 clf_scores = self._latent_clf(latents) 

339 

340 batch_loss, loss_dict = self._loss_fn( 

341 batch=batch, 

342 modality_dynamics=self._modality_dynamics, 

343 clf_scores=clf_scores, 

344 labels=labels, 

345 clf_loss_fn=self._clf_loss_fn, 

346 is_training=False, 

347 ) 

348 self._epoch_loss_valid += batch_loss.item() 

349 for k, v in loss_dict.items(): 

350 value = v.item() if hasattr(v, "item") else v 

351 if "_factor" not in k: 

352 sub_losses[k] += value 

353 else: 

354 sub_losses[k] = value 

355 

356 clf_loss = self._clf_loss_fn(clf_scores, labels) 

357 total_clf_loss += clf_loss.item() 

358 if self._is_checkpoint_epoch: 

359 batch_capture = self._capture_dynamics(batch) 

360 epoch_dynamics.append(batch_capture) 

361 

362 sub_losses["clf_loss"] = total_clf_loss 

363 n_samples_total /= ( 

364 self._n_train_modalities 

365 ) # n_modalities because each sample is counted once per modality and n_modalites is the same for train and valid 

366 for k, v in sub_losses.items(): 

367 if "_factor" not in k: 

368 sub_losses[k] = v / n_samples_total # Average over all samples 

369 self._epoch_loss_valid = ( 

370 self._epoch_loss_valid / n_samples_total 

371 ) # Average over all samples 

372 return epoch_dynamics, sub_losses, len(self._validset) 

373 

374 def _train_one_epoch(self) -> Tuple[List[Dict], Dict[str, float], int]: 

375 """Runs the training loop with corrected adversarial logic. 

376 Returns: 

377 A tuple containing: 

378 - epoch_dynamics: List of dictionaries capturing dynamics for each batch 

379 - sub_losses: Dictionary of accumulated sub-losses for the epoch 

380 - n_samples_total: Total number of samples processed in training 

381 

382 """ 

383 for dynamics in self._modality_dynamics.values(): 

384 dynamics["model"].train() 

385 self._latent_clf.train() 

386 

387 self._clf_epoch_loss = 0 

388 self._epoch_loss = 0 

389 epoch_dynamics: List[Dict] = [] 

390 sub_losses: Dict[str, float] = defaultdict(float) 

391 n_samples_total: int = ( 

392 0 # because of unpaired training we need to sum the samples instead of using len(dataset) 

393 ) 

394 

395 for batch in self._trainloader: 

396 with self._fabric.autocast(): 

397 # --- Stage 1: forward for each data modality --- 

398 self._modalities_forward(batch=batch) 

399 

400 # --- Stage 2: Train the Classifier --- 

401 latents, labels = self._prep_adver_training() 

402 n_samples_total += latents.size(0) 

403 self._train_clf(latents=latents, labels=labels) 

404 

405 # --- Stage 3: Train the Autoencoders --- 

406 for _, dynamics in self._modality_dynamics.items(): 

407 dynamics["optim"].zero_grad() 

408 

409 # We re-calculate scores here and cannot reuse the `clf_scores` from Stage 2. 

410 # The previous scores were from detached latents and would block the gradients 

411 # needed to train the autoencoders adversarially. 

412 clf_scores_for_adv = self._latent_clf(latents) 

413 

414 batch_loss, loss_dict = self._loss_fn( 

415 batch=batch, 

416 modality_dynamics=self._modality_dynamics, 

417 clf_scores=clf_scores_for_adv, 

418 labels=labels, 

419 clf_loss_fn=self._clf_loss_fn, 

420 is_training=True, 

421 ) 

422 self._fabric.backward(batch_loss) 

423 for _, dynamics in self._modality_dynamics.items(): 

424 dynamics["optim"].step() 

425 

426 # --- Logging and Capturing --- 

427 self._epoch_loss += batch_loss.item() 

428 for k, v in loss_dict.items(): 

429 value_to_add = v.item() if hasattr(v, "item") else v 

430 if "_factor" not in k: 

431 sub_losses[k] += value_to_add 

432 else: 

433 sub_losses[k] = value_to_add 

434 

435 if self._is_checkpoint_epoch: 

436 batch_capture = self._capture_dynamics(batch) 

437 epoch_dynamics.append(batch_capture) 

438 n_samples_total /= self._n_train_modalities 

439 sub_losses["clf_loss"] = self._clf_epoch_loss 

440 for k, v in sub_losses.items(): 

441 if "_factor" not in k: 

442 sub_losses[k] = v / n_samples_total # Average over all samples 

443 self._epoch_loss = ( 

444 self._epoch_loss / n_samples_total 

445 ) # Average over all samples 

446 

447 return epoch_dynamics, sub_losses, n_samples_total 

448 

449 def _pretraining(self): 

450 """Pretrain each modality's model if needed.""" 

451 for mod_name, dynamic in self._modality_dynamics.items(): 

452 mytype = dynamic.get("mytype") 

453 pretrain_epochs = dynamic.get("pretrain_epochs") 

454 print(f"Check if we need to pretrain: {mod_name}") 

455 print(f"pretrain epochs : {pretrain_epochs}") 

456 if not pretrain_epochs: 

457 print(f"No pretraining for {mod_name}") 

458 continue 

459 model_type = self.model_map.get(mytype) 

460 pretrainer_type = self.model_trainer_map.get(model_type) 

461 print(f"starting pretraining for: {mod_name} with {pretrainer_type}") 

462 trainset = self._trainset.datasets.get(mod_name) 

463 validset = self._validset.datasets.get(mod_name) 

464 pretrainer = pretrainer_type( 

465 trainset=trainset, 

466 validset=validset, 

467 result=Result(), 

468 config=self._config, 

469 model_type=model_type, 

470 loss_type=self._sub_loss_type, 

471 ontologies=self.ontologies, 

472 ) 

473 pretrain_result = pretrainer.train(epochs_overwrite=pretrain_epochs) 

474 self._result.sub_results[f"pretrain.{mod_name}"] = pretrain_result 

475 self._modality_dynamics[mod_name]["pretrain_result"] = pretrain_result 

476 self._modality_dynamics[mod_name]["model"] = pretrain_result.model 

477 

def train(self):
    """Run pretraining and all training epochs, then store the final models.

    Per epoch: train, log, optionally validate and log, and store checkpoints
    on checkpoint epochs. When profiling is enabled in the config, epoch 0 is
    run through the profiling variant and skips logging, validation and
    checkpointing.

    After the loop the dynamics of the last trained epoch are stored once
    more. This step is skipped when no training dynamics were captured (for
    example zero configured epochs, or a single profiled epoch), where the
    dynamics variable previously was unbound.

    Returns:
        The Result object populated during training.
    """
    self._pretraining()

    # Pre-bound so the code after the loop never hits an unbound name.
    train_epoch_dynamics: List[Dict] = []
    valid_epoch_dynamics: List[Dict] = []

    for epoch in range(self._config.epochs):
        self._cur_epoch = epoch
        self._is_checkpoint_epoch = self._should_checkpoint(epoch=epoch)
        self._fabric.print(f"--- Epoch {epoch + 1}/{self._config.epochs} ---")
        if epoch == 0 and self._config.profiling:
            print("Profiling enabled for epoch 0")
            self._train_one_epoch_with_profiling()
            continue
        train_epoch_dynamics, train_sub_losses, n_samples_train = (
            self._train_one_epoch()
        )

        self._log_losses(
            split="train",
            total_loss=self._epoch_loss,
            sub_losses=train_sub_losses,
            n_samples=n_samples_train,
        )

        if self._validset:
            valid_epoch_dynamics, valid_sub_losses, n_samples_valid = (
                self._validate_one_epoch()
            )
            self._log_losses(
                split="valid",
                total_loss=self._epoch_loss_valid,
                sub_losses=valid_sub_losses,
                n_samples=n_samples_valid,
            )
        if self._is_checkpoint_epoch:
            self._fabric.print(f"Storing checkpoint for epoch {epoch}...")
            self._store_checkpoint(
                split="train", epoch_dynamics=train_epoch_dynamics
            )
            if self._validset:
                self._store_checkpoint(
                    split="valid",
                    epoch_dynamics=valid_epoch_dynamics,
                )

    # Save the final state; without captured dynamics there is nothing to
    # aggregate, so only the final models are stored.
    if train_epoch_dynamics:
        self._store_checkpoint(split="train", epoch_dynamics=train_epoch_dynamics)
    self._store_final_models()
    return self._result

524 

def decode(self, x: torch.Tensor):
    """Reject generic decoding; this trainer has no single decoder.

    Args:
        x: Latent representations (unused).

    Raises:
        NotImplementedError: Always.
    """
    message = "General decode step is not implemented for XModalix"
    raise NotImplementedError(message)

531 

def predict(
    self,
    data: BaseDataset,
    model: Optional[Dict[str, torch.nn.Module]] = None,
    **kwargs,
) -> Result:
    """Translate data from one modality to another and store the outputs in the Result.

    The source and target modality are resolved by ``find_translation_keys``
    from the config and the optional ``from_key`` / ``to_key`` keyword
    arguments. Besides the translation, a reference reconstruction
    (target modality decoded from its own latents) is stored.

    Args:
        data: Dataset to run inference on; must be a MultiModalDataset.
        model: Optional mapping of modality name to module. It is only used
            when this trainer has no ``_modality_dynamics`` attribute yet.
        **kwargs: ``split`` (default ``"test"``), ``from_key``, ``to_key``.

    Returns:
        The Result object populated with the prediction outputs.

    Raises:
        ValueError: If ``split`` is not ``'train'``, ``'valid'`` or ``'test'``.
        TypeError: If ``data`` is not a MultiModalDataset.
    """
    split: str = kwargs.pop("split", "test")
    if split not in ["train", "valid", "test"]:
        raise ValueError(
            f"split must be one of 'train', 'valid', or 'test', got {split}"
        )
    if not isinstance(data, MultiModalDataset):
        raise TypeError(
            f"type of data has to be MultiModalDataset, got: {type(data)}"
        )

    # Adopt externally supplied models only when nothing was trained/set up here.
    if model is not None and not hasattr(self, "_modality_dynamics"):
        self._modality_dynamics = {mod_name: {} for mod_name in model.keys()}

        for mod_name, mod_model in model.items():
            mod_model.eval()
            if not isinstance(mod_model, _FabricModule):
                mod_model = self._fabric.setup_module(mod_model)
            mod_model.to(self._fabric.device)
            self._modality_dynamics[mod_name]["model"] = mod_model

    requested_from = kwargs.pop("from_key", None)
    requested_to = kwargs.pop("to_key", None)
    predict_keys = find_translation_keys(
        config=self._config,
        trained_modalities=list(self._modality_dynamics.keys()),
        from_key=requested_from,
        to_key=requested_to,
    )
    from_key = predict_keys["from"]
    to_key = predict_keys["to"]

    from_model = self._modality_dynamics[from_key]["model"]
    to_model = self._modality_dynamics[to_key]["model"]
    from_model.eval()
    to_model.eval()

    # drop_last is handled by the custom batch sampler.
    batch_sampler = CoverageEnsuringSampler(
        multimodal_dataset=data, batch_size=self._config.batch_size
    )
    inference_loader = DataLoader(
        data,
        batch_sampler=batch_sampler,
        collate_fn=create_multimodal_collate_fn(multimodal_dataset=data),
    )
    inference_loader = self._fabric.setup_dataloaders(inference_loader)  # type: ignore

    translation_key = "translation"
    reference_key = f"reference_{to_key}_to_{to_key}"
    epoch_dynamics: List[Dict] = []

    with self._fabric.autocast(), torch.inference_mode():
        for batch in inference_loader:
            # Forward every modality so the outputs are available for visualization.
            self._get_vis_dynamics(batch=batch)

            source_latents = self._modality_dynamics[from_key]["mp"].latentspace
            translated = to_model.decode(x=source_latents)

            # Reference: target modality decoded from its own latents.
            target_latents = self._modality_dynamics[to_key]["mp"].latentspace
            to_to_reference = to_model.decode(x=target_latents)

            batch_capture = self._capture_dynamics(batch)
            reconstructions = batch_capture["reconstructions"]
            reconstructions[translation_key] = translated.cpu().numpy()
            reconstructions[reference_key] = to_to_reference.cpu().numpy()

            if "sample_ids" in batch[from_key]:
                batch_capture["sample_ids"][translation_key] = np.array(
                    batch[from_key]["sample_ids"]
                )
            if "sample_ids" in batch[to_key]:
                batch_capture["sample_ids"][reference_key] = np.array(
                    batch[to_key]["sample_ids"]
                )

            epoch_dynamics.append(batch_capture)

    self._dynamics_to_result(split=split, epoch_dynamics=epoch_dynamics)

    print("Prediction complete.")
    return self._result

634 

635 def _get_vis_dynamics(self, batch: Dict[str, Dict[str, Any]]): 

636 """Runs a forward pass for each modality in the batch and stores the model outputs. 

637 

638 This is used to capture dynamics for visualization or analysis not for actual translation as specified in predict. 

639 Args: 

640 batch: A dictionary where keys are modality names and values are dictionaries containing 'data' tensors. 

641 """ 

642 for mod_name, mod_data in batch.items(): 

643 model = self._modality_dynamics[mod_name]["model"] 

644 # Store model output (containing latents, recons, etc.) 

645 self._modality_dynamics[mod_name]["mp"] = model(mod_data["data"]) 

646 

647 def _capture_dynamics( 

648 self, 

649 batch_data: Union[Dict[str, Dict[str, Any]], Any], 

650 ) -> Union[Dict[str, Dict[str, np.ndarray]], Any]: 

651 """Captures and returns the dynamics (latents, reconstructions, etc.) for each modality in the batch. 

652 

653 Args: 

654 batch_data: A dictionary where keys are modality names and values are dictionaries containing 'data' tensors 

655 and optionally 'sample_ids'. 

656 Returns: 

657 A dictionary with keys 'latentspaces', 'reconstructions', 'mus', 'sigmas', and 'sample_ids', 

658 each mapping to another dictionary where keys are modality names and values are numpy arrays. 

659 """ 

660 captured_data: Dict[str, Dict[str, Any]] = { 

661 "latentspaces": {}, 

662 "reconstructions": {}, 

663 "mus": {}, 

664 "sigmas": {}, 

665 "sample_ids": {}, 

666 } 

667 for mod_name, dynamics in self._modality_dynamics.items(): 

668 sample_ids = batch_data[mod_name].get("sample_ids") 

669 # get index of 

670 if sample_ids is not None: 

671 captured_data["sample_ids"][mod_name] = np.array(sample_ids) 

672 

673 model_output = dynamics["mp"] 

674 captured_data["latentspaces"][mod_name] = model_output.latentspace.detach() 

675 captured_data["reconstructions"][ 

676 mod_name 

677 ] = model_output.reconstruction.detach() 

678 if model_output.latent_mean is not None: 

679 captured_data["mus"][mod_name] = model_output.latent_mean.detach() 

680 if model_output.latent_logvar is not None: 

681 captured_data["sigmas"][mod_name] = model_output.latent_logvar.detach() 

682 

683 return captured_data 

684 

685 def _dynamics_to_result(self, split: str, epoch_dynamics: List[Dict]) -> None: 

686 """Aggregates and stores epoch dynamics into the Result object. 

687 

688 Due to our multi modal Traning with unpaired data, we can see sample_ids more than once per epoch. 

689 For more consistent downstream analysis, we only keep the first occurence of this sample in the epoch 

690 and report the dynamics for this sample 

691 

692 Args: 

693 split: The data split name (e.g., 'train', 'valid', 'test'). 

694 epoch_dynamics: List of dictionaries capturing dynamics for each batch in the epoch. 

695 

696 """ 

697 final_data: Dict[str, Any] = defaultdict(lambda: defaultdict(list)) 

698 for batch_data in epoch_dynamics: 

699 for dynamic_type, mod_data in batch_data.items(): 

700 for mod_name, data in mod_data.items(): 

701 if isinstance(data, torch.Tensor): 

702 data = data.cpu().numpy() 

703 final_data[dynamic_type][mod_name].append(data) 

704 

705 sample_ids: Optional[Dict[str, np.ndarray]] = final_data.get("sample_ids") 

706 if sample_ids is None: 

707 raise ValueError("No Sample Ids in TrainingDynamics") 

708 concat_ids: Dict[str, np.ndarray] = { 

709 k: np.concatenate(v) for k, v in sample_ids.items() 

710 } 

711 unique_idx: Dict[str, np.ndarray] = { 

712 k: np.unique(v, return_index=True)[1] for k, v in concat_ids.items() 

713 } 

714 

715 deduplicated_data: Dict[str, Dict[str, np.ndarray]] = defaultdict(dict) 

716 # all_dynamics_inner: Dict[str, Any] = {} 

717 for dynamic_type, dynamic_data in final_data.items(): 

718 concat_dynamic: Dict[str, np.ndarray] = { 

719 k: np.concatenate(v) for k, v in dynamic_data.items() 

720 } 

721 # all_dynamics_inner[dynamic_type] = concat_dynamic 

722 deduplicated_helper: Dict[str, np.ndarray] = { 

723 k: v[unique_idx[k]] for k, v in concat_dynamic.items() 

724 } 

725 deduplicated_data[dynamic_type] = deduplicated_helper 

726 # Store the deduplicated data 

727 # self._all_dynamics[split][self._cur_epoch] = all_dynamics_inner 

728 self._result.latentspaces.add( 

729 epoch=self._cur_epoch, 

730 split=split, 

731 data=deduplicated_data.get("latentspaces", {}), 

732 ) 

733 

734 self._result.reconstructions.add( 

735 epoch=self._cur_epoch, 

736 split=split, 

737 data=deduplicated_data.get("reconstructions", {}), 

738 ) 

739 

740 if deduplicated_data.get("sample_ids"): 

741 self._result.sample_ids.add( 

742 epoch=self._cur_epoch, 

743 split=split, 

744 data=deduplicated_data["sample_ids"], 

745 ) 

746 

747 if deduplicated_data.get("mus"): 

748 self._result.mus.add( 

749 epoch=self._cur_epoch, 

750 split=split, 

751 data=deduplicated_data["mus"], 

752 ) 

753 

754 if deduplicated_data.get("sigmas"): 

755 self._result.sigmas.add( 

756 epoch=self._cur_epoch, 

757 split=split, 

758 data=deduplicated_data["sigmas"], 

759 ) 

760 

761 def _log_losses( 

762 self, 

763 split: str, 

764 total_loss: float, 

765 sub_losses: Dict[str, float], 

766 n_samples: int, 

767 ) -> None: 

768 """Logs and stores average losses for the epoch. 

769 Args: 

770 split: The data split name (e.g., 'train', 'valid'). 

771 total_loss: The cumulative total loss for the epoch. 

772 sub_losses: Dictionary of accumulated sub-losses for the epoch. 

773 n_samples: Total number of samples processed in the epoch. 

774 

775 """ 

776 

777 n_samples = max(n_samples, 1) 

778 print(f"split: {split}, n_samples: {n_samples}") 

779 # avg_total_loss = total_loss / n_samples ## Now already normalized 

780 self._result.losses.add(epoch=self._cur_epoch, split=split, data=total_loss) 

781 

782 # avg_sub_losses = { 

783 # k: v / n_samples if "_factor" not in k else v for k, v in sub_losses.items() 

784 # } 

785 self._result.sub_losses.add( 

786 epoch=self._cur_epoch, 

787 split=split, 

788 data=sub_losses, 

789 ) 

790 self._fabric.print( 

791 f"Epoch {self._cur_epoch + 1}/{self._config.epochs} - {split.capitalize()} Loss: {total_loss:.4f}" 

792 ) 

793 # Detailed sub-loss logging in one line 

794 sub_loss_str = ", ".join( 

795 [f"{k}: {v:.4f}" for k, v in sub_losses.items() if "_factor" not in k] 

796 ) 

797 self._fabric.print(f"Sub-losses - {sub_loss_str}") 

798 

799 def _store_checkpoint(self, split: str, epoch_dynamics: List[Dict]) -> None: 

800 """Stores model checkpoints and epoch dynamics into the Result object. 

801 

802 Args: 

803 split: The data split name (e.g., 'train', 'valid'). 

804 epoch_dynamics: List of dictionaries capturing dynamics for each batch in the epoch. 

805 

806 """ 

807 

808 state_to_save = { 

809 mod_name: dynamics["model"].state_dict() 

810 for mod_name, dynamics in self._modality_dynamics.items() 

811 } 

812 state_to_save["latent_clf"] = self._latent_clf.state_dict() 

813 self._result.model_checkpoints.add(epoch=self._cur_epoch, data=state_to_save) 

814 self._dynamics_to_result(split=split, epoch_dynamics=epoch_dynamics) 

815 

816 def _store_final_models(self) -> None: 

817 """Stores the final trained models into the Result object.""" 

818 final_models = { 

819 mod_name: dynamics["model"] 

820 for mod_name, dynamics in self._modality_dynamics.items() 

821 } 

822 self._result.model = final_models 

823 

def purge(self) -> None:
    """Drops the trainer's references to its training resources.

    Removes the "model" and "optim" entries (when present and not None) from
    every per-modality dynamics dict, deletes the dynamics container and the
    remaining training attributes from the instance, empties the CUDA cache
    when CUDA is available and finally runs the garbage collector.
    Attributes that do not exist are skipped, so calling this on a partially
    initialised trainer does not raise for missing attributes.
    """
    if hasattr(self, "_modality_dynamics"):
        for entry in self._modality_dynamics.values():
            # Model first, then its optimizer.
            for key in ("model", "optim"):
                if key in entry and entry[key] is not None:
                    del entry[key]
        # The container goes once its contents have been cleaned.
        del self._modality_dynamics

    disposable = (
        "_latent_clf",
        "_clf_optim",
        "_trainloader",
        "_validloader",
        "_trainset",
        "_validset",
        "_loss_fn",
        "sub_loss",  # directly accessible sub_loss instance
        "_clf_loss_fn",
        "_all_dynamics",
    )
    for attr in disposable:
        if not hasattr(self, attr):
            continue
        # A validation loader that is None is left in place.
        if attr == "_validloader" and self._validloader is None:
            continue
        delattr(self, attr)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    gc.collect()

867 

868 def _train_one_epoch_with_profiling( 

869 self, 

870 ) -> Tuple[List[Dict], Dict[str, float], int]: 

871 """Training loop with torch.profiler integration.""" 

872 for dynamics in self._modality_dynamics.values(): 

873 dynamics["model"].train() 

874 self._latent_clf.train() 

875 

876 self._clf_epoch_loss = 0 

877 self._epoch_loss = 0 

878 epoch_dynamics: List[Dict] = [] 

879 sub_losses: Dict[str, float] = defaultdict(float) 

880 n_samples_total: int = 0 

881 

882 # Configure profiler 

883 with profile( 

884 activities=[ 

885 ProfilerActivity.CPU, 

886 ProfilerActivity.CUDA, 

887 ], 

888 record_shapes=True, # Record tensor shapes 

889 profile_memory=True, # Track memory usage 

890 with_stack=True, # Include stack traces 

891 # Schedule: skip first 5 batches (warmup), profile next 5, repeat 

892 schedule=torch.profiler.schedule( 

893 wait=1, # Skip first batch (warmup) 

894 warmup=1, # Warmup for 1 batch 

895 active=3, # Profile 3 batches 

896 repeat=1, # Do this once 

897 ), 

898 on_trace_ready=torch.profiler.tensorboard_trace_handler( 

899 "./profiler_logs", worker_name=self._config.profile_logs 

900 ), 

901 ) as prof: 

902 train_iter = iter(self._trainloader) 

903 for batch_idx, batch in enumerate(self._trainloader): 

904 with record_function("dataloader"): 

905 batch = next(train_iter) 

906 with record_function("total_batch"): 

907 with self._fabric.autocast(): 

908 # --- Stage 1: Forward for each modality --- 

909 with record_function("modalities_forward"): 

910 self._modalities_forward(batch=batch) 

911 

912 # --- Stage 2: Prepare adversarial training --- 

913 with record_function("prep_adver_training"): 

914 latents, labels = self._prep_adver_training() 

915 n_samples_total += latents.size(0) 

916 

917 # --- Stage 3: Train Classifier --- 

918 with record_function("train_classifier"): 

919 self._train_clf(latents=latents, labels=labels) 

920 

921 # --- Stage 4: Train Autoencoders --- 

922 with record_function("train_autoencoders"): 

923 for _, dynamics in self._modality_dynamics.items(): 

924 dynamics["optim"].zero_grad() 

925 

926 clf_scores_for_adv = self._latent_clf(latents) 

927 

928 batch_loss, loss_dict = self._loss_fn( 

929 batch=batch, 

930 modality_dynamics=self._modality_dynamics, 

931 clf_scores=clf_scores_for_adv, 

932 labels=labels, 

933 clf_loss_fn=self._clf_loss_fn, 

934 is_training=True, 

935 ) 

936 

937 with record_function("backward_pass"): 

938 self._fabric.backward(batch_loss) 

939 

940 with record_function("optimizer_step"): 

941 for _, dynamics in self._modality_dynamics.items(): 

942 dynamics["optim"].step() 

943 

944 # --- Logging and Capturing --- 

945 self._epoch_loss += batch_loss.item() 

946 for k, v in loss_dict.items(): 

947 value_to_add = v.item() if hasattr(v, "item") else v 

948 if "_factor" not in k: 

949 sub_losses[k] += value_to_add 

950 else: 

951 sub_losses[k] = value_to_add 

952 

953 if self._is_checkpoint_epoch: 

954 with record_function("capture_dynamics"): 

955 batch_capture = self._capture_dynamics(batch) 

956 epoch_dynamics.append(batch_capture) 

957 

958 # Step the profiler 

959 prof.step() 

960 

961 # Stop profiling after a few batches to avoid huge logs 

962 if batch_idx >= 5: 

963 break 

964 # Print summary to console 

965 print("\n" + "=" * 80) 

966 print("PROFILER SUMMARY - Top Operations by CPU Time") 

967 print("=" * 80) 

968 print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) 

969 

970 print("\n" + "=" * 80) 

971 print("PROFILER SUMMARY - Top Operations by CUDA Time") 

972 print("=" * 80) 

973 print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) 

974 

975 # --- Write profiler summary to CSV --- 

976 

977 import csv 

978 from pathlib import Path 

979 

980 csv_path = Path(f"./profiler_logs/{self._config.profile_logs}.csv") 

981 csv_path.parent.mkdir(parents=True, exist_ok=True) 

982 

983 with csv_path.open(mode="w", newline="") as f: 

984 writer = csv.writer(f) 

985 writer.writerow( 

986 [ 

987 "name", 

988 "cpu_timecuda_timecpu_memory_usage", 

989 "cuda_memory_usage", 

990 "input_shapes", 

991 "calls", 

992 ] 

993 ) 

994 

995 for evt in prof.key_averages(): 

996 writer.writerow( 

997 [ 

998 evt.key, 

999 evt.cpu_time, 

1000 evt.cuda_time, 

1001 evt.cpu_memory_usage, 

1002 evt.cuda_memory_usage, 

1003 str(evt.input_shapes), 

1004 evt.count, 

1005 ] 

1006 )