Coverage for src / autoencodix / trainers / _xmodal_trainer.py: 11%
436 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from collections import defaultdict
2from packaging.version import Version
3from torch.profiler import profile, record_function, ProfilerActivity
5from lightning_fabric.wrappers import _FabricModule
6import gc
7import os
8from typing import Any, Dict, List, Optional, Tuple, Type, Union
10import numpy as np
11import torch
12from torch.utils.data import DataLoader
14from autoencodix.base._base_autoencoder import BaseAutoencoder
15from autoencodix.base._base_dataset import BaseDataset, DataSetTypes
16from autoencodix.base._base_loss import BaseLoss
17from autoencodix.base._base_trainer import BaseTrainer
18from autoencodix.data._multimodal_dataset import (
19 CoverageEnsuringSampler,
20 MultiModalDataset,
21 create_multimodal_collate_fn,
22)
23from autoencodix.modeling._classifier import Classifier
24from autoencodix.modeling._imagevae_architecture import ImageVAEArchitecture
25from autoencodix.modeling._varix_architecture import VarixArchitecture
26from autoencodix.utils._losses import VarixLoss
27from autoencodix.utils._model_output import ModelOutput
28from autoencodix.utils._result import Result
29from autoencodix.configs.default_config import DefaultConfig
30from autoencodix.utils._internals import model_trainer_map
31from autoencodix.utils._utils import find_translation_keys
34class XModalTrainer(BaseTrainer):
35 """Trainer for cross-modal autoencoders, implements multimodal training with adversarial component.
37 Attributes:
38 _trainset: The dataset used for training, must be a MultiModalDataset.
39 _n_train_modalities: Number of modalities in the training dataset.
40 model_map: Mapping from DataSetTypes to specific autoencoder architectures.
41 model_trainer_map: Mapping from autoencoder architectures to their corresponding trainer classes.
42 _n_cpus: Number of CPU cores available for data loading.
43 latent_dim: Dimensionality of the shared latent space.
44 n_test: Number of samples in the test set, if provided.
45 n_train: Number of samples in the training set.
46 n_valid: Number of samples in the validation set, if provided.
47 n_features: Total number of features across all modalities.
48 _cur_epoch: Current epoch number during training.
49 _is_checkpoint_epoch: Flag indicating if the current epoch is a checkpoint epoch.
50 _sub_loss_type: Loss function type for individual modality autoencoders.
51 sub_loss: Instantiated loss function for individual modality autoencoders.
52 _clf_epoch_loss: Cumulative classifier loss for the current epoch.
53 _epoch_loss: Cumulative total loss for the current epoch.
54 _epoch_loss_valid: Cumulative total validation loss for the current epoch.
55 _modality_dynamics: Dictionary holding model, optimizer, and training state for each modality.
56 _latent_clf: Classifier model for adversarial training on latent spaces.
57 _clf_optim: Optimizer for the classifier model.
58 _clf_loss_fn: Loss function for the classifier.
59 _trainloader: DataLoader for the training dataset.
60 _validloader: DataLoader for the validation dataset, if provided.
61 _model: The instantiated stacked model architecture.
62 _validset: The dataset used for validation, if provided, must be a MultiModalDataset.
63 _fabric: Lightning Fabric wrapper for device and precision management.
65 """
def __init__(
    self,
    trainset: MultiModalDataset,
    validset: MultiModalDataset,
    result: Result,
    config: DefaultConfig,
    model_type: Type[BaseAutoencoder],
    loss_type: Type[BaseLoss],
    model_map: Dict[DataSetTypes, Type[BaseAutoencoder]],
    sub_loss_type: Type[BaseLoss] = VarixLoss,
    ontologies: Optional[Union[Tuple, List]] = None,
    **kwargs,
):
    """Create the trainer state and delegate the shared setup to ``BaseTrainer``.

    Args:
        trainset: Training dataset containing multiple modalities.
        validset: Validation dataset containing multiple modalities; a falsy
            value is counted as zero validation samples.
        result: Result object that receives the training outcomes.
        config: Configuration for training and model architecture.
        model_type: Forwarded to the parent constructor (not used directly here).
        loss_type: Loss type forwarded to the parent constructor.
        model_map: Mapping from DataSetTypes to autoencoder architectures.
        sub_loss_type: Loss type instantiated for the per-modality autoencoders.
        ontologies: Ontology information, forwarded to the parent constructor.
        **kwargs: Accepted for call compatibility; not read in this constructor.
    """
    # Assigned before the parent constructor runs.
    # NOTE(review): presumably the parent triggers the _init_* hooks that
    # read these attributes - confirm against BaseTrainer.
    self._trainset = trainset
    self._n_train_modalities = self._trainset.n_modalities  # type: ignore
    self.model_map = model_map
    self.model_trainer_map = model_trainer_map
    # os.cpu_count() may return None; fall back to 0 in that case.
    self._n_cpus = os.cpu_count() or 0

    super().__init__(
        trainset, validset, result, config, model_type, loss_type, ontologies
    )

    # --- bookkeeping attributes ---------------------------------------
    self.latent_dim = config.latent_dim
    self.n_test: Optional[int] = None
    self.n_train = len(trainset.data) if trainset else 0
    self.n_valid = len(validset.data) if validset else 0
    self._cur_epoch: int = 0
    self._is_checkpoint_epoch: Optional[bool] = None
    self._sub_loss_type = sub_loss_type
    self.sub_loss = sub_loss_type(config=self._config)
    self._clf_epoch_loss: float = 0.0
    self._epoch_loss: float = 0.0
    self._epoch_loss_valid: float = 0.0
    self._all_dynamics: Dict[str, Dict[int, Dict]] = {
        split_name: defaultdict(dict) for split_name in ("test", "train", "valid")
    }
122 def _init_model_architecture(self, ontologies: Optional[Tuple] = None):
123 """Override parent's model init - we don't use a single model, needs to be there because is abstract in parent."""
124 pass
126 def _setup_fabric(self, old_model=None):
127 """Sets up the models, optimizers, and data loaders with Lightning Fabric."""
129 self._fabric.launch()
130 self._init_adversarial_training()
131 self._init_modality_training()
132 self._latent_clf, self._clf_optim = self._fabric.setup(
133 self._latent_clf, self._clf_optim
134 )
136 self._trainloader = self._fabric.setup_dataloaders(self._trainloader)
137 if self._validloader is not None:
138 self._validloader = self._fabric.setup_dataloaders(self._validloader)
140 for dynamics in self._modality_dynamics.values():
141 dynamics["model"], dynamics["optim"] = self._fabric.setup(
142 dynamics["model"], dynamics["optim"]
143 )
144 dynamics["model"].mark_forward_method("decode")
def _init_loaders(self):
    """Build the train and validation DataLoaders, choosing the batching strategy per dataset.

    A fully paired dataset gets a plain ``DataLoader`` (shuffled, and with
    the last incomplete batch dropped, only for training). A dataset with
    unpaired samples is batched by a ``CoverageEnsuringSampler`` passed as
    ``batch_sampler``.

    When no validation set was provided, ``self._validloader`` is set to
    ``None`` instead of attempting to build a loader; the rest of the
    trainer already checks for a missing validation loader/set.
    """

    def _build_loader(dataset: MultiModalDataset, is_train: bool):
        batch_size = self._config.batch_size
        collate_fn = create_multimodal_collate_fn(multimodal_dataset=dataset)

        if dataset.is_fully_paired:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=is_train,
                drop_last=is_train,
                collate_fn=collate_fn,
                pin_memory=self._config.pin_memory,
            )

        print(
            f"Dataset has UNPAIRED samples → using CoverageEnsuringSampler "
            f"({len(dataset.paired_sample_ids)} paired + {len(dataset.unpaired_sample_ids)} unpaired)"
        )
        sampler = CoverageEnsuringSampler(
            multimodal_dataset=dataset,
            batch_size=batch_size,
        )
        return DataLoader(
            dataset,
            batch_sampler=sampler,  # note: batch_sampler, not sampler
            collate_fn=collate_fn,
            pin_memory=self._config.pin_memory,
        )

    self._trainloader = _build_loader(self._trainset, is_train=True)
    # A missing validation set must not crash loader construction.
    self._validloader = (
        _build_loader(self._validset, is_train=False)
        if self._validset is not None
        else None
    )

    # Optional: expose for logging
    self._train_is_fully_paired = len(self._trainset.unpaired_sample_ids) == 0
185 def _init_modality_training(self):
186 """Initializes models, optimizers, and training state for each modality."""
187 self._modality_dynamics = {
188 mod_name: None for mod_name in self._trainset.datasets.keys()
189 }
191 self._result.sub_results = {
192 mod_name: None for mod_name in self._trainset.datasets.keys()
193 }
195 data_info = self._config.data_config.data_info
196 for mod_name, ds in self._trainset.datasets.items():
197 simple_name = mod_name.split(".")[1]
198 local_epochs = data_info[simple_name].pretrain_epochs
200 pretrain_epochs = (
201 local_epochs
202 if local_epochs is not None
203 else self._config.pretrain_epochs
204 )
206 model_type = self.model_map.get(ds.mytype)
207 if model_type is None:
208 raise ValueError(
209 f"No Mapping exists for {ds.mytype}, you passed this mapping: {self.model_map}"
210 )
211 model = model_type(config=self._config, input_dim=ds.get_input_dim())
212 if (
213 Version(torch.__version__) >= Version("2.0")
214 and torch.cuda.is_available()
215 and self._config.compile_model
216 ):
217 model = torch.compile(model)
218 optimizer = torch.optim.AdamW(
219 params=model.parameters(),
220 lr=self._config.learning_rate,
221 weight_decay=self._config.weight_decay,
222 )
223 self._modality_dynamics[mod_name] = {
224 "model": model,
225 "optim": optimizer,
226 "mp": [],
227 "losses": [],
228 "config_name": simple_name,
229 "pretrain_epochs": pretrain_epochs,
230 "mytype": ds.mytype,
231 "pretrain_result": None,
232 }
def _init_adversarial_training(self):
    """Create the latent-space classifier, its AdamW optimizer and its cross-entropy loss."""
    cfg = self._config
    classifier = Classifier(
        input_dim=cfg.latent_dim, n_modalities=self._n_train_modalities
    )
    self._latent_clf = classifier
    self._clf_optim = torch.optim.AdamW(
        params=classifier.parameters(),
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
    )
    self._clf_loss_fn = torch.nn.CrossEntropyLoss(reduction=cfg.loss_reduction)
248 def _modalities_forward(self, batch: Dict[str, Dict[str, Any]]):
249 """Performs forward pass for each modality in the batch and computes losses."""
250 for k, v in batch.items():
251 model = self._modality_dynamics[k]["model"]
252 data = v["data"]
254 mp = model(data)
256 loss, loss_stats = self.sub_loss(
257 model_output=mp, targets=data, epoch=self._cur_epoch
258 )
260 self._modality_dynamics[k]["loss_stats"] = loss_stats
261 self._modality_dynamics[k]["loss"] = loss
262 self._modality_dynamics[k]["mp"] = mp
264 def _prep_adver_training(self) -> Tuple[torch.Tensor, torch.Tensor]:
265 """Prepares concatenated latent spaces and corresponding labels for adversarial training.
267 Returns:
268 A tuple containing:
269 - latents: Concatenated latent space tensor of shape (total_samples, latent_dim)
270 - labels: Tensor of modality labels of shape (total_samples,)
271 """
272 all_latents: List[torch.Tensor] = []
273 all_labels: List[torch.Tensor] = []
274 mod2idx = {
275 mod_name: idx for idx, mod_name in enumerate(self._modality_dynamics.keys())
276 }
278 for mod_name, helper in self._modality_dynamics.items():
279 output: ModelOutput = helper["mp"]
280 latents = output.latentspace
281 label_id = mod2idx[mod_name]
282 all_latents.append(latents)
283 all_labels.append(
284 torch.full(
285 (latents.size(0),),
286 fill_value=label_id,
287 dtype=torch.long,
288 device=self._fabric.device,
289 )
290 )
292 return torch.cat(all_latents, dim=0), torch.cat(all_labels, dim=0)
294 def _train_clf(self, latents: torch.Tensor, labels: torch.Tensor) -> None:
295 """Performs a single optimization step for the classifier.
296 Args:
297 latents: Concatenated latent space tensor of shape (total_samples, latent_dim)
298 labels: Tensor of modality labels of shape (total_samples,)
300 """
301 self._clf_optim.zero_grad()
302 # We must detach the latents from the computation graph. This serves two purposes:
303 # 1. It isolates the classifier, ensuring that gradients only update its weights.
304 # 2. It prevents a "two backward passes" error by leaving the graph to the
305 # autoencoders untouched, ready for use in the next stage.
306 detached_latents = latents.detach()
307 clf_scores = self._latent_clf(detached_latents)
308 clf_loss = self._clf_loss_fn(clf_scores, labels)
309 self._fabric.backward(clf_loss)
310 self._clf_optim.step()
311 self._clf_epoch_loss += clf_loss.item()
313 def _validate_one_epoch(self) -> Tuple[List[Dict], Dict[str, float], int]:
314 """Runs a single, combined validation pass for the entire model, including dedicated metrics for the classifier.
316 Returns:
317 A tuple containing:
318 - epoch_dynamics: List of dictionaries capturing dynamics for each batch
319 - sub_losses: Dictionary of accumulated sub-losses for the epoch
320 - n_samples_total: Total number of samples processed in validation
321 """
322 for dynamics in self._modality_dynamics.values():
323 dynamics["model"].eval()
324 self._latent_clf.eval()
326 self._epoch_loss_valid = 0.0
327 total_clf_loss = 0.0
328 n_samples_total: int = 0
329 epoch_dynamics: List[Dict] = []
330 sub_losses: Dict[str, float] = defaultdict(float)
332 with torch.no_grad():
333 for batch in self._validloader:
334 with self._fabric.autocast():
335 self._modalities_forward(batch=batch)
336 latents, labels = self._prep_adver_training()
337 n_samples_total += len(labels)
338 clf_scores = self._latent_clf(latents)
340 batch_loss, loss_dict = self._loss_fn(
341 batch=batch,
342 modality_dynamics=self._modality_dynamics,
343 clf_scores=clf_scores,
344 labels=labels,
345 clf_loss_fn=self._clf_loss_fn,
346 is_training=False,
347 )
348 self._epoch_loss_valid += batch_loss.item()
349 for k, v in loss_dict.items():
350 value = v.item() if hasattr(v, "item") else v
351 if "_factor" not in k:
352 sub_losses[k] += value
353 else:
354 sub_losses[k] = value
356 clf_loss = self._clf_loss_fn(clf_scores, labels)
357 total_clf_loss += clf_loss.item()
358 if self._is_checkpoint_epoch:
359 batch_capture = self._capture_dynamics(batch)
360 epoch_dynamics.append(batch_capture)
362 sub_losses["clf_loss"] = total_clf_loss
363 n_samples_total /= (
364 self._n_train_modalities
365 ) # n_modalities because each sample is counted once per modality and n_modalites is the same for train and valid
366 for k, v in sub_losses.items():
367 if "_factor" not in k:
368 sub_losses[k] = v / n_samples_total # Average over all samples
369 self._epoch_loss_valid = (
370 self._epoch_loss_valid / n_samples_total
371 ) # Average over all samples
372 return epoch_dynamics, sub_losses, len(self._validset)
374 def _train_one_epoch(self) -> Tuple[List[Dict], Dict[str, float], int]:
375 """Runs the training loop with corrected adversarial logic.
376 Returns:
377 A tuple containing:
378 - epoch_dynamics: List of dictionaries capturing dynamics for each batch
379 - sub_losses: Dictionary of accumulated sub-losses for the epoch
380 - n_samples_total: Total number of samples processed in training
382 """
383 for dynamics in self._modality_dynamics.values():
384 dynamics["model"].train()
385 self._latent_clf.train()
387 self._clf_epoch_loss = 0
388 self._epoch_loss = 0
389 epoch_dynamics: List[Dict] = []
390 sub_losses: Dict[str, float] = defaultdict(float)
391 n_samples_total: int = (
392 0 # because of unpaired training we need to sum the samples instead of using len(dataset)
393 )
395 for batch in self._trainloader:
396 with self._fabric.autocast():
397 # --- Stage 1: forward for each data modality ---
398 self._modalities_forward(batch=batch)
400 # --- Stage 2: Train the Classifier ---
401 latents, labels = self._prep_adver_training()
402 n_samples_total += latents.size(0)
403 self._train_clf(latents=latents, labels=labels)
405 # --- Stage 3: Train the Autoencoders ---
406 for _, dynamics in self._modality_dynamics.items():
407 dynamics["optim"].zero_grad()
409 # We re-calculate scores here and cannot reuse the `clf_scores` from Stage 2.
410 # The previous scores were from detached latents and would block the gradients
411 # needed to train the autoencoders adversarially.
412 clf_scores_for_adv = self._latent_clf(latents)
414 batch_loss, loss_dict = self._loss_fn(
415 batch=batch,
416 modality_dynamics=self._modality_dynamics,
417 clf_scores=clf_scores_for_adv,
418 labels=labels,
419 clf_loss_fn=self._clf_loss_fn,
420 is_training=True,
421 )
422 self._fabric.backward(batch_loss)
423 for _, dynamics in self._modality_dynamics.items():
424 dynamics["optim"].step()
426 # --- Logging and Capturing ---
427 self._epoch_loss += batch_loss.item()
428 for k, v in loss_dict.items():
429 value_to_add = v.item() if hasattr(v, "item") else v
430 if "_factor" not in k:
431 sub_losses[k] += value_to_add
432 else:
433 sub_losses[k] = value_to_add
435 if self._is_checkpoint_epoch:
436 batch_capture = self._capture_dynamics(batch)
437 epoch_dynamics.append(batch_capture)
438 n_samples_total /= self._n_train_modalities
439 sub_losses["clf_loss"] = self._clf_epoch_loss
440 for k, v in sub_losses.items():
441 if "_factor" not in k:
442 sub_losses[k] = v / n_samples_total # Average over all samples
443 self._epoch_loss = (
444 self._epoch_loss / n_samples_total
445 ) # Average over all samples
447 return epoch_dynamics, sub_losses, n_samples_total
def _pretraining(self):
    """Pretrain each modality's model with its own trainer when pretrain epochs are configured.

    For every modality with a truthy ``pretrain_epochs`` a dedicated trainer
    (looked up through ``model_trainer_map``) is run on that modality's
    train/valid datasets. The resulting model replaces the modality's model,
    and a new optimizer is created for it: the optimizer built earlier is
    bound to the parameters of the replaced model and would never update the
    pretrained one.

    Raises:
        ValueError: If no trainer is mapped to the modality's model type.
    """
    for mod_name, dynamic in self._modality_dynamics.items():
        mytype = dynamic.get("mytype")
        pretrain_epochs = dynamic.get("pretrain_epochs")
        print(f"Check if we need to pretrain: {mod_name}")
        print(f"pretrain epochs : {pretrain_epochs}")
        if not pretrain_epochs:
            print(f"No pretraining for {mod_name}")
            continue

        model_type = self.model_map.get(mytype)
        pretrainer_type = self.model_trainer_map.get(model_type)
        if pretrainer_type is None:
            # Fail with a clear message instead of "'NoneType' is not callable".
            raise ValueError(
                f"No trainer is mapped to model type {model_type} "
                f"(modality {mod_name}); cannot pretrain."
            )
        print(f"starting pretraining for: {mod_name} with {pretrainer_type}")

        pretrainer = pretrainer_type(
            trainset=self._trainset.datasets.get(mod_name),
            validset=self._validset.datasets.get(mod_name),
            result=Result(),
            config=self._config,
            model_type=model_type,
            loss_type=self._sub_loss_type,
            ontologies=self.ontologies,
        )
        pretrain_result = pretrainer.train(epochs_overwrite=pretrain_epochs)
        self._result.sub_results[f"pretrain.{mod_name}"] = pretrain_result

        pretrained_model = pretrain_result.model
        dynamic["pretrain_result"] = pretrain_result
        dynamic["model"] = pretrained_model

        # Rebind the optimizer to the parameters of the model that is
        # actually trained from here on, using the same hyper-parameters as
        # the initial optimizer, and let Fabric wrap it.
        optimizer = torch.optim.AdamW(
            params=pretrained_model.parameters(),
            lr=self._config.learning_rate,
            weight_decay=self._config.weight_decay,
        )
        dynamic["optim"] = self._fabric.setup_optimizers(optimizer)
def train(self):
    """Run pretraining, then the epoch loop with logging, validation and checkpointing.

    When profiling is enabled, epoch 0 is trained through
    ``_train_one_epoch_with_profiling``; its results are logged, validated
    and checkpointed like those of any other epoch.

    Returns:
        The Result object holding losses, checkpoints and the final models.
    """
    self._pretraining()

    # Bound up front so the final checkpoint call never hits an unbound name.
    train_epoch_dynamics: List[Dict] = []

    for epoch in range(self._config.epochs):
        self._cur_epoch = epoch
        self._is_checkpoint_epoch = self._should_checkpoint(epoch=epoch)
        self._fabric.print(f"--- Epoch {epoch + 1}/{self._config.epochs} ---")

        if epoch == 0 and self._config.profiling:
            print("Profiling enabled for epoch 0")
            # The profiled variant declares the same return tuple as
            # _train_one_epoch, so its results are used, not discarded.
            train_epoch_dynamics, train_sub_losses, n_samples_train = (
                self._train_one_epoch_with_profiling()
            )
        else:
            train_epoch_dynamics, train_sub_losses, n_samples_train = (
                self._train_one_epoch()
            )

        self._log_losses(
            split="train",
            total_loss=self._epoch_loss,
            sub_losses=train_sub_losses,
            n_samples=n_samples_train,
        )

        if self._validset:
            valid_epoch_dynamics, valid_sub_losses, n_samples_valid = (
                self._validate_one_epoch()
            )
            self._log_losses(
                split="valid",
                total_loss=self._epoch_loss_valid,
                sub_losses=valid_sub_losses,
                n_samples=n_samples_valid,
            )

        if self._is_checkpoint_epoch:
            self._fabric.print(f"Storing checkpoint for epoch {epoch}...")
            self._store_checkpoint(
                split="train", epoch_dynamics=train_epoch_dynamics
            )
            if self._validset:
                self._store_checkpoint(
                    split="valid",
                    epoch_dynamics=valid_epoch_dynamics,
                )

    # save final model
    self._store_checkpoint(split="train", epoch_dynamics=train_epoch_dynamics)
    self._store_final_models()
    return self._result
def decode(self, x: torch.Tensor):
    """Always raise: this trainer offers no general decode step.

    Args:
        x: Latent representations; never used.

    Raises:
        NotImplementedError: Always.
    """
    message = "General decode step is not implemented for XModalix"
    raise NotImplementedError(message)
def predict(
    self,
    data: BaseDataset,
    model: Optional[Dict[str, torch.nn.Module]] = None,
    **kwargs,
) -> Result:
    """Translate data from a 'from' modality into a 'to' modality and store the outcome.

    The translation keys are resolved by ``find_translation_keys`` from the
    config, the known modalities and the optional ``from_key`` / ``to_key``
    keyword arguments. Besides the translation, the 'to' modality's own
    latents are decoded as a reference. Results are stored in the Result
    object under the given split and the trainer's current epoch
    (``self._cur_epoch``).

    Args:
        data: A MultiModalDataset with the input data.
        model: Optional mapping of modality name to model; only used when the
            trainer holds no modality state yet.
        **kwargs: ``split`` ('train', 'valid' or 'test', default 'test'),
            ``from_key`` and ``to_key``.

    Returns:
        The Result object populated with prediction results.

    Raises:
        ValueError: If ``split`` is not one of the allowed names.
        TypeError: If ``data`` is not a MultiModalDataset.
    """
    split: str = kwargs.pop("split", "test")
    allowed_splits = ["train", "valid", "test"]
    if split not in allowed_splits:
        raise ValueError(
            f"split must be one of 'train', 'valid', or 'test', got {split}"
        )
    if not isinstance(data, MultiModalDataset):
        raise TypeError(
            f"type of data has to be MultiModalDataset, got: {type(data)}"
        )

    if model is not None and not hasattr(self, "_modality_dynamics"):
        # No training state yet: register the passed models.
        self._modality_dynamics = {mod_name: {} for mod_name in model.keys()}
        for mod_name, mod_model in model.items():
            mod_model.eval()
            if not isinstance(mod_model, _FabricModule):
                mod_model = self._fabric.setup_module(mod_model)
            mod_model.to(self._fabric.device)
            self._modality_dynamics[mod_name]["model"] = mod_model

    predict_keys = find_translation_keys(
        config=self._config,
        trained_modalities=list(self._modality_dynamics.keys()),
        from_key=kwargs.pop("from_key", None),
        to_key=kwargs.pop("to_key", None),
    )
    from_key = predict_keys["from"]
    to_key = predict_keys["to"]
    from_model = self._modality_dynamics[from_key]["model"]
    to_model = self._modality_dynamics[to_key]["model"]
    from_model.eval()
    to_model.eval()

    # drop_last is handled in the custom sampler
    batch_sampler = CoverageEnsuringSampler(
        multimodal_dataset=data, batch_size=self._config.batch_size
    )
    inference_loader = DataLoader(
        data,
        batch_sampler=batch_sampler,
        collate_fn=create_multimodal_collate_fn(multimodal_dataset=data),
    )
    inference_loader = self._fabric.setup_dataloaders(inference_loader)  # type: ignore

    translation_key = "translation"
    reference_key = f"reference_{to_key}_to_{to_key}"
    epoch_dynamics: List[Dict] = []

    with self._fabric.autocast(), torch.inference_mode():
        for batch in inference_loader:
            # Forward all modalities; needed for visualisation later.
            self._get_vis_dynamics(batch=batch)

            from_z = self._modality_dynamics[from_key]["mp"].latentspace
            translated = to_model.decode(x=from_z)

            # Reference: decode the target modality's own latents.
            to_z = self._modality_dynamics[to_key]["mp"].latentspace
            to_to_reference = to_model.decode(x=to_z)

            batch_capture = self._capture_dynamics(batch)
            reconstructions = batch_capture["reconstructions"]
            reconstructions[translation_key] = translated.cpu().numpy()
            reconstructions[reference_key] = to_to_reference.cpu().numpy()

            if "sample_ids" in batch[from_key]:
                batch_capture["sample_ids"][translation_key] = np.array(
                    batch[from_key]["sample_ids"]
                )
            if "sample_ids" in batch[to_key]:
                batch_capture["sample_ids"][reference_key] = np.array(
                    batch[to_key]["sample_ids"]
                )

            epoch_dynamics.append(batch_capture)

    self._dynamics_to_result(split=split, epoch_dynamics=epoch_dynamics)

    print("Prediction complete.")
    return self._result
635 def _get_vis_dynamics(self, batch: Dict[str, Dict[str, Any]]):
636 """Runs a forward pass for each modality in the batch and stores the model outputs.
638 This is used to capture dynamics for visualization or analysis not for actual translation as specified in predict.
639 Args:
640 batch: A dictionary where keys are modality names and values are dictionaries containing 'data' tensors.
641 """
642 for mod_name, mod_data in batch.items():
643 model = self._modality_dynamics[mod_name]["model"]
644 # Store model output (containing latents, recons, etc.)
645 self._modality_dynamics[mod_name]["mp"] = model(mod_data["data"])
647 def _capture_dynamics(
648 self,
649 batch_data: Union[Dict[str, Dict[str, Any]], Any],
650 ) -> Union[Dict[str, Dict[str, np.ndarray]], Any]:
651 """Captures and returns the dynamics (latents, reconstructions, etc.) for each modality in the batch.
653 Args:
654 batch_data: A dictionary where keys are modality names and values are dictionaries containing 'data' tensors
655 and optionally 'sample_ids'.
656 Returns:
657 A dictionary with keys 'latentspaces', 'reconstructions', 'mus', 'sigmas', and 'sample_ids',
658 each mapping to another dictionary where keys are modality names and values are numpy arrays.
659 """
660 captured_data: Dict[str, Dict[str, Any]] = {
661 "latentspaces": {},
662 "reconstructions": {},
663 "mus": {},
664 "sigmas": {},
665 "sample_ids": {},
666 }
667 for mod_name, dynamics in self._modality_dynamics.items():
668 sample_ids = batch_data[mod_name].get("sample_ids")
669 # get index of
670 if sample_ids is not None:
671 captured_data["sample_ids"][mod_name] = np.array(sample_ids)
673 model_output = dynamics["mp"]
674 captured_data["latentspaces"][mod_name] = model_output.latentspace.detach()
675 captured_data["reconstructions"][
676 mod_name
677 ] = model_output.reconstruction.detach()
678 if model_output.latent_mean is not None:
679 captured_data["mus"][mod_name] = model_output.latent_mean.detach()
680 if model_output.latent_logvar is not None:
681 captured_data["sigmas"][mod_name] = model_output.latent_logvar.detach()
683 return captured_data
685 def _dynamics_to_result(self, split: str, epoch_dynamics: List[Dict]) -> None:
686 """Aggregates and stores epoch dynamics into the Result object.
688 Due to our multi modal Traning with unpaired data, we can see sample_ids more than once per epoch.
689 For more consistent downstream analysis, we only keep the first occurence of this sample in the epoch
690 and report the dynamics for this sample
692 Args:
693 split: The data split name (e.g., 'train', 'valid', 'test').
694 epoch_dynamics: List of dictionaries capturing dynamics for each batch in the epoch.
696 """
697 final_data: Dict[str, Any] = defaultdict(lambda: defaultdict(list))
698 for batch_data in epoch_dynamics:
699 for dynamic_type, mod_data in batch_data.items():
700 for mod_name, data in mod_data.items():
701 if isinstance(data, torch.Tensor):
702 data = data.cpu().numpy()
703 final_data[dynamic_type][mod_name].append(data)
705 sample_ids: Optional[Dict[str, np.ndarray]] = final_data.get("sample_ids")
706 if sample_ids is None:
707 raise ValueError("No Sample Ids in TrainingDynamics")
708 concat_ids: Dict[str, np.ndarray] = {
709 k: np.concatenate(v) for k, v in sample_ids.items()
710 }
711 unique_idx: Dict[str, np.ndarray] = {
712 k: np.unique(v, return_index=True)[1] for k, v in concat_ids.items()
713 }
715 deduplicated_data: Dict[str, Dict[str, np.ndarray]] = defaultdict(dict)
716 # all_dynamics_inner: Dict[str, Any] = {}
717 for dynamic_type, dynamic_data in final_data.items():
718 concat_dynamic: Dict[str, np.ndarray] = {
719 k: np.concatenate(v) for k, v in dynamic_data.items()
720 }
721 # all_dynamics_inner[dynamic_type] = concat_dynamic
722 deduplicated_helper: Dict[str, np.ndarray] = {
723 k: v[unique_idx[k]] for k, v in concat_dynamic.items()
724 }
725 deduplicated_data[dynamic_type] = deduplicated_helper
726 # Store the deduplicated data
727 # self._all_dynamics[split][self._cur_epoch] = all_dynamics_inner
728 self._result.latentspaces.add(
729 epoch=self._cur_epoch,
730 split=split,
731 data=deduplicated_data.get("latentspaces", {}),
732 )
734 self._result.reconstructions.add(
735 epoch=self._cur_epoch,
736 split=split,
737 data=deduplicated_data.get("reconstructions", {}),
738 )
740 if deduplicated_data.get("sample_ids"):
741 self._result.sample_ids.add(
742 epoch=self._cur_epoch,
743 split=split,
744 data=deduplicated_data["sample_ids"],
745 )
747 if deduplicated_data.get("mus"):
748 self._result.mus.add(
749 epoch=self._cur_epoch,
750 split=split,
751 data=deduplicated_data["mus"],
752 )
754 if deduplicated_data.get("sigmas"):
755 self._result.sigmas.add(
756 epoch=self._cur_epoch,
757 split=split,
758 data=deduplicated_data["sigmas"],
759 )
761 def _log_losses(
762 self,
763 split: str,
764 total_loss: float,
765 sub_losses: Dict[str, float],
766 n_samples: int,
767 ) -> None:
768 """Logs and stores average losses for the epoch.
769 Args:
770 split: The data split name (e.g., 'train', 'valid').
771 total_loss: The cumulative total loss for the epoch.
772 sub_losses: Dictionary of accumulated sub-losses for the epoch.
773 n_samples: Total number of samples processed in the epoch.
775 """
777 n_samples = max(n_samples, 1)
778 print(f"split: {split}, n_samples: {n_samples}")
779 # avg_total_loss = total_loss / n_samples ## Now already normalized
780 self._result.losses.add(epoch=self._cur_epoch, split=split, data=total_loss)
782 # avg_sub_losses = {
783 # k: v / n_samples if "_factor" not in k else v for k, v in sub_losses.items()
784 # }
785 self._result.sub_losses.add(
786 epoch=self._cur_epoch,
787 split=split,
788 data=sub_losses,
789 )
790 self._fabric.print(
791 f"Epoch {self._cur_epoch + 1}/{self._config.epochs} - {split.capitalize()} Loss: {total_loss:.4f}"
792 )
793 # Detailed sub-loss logging in one line
794 sub_loss_str = ", ".join(
795 [f"{k}: {v:.4f}" for k, v in sub_losses.items() if "_factor" not in k]
796 )
797 self._fabric.print(f"Sub-losses - {sub_loss_str}")
799 def _store_checkpoint(self, split: str, epoch_dynamics: List[Dict]) -> None:
800 """Stores model checkpoints and epoch dynamics into the Result object.
802 Args:
803 split: The data split name (e.g., 'train', 'valid').
804 epoch_dynamics: List of dictionaries capturing dynamics for each batch in the epoch.
806 """
808 state_to_save = {
809 mod_name: dynamics["model"].state_dict()
810 for mod_name, dynamics in self._modality_dynamics.items()
811 }
812 state_to_save["latent_clf"] = self._latent_clf.state_dict()
813 self._result.model_checkpoints.add(epoch=self._cur_epoch, data=state_to_save)
814 self._dynamics_to_result(split=split, epoch_dynamics=epoch_dynamics)
816 def _store_final_models(self) -> None:
817 """Stores the final trained models into the Result object."""
818 final_models = {
819 mod_name: dynamics["model"]
820 for mod_name, dynamics in self._modality_dynamics.items()
821 }
822 self._result.model = final_models
def purge(self) -> None:
    """Drop the trainer's references to models, optimizers, loaders, datasets and losses.

    Attributes that do not exist are skipped. Afterwards the CUDA cache is
    emptied (when CUDA is available) and a garbage collection is triggered.
    """
    if hasattr(self, "_modality_dynamics"):
        for dynamics in self._modality_dynamics.values():
            # Remove the model and its optimizer from the per-modality dict.
            for key in ("model", "optim"):
                if key in dynamics and dynamics[key] is not None:
                    del dynamics[key]
        # Drop the container itself once its contents are cleaned.
        del self._modality_dynamics

    for attr in ("_latent_clf", "_clf_optim", "_trainloader"):
        if hasattr(self, attr):
            delattr(self, attr)

    # The validation loader is only removed when one was actually set.
    if getattr(self, "_validloader", None) is not None:
        del self._validloader

    for attr in (
        "_trainset",
        "_validset",
        "_loss_fn",
        "sub_loss",
        "_clf_loss_fn",
        "_all_dynamics",
    ):
        if hasattr(self, attr):
            delattr(self, attr)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    gc.collect()
868 def _train_one_epoch_with_profiling(
869 self,
870 ) -> Tuple[List[Dict], Dict[str, float], int]:
871 """Training loop with torch.profiler integration."""
872 for dynamics in self._modality_dynamics.values():
873 dynamics["model"].train()
874 self._latent_clf.train()
876 self._clf_epoch_loss = 0
877 self._epoch_loss = 0
878 epoch_dynamics: List[Dict] = []
879 sub_losses: Dict[str, float] = defaultdict(float)
880 n_samples_total: int = 0
882 # Configure profiler
883 with profile(
884 activities=[
885 ProfilerActivity.CPU,
886 ProfilerActivity.CUDA,
887 ],
888 record_shapes=True, # Record tensor shapes
889 profile_memory=True, # Track memory usage
890 with_stack=True, # Include stack traces
891 # Schedule: skip 1 batch, warm up for 1 batch, profile the next 3 batches, run this cycle once
892 schedule=torch.profiler.schedule(
893 wait=1, # Skip first batch (warmup)
894 warmup=1, # Warmup for 1 batch
895 active=3, # Profile 3 batches
896 repeat=1, # Do this once
897 ),
898 on_trace_ready=torch.profiler.tensorboard_trace_handler(
899 "./profiler_logs", worker_name=self._config.profile_logs
900 ),
901 ) as prof:
902 train_iter = iter(self._trainloader)
903 for batch_idx, batch in enumerate(self._trainloader):
904 with record_function("dataloader"):
905 batch = next(train_iter)
906 with record_function("total_batch"):
907 with self._fabric.autocast():
908 # --- Stage 1: Forward for each modality ---
909 with record_function("modalities_forward"):
910 self._modalities_forward(batch=batch)
912 # --- Stage 2: Prepare adversarial training ---
913 with record_function("prep_adver_training"):
914 latents, labels = self._prep_adver_training()
915 n_samples_total += latents.size(0)
917 # --- Stage 3: Train Classifier ---
918 with record_function("train_classifier"):
919 self._train_clf(latents=latents, labels=labels)
921 # --- Stage 4: Train Autoencoders ---
922 with record_function("train_autoencoders"):
923 for _, dynamics in self._modality_dynamics.items():
924 dynamics["optim"].zero_grad()
926 clf_scores_for_adv = self._latent_clf(latents)
928 batch_loss, loss_dict = self._loss_fn(
929 batch=batch,
930 modality_dynamics=self._modality_dynamics,
931 clf_scores=clf_scores_for_adv,
932 labels=labels,
933 clf_loss_fn=self._clf_loss_fn,
934 is_training=True,
935 )
937 with record_function("backward_pass"):
938 self._fabric.backward(batch_loss)
940 with record_function("optimizer_step"):
941 for _, dynamics in self._modality_dynamics.items():
942 dynamics["optim"].step()
944 # --- Logging and Capturing ---
945 self._epoch_loss += batch_loss.item()
946 for k, v in loss_dict.items():
947 value_to_add = v.item() if hasattr(v, "item") else v
948 if "_factor" not in k:
949 sub_losses[k] += value_to_add
950 else:
951 sub_losses[k] = value_to_add
953 if self._is_checkpoint_epoch:
954 with record_function("capture_dynamics"):
955 batch_capture = self._capture_dynamics(batch)
956 epoch_dynamics.append(batch_capture)
958 # Step the profiler
959 prof.step()
961 # Stop profiling after a few batches to avoid huge logs
962 if batch_idx >= 5:
963 break
964 # Print summary to console
965 print("\n" + "=" * 80)
966 print("PROFILER SUMMARY - Top Operations by CPU Time")
967 print("=" * 80)
968 print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
970 print("\n" + "=" * 80)
971 print("PROFILER SUMMARY - Top Operations by CUDA Time")
972 print("=" * 80)
973 print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
975 # --- Write profiler summary to CSV ---
977 import csv
978 from pathlib import Path
980 csv_path = Path(f"./profiler_logs/{self._config.profile_logs}.csv")
981 csv_path.parent.mkdir(parents=True, exist_ok=True)
983 with csv_path.open(mode="w", newline="") as f:
984 writer = csv.writer(f)
985 writer.writerow(
986 [
987 "name",
988 "cpu_timecuda_timecpu_memory_usage",
989 "cuda_memory_usage",
990 "input_shapes",
991 "calls",
992 ]
993 )
995 for evt in prof.key_averages():
996 writer.writerow(
997 [
998 evt.key,
999 evt.cpu_time,
1000 evt.cuda_time,
1001 evt.cpu_memory_usage,
1002 evt.cuda_memory_usage,
1003 str(evt.input_shapes),
1004 evt.count,
1005 ]
1006 )