Coverage for src / autoencodix / base / _base_loss.py: 65%

103 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from abc import abstractmethod, ABC 

2import itertools 

3from autoencodix.configs.default_config import DefaultConfig 

4import torch 

5from torch import nn 

6from autoencodix.utils._annealer import AnnealingScheduler 

7 

8from typing import Optional, Any 

9 

10 

class BaseLoss(nn.Module, ABC):
    """Provides common loss computation functionality for autoencoders.

    Implements standard loss calculations including reconstruction loss,
    KL divergence, Maximum Mean Discrepancy (MMD) and a paired latent
    distance loss, while requiring subclasses to implement the specific
    forward method.

    Attributes:
        config: Configuration parameters for the loss function.
        recon_loss: Module for computing reconstruction loss (MSE or BCE).
        reduction_fn: Function to apply reduction (mean or sum).
        compute_kernel: Function to compute kernel for MMD loss.
        annealing_scheduler: Helper for loss calculation with annealing.
    """

    def __init__(
        self,
        config: "DefaultConfig",
        annealing_scheduler: Optional["AnnealingScheduler"] = None,
    ):
        """Initializes the loss module with the specified configuration.

        Args:
            config: Configuration parameters for the loss function. The
                attributes ``loss_reduction`` and ``reconstruction_loss`` are
                read here; ``default_vae_loss`` is read later by
                ``compute_variational_loss``.
            annealing_scheduler: Helper class for loss calculation with
                annealing. A default ``AnnealingScheduler`` is created when
                this is ``None``.

        Raises:
            NotImplementedError: If unsupported loss reduction or reconstruction
                loss type is specified.
        """
        super().__init__()
        # Compare against None explicitly: a truthiness test would silently
        # replace a caller-supplied scheduler that happens to evaluate to False.
        if annealing_scheduler is None:
            annealing_scheduler = AnnealingScheduler()
        self.annealing_scheduler = annealing_scheduler
        self.config = config
        self.recon_loss: nn.Module

        if config.loss_reduction == "mean":
            self.reduction_fn = torch.mean
        elif config.loss_reduction == "sum":
            self.reduction_fn = torch.sum
        else:
            raise NotImplementedError(
                f"Invalid loss reduction type: {self.config.loss_reduction}. "
                f"Only 'mean' and 'sum' are supported."
            )

        if config.reconstruction_loss == "mse":
            self.recon_loss = nn.MSELoss(reduction=config.loss_reduction)
        elif config.reconstruction_loss == "bce":
            self.recon_loss = nn.BCEWithLogitsLoss(reduction=config.loss_reduction)
        else:
            raise NotImplementedError(
                f"Invalid reconstruction loss type: {self.config.reconstruction_loss}. "
                f"Only 'mse' and 'bce' are supported. Please check the value of "
                f"'config.reconstruction_loss' for typos or unsupported types."
            )

        self.compute_kernel = self._mmd_kernel

    def _mmd_kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes Gaussian kernel for Maximum Mean Discrepancy calculation.

        Calculates the kernel matrix between two sets of samples, using a
        Gaussian kernel with normalization by feature dimension.

        Args:
            x: First set of input samples, indexed as (sample, feature).
            y: Second set of input samples, indexed as (sample, feature).

        Returns:
            Kernel matrix of shape (x.shape[0], y.shape[0]).
        """
        x_size = x.size(0)
        y_size = y.size(0)
        dim = x.size(1)

        # Tile both inputs to (x_size, y_size, dim) so every x/y pair lines up.
        tiled_x = x.unsqueeze(1).expand(x_size, y_size, dim)
        tiled_y = y.unsqueeze(0).expand(x_size, y_size, dim)

        # The squared distance is averaged over the feature axis and then
        # divided by `dim` once more.
        # NOTE(review): the double normalization looks intentional (it mirrors
        # commonly used MMD-VAE reference code) - confirm before changing.
        kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
        return torch.exp(-kernel_input)

    def compute_mmd_loss(
        self, z: torch.Tensor, true_samples: torch.Tensor
    ) -> torch.Tensor:
        """Computes Maximum Mean Discrepancy loss.

        Args:
            z: Samples from the encoded distribution.
            true_samples: Samples from the prior distribution. They are moved
                to the device of ``z`` before any kernel is evaluated.

        Returns:
            The MMD loss value.

        Raises:
            NotImplementedError: If unsupported loss reduction type is specified.
        """
        # Move the prior samples first so that all three kernels are computed
        # on the same device as `z`.
        true_samples = true_samples.to(z.device)

        true_samples_kernel = self.compute_kernel(x=true_samples, y=true_samples)
        z_kernel = self.compute_kernel(z, z)
        ztr_kernel = self.compute_kernel(x=true_samples, y=z)

        if self.config.loss_reduction == "mean":
            return true_samples_kernel.mean() + z_kernel.mean() - 2 * ztr_kernel.mean()
        elif self.config.loss_reduction == "sum":
            return true_samples_kernel.sum() + z_kernel.sum() - 2 * ztr_kernel.sum()
        else:
            raise NotImplementedError(
                f"Invalid loss reduction type: {self.config.loss_reduction}. "
                f"Only 'mean' and 'sum' are supported."
            )

    def compute_kl_loss(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Computes KL divergence loss between N(mu, logvar) and N(0, 1).

        Args:
            mu: Mean tensor.
            logvar: Log variance tensor.

        Returns:
            The KL divergence loss value, reduced over all elements with the
            configured reduction function.

        Raises:
            ValueError: If mu and logvar do not have the same shape.
        """
        if mu.shape != logvar.shape:
            raise ValueError(
                f"Shape mismatch: mu has shape {mu.shape}, but logvar has shape {logvar.shape}."
            )
        return -0.5 * self.reduction_fn(1 + logvar - mu.pow(2) - logvar.exp())

    def compute_variational_loss(
        self,
        mu: Optional[torch.Tensor],
        logvar: Optional[torch.Tensor],
        z: Optional[torch.Tensor] = None,
        true_samples: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Computes either KL or MMD loss based on configuration.

        Args:
            mu: Mean tensor for variational loss.
            logvar: Log variance tensor for variational loss.
            z: Encoded samples for MMD loss.
            true_samples: Prior samples for MMD loss.

        Returns:
            The computed variational loss.

        Raises:
            ValueError: If required parameters are missing or if mu and logvar have shape mismatch.
            NotImplementedError: If unsupported VAE loss type is specified.
        """
        if self.config.default_vae_loss == "kl":
            if mu is None:
                raise ValueError("mu must be provided for VAE loss")
            if logvar is None:
                raise ValueError("logvar must be provided for VAE loss")
            if mu.shape != logvar.shape:
                raise ValueError(
                    f"Shape mismatch: mu has shape {mu.shape}, but logvar has shape {logvar.shape}"
                )

            return self.compute_kl_loss(mu=mu, logvar=logvar)

        elif self.config.default_vae_loss == "mmd":
            if z is None:
                raise ValueError("z must be provided for MMD loss")
            if true_samples is None:
                raise ValueError("true_samples must be provided for MMD loss")
            return self.compute_mmd_loss(z=z, true_samples=true_samples)
        else:
            raise NotImplementedError(
                f"VAE loss type {self.config.default_vae_loss} is not implemented. "
                f"Only 'kl' and 'mmd' are supported."
            )

    def compute_paired_loss(
        self,
        latentspaces: dict[str, torch.Tensor],
        sample_ids: dict[str, list],
    ) -> torch.Tensor:
        """Calculates the paired distance loss across all pairs of modalities in a batch.

        For every unordered pair of modalities, the samples present in both are
        aligned by sample ID and the L1 distance between their latent vectors is
        averaged over the latent dimension and then reduced over samples with
        the configured reduction function. The per-pair losses are averaged.

        Args:
            latentspaces: A dictionary mapping modality names to their latent space tensors.
                e.g., {'RNA': tensor_rna, 'ATAC': tensor_atac}
            sample_ids: A dictionary mapping modality names to their list of sample IDs.

        Returns:
            A single scalar tensor representing the total paired loss. When no
            pair of modalities shares any sample ID, a zero scalar is returned,
            placed on the device of the first latent tensor (if there is one).
        """
        pair_losses = []
        modality_names = list(latentspaces.keys())

        # 1. Iterate through all unique pairs of modalities
        for mod_a, mod_b in itertools.combinations(modality_names, 2):
            ids_a = sample_ids[mod_a]
            ids_b = sample_ids[mod_b]

            # 2. Find the intersection of sample IDs
            common_ids = set(ids_a) & set(ids_b)

            if not common_ids:
                print("no common ids")
                continue

            # 3. Create a mapping from sample ID to index for efficient lookup
            id_to_idx_a = {sample_id: i for i, sample_id in enumerate(ids_a)}
            id_to_idx_b = {sample_id: i for i, sample_id in enumerate(ids_b)}

            # Both index lists walk the same set object, so position k in
            # indices_a and indices_b always refers to the same sample ID.
            indices_a = [id_to_idx_a[common_id] for common_id in common_ids]
            indices_b = [id_to_idx_b[common_id] for common_id in common_ids]

            # 4. Select the latent vectors for the paired samples
            paired_latents_a = latentspaces[mod_a][indices_a]
            paired_latents_b = latentspaces[mod_b][indices_b]

            # 5. L1 distance, averaged over latent dimensions and then reduced
            # over samples
            distance = torch.abs(paired_latents_a - paired_latents_b).mean(dim=1)
            pair_losses.append(self.reduction_fn(distance))

        if not pair_losses:
            # Keep the fallback on the same device as the latents so callers
            # get a tensor on a consistent device regardless of overlap.
            reference = next(iter(latentspaces.values()), None)
            if reference is None:
                return torch.tensor(0.0)
            return torch.tensor(0.0, device=reference.device)
        return torch.stack(pair_losses).mean()

    @staticmethod
    def _compute_log_gauss_dense(
        z: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor
    ) -> torch.Tensor:
        """Computes the element-wise log density of a Gaussian distribution.

        Args:
            z: Latent variable tensor.
            mu: Mean tensor.
            logvar: Log variance tensor.

        Returns:
            Log probability of ``z`` under the Gaussian parameterized by ``mu``
            and ``logvar``, computed element-wise.
        """
        # Create the constant directly on the target device instead of
        # allocating it on the CPU and copying it over.
        log_two_pi = torch.log(torch.tensor([2 * torch.pi], device=z.device))
        return -0.5 * (log_two_pi + logvar + (z - mu) ** 2 * torch.exp(-logvar))

    @staticmethod
    def _compute_log_import_weight_mat(batch_size: int, n_samples: int) -> torch.Tensor:
        """Computes the log import weight matrix for disentangled loss.

        Similar to: https://github.com/rtqichen/beta-tcvae

        Args:
            batch_size: Number of samples in the batch. Must be at least 2;
                a batch size of 1 leads to a division by zero.
            n_samples: Total number of samples in the dataset.

        Returns:
            Log import weight matrix of shape (batch_size, batch_size).
        """
        N = n_samples
        M = batch_size - 1
        strat_weight = (N - M) / (N * M)
        W = torch.full((batch_size, batch_size), 1 / M)
        # The flat view is stepped with stride M + 1 == batch_size, i.e. one
        # entry per row: offset 0 writes column 0, offset 1 writes column 1.
        W.view(-1)[:: M + 1] = 1 / N
        W.view(-1)[1 :: M + 1] = strat_weight
        W[M - 1, 0] = strat_weight
        return W.log()

    @abstractmethod
    def forward(
        self,
        *args,
        **kwargs,
    ) -> Any:
        """Calculates the loss for the autoencoder.

        This method must be implemented by subclasses to define the specific
        loss computation logic for the autoencoder. The implementation should
        compute the total loss as well as any individual loss components
        (e.g., reconstruction loss, KL divergence, etc.) based on the model's
        output and the provided targets.

        Args:
            *args: Positional arguments, depending on the loss type and pipeline.
            **kwargs: Keyword arguments, depending on the loss type and pipeline.

        Returns:
            - The total loss value as a scalar tensor.
            - A dictionary of individual loss components, where the keys are
              descriptive strings (e.g., "reconstruction_loss", "kl_loss") and
              the values are the corresponding loss tensors.
            - Implementation in subclasses is flexible, so for new loss classes this can differ.

        Note:
            Subclasses must implement this method to define the specific loss
            computation logic for their use case.
        """
        # TODO maybe standardize the return types more i.e. request a scalar and a dict
        pass