ise.models package
Submodules
ise.models.loss module
This module contains custom loss functions for training neural network emulators.
- Classes:
WeightedGridLoss: Custom loss function that penalizes errors based on the total variation of a grid.
WeightedMSELoss: Custom loss function that penalizes errors on extreme values more.
WeightedMSEPCALoss: Custom loss function that penalizes errors on extreme values more and allows for custom weighting of each prediction in a batched manner.
WeightedMSELossWithSignPenalty: Custom loss function that penalizes errors on extreme values more and adds a penalty for opposite sign predictions.
- class ise.models.loss.GridCriterion
Bases: torch.nn.Module
- forward(true, predicted, smoothness_weight=0.001)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class ise.models.loss.MSEDeviationLoss(threshold=1.0, penalty_multiplier=2.0)
Bases: torch.nn.Module
- class ise.models.loss.WeightedGridLoss
Bases: torch.nn.Module
- forward(true, predicted, smoothness_weight=0.001, extreme_value_threshold=1e-06)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
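WeightedGridLoss is described in the class list as penalizing errors based on the total variation of a grid. The sketch below illustrates that idea under stated assumptions. It computes the anisotropic total variation (mean absolute difference between vertically and horizontally adjacent cells) of a batch of grids shaped (batch, H, W) and adds it, scaled by `smoothness_weight`, to a plain MSE term. `total_variation` and `grid_loss_sketch` are hypothetical helpers, not part of ise.models; how the package actually weights errors, and what `extreme_value_threshold` controls, is not documented on this page.

```python
import torch
import torch.nn.functional as F


def total_variation(grid: torch.Tensor) -> torch.Tensor:
    """Anisotropic total variation of a batch of grids shaped (batch, H, W)."""
    # Mean absolute difference between vertically adjacent cells.
    dv = (grid[:, 1:, :] - grid[:, :-1, :]).abs().mean()
    # Mean absolute difference between horizontally adjacent cells.
    dh = (grid[:, :, 1:] - grid[:, :, :-1]).abs().mean()
    return dv + dh


def grid_loss_sketch(true, predicted, smoothness_weight=0.001):
    # Hypothetical combination: pointwise MSE plus a smoothness penalty on the prediction.
    return F.mse_loss(predicted, true) + smoothness_weight * total_variation(predicted)


true = torch.zeros(2, 4, 4)
flat = torch.zeros(2, 4, 4)
noisy = flat.clone()
noisy[:, ::2, ::2] = 1.0  # isolated spikes raise the total variation
print(total_variation(flat).item())   # 0.0
print(total_variation(noisy).item())  # 1.0
print(grid_loss_sketch(true, noisy).item())
```

A smooth prediction pays only the MSE term, while a spiky prediction with the same pointwise error pays extra in proportion to `smoothness_weight`.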
- class ise.models.loss.WeightedMSELoss(data_mean, data_std, weight_factor=1.0)
Bases: torch.nn.Module
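The class list describes WeightedMSELoss as penalizing errors on extreme values more, and its constructor takes `data_mean`, `data_std` and `weight_factor`. One plausible reading, sketched below as an assumption rather than the package's formula, scales each squared error by `1 + weight_factor * |true - data_mean| / data_std`, so targets several standard deviations from the mean dominate the loss. The class name `WeightedMSESketch` and the `(true, predicted)` argument order (borrowed from the grid losses' documented `forward` signatures) are hypothetical.

```python
import torch
import torch.nn as nn


class WeightedMSESketch(nn.Module):
    """Hypothetical weighted MSE: errors on targets far from the data mean count more."""

    def __init__(self, data_mean, data_std, weight_factor=1.0):
        super().__init__()
        self.data_mean = data_mean
        self.data_std = data_std
        self.weight_factor = weight_factor

    def forward(self, true, predicted):
        # Distance of each target from the mean, in units of standard deviations.
        z = (true - self.data_mean).abs() / self.data_std
        # Assumed scheme: weight grows linearly with that distance.
        weights = 1.0 + self.weight_factor * z
        return (weights * (predicted - true) ** 2).mean()


loss_fn = WeightedMSESketch(data_mean=0.0, data_std=1.0, weight_factor=1.0)
near = loss_fn(torch.tensor([0.0]), torch.tensor([0.5]))  # target at the mean
far = loss_fn(torch.tensor([3.0]), torch.tensor([3.5]))   # target 3 std devs away
print(near.item(), far.item())  # 0.25 1.0
```

Both predictions miss by 0.5, but the miss on the extreme target costs four times as much.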
- class ise.models.loss.WeightedMSELossWithSignPenalty(data_mean, data_std, weight_factor=1.0, sign_penalty_factor=1.0)
Bases: torch.nn.Module
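WeightedMSELossWithSignPenalty adds a penalty for predictions whose sign is opposite to the target's. A minimal sketch, assuming the penalty is the mean of `relu(-true * predicted)` scaled by `sign_penalty_factor`, and assuming the MSE term weights each squared error by `1 + weight_factor * |true - data_mean| / data_std`. The product `true * predicted` is negative exactly when the signs disagree, so the ReLU leaves same-sign predictions unpenalized. `weighted_mse_with_sign_penalty` is a hypothetical helper, not the package's implementation.

```python
import torch


def weighted_mse_with_sign_penalty(true, predicted, data_mean, data_std,
                                   weight_factor=1.0, sign_penalty_factor=1.0):
    # Assumed extreme-value weighting of the squared errors.
    weights = 1.0 + weight_factor * (true - data_mean).abs() / data_std
    mse = (weights * (predicted - true) ** 2).mean()
    # Opposite signs make true * predicted negative; relu keeps only those cases.
    sign_penalty = torch.relu(-true * predicted).mean()
    return mse + sign_penalty_factor * sign_penalty


true = torch.tensor([0.2])
same_sign = weighted_mse_with_sign_penalty(true, torch.tensor([0.6]), 0.0, 1.0)
flipped = weighted_mse_with_sign_penalty(true, torch.tensor([-0.2]), 0.0, 1.0)
# Both predictions miss by 0.4, but only the flipped one pays the sign penalty.
print(same_sign.item(), flipped.item())
```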
- class ise.models.loss.WeightedMSEPCALoss(data_mean, data_std, weight_factor=1.0, custom_weights=None)
Bases: torch.nn.Module
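WeightedMSEPCALoss adds `custom_weights` for weighting each prediction in a batched manner. The sketch below assumes those weights multiply the per-element squared errors through broadcasting: a 1-D tensor of length n_outputs weights each output column (for instance each principal component, which the class name suggests but this page does not confirm), while a (batch, n_outputs) tensor weights every prediction individually. `weighted_mse_custom` is a hypothetical helper, and its extreme-value term `1 + weight_factor * |true - data_mean| / data_std` is likewise an assumption.

```python
import torch


def weighted_mse_custom(true, predicted, data_mean, data_std,
                        weight_factor=1.0, custom_weights=None):
    # Assumed extreme-value weighting, one weight per element.
    weights = 1.0 + weight_factor * (true - data_mean).abs() / data_std
    if custom_weights is not None:
        # Broadcasting: shape (n_outputs,) weights each column,
        # shape (batch, n_outputs) weights each prediction individually.
        weights = weights * custom_weights
    return (weights * (predicted - true) ** 2).mean()


true, predicted = torch.zeros(2, 3), torch.ones(2, 3)
print(weighted_mse_custom(true, predicted, 0.0, 1.0).item())  # 1.0
column_weights = torch.tensor([1.0, 0.0, 0.0])  # only the first output counts
print(weighted_mse_custom(true, predicted, 0.0, 1.0, custom_weights=column_weights).item())
```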
ise.models.scenario module
- class ise.models.scenario.ScenarioPredictor(input_size, hidden_layers=[128, 64], output_size=1, dropout_rate=0.1)
Bases: torch.nn.Module
- evaluate(data_loader)
Evaluates the model on a dataset.
- Parameters:
data_loader (DataLoader) - DataLoader for the dataset to evaluate.
- Returns:
A tuple containing the average loss and accuracy on the dataset.
- Return type:
tuple
- fit(train_loader, val_loader=None, epochs=10, lr=0.001, print_every=1, save_checkpoint=True)
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
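ScenarioPredictor's constructor (input size, a list of hidden-layer widths, output size, dropout rate) suggests a feed-forward network, and `evaluate` is documented as returning the average loss and accuracy over a DataLoader. The sketch below is a hypothetical reconstruction of that shape, not the package's code: the Linear, ReLU, Dropout ordering, the binary-classification loss (`BCEWithLogitsLoss`) and the 0.5 probability threshold for accuracy are all assumptions, and `ScenarioPredictorSketch` is an illustrative name.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class ScenarioPredictorSketch(nn.Module):
    """Hypothetical MLP built from ScenarioPredictor's documented constructor arguments."""

    def __init__(self, input_size, hidden_layers=(128, 64), output_size=1, dropout_rate=0.1):
        super().__init__()
        layers, in_features = [], input_size
        for width in hidden_layers:
            # Assumed block per hidden layer: Linear -> ReLU -> Dropout.
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(dropout_rate)]
            in_features = width
        layers.append(nn.Linear(in_features, output_size))  # raw logits out
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

    @torch.no_grad()
    def evaluate(self, data_loader):
        """Return (average loss, accuracy), assuming binary 0/1 targets and logit outputs."""
        self.eval()  # disable dropout for evaluation
        criterion = nn.BCEWithLogitsLoss(reduction="sum")
        total_loss, correct, count = 0.0, 0, 0
        for x, y in data_loader:
            logits = self(x)
            total_loss += criterion(logits, y).item()
            # A logit above 0 means a predicted probability above 0.5.
            correct += ((logits > 0).float() == y).sum().item()
            count += y.numel()
        return total_loss / count, correct / count


# Usage on random data with binary targets of shape (batch, 1).
torch.manual_seed(0)
x = torch.randn(32, 5)
y = (x[:, :1] > 0).float()
loader = DataLoader(TensorDataset(x, y), batch_size=8)
model = ScenarioPredictorSketch(input_size=5)
avg_loss, accuracy = model.evaluate(loader)
print(f"loss={avg_loss:.3f} accuracy={accuracy:.2f}")
```

The sketch uses a tuple for the `hidden_layers` default to avoid Python's shared-mutable-default pitfall; the documented signature uses the list `[128, 64]`.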
ise.models.variational_lstm_emulator module
This module contains the VariationalLSTMEmulator class, which defines the model architecture for the variational LSTM emulator presented in https://doi.org/10.1029/2023MS003899.
- class ise.models.variational_lstm_emulator.VariationalLSTMEmulator(architecture, mc_dropout=False, dropout_prob=None)
Bases: torch.nn.Module
Variational LSTM Emulator model for time series data.
- enable_dropout()
Enable dropout during model evaluation.
This method turns on dropout for each layer whose name starts with “Dropout”.
- forward(x)
Forward pass of the VariationalLSTMEmulator model.
- Parameters:
x (torch.Tensor) - Input tensor.
- Returns:
Output tensor.
- Return type:
torch.Tensor
- predict(x, approx_dist=None, mc_iterations=None, quantile_range=[0.025, 0.975], confidence='95')
Make predictions using the VariationalLSTMEmulator model.
- Parameters:
x (np.ndarray or torch.Tensor or pd.DataFrame) - Input data.
approx_dist (bool, optional) - Whether to approximate the predictive distribution using Monte Carlo (MC) dropout. Defaults to None.
mc_iterations (int, optional) - Number of MC dropout iterations. Required if approx_dist is True.
quantile_range (list, optional) - Quantile range for prediction intervals. Defaults to [0.025, 0.975].
confidence (str, optional) - Confidence level for prediction intervals. Defaults to “95”.
- Returns:
Tuple containing the predictions, mean predictions, and standard deviations.
- Return type:
tuple
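The `enable_dropout` and `predict` entries describe Monte Carlo (MC) dropout: dropout stays active at inference, and repeated forward passes yield a distribution of outputs from which means, standard deviations and quantile intervals are computed. The sketch below shows that general technique for any torch.nn.Module containing dropout layers; it is not the package's implementation. `mc_dropout_predict` is a hypothetical helper, the class-name check in `enable_dropout` is one reading of the docstring, and returning the quantile bounds as a fourth element (instead of using `confidence`) is an assumption.

```python
import torch
import torch.nn as nn


def enable_dropout(model: nn.Module) -> None:
    """Switch only the dropout layers back to training mode so they keep sampling masks."""
    for module in model.modules():
        if module.__class__.__name__.startswith("Dropout"):
            module.train()


@torch.no_grad()
def mc_dropout_predict(model, x, mc_iterations=50, quantile_range=(0.025, 0.975)):
    """Hypothetical MC dropout predictor returning samples, mean, std and quantile bounds."""
    model.eval()           # every layer in inference mode first...
    enable_dropout(model)  # ...then re-enable only the dropout layers
    # mc_iterations stochastic forward passes: shape (mc_iterations, *output_shape).
    samples = torch.stack([model(x) for _ in range(mc_iterations)])
    mean = samples.mean(dim=0)
    std = samples.std(dim=0)
    q = torch.tensor(quantile_range, dtype=samples.dtype)
    bounds = torch.quantile(samples, q, dim=0)  # shape (2, *output_shape): lower, upper
    return samples, mean, std, bounds


# Usage with a toy network; any module containing nn.Dropout layers works the same way.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 1))
x = torch.randn(10, 4)
samples, mean, std, bounds = mc_dropout_predict(model, x, mc_iterations=100)
print(samples.shape, mean.shape, bounds.shape)
```

Calling `model.eval()` first and then re-enabling only the dropout layers keeps other mode-dependent layers, such as batch normalization, in inference mode. Dropout applied internally by `torch.nn.LSTM` through its own `dropout` argument is not a separate Dropout submodule, so a class-name check like this one does not switch it on.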