# Source code for ise.models.ISEFlow.ISEFlow

import json
import os
import pickle
import warnings
import shutil

import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from nflows import distributions, flows, transforms

from ise.data.dataclasses import EmulatorDataset
from ise.utils.functions import to_tensor
from ise.utils.training import EarlyStoppingCheckpointer, CheckpointSaver
from ise.data import feature_engineer as fe
from ise.models.predictors.deep_ensemble import DeepEnsemble
from ise.models.density_estimators.normalizing_flow import NormalizingFlow
from .de import ISEFlow_AIS_DE, ISEFlow_GrIS_DE
from .nf import ISEFlow_AIS_NF, ISEFlow_GrIS_NF
from ise.models.pretrained import ISEFlow_AIS_v1_0_0_path, ISEFlow_GrIS_v1_0_0_path

class ISEFlow(torch.nn.Module):
    """
    The ISEFlow (Flow-based Ice Sheet Emulator) that combines a deep ensemble and a normalizing flow model.

    The normalizing flow provides (a) a latent representation of the inputs, which is
    appended to the inputs before they are passed to the deep ensemble, and (b) the
    aleatoric uncertainty. The deep ensemble provides the prediction and the epistemic
    uncertainty.
    """

    def __init__(self, deep_ensemble, normalizing_flow):
        """
        Args:
            deep_ensemble (DeepEnsemble): Predictor trained on inputs concatenated with
                the flow's latent representation.
            normalizing_flow (NormalizingFlow): Density estimator providing the latent
                representation and the aleatoric uncertainty.

        Raises:
            ValueError: If either component is not of the expected type.
        """
        super(ISEFlow, self).__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        if not isinstance(deep_ensemble, DeepEnsemble):
            raise ValueError("deep_ensemble must be a DeepEnsemble instance")
        if not isinstance(normalizing_flow, NormalizingFlow):
            raise ValueError("normalizing_flow must be a NormalizingFlow instance")

        self.deep_ensemble = deep_ensemble.to(self.device)
        self.normalizing_flow = normalizing_flow.to(self.device)
        # The emulator only counts as trained once both components are.
        self.trained = self.deep_ensemble.trained and self.normalizing_flow.trained
        # Path of the pickled output (y) scaler, remembered by predict() / save().
        self.scaler_path = None
        # Directory the model was loaded from (set by load()); predict() looks for
        # "scaler_y.pkl" there when output_scaler=True.
        self.model_dir = None

    def fit(self, X, y, nf_epochs, de_epochs, batch_size=64, X_val=None, y_val=None,
            save_checkpoints=True, checkpoint_path='checkpoint_ensemble', early_stopping=True,
            sequence_length=5, patience=10, verbose=True):
        """
        Fits the hybrid emulator to the training data.

        The normalizing flow is trained first. Its latent representation of ``X`` is then
        concatenated to ``X`` (and likewise for the validation set) to train the deep
        ensemble. A component whose ``trained`` flag is already set is not retrained.

        Args:
            X, y: Training inputs and targets (anything ``to_tensor`` accepts).
            nf_epochs (int): (Maximum) epochs for the normalizing flow.
            de_epochs (int): (Maximum) epochs for the deep ensemble.
            batch_size (int): Mini-batch size passed to both components.
            X_val, y_val: Optional validation data.
            save_checkpoints (bool): Whether components write checkpoints.
            checkpoint_path (str): Prefix for checkpoint files ("_nf.pth" / "_de" appended).
            early_stopping (bool | None): ``None`` enables early stopping only when
                validation data is supplied.
            sequence_length (int): Forwarded to the deep ensemble's ``fit``.
            patience (int): Early-stopping patience.
            verbose (bool): Forwarded to both components.
        """
        if early_stopping is None:
            early_stopping = X_val is not None and y_val is not None

        # Seed torch from NumPy's global RNG, so a NumPy seed set by the caller also fixes torch's seed.
        torch.manual_seed(np.random.randint(0, 100000))
        X, y = to_tensor(X).to(self.device), to_tensor(y).to(self.device)

        if self.trained:
            warnings.warn("This model has already been trained. Training again.")

        # Train Normalizing Flow
        if not self.normalizing_flow.trained:
            print(f"\nTraining Normalizing Flow ({'Maximum ' if early_stopping else ''}{nf_epochs} epochs):")
            self.normalizing_flow.fit(
                X, y, nf_epochs, batch_size, save_checkpoints, f"{checkpoint_path}_nf.pth",
                early_stopping, patience, verbose,
            )

        # Latent representation
        z = self.normalizing_flow.get_latent(X).detach()
        X_latent = torch.cat((X, z), axis=1)

        X_val_latent = None
        if X_val is not None and y_val is not None:
            X_val, y_val = to_tensor(X_val).to(self.device), to_tensor(y_val).to(self.device)
            z_val = self.normalizing_flow.get_latent(X_val).detach()
            X_val_latent = torch.cat((X_val, z_val), axis=1)

        # Train Deep Ensemble
        if not self.deep_ensemble.trained:
            print(f"\nTraining Deep Ensemble ({'Maximum ' if early_stopping else ''}{de_epochs} epochs):")
            self.deep_ensemble.fit(
                X_latent, y, X_val_latent, y_val, save_checkpoints, f"{checkpoint_path}_de",
                early_stopping, de_epochs, batch_size, sequence_length, patience, verbose,
            )

        self.trained = True

    def forward(self, x, smooth_projection=False):
        """
        Performs a forward pass through the hybrid emulator.

        Args:
            x: Model inputs (anything ``to_tensor`` accepts); moved to ``self.device``.
            smooth_projection (bool): Currently unused; kept for API compatibility.

        Returns:
            tuple: ``(prediction, uncertainties)``. ``prediction`` is a NumPy array;
            ``uncertainties`` is a dict with keys ``"total"``, ``"epistemic"`` and
            ``"aleatoric"``, where ``total = aleatoric + epistemic``.
        """
        self.eval()
        x = to_tensor(x).to(self.device)
        if not self.trained:
            warnings.warn("This model has not been trained. Predictions may not be accurate.")

        z = self.normalizing_flow.get_latent(x).detach()
        X_latent = torch.cat((x, z), axis=1)
        prediction, epistemic = self.deep_ensemble(X_latent)
        aleatoric = self.normalizing_flow.aleatoric(x, 100)

        prediction = prediction.detach().cpu().numpy()
        epistemic = epistemic.detach().cpu().numpy()
        uncertainties = dict(
            total=aleatoric + epistemic,
            epistemic=epistemic,
            aleatoric=aleatoric,
        )
        return prediction, uncertainties

    @staticmethod
    def _load_pickle(path):
        """Load a pickled object (the output scaler) from ``path``."""
        # NOTE(security): pickle.load executes arbitrary code; only load trusted scaler files.
        with open(path, "rb") as f:
            return pickle.load(f)

    def _resolve_output_scaler(self, output_scaler):
        """
        Turn the ``output_scaler`` argument of :meth:`predict` into a scaler object.

        Returns ``None`` when no scaler is available (predictions stay in model units).

        Raises:
            ValueError: ``output_scaler=True`` but neither ``model_dir`` nor
                ``scaler_path`` is known.
        """
        model_dir = getattr(self, "model_dir", None)
        scaler_path = getattr(self, "scaler_path", None)

        if output_scaler is True:
            if model_dir is not None:
                return self._load_pickle(os.path.join(model_dir, "scaler_y.pkl"))
            if scaler_path is not None:
                return self._load_pickle(scaler_path)
            raise ValueError(
                "output_scaler=True requires a model loaded from a directory (model_dir) "
                "or a previously provided scaler path; pass output_scaler explicitly."
            )
        if output_scaler is False or output_scaler is None:
            # Fall back to a scaler path remembered from an earlier predict()/save() call.
            return None if scaler_path is None else self._load_pickle(scaler_path)
        if isinstance(output_scaler, (str, os.PathLike)):
            self.scaler_path = output_scaler
            return self._load_pickle(output_scaler)
        # Anything else is assumed to be an already-loaded scaler exposing inverse_transform.
        return output_scaler

    def predict(self, x, output_scaler=True, smooth_projection=False):
        """
        Predict and convert predictions/uncertainties back to output units with the y-scaler.

        Args:
            x: Model inputs (see :meth:`forward`).
            output_scaler: ``True`` uses ``<model_dir>/scaler_y.pkl`` (falling back to the
                remembered ``scaler_path``); a path (str/PathLike) loads and remembers that
                scaler; ``False``/``None`` uses the remembered ``scaler_path`` if any,
                otherwise returns unscaled output with a warning; any other object is
                used directly as a scaler with ``inverse_transform``.
            smooth_projection (bool): Forwarded to :meth:`forward`.

        Returns:
            tuple: ``(predictions, uncertainties)`` as in :meth:`forward`, reshaped to
            column vectors and inverse-transformed when a scaler is available.
        """
        self.eval()
        scaler = self._resolve_output_scaler(output_scaler)
        if scaler is None:
            warnings.warn("No scaler path provided, uncertainties are not in units of SLE.")
            return self.forward(x, smooth_projection=smooth_projection)

        predictions, uncertainties = self.forward(x, smooth_projection=smooth_projection)
        unscaled_predictions = scaler.inverse_transform(predictions.reshape(-1, 1))

        def _to_output_units(width):
            # The inverse transform need not be linear, so convert an uncertainty width by
            # transforming (prediction + width) and subtracting the transformed prediction.
            bound = scaler.inverse_transform((predictions + width).reshape(-1, 1))
            return bound - unscaled_predictions

        epistemic = _to_output_units(uncertainties["epistemic"])
        aleatoric = _to_output_units(uncertainties["aleatoric"])
        uncertainties = dict(
            total=epistemic + aleatoric,
            epistemic=epistemic,
            aleatoric=aleatoric,
        )
        return unscaled_predictions, uncertainties

    def save(self, save_dir, input_features=None, output_scaler_path=None):
        """
        Saves the trained model to the specified directory.

        Writes ``deep_ensemble.pth`` and ``normalizing_flow.pth``, optionally
        ``input_features.json``, and copies the output scaler to ``scaler_y.pkl`` when a
        scaler path is known.

        Raises:
            ValueError: If the model is untrained, ``save_dir`` looks like a ``.pth`` file,
                or ``input_features`` is not a list.
        """
        if not self.trained:
            raise ValueError("This model has not been trained yet. Train the model before saving.")
        if save_dir.endswith(".pth"):
            raise ValueError("save_dir must be a directory, not a file")
        os.makedirs(save_dir, exist_ok=True)

        self.deep_ensemble.save(os.path.join(save_dir, "deep_ensemble.pth"))
        self.normalizing_flow.save(os.path.join(save_dir, "normalizing_flow.pth"))

        if input_features is not None:
            if not isinstance(input_features, list):
                raise ValueError("input_features must be a list of feature names")
            with open(os.path.join(save_dir, "input_features.json"), "w") as f:
                json.dump(input_features, f, indent=4)

        if output_scaler_path is not None:
            if str(output_scaler_path).endswith(".pkl"):
                self.scaler_path = output_scaler_path
            else:
                warnings.warn(f"Ignoring output_scaler_path {output_scaler_path!r}: expected a .pkl file.")
        if self.scaler_path is not None:
            shutil.copy(self.scaler_path, os.path.join(save_dir, "scaler_y.pkl"))

    @staticmethod
    def load(model_dir=None, deep_ensemble_path=None, normalizing_flow_path=None,):
        """
        Loads a trained model from the specified paths.

        Explicit component paths take precedence; missing ones default to
        ``deep_ensemble.pth`` / ``normalizing_flow.pth`` inside ``model_dir``.

        Raises:
            ValueError: If a component path cannot be determined.
        """
        if model_dir:
            deep_ensemble_path = deep_ensemble_path or os.path.join(model_dir, "deep_ensemble.pth")
            normalizing_flow_path = normalizing_flow_path or os.path.join(model_dir, "normalizing_flow.pth")
        if deep_ensemble_path is None or normalizing_flow_path is None:
            raise ValueError("Provide model_dir or both deep_ensemble_path and normalizing_flow_path.")

        deep_ensemble = DeepEnsemble.load(deep_ensemble_path)
        normalizing_flow = NormalizingFlow.load(normalizing_flow_path)
        model = ISEFlow(deep_ensemble, normalizing_flow)
        model.trained = True
        model.model_dir = model_dir
        return model
class ISEFlow_AIS(ISEFlow):
    """
    ISEFlow emulator configured for the ``"AIS"`` ice sheet.

    Wraps the AIS-specific deep ensemble / normalizing flow and adds :meth:`process` and
    :meth:`predict`, which turn raw forcing arrays plus model-configuration choices into
    the scaled feature table expected by the pretrained v1.0.0 model.
    """

    # Accepted (lower-case) user input -> category label used when building dummies.
    # pd.get_dummies joins these labels to the column prefix, so they must match _COLUMNS
    # exactly (underscores, not spaces: e.g. "Local_anom", "Floating_condition").
    _ARG_MAP = {
        "numerics": {
            "fe": "FE",
            "fd": "FD",
            "fe/fv": "FE/FV",
        },
        "stress_balance": {
            "ho": "HO",
            "hybrid": "Hybrid",
            "l1l2": "L1L2",
            "sia+ssa": "SIA_SSA",
            "ssa": "SSA",
            "stokes": "Stokes",
        },
        "init_method": {
            "da": "DA",
            "da*": "DA_geom",
            "da+": "DA_relax",
            "eq": "Eq",
            "sp": "SP",
            "sp+": "SP_icethickness",
        },
        "melt": {
            "floating condition": "Floating_condition",
            "sub-grid": "Sub-grid",
            # None / False (accepted by validation) map to the "melt_No" column.
            "none": "No",
            "false": "No",
        },
        "ice_front": {
            "str": "StR",
            "fix": "Fix",
            "mh": "MH",
            "ro": "RO",
            "div": "Div",
        },
        "Ocean forcing": {
            "standard": "Standard",
            "open": "Open",
        },
        "Ocean sensitivity": {
            "low": "Low",
            "medium": "Medium",
            "high": "High",
            "pigl": "PIGL",
        },
        "open_melt_param": {
            "lin": "Lin",
            "quad": "Quad",
            "nonlocal+slope": "Nonlocal_Slope",
            "pico": "PICO",
            "picop": "PICOP",
            "plume": "Plume",
        },
        "standard_melt_param": {
            "local": "Local",
            "nonlocal": "Nonlocal",
            "local anom": "Local_anom",
            "nonlocal anom": "Nonlocal_anom",
        },
    }

    # Values used to fill missing (NaN) mrro_anomaly entries, indexed by the
    # 2015-relative year (86 entries: index 0 .. 85).
    _MRRO_MEANS = np.array([
        3.61493220e-08, 2.77753815e-08, 5.50841177e-08, 4.17617754e-08, 5.58558082e-08,
        5.74870861e-08, 1.07017988e-07, 7.72183085e-08, 6.44275121e-08, 2.10466987e-08,
        5.36071770e-08, 8.32501757e-08, 9.31873131e-08, 7.84747761e-08, 8.41751157e-08,
        8.56960829e-08, 7.81743956e-08, 9.74934761e-08, 6.04155892e-08, 8.31572351e-08,
        1.16800344e-07, 9.96168899e-08, 1.41262144e-07, 8.76467771e-08, 1.03335698e-07,
        1.23414214e-07, 9.29483909e-08, 1.95530928e-07, 1.18321950e-07, 1.68664275e-07,
        1.56460562e-07, 1.40309916e-07, 1.08267844e-07, 1.85627395e-07, 1.29400203e-07,
        1.98725020e-07, 1.39994753e-07, 1.86775688e-07, 1.68388442e-07, 2.04534154e-07,
        1.49715175e-07, 1.50418319e-07, 1.44444531e-07, 1.67211070e-07, 1.83698063e-07,
        2.05489898e-07, 2.42246565e-07, 1.98110423e-07, 2.40505470e-07, 2.37863389e-07,
        2.55668987e-07, 2.93048624e-07, 2.57849749e-07, 2.72915753e-07, 2.82135517e-07,
        2.27647208e-07, 2.21859448e-07, 2.07266200e-07, 2.42241281e-07, 2.55693726e-07,
        2.52039399e-07, 2.82802604e-07, 2.94193847e-07, 3.00380753e-07, 3.60152406e-07,
        3.47886784e-07, 3.58344925e-07, 3.84398045e-07, 4.41053179e-07, 3.84072892e-07,
        4.42520286e-07, 4.30170222e-07, 4.34444387e-07, 4.77483307e-07, 3.52802246e-07,
        4.96503280e-07, 5.22078462e-07, 4.78644041e-07, 4.86755806e-07, 5.04600526e-07,
        4.80814514e-07, 5.38276914e-07, 5.91539053e-07, 5.84794672e-07, 5.33792907e-07,
        5.37435986e-07,
    ])

    # Forcing variables that must be converted to arrays.
    _FORCING_KEYS = (
        "year", "pr_anomaly", "evspsbl_anomaly", "mrro_anomaly", "smb_anomaly",
        "ts_anomaly", "thermal_forcing", "salinity", "temperature",
    )

    # Categorical columns one-hot encoded with pd.get_dummies.
    _DUMMY_COLUMNS = [
        "numerics", "stress_balance", "resolution", "init_method", "melt", "ice_front",
        "Ocean forcing", "Ocean sensitivity", "open_melt_param", "standard_melt_param",
        "Ice shelf fracture",
    ]

    # Exact feature columns (and order) the pretrained model expects; dummies that the
    # current input does not produce are filled with 0.
    _COLUMNS = [
        'year', 'sector', 'pr_anomaly', 'evspsbl_anomaly', 'mrro_anomaly', 'smb_anomaly',
        'ts_anomaly', 'thermal_forcing', 'salinity', 'temperature',
        'pr_anomaly.lag1', 'evspsbl_anomaly.lag1', 'mrro_anomaly.lag1', 'smb_anomaly.lag1',
        'ts_anomaly.lag1', 'thermal_forcing.lag1', 'salinity.lag1', 'temperature.lag1',
        'pr_anomaly.lag2', 'evspsbl_anomaly.lag2', 'mrro_anomaly.lag2', 'smb_anomaly.lag2',
        'ts_anomaly.lag2', 'thermal_forcing.lag2', 'salinity.lag2', 'temperature.lag2',
        'pr_anomaly.lag3', 'evspsbl_anomaly.lag3', 'mrro_anomaly.lag3', 'smb_anomaly.lag3',
        'ts_anomaly.lag3', 'thermal_forcing.lag3', 'salinity.lag3', 'temperature.lag3',
        'pr_anomaly.lag4', 'evspsbl_anomaly.lag4', 'mrro_anomaly.lag4', 'smb_anomaly.lag4',
        'ts_anomaly.lag4', 'thermal_forcing.lag4', 'salinity.lag4', 'temperature.lag4',
        'pr_anomaly.lag5', 'evspsbl_anomaly.lag5', 'mrro_anomaly.lag5', 'smb_anomaly.lag5',
        'ts_anomaly.lag5', 'thermal_forcing.lag5', 'salinity.lag5', 'temperature.lag5',
        'initial_year',
        'numerics_FD', 'numerics_FE', 'numerics_FE/FV',
        'stress_balance_HO', 'stress_balance_Hybrid', 'stress_balance_L1L2',
        'stress_balance_SIA_SSA', 'stress_balance_SSA', 'stress_balance_Stokes',
        'resolution_16', 'resolution_20', 'resolution_32', 'resolution_4', 'resolution_8',
        'resolution_variable',
        'init_method_DA', 'init_method_DA_geom', 'init_method_DA_relax', 'init_method_Eq',
        'init_method_SP', 'init_method_SP_icethickness',
        'melt_Floating_condition', 'melt_No', 'melt_Sub-grid',
        'ice_front_Div', 'ice_front_Fix', 'ice_front_MH', 'ice_front_RO', 'ice_front_StR',
        'open_melt_param_Lin', 'open_melt_param_Nonlocal_Slope', 'open_melt_param_PICO',
        'open_melt_param_PICOP', 'open_melt_param_Plume', 'open_melt_param_Quad',
        'standard_melt_param_Local', 'standard_melt_param_Local_anom',
        'standard_melt_param_Nonlocal', 'standard_melt_param_Nonlocal_anom',
        'Ocean forcing_Open', 'Ocean forcing_Standard',
        'Ocean sensitivity_High', 'Ocean sensitivity_Low', 'Ocean sensitivity_Medium',
        'Ocean sensitivity_PIGL',
        'Ice shelf fracture_False', 'Ice shelf fracture_True',
    ]

    def __init__(self,):
        """Build an untrained AIS emulator from the AIS-specific component architectures."""
        deep_ensemble = ISEFlow_AIS_DE()
        normalizing_flow = ISEFlow_AIS_NF()
        super(ISEFlow_AIS, self).__init__(deep_ensemble, normalizing_flow)
        self.ice_sheet = "AIS"

    @staticmethod
    def load(version="v1.0.0", model_dir=None, deep_ensemble_path=None, normalizing_flow_path=None):
        """
        Load a trained AIS emulator.

        Args:
            version (str): Pretrained version used when ``model_dir`` is not given.
            model_dir (str | None): Directory holding ``deep_ensemble.pth`` /
                ``normalizing_flow.pth`` (and the scalers used by :meth:`predict`).
            deep_ensemble_path, normalizing_flow_path (str | None): Explicit component
                files; when given they take precedence over ``model_dir``.

        Raises:
            NotImplementedError: ``model_dir`` is None and ``version`` is not "v1.0.0".
        """
        if model_dir is None:
            if version == "v1.0.0":
                model_dir = ISEFlow_AIS_v1_0_0_path
            else:
                raise NotImplementedError("Only version v1.0.0 is supported")

        deep_ensemble = DeepEnsemble.load(
            deep_ensemble_path or os.path.join(model_dir, "deep_ensemble.pth")
        )
        normalizing_flow = NormalizingFlow.load(
            normalizing_flow_path or os.path.join(model_dir, "normalizing_flow.pth")
        )

        # Return an ISEFlow_AIS (not a plain ISEFlow). The loaded components replace the
        # freshly built ones, so move them to the model's device as ISEFlow.__init__ would.
        model = ISEFlow_AIS()
        model.deep_ensemble = deep_ensemble.to(model.device)
        model.normalizing_flow = normalizing_flow.to(model.device)
        model.trained = True
        model.model_dir = model_dir
        return model

    def process(
        self,
        year: np.array,
        pr_anomaly: np.array,
        evspsbl_anomaly: np.array,
        mrro_anomaly: np.array,
        smb_anomaly: np.array,
        ts_anomaly: np.array,
        ocean_thermal_forcing: np.array,
        ocean_salinity: np.array,
        ocean_temperature: np.array,
        initial_year: int,
        numerics: str,
        stress_balance: str,
        resolution: int,
        init_method: str,
        melt_in_floating_cells: str,
        icefront_migration: str,
        ocean_forcing_type: str,
        ocean_sensitivity: str,
        ice_shelf_fracture: bool,
        open_melt_type: str = None,
        standard_melt_type: str = None,
    ):
        """
        Validate inputs and build the scaled feature table for one AIS projection.

        Forcing arrays are aligned per time step. If ``year`` starts at 2015 it is
        converted to years since 2015. Categorical options are matched
        case-insensitively. ``None`` (or "None") for the melt parameterisation of the
        forcing type that is not used means "not applicable".

        Returns:
            pandas.DataFrame: ``_COLUMNS`` plus ``outlier``, scaled with the v1.0.0
            ``scaler_X.pkl`` via ``fe.scale_data``.

        Raises:
            ValueError: On any invalid option or inconsistent melt configuration.
        """
        amap = self._ARG_MAP

        def choice(value):
            # Categorical options are matched case-insensitively.
            return str(value).lower()

        def optional_choice(value):
            # None or the string "None" means the melt parameterisation does not apply.
            key = choice(value)
            return None if value is None or key == "none" else key

        year = np.asarray(year)
        if year[0] == 2015:
            year = year - 2015

        # check inputs
        if not isinstance(initial_year, (int, np.integer)):
            raise ValueError("initial_year must be an integer")
        if choice(numerics) not in amap["numerics"]:
            raise ValueError("numerics must be one of 'fe', 'fd', or 'fe/fv'")
        if choice(stress_balance) not in amap["stress_balance"]:
            raise ValueError(
                "stress_balance must be one of 'ho', 'hybrid', 'l1l2', "
                "'sia+ssa', 'ssa', or 'stokes'"
            )
        if str(resolution) not in ('16', '20', '32', '4', '8', 'variable'):
            raise ValueError("resolution must be one of '16', '20', '32', '4', '8', or 'variable'")
        if choice(init_method) not in amap["init_method"]:
            raise ValueError("init_method must be one of 'da', 'da*', 'da+', 'eq', 'sp', or 'sp+'")
        if choice(melt_in_floating_cells) not in amap["melt"]:
            raise ValueError(
                "melt_in_floating_cells must be one of 'floating condition', 'sub-grid', 'None', or 'False'"
            )
        if choice(icefront_migration) not in amap["ice_front"]:
            raise ValueError("icefront_migration must be one of 'str', 'fix', 'mh', 'ro', or 'div'")

        forcing = choice(ocean_forcing_type)
        if forcing not in amap["Ocean forcing"]:
            raise ValueError("ocean_forcing_type must be one of 'standard' or 'open'")

        open_melt = optional_choice(open_melt_type)
        standard_melt = optional_choice(standard_melt_type)
        if forcing == "standard":
            if standard_melt_type is None:
                raise ValueError("standard_melt_type must be provided if ocean_forcing_type is 'standard'")
            if standard_melt is not None and standard_melt not in amap["standard_melt_param"]:
                raise ValueError(
                    "standard_melt_type must be one of 'local', 'nonlocal', 'local anom', "
                    "'nonlocal anom', or None"
                )
        if forcing == "open":
            if open_melt_type is None:
                raise ValueError("open_melt_type must be provided if ocean_forcing_type is 'open'")
            if open_melt is not None and open_melt not in amap["open_melt_param"]:
                raise ValueError(
                    "open_melt_type must be one of 'lin', 'quad', 'nonlocal+slope', 'pico', "
                    "'picop', 'plume', or None"
                )
        if choice(ocean_sensitivity) not in amap["Ocean sensitivity"]:
            raise ValueError("ocean_sensitivity must be one of 'low', 'medium', 'high', or 'pigl'")
        if not isinstance(ice_shelf_fracture, bool):
            raise ValueError("ice_shelf_fracture must be a boolean")
        # The melt parameter of the unused forcing type is optional but must still be valid.
        for key, value in (("open_melt_param", open_melt), ("standard_melt_param", standard_melt)):
            if value is not None and value not in amap[key]:
                raise ValueError(
                    f"Invalid value for {key}: {value}. Must be one of {list(amap[key].keys())}"
                )

        # make sure forcings are numpy arrays
        forcings = dict(zip(self._FORCING_KEYS, (
            year, pr_anomaly, evspsbl_anomaly, mrro_anomaly, smb_anomaly, ts_anomaly,
            ocean_thermal_forcing, ocean_salinity, ocean_temperature,
        )))
        data = {}
        for key, value in forcings.items():
            try:
                data[key] = np.array(value)
            except Exception as e:
                raise ValueError(f"Variable {key} must be a numpy array.") from e

        # remap categorical args to the labels the model was trained with
        data.update({
            "initial_year": initial_year,
            "numerics": amap["numerics"][choice(numerics)],
            "stress_balance": amap["stress_balance"][choice(stress_balance)],
            "resolution": resolution,
            "init_method": amap["init_method"][choice(init_method)],
            "melt": amap["melt"][choice(melt_in_floating_cells)],
            "ice_front": amap["ice_front"][choice(icefront_migration)],
            "Ocean sensitivity": amap["Ocean sensitivity"][choice(ocean_sensitivity)],
            "Ice shelf fracture": ice_shelf_fracture,
            "Ocean forcing": amap["Ocean forcing"][forcing],
            # None leaves the column empty, so get_dummies produces no dummy for it.
            "open_melt_param": None if open_melt is None else amap["open_melt_param"][open_melt],
            "standard_melt_param": (
                None if standard_melt is None else amap["standard_melt_param"][standard_melt]
            ),
        })
        data = pd.DataFrame(data)

        # Fill missing runoff anomalies with the per-year values in _MRRO_MEANS.
        year_means = dict(enumerate(self._MRRO_MEANS))
        data["mrro_anomaly"] = [
            year_means[yr] if pd.isna(val) else val
            for yr, val in zip(data["year"], data["mrro_anomaly"])
        ]

        data = fe.add_lag_variables(data, lag=5, verbose=False)
        data = pd.get_dummies(data, columns=self._DUMMY_COLUMNS)
        # Add dummy columns this input did not produce (as 0) and fix the column order.
        data = data.reindex(columns=self._COLUMNS, fill_value=0)
        data["outlier"] = False
        data = fe.scale_data(data, scaler_path=f"{ISEFlow_AIS_v1_0_0_path}/scaler_X.pkl")
        return data

    def predict(
        self,
        year: np.array,
        pr_anomaly: np.array,
        evspsbl_anomaly: np.array,
        mrro_anomaly: np.array,
        smb_anomaly: np.array,
        ts_anomaly: np.array,
        ocean_thermal_forcing: np.array,
        ocean_salinity: np.array,
        ocean_temperature: np.array,
        initial_year: int,
        numerics: str,
        stress_balance: str,
        resolution: int,
        init_method: str,
        melt_in_floating_cells: str,
        icefront_migration: str,
        ocean_forcing_type: str,
        ocean_sensitivity: str,
        ice_shelf_fracture: bool,
        open_melt_type: str = None,
        standard_melt_type: str = None,
    ):
        """
        Predict for raw AIS inputs.

        Arguments are as in :meth:`process`. The inputs are processed into features and
        passed to :meth:`ISEFlow.predict`, using the v1.0.0 ``scaler_y.pkl``.

        Returns:
            tuple: ``(predictions, uncertainties)`` from :meth:`ISEFlow.predict`.
        """
        data = self.process(
            year, pr_anomaly, evspsbl_anomaly, mrro_anomaly, smb_anomaly, ts_anomaly,
            ocean_thermal_forcing, ocean_salinity, ocean_temperature, initial_year,
            numerics, stress_balance, resolution, init_method, melt_in_floating_cells,
            icefront_migration, ocean_forcing_type, ocean_sensitivity, ice_shelf_fracture,
            open_melt_type, standard_melt_type,
        )
        X = to_tensor(data.values).to(self.device)
        return super().predict(X, output_scaler=f"{ISEFlow_AIS_v1_0_0_path}/scaler_y.pkl")
class ISEFlow_GrIS(ISEFlow):
    """ISEFlow emulator configured for the ``"GrIS"`` ice sheet."""

    def __init__(self,):
        """Build an untrained GrIS emulator from the GrIS-specific component architectures."""
        deep_ensemble = ISEFlow_GrIS_DE()
        normalizing_flow = ISEFlow_GrIS_NF()
        super(ISEFlow_GrIS, self).__init__(deep_ensemble, normalizing_flow)
        self.ice_sheet = "GrIS"

    @staticmethod
    def load(version="v1.0.0", model_dir=None, deep_ensemble_path=None, normalizing_flow_path=None,):
        """
        Load a trained GrIS emulator.

        Args:
            version (str): Pretrained version used when ``model_dir`` is not given.
            model_dir (str | None): Directory holding ``deep_ensemble.pth`` and
                ``normalizing_flow.pth``.
            deep_ensemble_path, normalizing_flow_path (str | None): Explicit component
                files; when given they take precedence over ``model_dir``.

        Returns:
            ISEFlow_GrIS: The loaded model (previously a plain ``ISEFlow`` was returned).

        Raises:
            NotImplementedError: ``model_dir`` is None and ``version`` is not "v1.0.0".
        """
        if model_dir is None:
            if version == "v1.0.0":
                model_dir = ISEFlow_GrIS_v1_0_0_path
            else:
                raise NotImplementedError("Only version v1.0.0 is supported")

        deep_ensemble = DeepEnsemble.load(
            deep_ensemble_path or os.path.join(model_dir, "deep_ensemble.pth")
        )
        normalizing_flow = NormalizingFlow.load(
            normalizing_flow_path or os.path.join(model_dir, "normalizing_flow.pth")
        )

        # Mirror ISEFlow_AIS.load: return the subclass and move the loaded components to
        # the model's device, as ISEFlow.__init__ does for constructor arguments.
        model = ISEFlow_GrIS()
        model.deep_ensemble = deep_ensemble.to(model.device)
        model.normalizing_flow = normalizing_flow.to(model.device)
        model.trained = True
        model.model_dir = model_dir
        return model