import torch
from torch import nn
import numpy as np
import warnings
import os
import json
from ise.models.predictors.lstm import LSTM
class DeepEnsemble(nn.Module):
    """Deep ensemble of LSTM emulators.

    Aggregates several independently trained :class:`LSTM` members: the
    ensemble mean is the prediction and the across-member standard deviation
    is an epistemic-uncertainty estimate.

    Args:
        ensemble_members: Optional pre-built list of ``LSTM`` instances. When
            omitted/empty, ``num_ensemble_members`` randomized members are
            created (random depth, hidden size, and loss) to encourage
            diversity across the ensemble.
        input_size: Number of raw input features (before latent features).
        output_size: Output feature dimension of each member.
        num_ensemble_members: How many members to build when none are given.
        output_sequence_length: Length of the predicted output sequence.
        latent_dim: Number of latent features appended to the inputs; the
            members therefore see ``input_size + latent_dim`` features.

    Raises:
        ValueError: If ``ensemble_members`` is given but is not a list of
            ``LSTM`` instances.
    """

    def __init__(self, ensemble_members=None, input_size=83, output_size=1,
                 num_ensemble_members=3, output_sequence_length=86, latent_dim=1):
        super().__init__()
        # Members receive the raw inputs plus latent_dim extra features.
        self.input_size = input_size + latent_dim
        self.output_size = output_size
        self.output_sequence_length = output_sequence_length
        # Candidate training criteria; each randomly built member draws one.
        self.loss_choices = [torch.nn.MSELoss(), torch.nn.L1Loss(), torch.nn.HuberLoss()]

        if not ensemble_members:
            members = [
                LSTM(
                    lstm_num_layers=np.random.randint(1, 3),  # 1 or 2 layers
                    lstm_hidden_size=np.random.choice([512, 256, 128, 64]),
                    criterion=np.random.choice(self.loss_choices),
                    input_size=self.input_size,
                    output_size=self.output_size,
                    output_sequence_length=self.output_sequence_length,
                )
                for _ in range(num_ensemble_members)
            ]
        elif isinstance(ensemble_members, list) and all(isinstance(m, LSTM) for m in ensemble_members):
            members = ensemble_members
        else:
            raise ValueError("ensemble_members must be a list of LSTM instances")

        # Register members as submodules so that state_dict(), .eval(),
        # .to(device), etc. propagate to them. (A plain Python list would
        # leave the members invisible to nn.Module bookkeeping.)
        self.ensemble_members = nn.ModuleList(members)

        # The ensemble counts as trained only if every member is trained.
        self.trained = all(member.trained for member in self.ensemble_members)

    def forward(self, x):
        """Perform a forward pass through the ensemble.

        Args:
            x: Input data, forwarded unchanged to each member's ``predict``.

        Returns:
            Tuple of (mean_prediction, epistemic_uncertainty): the mean and
            standard deviation of member predictions across the ensemble
            dimension.
        """
        if not self.trained:
            warnings.warn("This model has not been trained. Predictions may be inaccurate.")
        # Stack member predictions along a new "ensemble" dimension (dim=1).
        preds = torch.cat([member.predict(x).unsqueeze(1) for member in self.ensemble_members], dim=1)
        # NOTE(review): .squeeze() drops *every* size-1 dimension, including a
        # batch of one — confirm callers expect that.
        mean_prediction = preds.mean(dim=1).squeeze()
        epistemic_uncertainty = preds.std(dim=1).squeeze()
        return mean_prediction, epistemic_uncertainty

    def predict(self, x):
        """Make predictions using the ensemble (in eval mode).

        Args:
            x: Input data.

        Returns:
            Tuple[Tensor, Tensor]: Mean predictions and uncertainty estimates.
        """
        self.eval()
        return self.forward(x)

    def fit(self, X, y, X_val=None, y_val=None, save_checkpoints=True,
            checkpoint_path='checkpoint_ensemble', early_stopping=False,
            epochs=100, batch_size=128, sequence_length=5, patience=10, verbose=True):
        """Train every ensemble member in sequence, with optional early stopping.

        Args:
            X, y: Training data.
            X_val, y_val: Optional validation data passed through to members.
            save_checkpoints: Whether each member writes a checkpoint file.
            checkpoint_path: Stem for per-member checkpoint filenames
                (``{checkpoint_path}_member{i}.pth``).
            early_stopping: Use early stopping. Defaults to False.
            epochs, batch_size, sequence_length, patience, verbose: Forwarded
                to each member's ``fit``.
        """
        if self.trained:
            warnings.warn("Model already trained. Proceeding to train again.")
        for i, member in enumerate(self.ensemble_members):
            if verbose:
                print(f"Training Ensemble Member {i+1} of {len(self.ensemble_members)}:")
            member.fit(X, y, X_val=X_val, y_val=y_val, epochs=epochs,
                       batch_size=batch_size, sequence_length=sequence_length,
                       save_checkpoints=save_checkpoints,
                       checkpoint_path=f'{checkpoint_path}_member{i+1}.pth',
                       early_stopping=early_stopping, patience=patience, verbose=verbose)
            print("")
        self.trained = True

    def save(self, model_path):
        """Save the ensemble: metadata JSON, ensemble state dict, and one
        state-dict file per member under ``<model_dir>/ensemble_members/``.

        Args:
            model_path: Destination ``.pth`` path for the ensemble state dict;
                the metadata file is written next to it.

        Raises:
            ValueError: If the model has not been trained.
        """
        if not self.trained:
            raise ValueError("Train the model before saving.")
        # Ensure the save directory exists. dirname() is "" for a bare
        # filename, and os.makedirs("") raises — so only create when non-empty.
        model_dir = os.path.dirname(model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        ensemble_dir = os.path.join(model_dir, "ensemble_members")
        os.makedirs(ensemble_dir, exist_ok=True)

        # Per-member metadata, with paths relative to the model directory so
        # the saved bundle is relocatable.
        metadata = {
            "model_type": self.__class__.__name__,
            "version": "1.0",
            "device": "cuda" if torch.cuda.is_available() else "cpu",
            "ensemble_members": [
                {
                    "lstm_num_layers": member.lstm_num_layers,
                    "lstm_num_hidden": member.lstm_num_hidden,
                    "criterion": member.criterion.__class__.__name__,
                    "input_size": member.input_size,
                    "output_size": member.output_size,
                    "trained": member.trained,
                    "path": os.path.join("ensemble_members", f"member_{i+1}.pth"),
                    "best_loss": float(member.best_loss),
                    "epochs_trained": int(member.epochs_trained),
                }
                for i, member in enumerate(self.ensemble_members)
            ],
        }
        # Metadata lives next to the model file.
        metadata_path = model_path.replace(".pth", "_metadata.json")
        with open(metadata_path, "w") as file:
            json.dump(metadata, file, indent=4)
        print(f"Model metadata saved to {metadata_path}")

        # Ensemble-level state dict (members are registered submodules).
        torch.save(self.state_dict(), model_path)
        print(f"Model parameters saved to {model_path}")

        # Each member's state dict, individually loadable.
        for i, member in enumerate(self.ensemble_members):
            member_path = os.path.join(ensemble_dir, f"member_{i+1}.pth")
            torch.save(member.state_dict(), member_path)
            print(f"Ensemble Member {i+1} saved to {member_path}")

        print('Removing checkpoints after saving to model directory...')
        for member in self.ensemble_members:
            if hasattr(member, "checkpoint_path"):
                # Best effort: a member trained with save_checkpoints=False
                # may never have written its checkpoint file.
                try:
                    os.remove(member.checkpoint_path)
                except FileNotFoundError:
                    pass

    @classmethod
    def load(cls, model_path):
        """Reconstruct a saved ensemble from ``model_path`` and its metadata.

        Args:
            model_path: Path to the ensemble ``.pth`` file written by
                :meth:`save`; ``*_metadata.json`` must sit next to it.

        Returns:
            A trained ``DeepEnsemble`` in eval mode.

        Raises:
            ValueError: If the metadata's model type does not match ``cls``.
            FileNotFoundError: If a member's weight file is missing.
        """
        metadata_path = model_path.replace(".pth", "_metadata.json")
        model_dir = os.path.dirname(model_path)
        with open(metadata_path, "r") as file:
            metadata = json.load(file)
        if cls.__name__ != metadata["model_type"]:
            raise ValueError(f"Metadata type {metadata['model_type']} does not match {cls.__name__}")

        # Map saved criterion class names back to fresh loss instances.
        loss_lookup = {"MSELoss": torch.nn.MSELoss(), "L1Loss": torch.nn.L1Loss(), "HuberLoss": torch.nn.HuberLoss()}

        ensemble_members = []
        for member_metadata in metadata["ensemble_members"]:
            member_path = os.path.join(model_dir, member_metadata["path"])
            if not os.path.isfile(member_path):
                raise FileNotFoundError(f"Ensemble member file not found: {member_path}")
            criterion = loss_lookup[member_metadata["criterion"]]
            # NOTE(review): output_sequence_length is not stored in the
            # metadata, so members are rebuilt with LSTM's default — confirm
            # that matches the value used at save time.
            member = LSTM(
                lstm_num_layers=member_metadata["lstm_num_layers"],
                lstm_hidden_size=member_metadata["lstm_num_hidden"],
                input_size=member_metadata["input_size"],
                output_size=member_metadata["output_size"],
                criterion=criterion,
            )
            # Weights are trusted local artifacts written by save();
            # torch.load unpickles, so never point this at untrusted files.
            state_dict = torch.load(member_path, map_location="cpu" if not torch.cuda.is_available() else None)
            member.load_state_dict(state_dict)
            member.trained = True
            member.eval()
            ensemble_members.append(member)

        model = cls(ensemble_members=ensemble_members)
        # strict=False tolerates older checkpoints whose ensemble-level state
        # dict did not include member parameters.
        ensemble_state_dict = torch.load(model_path, map_location="cpu" if not torch.cuda.is_available() else None)
        model.load_state_dict(ensemble_state_dict, strict=False)
        model.eval()
        return model