Coverage for src / autoencodix / base / _base_loss.py: 65%
103 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from abc import abstractmethod, ABC
2import itertools
3from autoencodix.configs.default_config import DefaultConfig
4import torch
5from torch import nn
6from autoencodix.utils._annealer import AnnealingScheduler
8from typing import Optional, Any
class BaseLoss(nn.Module, ABC):
    """Provides common loss computation functionality for autoencoders.

    Implements standard loss calculations including reconstruction loss,
    KL divergence, Maximum Mean Discrepancy (MMD) and a paired-latent
    distance, while requiring subclasses to implement ``forward``.

    Attributes:
        config: Configuration parameters for the loss function.
        recon_loss: Module for computing reconstruction loss
            (``nn.MSELoss`` or ``nn.BCEWithLogitsLoss``).
        reduction_fn: Function to apply reduction (``torch.mean`` or
            ``torch.sum``).
        compute_kernel: Function to compute the kernel for the MMD loss.
        annealing_scheduler: Helper for loss calculation with annealing.
    """

    def __init__(
        self,
        config: "DefaultConfig",
        annealing_scheduler: Optional["AnnealingScheduler"] = None,
    ) -> None:
        """Initializes the loss module with the specified configuration.

        Args:
            config: Configuration parameters for the loss function. The
                attributes ``loss_reduction`` and ``reconstruction_loss``
                are read here.
            annealing_scheduler: Helper class for loss calculation with
                annealing. A default ``AnnealingScheduler`` is created only
                when this is ``None``.

        Raises:
            NotImplementedError: If unsupported loss reduction or reconstruction
                loss type is specified.
        """
        super().__init__()
        # Test against None explicitly: a scheduler object that happens to be
        # falsy must not be silently replaced by a fresh default instance.
        if annealing_scheduler is None:
            annealing_scheduler = AnnealingScheduler()
        self.annealing_scheduler = annealing_scheduler
        self.config = config
        self.recon_loss: nn.Module

        if self.config.loss_reduction == "mean":
            self.reduction_fn = torch.mean
        elif self.config.loss_reduction == "sum":
            self.reduction_fn = torch.sum
        else:
            raise NotImplementedError(
                f"Invalid loss reduction type: {self.config.loss_reduction}. "
                f"Only 'mean' and 'sum' are supported."
            )

        if self.config.reconstruction_loss == "mse":
            self.recon_loss = nn.MSELoss(reduction=config.loss_reduction)
        elif self.config.reconstruction_loss == "bce":
            # The logits variant is used, so inputs are raw (pre-sigmoid) scores.
            self.recon_loss = nn.BCEWithLogitsLoss(reduction=config.loss_reduction)
        else:
            raise NotImplementedError(
                f"Invalid reconstruction loss type: {self.config.reconstruction_loss}. "
                f"Only 'mse' and 'bce' are supported. Please check the value of "
                f"'config.reconstruction_loss' for typos or unsupported types."
            )

        self.compute_kernel = self._mmd_kernel

    def _mmd_kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes Gaussian kernel for Maximum Mean Discrepancy calculation.

        Calculates the kernel matrix between two sets of samples. The squared
        differences are averaged over the feature axis and then divided by
        the feature dimension once more before exponentiation.

        Args:
            x: First set of input samples, indexed as (sample, feature).
            y: Second set of input samples, indexed as (sample, feature).

        Returns:
            Kernel matrix of shape (x.shape[0], y.shape[0]).
        """
        x_size = x.size(0)
        y_size = y.size(0)
        dim = x.size(1)

        # Tile both inputs to (x_size, y_size, dim) so every x row is paired
        # with every y row.
        tiled_x = x.unsqueeze(1).expand(x_size, y_size, dim)
        tiled_y = y.unsqueeze(0).expand(x_size, y_size, dim)

        kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
        return torch.exp(-kernel_input)

    def compute_mmd_loss(
        self, z: torch.Tensor, true_samples: torch.Tensor
    ) -> torch.Tensor:
        """Computes Maximum Mean Discrepancy loss.

        Args:
            z: Samples from the encoded distribution.
            true_samples: Samples from the prior distribution. They are moved
                to the device of ``z`` before any kernel is evaluated.

        Returns:
            The MMD loss value.

        Raises:
            NotImplementedError: If unsupported loss reduction type is specified.
        """
        # Align devices first so that all three kernels are computed on the
        # same device as z.
        true_samples = true_samples.to(z.device)

        true_samples_kernel = self.compute_kernel(x=true_samples, y=true_samples)
        z_kernel = self.compute_kernel(z, z)
        ztr_kernel = self.compute_kernel(x=true_samples, y=z)

        if self.config.loss_reduction == "mean":
            return true_samples_kernel.mean() + z_kernel.mean() - 2 * ztr_kernel.mean()
        elif self.config.loss_reduction == "sum":
            return true_samples_kernel.sum() + z_kernel.sum() - 2 * ztr_kernel.sum()
        else:
            raise NotImplementedError(
                f"Invalid loss reduction type: {self.config.loss_reduction}. "
                f"Only 'mean' and 'sum' are supported."
            )

    def compute_kl_loss(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Computes KL divergence loss between N(mu, exp(logvar)) and N(0, 1).

        The element-wise term ``1 + logvar - mu^2 - exp(logvar)`` is reduced
        with ``self.reduction_fn`` over all elements and scaled by ``-0.5``.

        Args:
            mu: Mean tensor.
            logvar: Log variance tensor.

        Returns:
            The KL divergence loss value.

        Raises:
            ValueError: If mu and logvar do not have the same shape.
        """
        if mu.shape != logvar.shape:
            raise ValueError(
                f"Shape mismatch: mu has shape {mu.shape}, but logvar has shape {logvar.shape}."
            )
        return -0.5 * self.reduction_fn(1 + logvar - mu.pow(2) - logvar.exp())

    def compute_variational_loss(
        self,
        mu: Optional[torch.Tensor],
        logvar: Optional[torch.Tensor],
        z: Optional[torch.Tensor] = None,
        true_samples: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Computes either KL or MMD loss based on configuration.

        The branch is selected by ``config.default_vae_loss``.

        Args:
            mu: Mean tensor for variational loss (required for 'kl').
            logvar: Log variance tensor for variational loss (required for 'kl').
            z: Encoded samples for MMD loss (required for 'mmd').
            true_samples: Prior samples for MMD loss (required for 'mmd').

        Returns:
            The computed variational loss.

        Raises:
            ValueError: If required parameters are missing or if mu and logvar
                have shape mismatch.
            NotImplementedError: If unsupported VAE loss type is specified.
        """
        if self.config.default_vae_loss == "kl":
            if mu is None:
                raise ValueError("mu must be provided for VAE loss")
            if logvar is None:
                raise ValueError("logvar must be provided for VAE loss")
            if mu.shape != logvar.shape:
                raise ValueError(
                    f"Shape mismatch: mu has shape {mu.shape}, but logvar has shape {logvar.shape}"
                )
            return self.compute_kl_loss(mu=mu, logvar=logvar)

        elif self.config.default_vae_loss == "mmd":
            if z is None:
                raise ValueError("z must be provided for MMD loss")
            if true_samples is None:
                raise ValueError("true_samples must be provided for MMD loss")
            return self.compute_mmd_loss(z=z, true_samples=true_samples)
        else:
            raise NotImplementedError(
                f"VAE loss type {self.config.default_vae_loss} is not implemented. "
                f"Only 'kl' and 'mmd' are supported."
            )

    def compute_paired_loss(
        self,
        latentspaces: dict[str, torch.Tensor],
        sample_ids: dict[str, list],
    ) -> torch.Tensor:
        """Calculates the paired distance loss across all pairs of modalities.

        For every unordered pair of modalities, the samples whose IDs occur in
        both are aligned and the absolute difference of their latent vectors
        is averaged over the latent axis (``dim=1``), reduced over samples
        with ``self.reduction_fn``, and finally averaged over modality pairs.

        Args:
            latentspaces: A dictionary mapping modality names to their latent
                space tensors, e.g. {'RNA': tensor_rna, 'ATAC': tensor_atac}.
            sample_ids: A dictionary mapping modality names to their list of
                sample IDs. IDs must be hashable; if an ID is repeated within
                a modality, its last position is used.

        Returns:
            A single scalar tensor representing the total paired loss. When no
            modality pair shares any sample ID, a zero scalar is returned on
            the device and with the dtype of the first latent tensor (or a
            default CPU tensor if ``latentspaces`` is empty).
        """
        pair_losses = []

        # 1. Iterate through all unique pairs of modalities
        for mod_a, mod_b in itertools.combinations(latentspaces, 2):
            # 2. Map sample ID -> row index for each modality
            id_to_idx_a = {sid: i for i, sid in enumerate(sample_ids[mod_a])}
            id_to_idx_b = {sid: i for i, sid in enumerate(sample_ids[mod_b])}

            # 3. Shared IDs, kept in the order of the first modality so the
            #    selection is deterministic across runs.
            common_ids = [sid for sid in id_to_idx_a if sid in id_to_idx_b]
            if not common_ids:
                # NOTE(review): consider routing this through logging/warnings.
                print("no common ids")
                continue

            indices_a = [id_to_idx_a[sid] for sid in common_ids]
            indices_b = [id_to_idx_b[sid] for sid in common_ids]

            # 4. Select the latent vectors for the paired samples
            paired_latents_a = latentspaces[mod_a][indices_a]
            paired_latents_b = latentspaces[mod_b][indices_b]

            # 5. L1 distance, averaged over latent dimensions, then reduced
            #    over samples
            distance = torch.abs(paired_latents_a - paired_latents_b).mean(dim=1)
            pair_losses.append(self.reduction_fn(distance))

        if not pair_losses:
            # Keep the zero on the same device/dtype as the latents so it
            # combines cleanly with the other loss terms.
            for latent in latentspaces.values():
                return latent.new_zeros(())
            return torch.tensor(0.0)
        return torch.stack(pair_losses).mean()

    @staticmethod
    def _compute_log_gauss_dense(
        z: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor
    ) -> torch.Tensor:
        """Computes the element-wise log density of a Gaussian distribution.

        Args:
            z: Latent variable tensor.
            mu: Mean tensor.
            logvar: Log variance tensor.

        Returns:
            Element-wise log probability
            ``-0.5 * (log(2*pi) + logvar + (z - mu)^2 * exp(-logvar))``.
        """
        # One-element tensor created on z's device; broadcasts over the rest.
        log_two_pi = torch.log(torch.tensor([2 * torch.pi], device=z.device))
        return -0.5 * (log_two_pi + logvar + (z - mu) ** 2 * torch.exp(-logvar))

    @staticmethod
    def _compute_log_import_weight_mat(batch_size: int, n_samples: int) -> torch.Tensor:
        """Computes the log importance weight matrix for disentangled loss.

        Similar to: https://github.com/rtqichen/beta-tcvae

        Args:
            batch_size: Number of samples in the batch. Must be at least 2,
                because ``batch_size - 1`` is used as a divisor.
            n_samples: Total number of samples in the dataset.

        Returns:
            Log importance weight matrix of shape (batch_size, batch_size).
        """
        N = n_samples
        M = batch_size - 1
        strat_weight = (N - M) / (N * M)

        W = torch.full((batch_size, batch_size), 1 / M)
        # Strided writes on the flattened view; the stride M + 1 equals
        # batch_size, so these fill the first and second columns.
        W.view(-1)[:: M + 1] = 1 / N
        W.view(-1)[1 :: M + 1] = strat_weight
        W[M - 1, 0] = strat_weight
        return W.log()

    @abstractmethod
    def forward(
        self,
        *args,
        **kwargs,
    ) -> Any:
        """Calculates the loss for the autoencoder.

        This method must be implemented by subclasses to define the specific
        loss computation logic for the autoencoder. The implementation should
        compute the total loss as well as any individual loss components
        (e.g., reconstruction loss, KL divergence, etc.) based on the model's
        output and the provided targets.

        Args:
            *args: Positional inputs, depending on the loss type and pipeline.
            **kwargs: Keyword inputs, depending on the loss type and pipeline.

        Returns:
            - The total loss value as a scalar tensor.
            - A dictionary of individual loss components, where the keys are
              descriptive strings (e.g., "reconstruction_loss", "kl_loss") and
              the values are the corresponding loss tensors.
            - Implementation in subclasses is flexible, so for new loss classes
              this can differ.

        Note:
            Subclasses must implement this method to define the specific loss
            computation logic for their use case.
        """
        # TODO maybe standardize the return types more i.e. request a scalar and a dict
        pass