ise.models
Subpackages
- ise.models.ISEFlow
- ise.models.density_estimators
- ise.models.density_estimators.normalizing_flow
  - NormalizingFlow
    - NormalizingFlow.num_flow_transforms
    - NormalizingFlow.num_input_features
    - NormalizingFlow.num_predicted_sle
    - NormalizingFlow.flow_hidden_features
    - NormalizingFlow.output_sequence_length
    - NormalizingFlow.device
    - NormalizingFlow.base_distribution
    - NormalizingFlow.t
    - NormalizingFlow.flow
    - NormalizingFlow.optimizer
    - NormalizingFlow.criterion
    - NormalizingFlow.trained
    - NormalizingFlow.aleatoric()
    - NormalizingFlow.fit()
    - NormalizingFlow.get_latent()
    - NormalizingFlow.load()
    - NormalizingFlow.sample()
    - NormalizingFlow.save()
- ise.models.predictors
- ise.models.pretrained
ise.models.loss
- class ise.models.loss.GridCriterion
Bases: Module
Custom loss function that enforces spatial smoothness using total variation regularization.
This loss encourages smooth spatial predictions by penalizing large variations between adjacent grid values.
- Methods:
total_variation_regularization - Computes the smoothness loss.
forward - Computes the final loss.
- forward(true, predicted, smoothness_weight=0.001)
Computes the final loss by combining the pixel-wise mean squared error (MSE) with the total variation regularization (TVR) term.
- Parameters:
true (Tensor) - Ground truth values.
predicted (Tensor) - Model predictions.
smoothness_weight (float, optional) - Weight for TVR. Defaults to 0.001.
- Returns:
Computed loss.
- Return type:
Tensor
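The combination described above can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions, not the package's implementation: total variation is taken as the mean absolute difference between adjacent cells along the last two axes, the MSE is unweighted, and the helper names total_variation and grid_loss are hypothetical.

```python
import torch
import torch.nn.functional as F


def total_variation(grid: torch.Tensor) -> torch.Tensor:
    """Mean absolute difference between neighbouring cells along the last two axes."""
    dy = torch.abs(grid[..., 1:, :] - grid[..., :-1, :])  # vertical neighbours
    dx = torch.abs(grid[..., :, 1:] - grid[..., :, :-1])  # horizontal neighbours
    return dy.mean() + dx.mean()


def grid_loss(true: torch.Tensor, predicted: torch.Tensor,
              smoothness_weight: float = 0.001) -> torch.Tensor:
    """Hypothetical helper: pixel-wise MSE plus a weighted TV penalty on the prediction."""
    return F.mse_loss(predicted, true) + smoothness_weight * total_variation(predicted)


true = torch.zeros(4, 4)
predicted = torch.zeros(4, 4)
predicted[0, 0] = 1.0  # a single spike: MSE = 1/16, TV = 1/12 + 1/12
loss = grid_loss(true, predicted)
```

Absolute differences (the L1 form of total variation) tolerate a single sharp step better than squared differences, which favour spreading a change over many small steps; which form GridCriterion actually uses is not stated in this reference.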
- class ise.models.loss.MSEDeviationLoss(threshold=1.0, penalty_multiplier=2.0)
Bases: Module
Custom MSE loss with an additional penalty for large deviations.
This loss adds an extra penalty to predictions that deviate from the target by more than threshold.
- threshold
Deviation threshold for applying penalties.
- Type:
float
- penalty_multiplier
Multiplier controlling penalty severity.
- Type:
float
- Methods:
forward - Computes the loss with deviation penalties.
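One plausible reading of MSEDeviationLoss is sketched below: squared errors whose absolute deviation exceeds threshold are multiplied by penalty_multiplier before averaging. The class name DeviationPenaltyMSE, the forward(true, predicted) argument order, and the exact penalty form are assumptions; this reference does not document the formula.

```python
import torch
import torch.nn as nn


class DeviationPenaltyMSE(nn.Module):
    """Hypothetical stand-in for MSEDeviationLoss: squared errors whose absolute
    deviation exceeds `threshold` are multiplied by `penalty_multiplier`."""

    def __init__(self, threshold: float = 1.0, penalty_multiplier: float = 2.0):
        super().__init__()
        self.threshold = threshold
        self.penalty_multiplier = penalty_multiplier

    def forward(self, true: torch.Tensor, predicted: torch.Tensor) -> torch.Tensor:
        error = predicted - true
        squared = error ** 2
        # Scale only the elements whose deviation is larger than the threshold.
        scale = torch.where(error.abs() > self.threshold,
                            torch.full_like(squared, self.penalty_multiplier),
                            torch.ones_like(squared))
        return (scale * squared).mean()


criterion = DeviationPenaltyMSE()
true = torch.zeros(4)
predicted = torch.tensor([0.5, 2.0, 0.0, 0.0])
loss = criterion(true, predicted)  # (0.25 + 2 * 4.0) / 4
```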
- class ise.models.loss.WeightedGridLoss
Bases: Module
Custom loss function that penalizes errors based on the total variation of a grid.
This loss function consists of two components:
1. Pixel-wise weighted mean squared error (MSE): extreme values are assigned a higher weight.
2. Total variation regularization (TVR): enforces spatial smoothness by penalizing large differences between adjacent grid values.
- device
The device on which the model runs (‘cuda’ or ‘cpu’).
- Type:
str
- Methods:
total_variation_regularization - Computes the TVR penalty for smoothness.
weighted_pixelwise_mse - Computes the MSE loss with per-pixel weighting.
forward - Computes the final weighted loss.
- forward(true, predicted, smoothness_weight=0.001, extreme_value_threshold=1e-06)
Computes the final weighted loss combining pixel-wise MSE and TVR.
- Parameters:
true (Tensor) - Ground truth values.
predicted (Tensor) - Model predictions.
smoothness_weight (float, optional) - Weighting factor for the TVR loss. Defaults to 0.001.
extreme_value_threshold (float, optional) - Threshold to define extreme values. Defaults to 1e-6.
- Returns:
The total computed loss.
- Return type:
Tensor
- total_variation_regularization(grid)
Computes the total variation regularization (TVR) loss for spatial smoothness.
- Parameters:
grid (Tensor) - A 2D tensor representing spatial data.
- Returns:
The total variation loss.
- Return type:
Tensor
- weighted_pixelwise_mse(true, predicted, weights)
Computes the pixel-wise mean squared error (MSE) with custom weights.
- Parameters:
true (Tensor) - Ground truth values.
predicted (Tensor) - Model predictions.
weights (Tensor) - Weighting factor for each pixel.
- Returns:
Weighted mean squared error.
- Return type:
Tensor
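The per-pixel weighting can be sketched as follows, assuming (hypothetically) that cells whose ground-truth magnitude exceeds extreme_value_threshold count as extreme and receive a fixed, larger weight. The helper extreme_value_weights and its extreme_weight parameter are illustrative inventions; only the weighted_pixelwise_mse(true, predicted, weights) signature comes from this reference, and its mean reduction is an assumption.

```python
import torch


def extreme_value_weights(true: torch.Tensor, extreme_value_threshold: float = 1e-6,
                          extreme_weight: float = 10.0) -> torch.Tensor:
    """Hypothetical rule: cells with |true| > threshold get `extreme_weight`, the rest get 1."""
    return torch.where(true.abs() > extreme_value_threshold,
                       torch.full_like(true, extreme_weight),
                       torch.ones_like(true))


def weighted_pixelwise_mse(true: torch.Tensor, predicted: torch.Tensor,
                           weights: torch.Tensor) -> torch.Tensor:
    """Mean over all cells of weight * squared error."""
    return torch.mean(weights * (predicted - true) ** 2)


true = torch.tensor([[0.0, 0.0], [0.0, 2.0]])
predicted = torch.tensor([[0.1, 0.0], [0.0, 1.0]])
weights = extreme_value_weights(true)  # [[1, 1], [1, 10]]
loss = weighted_pixelwise_mse(true, predicted, weights)  # (0.01 + 10 * 1.0) / 4
```

In forward, this weighted MSE would then be added to smoothness_weight times the TVR term.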
- class ise.models.loss.WeightedMSELoss(data_mean, data_std, weight_factor=1.0)
Bases: Module
Custom loss function that applies a weighted penalty to extreme values.
This function increases the weight of extreme values based on their deviation from the dataset mean, normalizing by the standard deviation.
- data_mean
Mean value of the dataset.
- Type:
Tensor
- data_std
Standard deviation of the dataset.
- Type:
Tensor
- weight_factor
Factor controlling how much extreme values are penalized.
- Type:
Tensor
- Methods:
forward - Computes the weighted mean squared error loss.
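A minimal sketch of the weighting the description suggests, assuming each squared error is scaled by 1 + weight_factor * |true - data_mean| / data_std (a z-score magnitude). The class name ZScoreWeightedMSE and the forward(true, predicted) argument order are hypothetical; the package's exact weighting formula is not given in this reference.

```python
import torch
import torch.nn as nn


class ZScoreWeightedMSE(nn.Module):
    """Hypothetical WeightedMSELoss-style criterion.

    Each squared error is weighted by 1 + weight_factor * |true - mean| / std,
    so targets far from the dataset mean count more.
    """

    def __init__(self, data_mean, data_std, weight_factor: float = 1.0):
        super().__init__()
        # Buffers follow the module under .to(device) but are not trained.
        self.register_buffer("data_mean", torch.as_tensor(data_mean, dtype=torch.float32))
        self.register_buffer("data_std", torch.as_tensor(data_std, dtype=torch.float32))
        self.weight_factor = weight_factor

    def forward(self, true: torch.Tensor, predicted: torch.Tensor) -> torch.Tensor:
        z = torch.abs(true - self.data_mean) / self.data_std
        weights = 1.0 + self.weight_factor * z
        return torch.mean(weights * (predicted - true) ** 2)


criterion = ZScoreWeightedMSE(data_mean=0.0, data_std=1.0)
loss = criterion(torch.tensor([0.0, 2.0]), torch.tensor([1.0, 3.0]))  # weights [1, 3] -> (1 + 3) / 2
```

Registering data_mean and data_std as buffers means they move to the GPU together with the module, which avoids device-mismatch errors when the criterion runs on 'cuda'.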
- class ise.models.loss.WeightedMSELossWithSignPenalty(data_mean, data_std, weight_factor=1.0, sign_penalty_factor=1.0)
Bases: Module
Custom loss function that penalizes errors on extreme values and predictions whose sign is opposite to that of the target.
This loss extends WeightedMSELoss by adding a penalty when the sign of the prediction differs from the sign of the target.
- data_mean
Mean of the dataset.
- Type:
Tensor
- data_std
Standard deviation of the dataset.
- Type:
Tensor
- weight_factor
Factor controlling extreme value weighting.
- Type:
Tensor
- sign_penalty_factor
Factor controlling the penalty for opposite-sign predictions.
- Type:
Tensor
- Methods:
forward - Computes the weighted loss with sign penalties.
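The sign term can be sketched as an extra loss added on top of the weighted MSE, assuming the penalty is the squared error restricted to elements whose prediction and target have strictly opposite signs, scaled by sign_penalty_factor. The function name sign_penalty and this exact form are hypothetical.

```python
import torch


def sign_penalty(true: torch.Tensor, predicted: torch.Tensor,
                 sign_penalty_factor: float = 1.0) -> torch.Tensor:
    """Hypothetical extra term: mean squared error restricted to opposite-sign elements."""
    # A strictly negative product of signs means the two values have opposite signs;
    # zeros on either side are not counted as a mismatch.
    opposite = (torch.sign(true) * torch.sign(predicted) < 0).float()
    return sign_penalty_factor * torch.mean(opposite * (predicted - true) ** 2)


true = torch.tensor([1.0, -1.0, 0.5])
predicted = torch.tensor([-1.0, -2.0, 0.5])
penalty = sign_penalty(true, predicted)  # only the first element flips sign: 4 / 3
```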
- class ise.models.loss.WeightedMSEPCALoss(data_mean, data_std, weight_factor=1.0, custom_weights=None)
Bases: Module
Extension of WeightedMSELoss that allows for custom per-batch weighting.
This loss function enables additional user-defined weights to further adjust penalties for different predictions.
- data_mean
Mean of the dataset.
- Type:
Tensor
- data_std
Standard deviation of the dataset.
- Type:
Tensor
- weight_factor
Controls the penalty for extreme values.
- Type:
Tensor
- custom_weights
User-defined weight tensor.
- Type:
Tensor, optional
- Methods:
forward - Computes the batch-weighted MSE loss.
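A sketch of how custom weights might compose with the extreme-value weights, assuming they simply multiply the z-score weights and are broadcastable to the target shape (for example, one weight per output column). The function weighted_mse_custom is hypothetical.

```python
import torch


def weighted_mse_custom(true, predicted, data_mean, data_std,
                        weight_factor=1.0, custom_weights=None):
    """Hypothetical sketch: z-score extreme-value weights, optionally scaled by user weights."""
    weights = 1.0 + weight_factor * torch.abs(true - data_mean) / data_std
    if custom_weights is not None:
        # Broadcasting lets a 1-D tensor supply one weight per output column.
        weights = weights * custom_weights
    return torch.mean(weights * (predicted - true) ** 2)


true = torch.zeros(2, 3)
predicted = torch.ones(2, 3)
plain = weighted_mse_custom(true, predicted, 0.0, 1.0)
emphasised = weighted_mse_custom(true, predicted, 0.0, 1.0,
                                 custom_weights=torch.tensor([3.0, 1.0, 1.0]))
```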
- class ise.models.loss.WeightedPCALoss(component_weights, reduction='mean')
Bases: Module
Custom loss function that applies a separate weight to the error on each principal component.
This allows higher penalties to be assigned to the leading components.
- component_weights
Weighting factors for each principal component.
- Type:
Tensor
- reduction
Specifies reduction mode (‘mean’, ‘sum’, or ‘none’).
- Type:
str
- Methods:
forward - Computes the weighted PCA loss.
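A minimal sketch of a per-component weighted squared error with the three documented reduction modes, assuming inputs of shape (batch, n_components) and one weight per component. ComponentWeightedMSE is a hypothetical name, not the package's class.

```python
import torch
import torch.nn as nn


class ComponentWeightedMSE(nn.Module):
    """Hypothetical sketch of WeightedPCALoss: the squared error on each principal
    component is scaled by that component's weight."""

    def __init__(self, component_weights, reduction: str = "mean"):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.register_buffer("component_weights",
                             torch.as_tensor(component_weights, dtype=torch.float32))
        self.reduction = reduction

    def forward(self, true: torch.Tensor, predicted: torch.Tensor) -> torch.Tensor:
        # Inputs are (batch, n_components); the weights broadcast along the last axis.
        loss = self.component_weights * (predicted - true) ** 2
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss  # reduction == "none"


weights = [4.0, 2.0, 1.0]  # leading components weighted most heavily
true, predicted = torch.zeros(1, 3), torch.ones(1, 3)
mean_loss = ComponentWeightedMSE(weights)(true, predicted)
sum_loss = ComponentWeightedMSE(weights, reduction="sum")(true, predicted)
```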
ise.models.scenario
- class ise.models.scenario.ScenarioPredictor(input_size, hidden_layers=[128, 64], output_size=1, dropout_rate=0.1)
Bases: Module
- evaluate(data_loader)
Evaluates the model on a dataset.
- Parameters:
data_loader (torch.utils.data.DataLoader) - DataLoader for the dataset to evaluate.
- Returns:
- A tuple containing:
avg_loss (float): Average loss over the dataset.
accuracy (float): Accuracy of predictions (0 to 1).
- Return type:
tuple
- fit(train_loader, val_loader=None, epochs=10, lr=0.001, print_every=1, save_checkpoint=True)
Trains the model on the given dataset.
- Parameters:
train_loader (torch.utils.data.DataLoader) - DataLoader for the training dataset.
val_loader (torch.utils.data.DataLoader, optional) - DataLoader for the validation dataset. Defaults to None.
epochs (int, optional) - Number of epochs for training. Defaults to 10.
lr (float, optional) - Learning rate for the optimizer. Defaults to 1e-3.
print_every (int, optional) - Interval for printing training progress. Defaults to 1.
save_checkpoint (bool, optional) - Whether to save model checkpoints based on validation loss. Defaults to True.
- Returns:
None
- forward(x)
Forward pass through the model.
- Parameters:
x (torch.Tensor) - Input tensor with shape (batch_size, input_size).
- Returns:
Output tensor with probabilities in the range [0, 1].
- Return type:
torch.Tensor
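The constructor arguments suggest a feed-forward classifier; a minimal sketch under that assumption follows. The ReLU activations, the placement of dropout after each hidden layer, the sigmoid output and the 0.5 decision threshold used for accuracy are all assumptions, and build_scenario_mlp is a hypothetical helper.

```python
import torch
import torch.nn as nn


def build_scenario_mlp(input_size: int, hidden_layers=(128, 64), output_size: int = 1,
                       dropout_rate: float = 0.1) -> nn.Sequential:
    """Hypothetical MLP shaped like the ScenarioPredictor signature."""
    layers = []
    in_features = input_size
    for width in hidden_layers:
        layers += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(dropout_rate)]
        in_features = width
    # The sigmoid maps each output into [0, 1], matching the documented probability range.
    layers += [nn.Linear(in_features, output_size), nn.Sigmoid()]
    return nn.Sequential(*layers)


model = build_scenario_mlp(input_size=10)
model.eval()  # disable dropout for inference
with torch.no_grad():
    probs = model(torch.randn(5, 10))
labels = (probs > 0.5).float()  # assumed decision threshold for accuracy
```

A tuple is used as the default for hidden_layers here because a mutable list default is shared across calls in Python; the documented [128, 64] default is safe only as long as it is never modified in place.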
ise.models.variational_lstm_emulator
This module contains VariationalLSTMEmulator, the model architecture for the variational LSTM emulator presented in https://doi.org/10.1029/2023MS003899.
- class ise.models.variational_lstm_emulator.VariationalLSTMEmulator(architecture, mc_dropout=False, dropout_prob=None)
Bases: Module
Variational LSTM emulator model for time series data.
- enable_dropout()
Enable dropout during model evaluation.
This method switches every layer whose class name starts with “Dropout” back to training mode, so dropout remains active at inference time.
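The Monte Carlo dropout trick behind enable_dropout can be sketched as follows, assuming the match is on the layer's class name as the description suggests. Returning the number of layers switched is an addition for illustration, and the small Sequential model is a hypothetical stand-in.

```python
import torch.nn as nn


def enable_dropout(model: nn.Module) -> int:
    """Switch dropout layers back to training mode; return how many were switched."""
    switched = 0
    for module in model.modules():
        # Matches classes such as nn.Dropout and nn.Dropout2d by name.
        if module.__class__.__name__.startswith("Dropout"):
            module.train()
            switched += 1
    return switched


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.2), nn.Linear(8, 1))
model.eval()               # everything deterministic...
n = enable_dropout(model)  # ...except the dropout layer, which is stochastic again
```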
- forward(x)
Forward pass of the VariationalLSTMEmulator model.
- Parameters:
x (torch.Tensor) - Input tensor.
- Returns:
Output tensor.
- Return type:
torch.Tensor
- predict(x, approx_dist=None, mc_iterations=None, quantile_range=[0.025, 0.975], confidence='95')
Make predictions using the VariationalLSTMEmulator model.
- Parameters:
x (np.ndarray or torch.Tensor or pd.DataFrame) - Input data.
approx_dist (bool, optional) - Whether to approximate the predictive distribution using Monte Carlo (MC) dropout. Defaults to None.
mc_iterations (int, optional) - Number of MC iterations. Required if approx_dist is True.
quantile_range (list, optional) - Quantile range for prediction intervals. Defaults to [0.025, 0.975].
confidence (str, optional) - Confidence level for prediction intervals. Defaults to “95”.
- Returns:
Tuple containing the predictions, mean predictions, and standard deviations.
- Return type:
tuple
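A minimal sketch of MC-dropout prediction in the spirit of predict with approx_dist=True: run mc_iterations stochastic forward passes with dropout active, then summarise them with the mean, standard deviation and an empirical quantile interval (quantile_range=(0.025, 0.975) corresponds to the default 95% confidence). The helper mc_dropout_predict, its return layout, and the small feed-forward network standing in for the LSTM are hypothetical.

```python
import torch
import torch.nn as nn


def mc_dropout_predict(model: nn.Module, x: torch.Tensor, mc_iterations: int = 100,
                       quantile_range=(0.025, 0.975)):
    """Hypothetical MC-dropout predictor: repeated stochastic passes summarised by
    the sample mean, standard deviation and an empirical quantile interval."""
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()  # keep dropout active so each pass samples a new mask
    with torch.no_grad():
        samples = torch.stack([model(x) for _ in range(mc_iterations)])  # (mc, batch, out)
    mean = samples.mean(dim=0)
    std = samples.std(dim=0)
    q = torch.tensor(quantile_range, dtype=samples.dtype)
    lower, upper = torch.quantile(samples, q, dim=0)  # each (batch, out)
    return samples, mean, std, lower, upper


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 1))
samples, mean, std, lower, upper = mc_dropout_predict(model, torch.randn(4, 3), mc_iterations=50)
```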