Coverage for src / autoencodix / losses / xmodal_loss.py: 15%
117 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2import numpy as np
3import pandas as pd
4from typing import Dict, Optional, Any, List
5from collections import defaultdict
6from autoencodix.base._base_loss import BaseLoss
7from autoencodix.utils._annealer import AnnealingScheduler
8from autoencodix.configs.default_config import DefaultConfig
9from autoencodix.utils._utils import flip_labels
class XModalLoss(BaseLoss):
    """Implements the loss for XModalix.

    The loss of the XModalix consists of five parts:

    - Combined (mean) reconstruction loss over all sub-modalities.
    - Combined distribution loss over all sub-modalities (KL or MMD).
    - Class loss: when metadata about a sample is available (e.g. cancer
      type), a mean latent vector is kept per class and the distance between
      each sample and the mean of its class is penalised.
    - Paired loss: when samples of different modalities are paired, their
      latent space representations should be proximal.
    - Adversarial loss: forces the latent spaces of the different data
      modalities to be similar.

    The reconstruction and distribution loss are calculated, combined, and
    weighted (with hyperparameter beta) for each sub-modality first. This class
    reads that loss from ``modality_dynamics`` and combines it across all
    sub-modalities.

    Attributes:
        class_means_train: Running mean latent vector per class label,
            maintained from batches processed with ``is_training=True``.
        class_means_valid: Running mean latent vector per class label,
            maintained from batches processed with ``is_training=False``.
        sample_to_class_map: Maps every sample id seen by the class loss to
            its class label.
        class_mean_momentum: Weight of the previous class mean in the
            exponential moving average update (the batch mean gets
            ``1 - class_mean_momentum``).
    """

    def __init__(self, config: DefaultConfig, annealing_scheduler=None):
        """Set up the empty class-mean state.

        Args:
            config: Configuration object; forwarded to the base class.
            annealing_scheduler: Accepted but neither stored nor forwarded by
                this class.
        """
        # NOTE(review): annealing_scheduler is dropped here; confirm whether
        # BaseLoss is expected to receive it.
        super().__init__(config)
        self.class_means_train: Dict[str, torch.Tensor] = {}
        self.class_means_valid: Dict[str, torch.Tensor] = {}
        self.sample_to_class_map: Dict[str, Any] = {}
        self.class_mean_momentum = 0.75

    def forward(
        self,
        batch: Dict[str, Dict[str, Any]],
        modality_dynamics: Dict[str, Dict[str, Any]],
        clf_scores: torch.Tensor,
        labels: torch.Tensor,
        clf_loss_fn: torch.nn.Module,
        is_training: bool = True,
        epoch: Optional[int] = None,
        **kwargs,
    ):
        """Forward pass of the XModal loss.

        Args:
            batch: Data from the custom dataset, keyed by modality name.
            modality_dynamics: Training dynamics such as losses and
                reconstructions for each data modality.
            clf_scores: Output of the latent classifier for the adversarial loss.
            labels: Indicator of which data modality each latent space belongs
                to (for the adversarial loss).
            clf_loss_fn: Loss for the classifier, passed in xmodal_trainer.py,
                defaults to cross-entropy.
            is_training: Whether we are in training; False for the valid and
                test loop. Selects which class-mean store the class loss uses.
            epoch: Not used by this method.
            **kwargs: Additional keyword arguments; not used by this method.

        Returns:
            A tuple ``(total_loss, loss_dict)``. ``loss_dict`` holds the
            weighted adversarial, paired and class terms, the aggregated
            sub-modality loss, and the per-modality scalar entries produced by
            ``_store_sub_losses``.
        """
        adver_loss = self._calc_adversial_loss(
            labels=labels, clf_loss_fn=clf_loss_fn, clf_scores=clf_scores
        )
        aggregated_sub_losses = self._combine_sub_losses(
            modality_dynamics=modality_dynamics
        )
        sub_losses = self._store_sub_losses(modality_dynamics=modality_dynamics)
        paired_loss = self._calc_paired_loss(
            batch=batch, modality_dynamics=modality_dynamics
        )
        class_loss = self._calc_class_loss(
            batch=batch,
            modality_dynamics=modality_dynamics,
            is_training=is_training,
        )

        # Weight each term once so the reported values are exactly the ones
        # that enter the total.
        weighted_adver_loss = self.config.gamma * adver_loss
        weighted_paired_loss = self.config.delta_pair * paired_loss
        weighted_class_loss = self.config.delta_class * class_loss

        # beta is already applied when calculating the sub loss in the train
        # loop (works because of the distributive law).
        total_loss = (
            weighted_adver_loss
            + aggregated_sub_losses
            + weighted_paired_loss
            + weighted_class_loss
        )
        loss_dict = {
            "adver_loss": weighted_adver_loss,
            "aggregated_sub_losses": aggregated_sub_losses,
            "paired_loss": weighted_paired_loss,
            "class_loss": weighted_class_loss,
        }
        loss_dict.update(sub_losses)
        return total_loss, loss_dict

    @staticmethod
    def _metadata_to_list(values: Any) -> Optional[List[Any]]:
        """Return per-sample metadata as a plain list, or None if absent/empty.

        Array-like containers (anything exposing ``tolist``) are converted to
        Python objects so that they can be length-checked without an ambiguous
        truth value and so that equal labels hash to the same dictionary key.
        """
        if values is None:
            return None
        if hasattr(values, "tolist"):
            values = values.tolist()
        return list(values) if len(values) > 0 else None

    @staticmethod
    def _scalar_to_float(value: Any) -> float:
        """Convert a one-element tensor/array or a plain number to a float."""
        return value.item() if hasattr(value, "item") else float(value)

    def _calc_class_loss(
        self,
        batch: Dict[str, Dict[str, Any]],
        modality_dynamics: Dict[str, Dict[str, Any]],
        is_training: bool,
    ) -> torch.Tensor:
        """Compute the class loss and update the running class means.

        For every modality in ``batch`` that provides both ``"sample_ids"`` and
        ``"class_labels"``, the distance between each latent vector and the
        stored mean of its class is computed and reduced with
        ``self.reduction_fn``. The result is averaged over those modalities.

        Notes:
            - Handles arbitrary hashable class labels (e.g. strings or ints).
            - Latents of all modalities are grouped by label in a single pass.
            - Unseen classes are initialised with the detached batch mean.
            - After the loss is computed, the class means are updated in place
              with an exponential moving average under ``torch.no_grad()``.
              This happens for training and validation alike, each on its own
              store (``class_means_train`` / ``class_means_valid``).

        Args:
            batch: Maps modality names to modality data.
            modality_dynamics: Maps modality names to their dynamics; each
                entry holds an ``"mp"`` object with a ``latentspace`` tensor.
            is_training: Selects the training or validation class-mean store.

        Returns:
            A scalar tensor. Zero if the class loss is disabled via
            ``config.class_param``, if ``modality_dynamics`` is empty, or if
            no modality carries metadata.

        Raises:
            ValueError: If a modality has a different number of latents and
                class labels.
        """
        import warnings

        first_dynamics = next(iter(modality_dynamics.values()), None)
        device = (
            first_dynamics["mp"].latentspace.device
            if first_dynamics is not None
            else torch.device("cpu")
        )
        if not self.config.class_param or first_dynamics is None:
            return torch.tensor(0.0, device=device)

        class_means_dict = (
            self.class_means_train if is_training else self.class_means_valid
        )

        # 1) Collect latents and labels of every modality that has metadata,
        #    remembering which rows of the concatenation belong to which one.
        latents_list = []
        labels_list = []  # python objects (strings/ints)
        modality_slices = {}  # mod_name -> (start_idx, end_idx)
        start = 0
        for mod_name, mod_data in batch.items():
            sample_ids = self._metadata_to_list(mod_data.get("sample_ids"))
            mod_labels = self._metadata_to_list(mod_data.get("class_labels"))
            if sample_ids is None or mod_labels is None:
                warnings.warn(f"No metadata for modality {mod_name}")
                continue

            latents = modality_dynamics[mod_name]["mp"].latentspace
            n_mod = latents.shape[0]
            # Validate before touching any state so that a bad batch does not
            # leave a partially updated sample_to_class_map behind.
            if len(mod_labels) != n_mod:
                raise ValueError(
                    f"Mismatch between number of latents ({n_mod}) and labels ({len(mod_labels)}) for modality {mod_name}"
                )

            self.sample_to_class_map.update(zip(sample_ids, mod_labels))
            latents_list.append(latents)
            labels_list.extend(mod_labels)
            modality_slices[mod_name] = (start, start + n_mod)
            start += n_mod

        if not latents_list:
            return torch.tensor(0.0, device=device)
        all_latents = torch.cat(latents_list, dim=0)

        # 2) Group row positions by label in one pass; dict insertion order
        #    keeps the labels in first-occurrence order.
        positions_by_label = defaultdict(list)
        for position, label in enumerate(labels_list):
            positions_by_label[label].append(position)

        # 3) Mean latent vector of the current batch for every label.
        batch_means: Dict[Any, torch.Tensor] = {}
        for label, positions in positions_by_label.items():
            idx_tensor = torch.tensor(
                positions, device=all_latents.device, dtype=torch.long
            )
            batch_means[label] = all_latents.index_select(0, idx_tensor).mean(dim=0)

        # 4) Initialise unseen classes with a detached copy of the batch mean
        #    so the stored mean is not linked to the computation graph.
        for label, batch_mean in batch_means.items():
            if label not in class_means_dict:
                class_means_dict[label] = batch_mean.detach().clone()

        # 5) Loss per modality against the class means as they are *before*
        #    this batch's moving-average update.
        total_class_loss = torch.tensor(0.0, device=device)
        for start_idx, end_idx in modality_slices.values():
            latents = all_latents[start_idx:end_idx]
            target_means = torch.stack(
                [
                    class_means_dict[label].to(device)
                    for label in labels_list[start_idx:end_idx]
                ],
                dim=0,
            )
            distances = torch.linalg.norm(latents - target_means, dim=1)
            total_class_loss = total_class_loss + self.reduction_fn(distances)

        # modality_slices is non-empty here because latents_list is non-empty.
        avg_class_loss = total_class_loss / len(modality_slices)

        # 6) Exponential moving average update of the class means (in place).
        with torch.no_grad():
            for label, batch_mean in batch_means.items():
                class_mean = class_means_dict[label]
                if class_mean.device != batch_mean.device:
                    class_mean = class_mean.to(batch_mean.device)
                    class_means_dict[label] = class_mean
                class_mean.mul_(self.class_mean_momentum).add_(
                    batch_mean, alpha=(1 - self.class_mean_momentum)
                )

        return avg_class_loss

    def _calc_paired_loss(
        self,
        batch: Dict[str, Dict[str, Any]],
        modality_dynamics: Dict[str, Dict[str, Any]],
    ) -> torch.Tensor:
        """Compute the paired loss across modalities in the current batch.

        Collects the latent space of each modality and, when at least two
        modalities are present, delegates to ``self.compute_paired_loss``
        together with the per-modality sample ids.

        Args:
            batch: A dictionary mapping modality names to modality data.
                Each entry contains sample identifiers under the key
                `"sample_ids"`.
            modality_dynamics: A dictionary mapping modality names to their
                dynamics, where each entry contains a `"mp"` object with a
                `latentspace` tensor.

        Returns:
            A scalar tensor representing the paired loss. Returns a zero tensor
            requiring gradients if fewer than two modalities are available; it
            lives on the device of the only latent tensor, or on the CPU when
            there is none.
        """
        latentspaces = {
            mod_name: dynamics["mp"].latentspace
            for mod_name, dynamics in modality_dynamics.items()
        }
        if len(latentspaces) < 2:
            # Nothing to pair: return a zero that still requires gradients.
            only_latent = next(iter(latentspaces.values()), None)
            zero_device = (
                only_latent.device if only_latent is not None else torch.device("cpu")
            )
            return torch.tensor(0.0, device=zero_device, requires_grad=True)

        sample_ids = {
            mod_name: mod_data["sample_ids"] for mod_name, mod_data in batch.items()
        }
        return self.compute_paired_loss(
            latentspaces=latentspaces, sample_ids=sample_ids
        )

    def _combine_sub_losses(
        self, modality_dynamics: Dict[str, Dict[str, Any]]
    ) -> torch.Tensor:
        """Combine the per-modality ``"loss"`` entries into one scalar.

        Uses the mean when ``config.loss_reduction == "mean"`` and the sum
        otherwise.
        """
        stacked = torch.stack(
            [dynamics["loss"] for dynamics in modality_dynamics.values()]
        )
        if self.config.loss_reduction == "mean":
            return stacked.mean()
        return stacked.sum()

    def _store_sub_losses(
        self, modality_dynamics: Dict[str, Dict[str, Any]]
    ) -> Dict[str, float]:
        """Flatten the per-modality losses into a ``{name: float}`` mapping.

        For every modality, each entry of every nested dictionary is stored
        under ``"<modality>.<entry key>"`` and the ``"loss"`` entry under
        ``"<modality>.loss"``. Values may be one-element tensors or plain
        numbers.
        """
        sub_losses: Dict[str, float] = {}
        for mod_name, dynamics in modality_dynamics.items():
            for key, value in dynamics.items():
                if isinstance(value, dict):
                    for inner_key, inner_value in value.items():
                        sub_losses[f"{mod_name}.{inner_key}"] = self._scalar_to_float(
                            inner_value
                        )
                if key == "loss":
                    sub_losses[f"{mod_name}.loss"] = self._scalar_to_float(value)
        return sub_losses

    def _calc_adversial_loss(
        self,
        labels: torch.Tensor,
        clf_scores: torch.Tensor,
        clf_loss_fn: torch.nn.Module,
    ) -> torch.Tensor:
        """Evaluate the classifier loss against flipped modality labels.

        Args:
            labels: Modality indicator for each latent vector.
            clf_scores: Output of the latent classifier.
            clf_loss_fn: Loss function called as ``clf_loss_fn(scores, targets)``.

        Returns:
            The value of ``clf_loss_fn`` for ``clf_scores`` and the labels
            returned by ``flip_labels``.
        """
        flipped_labels = flip_labels(labels=labels)
        return clf_loss_fn(clf_scores, flipped_labels)