ise.models package

Subpackages

Submodules

ise.models.loss module

This module contains custom loss functions for training neural network emulators.

Classes:
  • WeightedGridLoss: Custom loss function that penalizes errors based on the total variation of a grid.

  • WeightedMSELoss: Custom loss function that penalizes errors on extreme values more.

  • WeightedMSEPCALoss: Custom loss function that penalizes errors on extreme values more and allows for custom weighting of each prediction in a batched manner.

  • WeightedMSELossWithSignPenalty: Custom loss function that penalizes errors on extreme values more and adds a penalty for opposite sign predictions.

class ise.models.loss.GridCriterion[source]

Bases: Module

forward(true, predicted, smoothness_weight=0.001)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

total_variation_regularization(grid)[source]
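
The GridCriterion docstrings do not give the formula. The minimal sketch below makes two assumptions. First, total_variation_regularization is an anisotropic (L1) total variation over the last two grid dimensions. Second, forward adds that term, scaled by smoothness_weight, to a pixelwise MSE. The functions total_variation and grid_criterion are hypothetical stand-ins, not the package's implementation.

```python
import torch
import torch.nn.functional as F


def total_variation(grid: torch.Tensor) -> torch.Tensor:
    """Anisotropic L1 total variation over the last two (spatial) dimensions.

    Assumption: grids are laid out as (..., height, width).
    """
    # Mean absolute difference between vertically adjacent cells
    dh = (grid[..., 1:, :] - grid[..., :-1, :]).abs().mean()
    # Mean absolute difference between horizontally adjacent cells
    dw = (grid[..., :, 1:] - grid[..., :, :-1]).abs().mean()
    return dh + dw


def grid_criterion(true: torch.Tensor, predicted: torch.Tensor,
                   smoothness_weight: float = 0.001) -> torch.Tensor:
    """Hypothetical GridCriterion-style loss: MSE plus a TV smoothness term."""
    data_term = F.mse_loss(predicted, true)
    return data_term + smoothness_weight * total_variation(predicted)


true = torch.zeros(2, 4, 4)
pred = torch.ones(2, 4, 4)
print(grid_criterion(true, pred).item())  # 1.0
```

A constant prediction has zero total variation, so here the loss reduces to the plain MSE. Raising smoothness_weight trades pointwise accuracy for spatially smoother predicted grids.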

class ise.models.loss.MSEDeviationLoss(threshold=1.0, penalty_multiplier=2.0)[source]

Bases: Module

forward(predictions, targets)[source]

Compute the custom loss.

Parameters:
  • predictions - The predicted values.

  • targets - The ground truth values.

Returns:

The computed custom loss.
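
The docstring does not say how threshold and penalty_multiplier enter the loss. One plausible reading, sketched below purely as an assumption, is an MSE in which squared errors are multiplied by penalty_multiplier when their absolute deviation exceeds threshold. The class name DeviationPenaltyMSE is hypothetical.

```python
import torch
import torch.nn as nn


class DeviationPenaltyMSE(nn.Module):
    """Hypothetical sketch: squared errors whose absolute deviation exceeds
    `threshold` are scaled by `penalty_multiplier` before averaging."""

    def __init__(self, threshold: float = 1.0, penalty_multiplier: float = 2.0):
        super().__init__()
        self.threshold = threshold
        self.penalty_multiplier = penalty_multiplier

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        error = predictions - targets
        squared = error ** 2
        # Weight is penalty_multiplier where |error| > threshold, otherwise 1
        large = error.abs() > self.threshold
        weights = torch.where(large,
                              torch.full_like(squared, self.penalty_multiplier),
                              torch.ones_like(squared))
        return (weights * squared).mean()


loss_fn = DeviationPenaltyMSE(threshold=1.0, penalty_multiplier=2.0)
pred = torch.tensor([0.5, 3.0])
target = torch.tensor([0.0, 0.0])
print(loss_fn(pred, target).item())  # 9.125
```

The arithmetic is (0.25 * 1 + 9.0 * 2) / 2 = 9.125. Only the second error exceeds the threshold, so only that error is doubled.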

class ise.models.loss.WeightedGridLoss[source]

Bases: Module

forward(true, predicted, smoothness_weight=0.001, extreme_value_threshold=1e-06)[source]

total_variation_regularization(grid)[source]
weighted_pixelwise_mse(true, predicted, weights)[source]
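
The signature weighted_pixelwise_mse(true, predicted, weights) suggests a weighted per-cell MSE. The sketch below assumes the loss is normalized by the total weight. It also assumes extreme_value_threshold flags cells whose magnitude exceeds it for extra weight; with the 1e-06 default, that effectively means nonzero cells. The helper threshold_weights and its high_weight factor are hypothetical.

```python
import torch


def weighted_pixelwise_mse(true: torch.Tensor, predicted: torch.Tensor,
                           weights: torch.Tensor) -> torch.Tensor:
    """Weighted mean of squared errors; `weights` must broadcast to `true`."""
    weights = weights.expand_as(true)
    return (weights * (true - predicted) ** 2).sum() / weights.sum()


def threshold_weights(true: torch.Tensor, extreme_value_threshold: float = 1e-6,
                      high_weight: float = 10.0) -> torch.Tensor:
    """Hypothetical weighting: cells with |true| above the threshold count
    `high_weight` times as much as (near-)zero cells."""
    above = true.abs() > extreme_value_threshold
    return 1.0 + (high_weight - 1.0) * above.float()


true = torch.tensor([[0.0, 2.0], [0.0, 0.0]])
pred = torch.zeros(2, 2)
w = threshold_weights(true)
print(weighted_pixelwise_mse(true, pred, w).item())  # 40 / 13, about 3.077
```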

class ise.models.loss.WeightedMSELoss(data_mean, data_std, weight_factor=1.0)[source]

Bases: Module

forward(input, target)[source]

Calculate the Weighted MSE Loss.

Parameters:
  • input (tensor) - Predicted values.

  • target (tensor) - Actual values.

Returns:

Computed loss.

Return type:

Tensor
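
The constructor takes data_mean and data_std, which suggests that "extreme" is measured relative to the data distribution. The minimal sketch below assumes each squared error is weighted by 1 + weight_factor times the absolute z-score of its target. ZScoreWeightedMSE is a hypothetical name, not the package's class.

```python
import torch
import torch.nn as nn


class ZScoreWeightedMSE(nn.Module):
    """Hypothetical weighted MSE that up-weights errors on extreme targets.

    Assumption: weight = 1 + weight_factor * |z|, z = (target - mean) / std.
    """

    def __init__(self, data_mean: float, data_std: float, weight_factor: float = 1.0):
        super().__init__()
        self.data_mean = data_mean
        self.data_std = data_std
        self.weight_factor = weight_factor

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        z = (target - self.data_mean) / self.data_std
        weights = 1.0 + self.weight_factor * z.abs()
        return (weights * (input - target) ** 2).mean()


loss_fn = ZScoreWeightedMSE(data_mean=0.0, data_std=1.0, weight_factor=1.0)
target = torch.tensor([0.0, 2.0])
pred = torch.tensor([1.0, 3.0])
print(loss_fn(pred, target).item())  # 2.0
```

Both predictions miss by 1.0. The miss on the target two standard deviations from the mean costs three times as much as the miss at the mean.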

class ise.models.loss.WeightedMSELossWithSignPenalty(data_mean, data_std, weight_factor=1.0, sign_penalty_factor=1.0)[source]

Bases: Module

forward(input, target)[source]

Calculate the Weighted MSE Loss with an additional penalty for opposite sign predictions.

Parameters:
  • input (tensor) - Predicted values.

  • target (tensor) - Actual values.

Returns:

Computed loss.

Return type:

Tensor
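
The form of the sign penalty is not documented. The sketch below assumes a z-score-weighted MSE (weight 1 + weight_factor * |z|). To that it adds sign_penalty_factor times the mean absolute error over elements where prediction and target have strictly opposite signs. The function name is hypothetical.

```python
import torch


def weighted_mse_with_sign_penalty(input: torch.Tensor, target: torch.Tensor,
                                   data_mean, data_std, weight_factor: float = 1.0,
                                   sign_penalty_factor: float = 1.0) -> torch.Tensor:
    """Hypothetical sketch: z-score-weighted MSE plus an opposite-sign penalty."""
    z = (target - data_mean) / data_std
    weights = 1.0 + weight_factor * z.abs()
    weighted_mse = (weights * (input - target) ** 2).mean()
    # 1.0 where prediction and target have strictly opposite signs, else 0.0
    opposite = (input * target < 0).float()
    sign_penalty = (opposite * (input - target).abs()).mean()
    return weighted_mse + sign_penalty_factor * sign_penalty


pred = torch.tensor([1.0, -1.0])
target = torch.tensor([1.0, 1.0])
print(weighted_mse_with_sign_penalty(pred, target, 0.0, 1.0).item())  # 5.0
```

The weighted MSE contributes 4.0 and the single sign flip adds 1.0. Zeros are treated as sign-neutral because the product test is strict.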

class ise.models.loss.WeightedMSEPCALoss(data_mean, data_std, weight_factor=1.0, custom_weights=None)[source]

Bases: Module

forward(input, target)[source]

Calculate the Weighted MSE Loss for batched inputs and outputs.

Parameters:
  • input (tensor) - Predicted values with shape (batch_size, num_targets).

  • target (tensor) - Actual values with shape (batch_size, num_targets).

Returns:

Computed loss.

Return type:

Tensor
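
A sketch of the batched variant follows. It assumes z-score weights as the extreme-value term. It also assumes custom_weights is a per-target vector of length num_targets that is broadcast across the batch dimension. The function name is hypothetical.

```python
import torch


def weighted_mse_batched(input: torch.Tensor, target: torch.Tensor,
                         data_mean, data_std, weight_factor: float = 1.0,
                         custom_weights=None) -> torch.Tensor:
    """Hypothetical sketch for (batch_size, num_targets) tensors."""
    # Extreme-value weighting from the targets' z-scores
    z = (target - data_mean) / data_std
    weights = 1.0 + weight_factor * z.abs()              # (batch_size, num_targets)
    if custom_weights is not None:
        cw = torch.as_tensor(custom_weights, dtype=target.dtype, device=target.device)
        weights = weights * cw.unsqueeze(0)              # broadcast over the batch
    return (weights * (input - target) ** 2).mean()


target = torch.zeros(2, 3)
pred = torch.ones(2, 3)
loss = weighted_mse_batched(pred, target, data_mean=0.0, data_std=1.0,
                            custom_weights=[1.0, 2.0, 3.0])
print(loss.item())  # 2.0
```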

class ise.models.loss.WeightedPCALoss(component_weights, reduction='mean')[source]

Bases: Module

forward(input, target)[source]

Calculate the weighted loss for principal component predictions.

Parameters:
  • input (tensor) - Predicted principal components.

  • target (tensor) - Actual principal components.

Returns:

Computed Weighted PCA Loss.

Return type:

Tensor
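
The minimal sketch below shows a per-component weighted loss. It assumes a squared error per principal component, scaled by component_weights and reduced according to reduction. The reduction values 'mean', 'sum', and 'none' mirror the convention of PyTorch's built-in losses. ComponentWeightedLoss is a hypothetical name.

```python
import torch
import torch.nn as nn


class ComponentWeightedLoss(nn.Module):
    """Hypothetical sketch: squared error on each principal component,
    scaled by a per-component weight, then reduced."""

    def __init__(self, component_weights, reduction: str = "mean"):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        # Buffer: moves with the module under .to(device) but is not trained
        self.register_buffer("component_weights",
                             torch.as_tensor(component_weights, dtype=torch.float32))
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        loss = self.component_weights * (input - target) ** 2  # (batch, n_components)
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss


loss_fn = ComponentWeightedLoss([1.0, 0.5])
pred = torch.ones(1, 2)
target = torch.zeros(1, 2)
print(loss_fn(pred, target).item())  # 0.75
```

Giving leading components larger weights emphasizes the directions that explain most of the variance in the original data.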

ise.models.scenario module

class ise.models.scenario.ScenarioPredictor(input_size, hidden_layers=[128, 64], output_size=1, dropout_rate=0.1)[source]

Bases: Module

evaluate(data_loader)[source]

Evaluates the model on a dataset.

Parameters:

data_loader (DataLoader) - DataLoader for the dataset to evaluate.

Returns:

A tuple containing the average loss and accuracy on the dataset.

Return type:

tuple

fit(train_loader, val_loader=None, epochs=10, lr=0.001, print_every=1, save_checkpoint=True)[source]

forward(x)[source]

load(path)[source]

Loads the model state from a file.

Parameters:

path (str) - Path to the file containing the model state.

predict(x)[source]

Predicts the output for a given input.

Parameters:

x (torch.Tensor) - Input tensor.

Returns:

Predicted output tensor.

Return type:

torch.Tensor
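
The constructor arguments (input_size, hidden_layers, output_size, dropout_rate) suggest a feed-forward network. The sketch below assumes a Linear, ReLU, and Dropout block for each hidden width. MLPSketch is a hypothetical class, and the real layer layout may differ.

```python
import torch
import torch.nn as nn


class MLPSketch(nn.Module):
    """Hypothetical feed-forward network built from ScenarioPredictor-style
    constructor arguments."""

    def __init__(self, input_size, hidden_layers=(128, 64), output_size=1,
                 dropout_rate=0.1):
        super().__init__()
        layers = []
        in_features = input_size
        for width in hidden_layers:
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(dropout_rate)]
            in_features = width
        layers.append(nn.Linear(in_features, output_size))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

    @torch.no_grad()
    def predict(self, x):
        # Inference: dropout disabled, no gradient tracking
        self.eval()
        return self(x)


model = MLPSketch(input_size=8)
out = model.predict(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 1])
```

A tuple default for hidden_layers avoids Python's shared mutable-default pitfall, which a list default such as [128, 64] can trigger if the list is ever modified in place.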

ise.models.variational_lstm_emulator module

This module contains VariationalLSTMEmulator, the model architecture for the variational LSTM emulator presented in https://doi.org/10.1029/2023MS003899.

class ise.models.variational_lstm_emulator.VariationalLSTMEmulator(architecture, mc_dropout=False, dropout_prob=None)[source]

Bases: Module

Variational LSTM Emulator model for time series data.

enable_dropout()[source]

Enable dropout during model evaluation.

This method turns on dropout for every layer whose class name starts with “Dropout”.
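
Switching dropout layers back to training mode by class name is a common PyTorch pattern for Monte Carlo dropout. A minimal sketch of that pattern follows. It is written as a standalone function rather than the emulator's own method.

```python
import torch.nn as nn


def enable_dropout(model: nn.Module) -> None:
    """Put every submodule whose class name starts with "Dropout" back into
    training mode, leaving all other submodules in their current mode."""
    for module in model.modules():
        if module.__class__.__name__.startswith("Dropout"):
            module.train()


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
model.eval()
enable_dropout(model)
print(model[1].training)  # True
```

Classes such as nn.Dropout and nn.Dropout2d match the prefix. The dropout argument of nn.LSTM is applied inside the LSTM module itself and follows that module's training flag, so a name-based toggle like this does not reactivate it.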

forward(x)[source]

Forward pass of the VariationalLSTMEmulator model.

Parameters:

x (torch.Tensor) - Input tensor.

Returns:

Output tensor.

Return type:

torch.Tensor

predict(x, approx_dist=None, mc_iterations=None, quantile_range=[0.025, 0.975], confidence='95')[source]

Make predictions using the VariationalLSTMEmulator model.

Parameters:
  • x (np.ndarray or torch.Tensor or pd.DataFrame) - Input data.

  • approx_dist (bool, optional) - Whether to approximate the predictive distribution using Monte Carlo (MC) dropout. Defaults to None.

  • mc_iterations (int, optional) - Number of MC dropout forward passes. Required if approx_dist is True.

  • quantile_range (list, optional) - Quantile range for prediction intervals. Defaults to [0.025, 0.975].

  • confidence (str, optional) - Confidence level for prediction intervals. Defaults to “95”.

Returns:

Tuple containing the predictions, mean predictions, and standard deviations.

Return type:

tuple
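
With approx_dist=True, predictions come from repeated stochastic forward passes. The sketch below shows Monte Carlo dropout under that assumption and returns the mean, standard deviation, and quantile bounds. mc_dropout_predict is hypothetical, and its return value does not match the emulator's documented tuple.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, mc_iterations: int = 100,
                       quantile_range=(0.025, 0.975)):
    """Monte Carlo dropout: repeat stochastic forward passes with dropout
    active and summarize the spread of the outputs."""
    model.eval()
    for module in model.modules():
        if module.__class__.__name__.startswith("Dropout"):
            module.train()  # keep dropout stochastic at inference time
    samples = torch.stack([model(x) for _ in range(mc_iterations)])  # (mc, batch, out)
    mean = samples.mean(dim=0)
    std = samples.std(dim=0)
    q = torch.tensor(quantile_range, dtype=samples.dtype, device=samples.device)
    lower, upper = torch.quantile(samples, q, dim=0)
    return mean, std, (lower, upper)


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Dropout(0.2), nn.Linear(16, 1))
x = torch.randn(5, 3)
mean, std, (lower, upper) = mc_dropout_predict(net, x, mc_iterations=50)
print(mean.shape, std.shape)  # torch.Size([5, 1]) torch.Size([5, 1])
```

The default quantile_range of (0.025, 0.975) brackets the central 95% of the sampled outputs, which is how it lines up with confidence='95'.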

Module contents