Source code for ise.models.loss

import torch


class WeightedGridLoss(torch.nn.Module):
    """
    Custom loss combining a weighted pixel-wise MSE with total variation regularization.

    This loss function consists of two components:

    1. **Pixel-wise Weighted Mean Squared Error (MSE):** pixels whose ground-truth
       magnitude exceeds ``extreme_value_threshold`` get weight ``extreme_weight``;
       all other pixels get weight 1.
    2. **Total Variation Regularization (TVR):** enforces spatial smoothness of the
       prediction by penalizing absolute differences between adjacent grid values.

    Attributes:
        device (str): Device the loss computes on ('cuda' if available, else 'cpu').

    Methods:
        - total_variation_regularization: Computes TVR penalty for smoothness.
        - weighted_pixelwise_mse: Computes MSE loss with per-pixel weighting.
        - forward: Computes the final weighted loss.
    """

    def __init__(self):
        super(WeightedGridLoss, self).__init__()
        # self.device was previously read here without ever being assigned, so
        # construction raised AttributeError. Choose it the same way the other
        # losses in this module do.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(self.device)

    def total_variation_regularization(self, grid):
        """
        Computes the total variation regularization (TVR) loss for spatial smoothness.

        Args:
            grid (Tensor): Either a batch of grids whose dims 1 and 2 are the spatial
                axes (e.g. ``(batch, rows, cols)``), or a single 2D grid
                ``(rows, cols)``, which is treated as a batch of one.

        Returns:
            Tensor: The absolute differences along dims 2 and 1, summed over those
            dims, then averaged over whatever remains.
        """
        if grid.dim() == 2:
            # Single grid: add a batch axis so dims 1 and 2 are the spatial axes.
            grid = grid.unsqueeze(0)
        horizontal_diff = torch.abs(torch.diff(grid, dim=2))
        vertical_diff = torch.abs(torch.diff(grid, dim=1))
        total_variation = torch.sum(horizontal_diff, dim=(1, 2)) + torch.sum(
            vertical_diff, dim=(1, 2)
        )
        return torch.mean(total_variation)

    def weighted_pixelwise_mse(self, true, predicted, weights):
        """
        Computes the pixel-wise mean squared error (MSE) with custom weights.

        Args:
            true (Tensor): Ground truth values.
            predicted (Tensor): Model predictions.
            weights (Tensor): Weighting factor for each pixel (broadcast against
                the squared error).

        Returns:
            Tensor: Scalar mean of ``weights * (true - predicted) ** 2``.
        """
        squared_error = (true - predicted) ** 2
        return torch.mean(weights * squared_error)

    def forward(
        self,
        true,
        predicted,
        smoothness_weight=0.001,
        extreme_value_threshold=1e-6,
        extreme_weight=10.0,
    ):
        """
        Computes the final weighted loss combining pixel-wise MSE and TVR.

        Args:
            true (Tensor or array-like): Ground truth values.
            predicted (Tensor or array-like): Model predictions. If this is a tensor
                that requires grad, gradients flow back through it.
            smoothness_weight (float, optional): Weighting factor for the TVR loss.
                Defaults to 0.001.
            extreme_value_threshold (float, optional): Pixels with
                ``|true| > extreme_value_threshold`` are treated as extreme. If
                ``None``, all pixels get weight 1. Defaults to 1e-6.
            extreme_weight (float, optional): Weight for extreme pixels.
                Defaults to 10.0.

        Returns:
            Tensor: Scalar ``weighted_mse + smoothness_weight * tvr(predicted)``.
        """
        # torch.tensor() copies *and detaches* an existing tensor, which cut the
        # autograd graph. torch.as_tensor only converts dtype/device when needed,
        # and it does so differentiably.
        true = torch.as_tensor(true, dtype=torch.float32, device=self.device)
        predicted = torch.as_tensor(predicted, dtype=torch.float32, device=self.device)

        if extreme_value_threshold is not None:
            extreme_mask = torch.abs(true) > extreme_value_threshold
            weights = torch.where(
                extreme_mask,
                torch.full_like(true, extreme_weight),
                torch.ones_like(true),
            )
        else:
            weights = torch.ones_like(true)

        pixelwise_mse = self.weighted_pixelwise_mse(true, predicted, weights)
        tvr = self.total_variation_regularization(predicted)
        return pixelwise_mse + smoothness_weight * tvr
class WeightedMSELoss(torch.nn.Module):
    """
    Mean squared error where each element is re-weighted by how far its target is from the dataset mean.

    The per-element weight is ``1 + weight_factor * |target - data_mean| / data_std``,
    so targets that are unusual (in units of standard deviation) contribute more.

    Attributes:
        device (str): 'cuda' when available, otherwise 'cpu'; inputs are moved here.
        data_mean (Tensor): Dataset mean used as the reference point.
        data_std (Tensor): Dataset standard deviation used to normalise deviations.
        weight_factor (Tensor): Scale of the extra weight given to deviating targets.

    Methods:
        - forward: Computes the weighted mean squared error loss.
    """

    def __init__(self, data_mean, data_std, weight_factor=1.0):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        float_on_device = {"dtype": torch.float32, "device": self.device}
        self.data_mean = torch.tensor(data_mean, **float_on_device)
        self.data_std = torch.tensor(data_std, **float_on_device)
        self.weight_factor = torch.tensor(weight_factor, **float_on_device)
        self.to(self.device)

    def forward(self, input, target):
        """
        Computes the weighted MSE between ``input`` and ``target``.

        Args:
            input (Tensor): Predicted values.
            target (Tensor): Ground truth values; also determines the weights.

        Returns:
            Tensor: Scalar mean of ``weight * (input - target) ** 2``.
        """
        prediction = input.to(self.device)
        reference = target.to(self.device)

        # Distance of each target from the dataset mean, in standard deviations.
        z_score = torch.abs(reference - self.data_mean) / self.data_std
        element_weights = z_score * self.weight_factor + 1

        per_element = torch.nn.functional.mse_loss(prediction, reference, reduction="none")
        return (element_weights * per_element).mean()
class WeightedMSEPCALoss(torch.nn.Module):
    """
    Extension of WeightedMSELoss that allows additional user-defined weights.

    The base weight of each element is
    ``1 + weight_factor * |target - data_mean| / data_std``. If ``custom_weights``
    is given, it is multiplied in elementwise. ``custom_weights`` may have the
    full input shape or any shape that broadcasts to it, e.g. a 1D per-component
    vector applied to every row of a batch.

    Attributes:
        device (str): 'cuda' when available, otherwise 'cpu'; inputs are moved here.
        data_mean (Tensor): Mean of the dataset.
        data_std (Tensor): Standard deviation of the dataset.
        weight_factor (Tensor): Controls the penalty for extreme values.
        custom_weights (Tensor or None): User-defined weight tensor.

    Methods:
        - forward: Computes the batch-weighted MSE loss.
    """

    def __init__(self, data_mean, data_std, weight_factor=1.0, custom_weights=None):
        super(WeightedMSEPCALoss, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.data_mean = torch.tensor(data_mean, dtype=torch.float32, device=self.device)
        self.data_std = torch.tensor(data_std, dtype=torch.float32, device=self.device)
        self.weight_factor = torch.tensor(weight_factor, dtype=torch.float32, device=self.device)
        self.custom_weights = (
            torch.tensor(custom_weights, dtype=torch.float32, device=self.device)
            if custom_weights is not None
            else None
        )
        self.to(self.device)

    def forward(self, input, target):
        """
        Computes the batch-weighted mean squared error loss.

        Args:
            input (Tensor): Predicted values.
            target (Tensor): Ground truth values; must have the same shape as input.

        Returns:
            Tensor: Scalar mean of ``weights * (input - target) ** 2``.

        Raises:
            ValueError: If input and target shapes differ, or if ``custom_weights``
                cannot be broadcast to the shape of the computed weights.
        """
        input = input.to(self.device)
        target = target.to(self.device)

        if input.shape != target.shape:
            raise ValueError("Input and target must have the same shape.")

        # Weight grows with the target's distance from the mean, in std units.
        normalized_deviation = torch.abs(target - self.data_mean) / self.data_std
        weights = 1 + (normalized_deviation * self.weight_factor)

        if self.custom_weights is not None:
            # Previously a 1D vector was unsqueezed to (1, D) and then required to
            # match exactly, which rejected every batch larger than 1. Any shape
            # that broadcasts to the weights without enlarging them is accepted now.
            # self.custom_weights is no longer mutated in place.
            try:
                broadcast_shape = torch.broadcast_shapes(
                    self.custom_weights.shape, weights.shape
                )
            except (RuntimeError, ValueError):
                broadcast_shape = None
            if broadcast_shape != weights.shape:
                raise ValueError("Custom weights shape must match input/target shape.")
            weights = weights * self.custom_weights

        squared_error = (input - target) ** 2
        return torch.mean(weights * squared_error)
class WeightedMSELossWithSignPenalty(torch.nn.Module):
    """
    Weighted MSE that additionally penalizes predictions with the wrong sign.

    Each element's weight is ``1 + weight_factor * |target - data_mean| / data_std``.
    Where ``torch.sign(input) != torch.sign(target)``, an extra term
    ``sign_penalty_factor * |input - target|`` is added to the squared error before
    weighting. Because ``torch.sign(0) == 0``, an element where exactly one of
    input/target is zero also counts as a sign mismatch.

    Attributes:
        device (str): 'cuda' when available, otherwise 'cpu'; inputs are moved here.
        data_mean (Tensor): Mean of the dataset.
        data_std (Tensor): Standard deviation of the dataset.
        weight_factor (Tensor): Factor controlling extreme value weighting.
        sign_penalty_factor (Tensor): Factor controlling penalty for opposite sign predictions.

    Methods:
        - forward: Computes the weighted loss with sign penalties.
    """

    def __init__(self, data_mean, data_std, weight_factor=1.0, sign_penalty_factor=1.0):
        super(WeightedMSELossWithSignPenalty, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.data_mean = torch.tensor(data_mean, dtype=torch.float32, device=self.device)
        self.data_std = torch.tensor(data_std, dtype=torch.float32, device=self.device)
        self.weight_factor = torch.tensor(weight_factor, dtype=torch.float32, device=self.device)
        self.sign_penalty_factor = torch.tensor(
            sign_penalty_factor, dtype=torch.float32, device=self.device
        )
        self.to(self.device)

    def forward(self, input, target):
        """
        Computes the Weighted MSE Loss with an additional sign penalty.

        Args:
            input (Tensor): Predicted values.
            target (Tensor): Ground truth values.

        Returns:
            Tensor: Scalar mean of ``weights * (squared_error + sign_penalty)``.
        """
        # Unlike the other losses in this module, the original never moved its
        # inputs. The statistics live on self.device, so a CPU batch on a CUDA
        # machine hit a device mismatch.
        input = input.to(self.device)
        target = target.to(self.device)

        normalized_deviation = torch.abs(target - self.data_mean) / self.data_std
        weights = 1 + (normalized_deviation * self.weight_factor)

        squared_error = torch.nn.functional.mse_loss(input, target, reduction="none")

        sign_penalty = torch.where(
            torch.sign(input) != torch.sign(target),
            torch.abs(input - target) * self.sign_penalty_factor,
            torch.zeros_like(input),
        )

        weighted_squared_error = weights * (squared_error + sign_penalty)
        return torch.mean(weighted_squared_error)
class GridCriterion(torch.nn.Module):
    """
    Custom loss enforcing spatial smoothness using total variation regularization.

    The loss is the plain mean squared error between ``true`` and ``predicted``
    plus ``smoothness_weight`` times the total variation of ``predicted``.

    Methods:
        - total_variation_regularization: Computes the smoothness loss.
        - forward: Computes the final loss.
    """

    def __init__(self):
        super(GridCriterion, self).__init__()

    def total_variation_regularization(self, grid):
        """
        Computes total variation regularization (TVR) loss.

        Args:
            grid (Tensor): Either a batch of grids whose dims 1 and 2 are the spatial
                axes (e.g. ``(batch, rows, cols)``), or a single 2D grid
                ``(rows, cols)``, which is treated as a batch of one.

        Returns:
            Tensor: The absolute differences along dims 2 and 1, summed over those
            dims, then averaged over whatever remains.
        """
        if grid.dim() == 2:
            # Single grid: add a batch axis so dims 1 and 2 are the spatial axes.
            grid = grid.unsqueeze(0)
        horizontal_diff = torch.abs(torch.diff(grid, dim=2))
        vertical_diff = torch.abs(torch.diff(grid, dim=1))
        total_variation = torch.sum(horizontal_diff, dim=(1, 2)) + torch.sum(
            vertical_diff, dim=(1, 2)
        )
        return torch.mean(total_variation)

    def forward(self, true, predicted, smoothness_weight=0.001):
        """
        Computes the final loss by combining pixel-wise MSE and TVR.

        Args:
            true (Tensor): Ground truth values.
            predicted (Tensor): Model predictions.
            smoothness_weight (float, optional): Weight for TVR. Defaults to 0.001.

        Returns:
            Tensor: Scalar ``mean((true - predicted) ** 2) + smoothness_weight * tvr``.
        """
        # A single scalar averaged over every element, not a per-image vector.
        pixelwise_mse = torch.mean((true - predicted) ** 2)
        tvr = self.total_variation_regularization(predicted)
        return pixelwise_mse + smoothness_weight * tvr
class WeightedPCALoss(torch.nn.Module):
    """
    Squared error with a separate weight for each principal component.

    A 1D ``component_weights`` vector is stored as a ``(1, n)`` row so it
    broadcasts against ``(batch, n)`` predictions. Larger weights on the leading
    components make their errors count more.

    Attributes:
        device (str): 'cuda' when available, otherwise 'cpu'; inputs are moved here.
        component_weights (Tensor): Per-component weights (row vector if given 1D).
        reduction (str): 'mean' or 'sum'; any other value returns the unreduced tensor.

    Methods:
        - forward: Computes the weighted PCA loss.
    """

    def __init__(self, component_weights, reduction="mean"):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        weight_tensor = torch.tensor(component_weights, dtype=torch.float32, device=self.device)
        # Promote a flat vector to a row vector so it broadcasts over the batch axis.
        self.component_weights = (
            weight_tensor.unsqueeze(0) if weight_tensor.dim() == 1 else weight_tensor
        )
        self.reduction = reduction
        self.to(self.device)

    def forward(self, input, target):
        """
        Computes the component-weighted squared error.

        Args:
            input (Tensor): Predicted principal components.
            target (Tensor): Actual principal components; same shape as input.

        Returns:
            Tensor: Scalar for 'mean'/'sum', otherwise the elementwise weighted error.

        Raises:
            ValueError: If input and target shapes differ.
        """
        prediction = input.to(self.device)
        reference = target.to(self.device)
        if prediction.shape != reference.shape:
            raise ValueError("Input and target must have the same shape")

        scale = self.component_weights.to(prediction.device)
        weighted_error = (prediction - reference) ** 2 * scale

        if self.reduction == "mean":
            return torch.mean(weighted_error)
        if self.reduction == "sum":
            return torch.sum(weighted_error)
        return weighted_error
class MSEDeviationLoss(torch.nn.Module):
    """
    Mean squared error plus an extra term for residuals larger than a threshold.

    The loss is ``mean(r**2) + mean(penalty_multiplier * r**2 where |r| > threshold, else 0)``
    with ``r = predictions - targets``.

    Attributes:
        threshold (float): Absolute residual above which the extra penalty applies.
        penalty_multiplier (float): Scale of the extra penalty term.

    Methods:
        - forward: Computes the loss with deviation penalties.
    """

    def __init__(self, threshold=1.0, penalty_multiplier=2.0):
        super().__init__()
        self.threshold = threshold
        self.penalty_multiplier = penalty_multiplier

    def forward(self, predictions, targets):
        """
        Computes the MSE loss with an additional deviation penalty.

        Args:
            predictions (Tensor): Predicted values.
            targets (Tensor): Ground truth values.

        Returns:
            Tensor: Scalar sum of the plain MSE and the mean large-deviation penalty.
        """
        residual = predictions - targets
        squared = residual ** 2
        base_mse = squared.mean()

        exceeds = residual.abs() > self.threshold
        zero = torch.tensor(0.0, device=predictions.device)
        penalty = torch.where(exceeds, self.penalty_multiplier * squared, zero).mean()

        return base_mse + penalty