# ise.models.loss

"""This module contains custom loss functions for training neural network emulators.

Classes:
    - WeightedGridLoss: Custom loss function that penalizes errors based on the total variation of a grid.
    - WeightedMSELoss: Custom loss function that penalizes errors on extreme values more.
    - WeightedMSEPCALoss: Custom loss function that penalizes errors on extreme values more and allows for custom weighting of each prediction in a batched manner.
    - WeightedMSELossWithSignPenalty: Custom loss function that penalizes errors on extreme values more and adds a penalty for opposite sign predictions.
"""


import torch


class WeightedGridLoss(torch.nn.Module):
    """Loss for gridded predictions: weighted pixelwise MSE plus a
    total-variation smoothness penalty on the predicted grid.

    True values whose magnitude exceeds a threshold receive a 10x weight in
    the MSE term, so the model is penalized more for missing extremes.
    """

    def __init__(self):
        super(WeightedGridLoss, self).__init__()
        # BUG FIX: the original called self.to(self.device) without ever
        # defining self.device, raising AttributeError on construction.
        # Define it the same way the sibling loss classes do.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(self.device)

    def total_variation_regularization(self, grid):
        """Return the mean total variation of a batch of grids.

        Args:
            grid (tensor): Grids shaped (batch, height, width) — assumed from
                the dim=1/dim=2 differencing; TODO confirm against callers.

        Returns:
            Tensor: scalar mean total variation over the batch.
        """
        # Sum of absolute horizontal and vertical neighbor differences.
        horizontal_diff = torch.abs(torch.diff(grid, dim=2))
        vertical_diff = torch.abs(torch.diff(grid, dim=1))
        total_variation = torch.sum(horizontal_diff, dim=(1, 2)) + torch.sum(
            vertical_diff, dim=(1, 2)
        )
        return torch.mean(total_variation)

    def weighted_pixelwise_mse(self, true, predicted, weights):
        """Return the mean of weights * (true - predicted)**2."""
        squared_error = (true - predicted) ** 2
        weighted_error = weights * squared_error
        return torch.mean(weighted_error)

    def forward(self, true, predicted, smoothness_weight=0.001, extreme_value_threshold=1e-6):
        """Compute weighted pixelwise MSE + smoothness_weight * total variation.

        Args:
            true (array-like or tensor): Ground-truth grids.
            predicted (array-like or tensor): Predicted grids.
            smoothness_weight (float): Weight of the total-variation term.
            extreme_value_threshold (float or None): |true| above this gets a
                10x MSE weight; None disables the extra weighting.

        Returns:
            Tensor: scalar loss.
        """
        # BUG FIX: torch.as_tensor avoids copying (and silently detaching)
        # inputs that are already tensors, unlike the original torch.tensor().
        true = torch.as_tensor(true, dtype=torch.float32).to(self.device)
        predicted = torch.as_tensor(predicted, dtype=torch.float32).to(self.device)
        if extreme_value_threshold is not None:
            # Up-weight grid cells whose true value is "extreme".
            extreme_mask = torch.abs(true) > extreme_value_threshold
            weights = torch.where(
                extreme_mask, 10.0 * torch.ones_like(true), torch.ones_like(true)
            )
        else:
            weights = torch.ones_like(true)
        pixelwise_mse = self.weighted_pixelwise_mse(true, predicted, weights)
        tvr = self.total_variation_regularization(predicted)
        return pixelwise_mse + smoothness_weight * tvr
class WeightedMSELoss(torch.nn.Module):
    """MSE loss whose per-element weight grows with how far the target lies
    from the training-set mean, so extreme targets are penalized more.

    Args:
        data_mean (float): Mean of the target variable in the training set.
        data_std (float): Standard deviation of the target variable in the training set.
        weight_factor (float): Scales how strongly extremes are up-weighted.
            Higher values increase the penalty on extremes. Default is 1.0.
    """

    def __init__(self, data_mean, data_std, weight_factor=1.0):
        super(WeightedMSELoss, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.data_mean = torch.tensor(data_mean, dtype=torch.float32, device=self.device)
        self.data_std = torch.tensor(data_std, dtype=torch.float32, device=self.device)
        self.weight_factor = torch.tensor(weight_factor, dtype=torch.float32, device=self.device)
        self.to(self.device)

    def forward(self, input, target):
        """Calculate the weighted MSE loss.

        Args:
            input (tensor): Predicted values.
            target (tensor): Actual values.

        Returns:
            Tensor: scalar loss.
        """
        input = input.to(self.device)
        target = target.to(self.device)
        # Normalized distance of each target from the training mean.
        z = torch.abs(target - self.data_mean) / self.data_std
        # Per-element squared error scaled by 1 + weight_factor * z, so
        # extreme targets contribute more to the loss.
        per_element = (1 + z * self.weight_factor) * (input - target) ** 2
        return per_element.mean()
class WeightedMSEPCALoss(torch.nn.Module):
    """Weighted MSE for batched PCA-component predictions.

    Per-element weights grow with the distance of the target from the
    training mean; optional custom per-component weights are multiplied in
    and broadcast across the batch.

    Args:
        data_mean (float): Mean of the target variable in the training set.
        data_std (float): Standard deviation of the target variable in the training set.
        weight_factor (float): Scales how strongly extremes are up-weighted. Default is 1.0.
        custom_weights (torch.Tensor, optional): Weights for each y-value
            (one per component, or one full (batch, num_targets) grid).
            Default is None.
    """

    def __init__(self, data_mean, data_std, weight_factor=1.0, custom_weights=None):
        super(WeightedMSEPCALoss, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(self.device)
        self.data_mean = torch.tensor(data_mean, dtype=torch.float32, device=self.device)
        self.data_std = torch.tensor(data_std, dtype=torch.float32, device=self.device)
        self.weight_factor = torch.tensor(weight_factor, dtype=torch.float32, device=self.device)
        self.custom_weights = (
            torch.tensor(custom_weights, dtype=torch.float32, device=self.device)
            if custom_weights is not None
            else None
        )

    def forward(self, input, target):
        """Calculate the weighted MSE loss for batched inputs and outputs.

        Args:
            input (tensor): Predicted values with shape (batch_size, num_targets).
            target (tensor): Actual values with shape (batch_size, num_targets).

        Returns:
            Tensor: scalar loss.

        Raises:
            ValueError: if input/target shapes differ, or custom weights
                cannot broadcast against them.
        """
        input = input.to(self.device)
        target = target.to(self.device)
        if input.shape != target.shape:
            raise ValueError("Input and target must have the same shape.")
        # Up-weight extreme targets by their normalized deviation from the mean.
        deviation = torch.abs(target - self.data_mean)
        normalized_deviation = deviation / self.data_std
        weights = 1 + (normalized_deviation * self.weight_factor)
        if self.custom_weights is not None:
            # BUG FIX: work on a local view instead of reassigning
            # self.custom_weights (the original mutated module state on every
            # call), and broadcast 1-D per-component weights across the batch
            # instead of demanding an exact shape match (the original's strict
            # equality check made 1-D custom weights unusable for batch > 1).
            custom = self.custom_weights
            if custom.dim() == 1:
                custom = custom.unsqueeze(0)  # (1, num_targets) row vector
            if custom.shape[-1] != weights.shape[-1] or custom.shape[0] not in (
                1,
                weights.shape[0],
            ):
                raise ValueError("Custom weights shape must match input/target shape.")
            weights = weights * custom
        squared_error = (input - target) ** 2
        return torch.mean(weights * squared_error)
class WeightedMSELossWithSignPenalty(torch.nn.Module):
    """Weighted MSE that up-weights extreme targets and adds a penalty when a
    prediction has the opposite sign of its target.

    Args:
        data_mean (float): Mean of the target variable in the training set.
        data_std (float): Standard deviation of the target variable in the training set.
        weight_factor (float): Scales how strongly extremes are up-weighted. Default is 1.0.
        sign_penalty_factor (float): Scales the opposite-sign penalty. Default is 1.0.
    """

    def __init__(self, data_mean, data_std, weight_factor=1.0, sign_penalty_factor=1.0):
        super(WeightedMSELossWithSignPenalty, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.data_mean = torch.tensor(data_mean, dtype=torch.float32, device=self.device)
        self.data_std = torch.tensor(data_std, dtype=torch.float32, device=self.device)
        self.weight_factor = torch.tensor(weight_factor, dtype=torch.float32, device=self.device)
        self.sign_penalty_factor = torch.tensor(
            sign_penalty_factor, dtype=torch.float32, device=self.device
        )
        self.to(self.device)

    def forward(self, input, target):
        """Calculate the weighted MSE loss with an opposite-sign penalty.

        Args:
            input (tensor): Predicted values.
            target (tensor): Actual values.

        Returns:
            Tensor: scalar loss.
        """
        # BUG FIX: move inputs onto the module's device, matching the sibling
        # losses; the original mixed devices on a CUDA machine with CPU
        # inputs, since self.data_mean lives on self.device.
        input = input.to(self.device)
        target = target.to(self.device)
        # Up-weight extreme targets by their normalized deviation from the mean.
        deviation = torch.abs(target - self.data_mean)
        normalized_deviation = deviation / self.data_std
        weights = 1 + (normalized_deviation * self.weight_factor)
        squared_error = torch.nn.functional.mse_loss(input, target, reduction="none")
        # Penalize sign disagreement proportionally to the error magnitude.
        sign_penalty = torch.where(
            torch.sign(input) != torch.sign(target),
            torch.abs(input - target) * self.sign_penalty_factor,
            torch.zeros_like(input),
        )
        weighted_squared_error = weights * (squared_error + sign_penalty)
        return torch.mean(weighted_squared_error)
class GridCriterion(torch.nn.Module):
    """Spatial loss for gridded predictions: mean squared pixel error plus a
    total-variation smoothness penalty on the predicted grid."""

    def __init__(self):
        super(GridCriterion, self).__init__()

    def total_variation_regularization(self, grid):
        """Return the mean total variation of a batch of grids.

        The grid is assumed shaped (batch, height, width) — implied by the
        dim=1/dim=2 differencing; confirm against callers.
        """
        dx = torch.diff(grid, dim=2).abs()
        dy = torch.diff(grid, dim=1).abs()
        per_grid_tv = dx.sum(dim=(1, 2)) + dy.sum(dim=(1, 2))
        return per_grid_tv.mean()

    def forward(self, true, predicted, smoothness_weight=0.001):
        """Return pixelwise MSE plus smoothness_weight * total variation.

        Args:
            true (tensor): Ground-truth grids, (batch, height, width).
            predicted (tensor): Predicted grids, same shape.
            smoothness_weight (float): Weight of the total-variation term.

        Returns:
            Tensor: scalar loss.
        """
        residual = true - predicted
        pixelwise_mse = (residual.abs() ** 2).mean()
        smoothness = self.total_variation_regularization(predicted)
        return pixelwise_mse + smoothness_weight * smoothness
class WeightedPCALoss(torch.nn.Module):
    """Squared-error loss with a fixed weight per principal component, so
    leading components can be penalized more heavily than trailing ones.

    Args:
        component_weights (list or torch.Tensor): Weights for each principal
            component's error, typically largest for the first component.
        reduction (str): One of 'none' | 'mean' | 'sum'.
    """

    def __init__(self, component_weights, reduction="mean"):
        super(WeightedPCALoss, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.component_weights = torch.tensor(
            component_weights, dtype=torch.float32, device=self.device
        )
        if len(self.component_weights.size()) == 1:
            # Row vector so the weights broadcast across the batch dimension.
            self.component_weights = self.component_weights.unsqueeze(0)
        self.reduction = reduction
        self.to(self.device)

    def forward(self, input, target):
        """Calculate the weighted loss for principal component predictions.

        Args:
            input (tensor): Predicted principal components.
            target (tensor): Actual principal components.

        Returns:
            Tensor: scalar if reduction is 'mean'/'sum', otherwise the
            elementwise weighted error.

        Raises:
            ValueError: if input and target shapes differ.
        """
        input = input.to(self.device)
        target = target.to(self.device)
        if input.shape != target.shape:
            raise ValueError("Input and target must have the same shape")
        weighted_error = self.component_weights.to(input.device) * (input - target) ** 2
        if self.reduction == "mean":
            return weighted_error.mean()
        if self.reduction == "sum":
            return weighted_error.sum()
        return weighted_error
class MSEDeviationLoss(torch.nn.Module):
    """MSE plus an extra penalty on errors whose magnitude exceeds a threshold.

    Parameters:
        - threshold: error magnitude beyond which the extra penalty applies.
        - penalty_multiplier: multiplier on the squared error of elements
          exceeding the threshold.
    """

    def __init__(self, threshold=1.0, penalty_multiplier=2.0):
        super(MSEDeviationLoss, self).__init__()
        self.threshold = threshold
        self.penalty_multiplier = penalty_multiplier

    def forward(self, predictions, targets):
        """Compute MSE plus the mean large-deviation penalty.

        Parameters:
            - predictions: predicted values.
            - targets: ground-truth values.

        Returns:
            - loss: scalar tensor.
        """
        residual = predictions - targets
        base = (residual ** 2).mean()
        # Extra quadratic penalty only where |error| exceeds the threshold.
        exceeds = residual.abs() > self.threshold
        penalty_terms = torch.where(
            exceeds,
            self.penalty_multiplier * residual ** 2,
            torch.tensor(0.0, device=predictions.device),
        )
        return base + penalty_terms.mean()