Coverage for src / autoencodix / losses / xmodal_loss.py: 15%

117 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

import warnings
from collections import defaultdict
from typing import Dict, Optional, Any, List

import numpy as np
import pandas as pd
import torch

from autoencodix.base._base_loss import BaseLoss
from autoencodix.configs.default_config import DefaultConfig
from autoencodix.utils._annealer import AnnealingScheduler
from autoencodix.utils._utils import flip_labels

10 

11 

class XModalLoss(BaseLoss):
    """Implements the loss for XModalix.

    The total loss is the sum of five terms:

    - Reconstruction loss: combined over all sub-modalities.
    - Distribution loss: combined over all sub-modalities (KL or MMD).
    - Class loss: when class metadata (e.g. cancer type) is available for a
      sample, the distance between the sample's latent vector and the running
      mean latent vector of its class.
    - Paired loss: when samples of different modalities are paired, their
      latent space representations should be proximal.
    - Adversarial loss: forces the latent spaces of the different data
      modalities to be similar.

    The reconstruction and distribution losses are calculated, combined and
    weighted (with the hyperparameter beta) for each sub-modality before this
    class is called. Here the resulting per-modality loss is read from
    ``modality_dynamics`` and reduced over all sub-modalities.

    Attributes:
        class_means_train: Running mean latent vector per class label, used and
            updated when ``is_training`` is true.
        class_means_valid: Running mean latent vector per class label, used and
            updated when ``is_training`` is false.
        sample_to_class_map: Maps every sample id seen so far to its class label.
        class_mean_momentum: Share of the previous class mean that is kept in
            each moving-average update (the batch mean gets the remainder).
    """

    def __init__(self, config: DefaultConfig, annealing_scheduler=None):
        """Set up the empty class-mean state.

        Args:
            config: Configuration object; handed to the base class.
            annealing_scheduler: Accepted for signature compatibility. It is
                not used or forwarded by this initializer.
        """
        super().__init__(config)
        # Labels may be any hashable object, hence ``Any`` keys.
        self.class_means_train: Dict[Any, torch.Tensor] = {}
        self.class_means_valid: Dict[Any, torch.Tensor] = {}
        self.sample_to_class_map: Dict[str, Any] = {}
        self.class_mean_momentum = 0.75

    def forward(
        self,
        batch: Dict[str, Dict[str, Any]],
        modality_dynamics: Dict[str, Dict[str, Any]],
        clf_scores: torch.Tensor,
        labels: torch.Tensor,
        clf_loss_fn: torch.nn.Module,
        is_training: bool = True,
        epoch: Optional[int] = None,
        **kwargs,
    ):
        """Forward pass of the XModal loss.

        Args:
            batch: Data from the custom dataset, keyed by modality name.
            modality_dynamics: Training dynamics such as losses and
                reconstructions for each data modality.
            clf_scores: Output of the latent classifier, used for the
                adversarial loss.
            labels: Indicates which data modality each latent space belongs to
                (for the adversarial loss).
            clf_loss_fn: Loss function for the classifier.
            is_training: True in the training loop, False for validation and
                test. Selects which set of running class means the class loss
                uses.
            epoch: Accepted for interface compatibility; not used here.
            **kwargs: Additional keyword arguments; ignored here.

        Returns:
            A tuple ``(total_loss, loss_dict)``. ``loss_dict`` holds the
            weighted adversarial, paired and class terms, the aggregated
            sub-modality loss, and the per-modality entries produced by
            ``_store_sub_losses``.
        """
        adver_loss = self._calc_adversial_loss(
            labels=labels, clf_loss_fn=clf_loss_fn, clf_scores=clf_scores
        )
        aggregated_sub_losses = self._combine_sub_losses(
            modality_dynamics=modality_dynamics
        )
        sub_losses = self._store_sub_losses(modality_dynamics=modality_dynamics)
        paired_loss = self._calc_paired_loss(
            batch=batch, modality_dynamics=modality_dynamics
        )
        class_loss = self._calc_class_loss(
            batch=batch,
            modality_dynamics=modality_dynamics,
            is_training=is_training,
        )
        total_loss = (
            self.config.gamma * adver_loss
            # beta is already applied when the sub losses are calculated in the
            # train loop (works because of the distributive law).
            + aggregated_sub_losses
            + self.config.delta_pair * paired_loss
            + self.config.delta_class * class_loss
        )
        loss_dict = {
            "adver_loss": self.config.gamma * adver_loss,
            "aggregated_sub_losses": aggregated_sub_losses,
            "paired_loss": self.config.delta_pair * paired_loss,
            "class_loss": self.config.delta_class * class_loss,
        }
        loss_dict.update(sub_losses)
        return total_loss, loss_dict

    def _calc_class_loss(
        self,
        batch: Dict[str, Dict[str, Any]],
        modality_dynamics: Dict[str, Dict[str, Any]],
        is_training: bool,
    ) -> torch.Tensor:
        """Distance of every latent vector to the running mean of its class.

        Latents of all modalities that carry metadata are concatenated and
        grouped by class label. Classes seen for the first time start from the
        (detached) batch mean. The loss is computed against the class means as
        they are *before* this batch's moving-average update.

        Args:
            batch: Per-modality data; the optional keys ``"sample_ids"`` and
                ``"class_labels"`` are read. Labels are used as dict keys, so
                they must be hashable.
            modality_dynamics: Per-modality dynamics; ``["mp"].latentspace`` is
                read for every modality that has metadata.
            is_training: Selects ``class_means_train`` (True) or
                ``class_means_valid`` (False) as the running means to use and
                update.

        Returns:
            A scalar tensor: ``self.reduction_fn`` applied to the per-sample
            Euclidean distances of each modality, averaged over the modalities
            that have metadata. Zero when ``config.class_param`` is falsy or no
            modality has metadata.

        Raises:
            ValueError: If a modality has a different number of class labels
                than latent vectors.
        """
        # Pick a device from the first modality; fall back to CPU when there
        # are no dynamics at all instead of leaking a StopIteration.
        first_dynamics = next(iter(modality_dynamics.values()), None)
        device = (
            first_dynamics["mp"].latentspace.device
            if first_dynamics is not None
            else torch.device("cpu")
        )
        if not self.config.class_param:
            return torch.tensor(0.0, device=device)

        class_means_dict = (
            self.class_means_train if is_training else self.class_means_valid
        )

        # 1) Collect latents and labels of all modalities that carry metadata
        #    and remember which rows of the concatenation belong to which one.
        latents_list: List[torch.Tensor] = []
        labels_list: List[Any] = []
        modality_slices: Dict[str, Any] = {}  # mod_name -> (start_idx, end_idx)
        start = 0
        for mod_name, mod_data in batch.items():
            sample_ids: Optional[List[str]] = mod_data.get("sample_ids")
            mod_labels: Optional[List[str]] = mod_data.get("class_labels")
            if not sample_ids or not mod_labels:
                warnings.warn(f"No metadata for modality {mod_name}")
                continue
            self.sample_to_class_map.update(zip(sample_ids, mod_labels))

            latents = modality_dynamics[mod_name]["mp"].latentspace
            n_mod = latents.shape[0]
            if len(mod_labels) != n_mod:
                raise ValueError(
                    f"Mismatch between number of latents ({n_mod}) and labels ({len(mod_labels)}) for modality {mod_name}"
                )

            latents_list.append(latents)
            labels_list.extend(mod_labels)
            modality_slices[mod_name] = (start, start + n_mod)
            start += n_mod

        if not latents_list:
            return torch.tensor(0.0, device=device)
        all_latents = torch.cat(latents_list, dim=0)

        # 2) Group row indices by label in a single pass. Dict insertion order
        #    keeps the labels in first-occurrence order.
        indices_by_label: Dict[Any, List[int]] = defaultdict(list)
        for row, lbl in enumerate(labels_list):
            indices_by_label[lbl].append(row)

        # 3) Mean latent per label over the whole batch (all modalities).
        batch_means: Dict[Any, torch.Tensor] = {}
        for lbl, rows in indices_by_label.items():
            idx_tensor = torch.tensor(rows, device=device, dtype=torch.long)
            batch_means[lbl] = all_latents.index_select(0, idx_tensor).mean(dim=0)

        # 4) New classes start from their batch mean. Detach and clone so the
        #    stored mean is not tied to the computation graph.
        for lbl, bmean in batch_means.items():
            if lbl not in class_means_dict:
                class_means_dict[lbl] = bmean.detach().clone()

        # 5) Per-modality loss against the current class means.
        total_class_loss = torch.tensor(0.0, device=device)
        for start_idx, end_idx in modality_slices.values():
            latents = all_latents[start_idx:end_idx]
            # torch.stack copies, so the in-place update in step 6 cannot
            # modify the targets used for this loss.
            target_means = torch.stack(
                [
                    class_means_dict[lbl].to(device)
                    for lbl in labels_list[start_idx:end_idx]
                ],
                dim=0,
            )
            distances = torch.linalg.norm(latents - target_means, dim=1)
            total_class_loss = total_class_loss + self.reduction_fn(distances)

        # modality_slices is non-empty here (otherwise we returned above).
        avg_class_loss = total_class_loss / len(modality_slices)

        # 6) Moving-average update of the class means with the batch means.
        #    This runs for training and validation alike; each mode updates
        #    its own dict. Every label in batch_means was initialised in
        #    step 4, so the lookup cannot fail.
        with torch.no_grad():
            for lbl, bmean in batch_means.items():
                running_mean = class_means_dict[lbl]
                if running_mean.device != bmean.device:
                    running_mean = running_mean.to(bmean.device)
                    class_means_dict[lbl] = running_mean
                running_mean.mul_(self.class_mean_momentum).add_(
                    bmean, alpha=(1 - self.class_mean_momentum)
                )

        return avg_class_loss

    def _calc_paired_loss(
        self,
        batch: Dict[str, Dict[str, Any]],
        modality_dynamics: Dict[str, Dict[str, Any]],
    ) -> torch.Tensor:
        """Compute the paired loss across modalities in the current batch.

        Collects the latent space and the sample ids of every modality and
        delegates to ``self.compute_paired_loss`` when at least two modalities
        are present.

        Args:
            batch: A dictionary mapping modality names to modality data. Each
                entry contains sample identifiers under the key
                ``"sample_ids"``.
            modality_dynamics: A dictionary mapping modality names to their
                dynamics, where each entry contains a ``"mp"`` object with a
                ``latentspace`` tensor.

        Returns:
            A scalar tensor representing the paired loss. With fewer than two
            modalities, a zero tensor with ``requires_grad=True`` on the device
            of the only latent space (CPU when there is none).
        """
        latentspaces = {
            mod_name: dynamics["mp"].latentspace
            for mod_name, dynamics in modality_dynamics.items()
        }
        sample_ids = {
            mod_name: mod_data["sample_ids"] for mod_name, mod_data in batch.items()
        }
        if len(latentspaces) < 2:
            # Nothing to pair. Fall back to CPU when there is no latent space
            # at all instead of leaking a StopIteration.
            any_latent_tensor = next(iter(latentspaces.values()), None)
            device = (
                any_latent_tensor.device
                if any_latent_tensor is not None
                else torch.device("cpu")
            )
            return torch.tensor(0.0, device=device, requires_grad=True)
        return self.compute_paired_loss(
            latentspaces=latentspaces, sample_ids=sample_ids
        )

    def _combine_sub_losses(
        self, modality_dynamics: Dict[str, Dict[str, Any]]
    ) -> torch.Tensor:
        """Reduce the per-modality ``"loss"`` entries to a single tensor.

        Uses the mean when ``config.loss_reduction`` is ``"mean"`` and the sum
        otherwise.
        """
        stacked = torch.stack(
            [dynamics["loss"] for dynamics in modality_dynamics.values()]
        )
        if self.config.loss_reduction == "mean":
            return stacked.mean()
        return stacked.sum()

    def _store_sub_losses(
        self, modality_dynamics: Dict[str, Dict[str, Any]]
    ) -> Dict[str, float]:
        """Flatten the per-modality loss values into one dict for logging.

        For every dict-valued entry of a modality's dynamics, each inner item
        is stored under ``"<modality>.<inner key>"``. The name of the nested
        dict itself is not part of the key, so equal inner keys of different
        nested dicts overwrite each other. The entry under the key ``"loss"``
        is stored as ``"<modality>.loss"``. Values are converted with
        ``.item()``.
        """
        sub_losses = {}
        for mod_name, dynamics in modality_dynamics.items():
            for key, value in dynamics.items():
                if isinstance(value, dict):
                    for inner_key, inner_value in value.items():
                        sub_losses[f"{mod_name}.{inner_key}"] = inner_value.item()
                if key == "loss":
                    sub_losses[f"{mod_name}.loss"] = value.item()
        return sub_losses

    def _calc_adversial_loss(
        self,
        labels: torch.Tensor,
        clf_scores: torch.Tensor,
        clf_loss_fn: torch.nn.Module,
    ):
        """Score the classifier output against the flipped modality labels.

        Args:
            labels: Modality labels; passed through ``flip_labels`` first.
            clf_scores: Output of the latent classifier.
            clf_loss_fn: Loss function called as
                ``clf_loss_fn(clf_scores, flipped_labels)``.

        Returns:
            Whatever ``clf_loss_fn`` returns for the flipped labels.
        """
        flipped_labels = flip_labels(labels=labels)
        return clf_loss_fn(clf_scores, flipped_labels)

303 adversarial_loss = clf_loss_fn(clf_scores, flipped_labels) 

304 return adversarial_loss