import torch
from torch import nn, optim
import torch.nn.functional as F
import warnings
import pandas as pd
import numpy as np
import os
from ise.utils.functions import to_tensor
from ise.data.dataclasses import EmulatorDataset
from ise.utils.training import CheckpointSaver, EarlyStoppingCheckpointer
class LSTM(nn.Module):
    """LSTM emulator that summarizes an input sequence with the final hidden
    state and maps it through a small feed-forward head to the output."""

    def __init__(
        self,
        lstm_num_layers,
        lstm_hidden_size,
        input_size=83,
        output_size=1,
        criterion=nn.MSELoss(),
        output_sequence_length=86,
        optimizer=optim.Adam,
    ):
        super(LSTM, self).__init__()

        # Initialize attributes
        self.lstm_num_layers = int(lstm_num_layers)
        self.lstm_num_hidden = int(lstm_hidden_size)
        self.input_size = input_size
        self.output_size = output_size
        self.output_sequence_length = output_sequence_length
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Initialize model layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=int(lstm_hidden_size),
            batch_first=True,
            num_layers=int(lstm_num_layers),
        )
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)  # available for regularization; not applied in forward()
        self.linear1 = nn.Linear(in_features=int(lstm_hidden_size), out_features=32)
        self.linear_out = nn.Linear(in_features=32, out_features=output_size)

        # Move the model to the target device *after* the layers exist, then
        # build the optimizer over the placed parameters
        self.to(self.device)
        self.optimizer = optimizer(self.parameters())
        self.criterion = criterion
        self.trained = False
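    # Illustrative construction (a sketch; hyperparameter values are arbitrary,
    # not recommendations from this module):
    #
    #   model = LSTM(lstm_num_layers=2, lstm_hidden_size=64)
    #
    # With the defaults above, the model expects 83 input features per timestep
    # and emits a single output value per sequence.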
    def forward(self, x):
        """Run a forward pass.

        Args:
            x: tensor of shape (batch_size, sequence_length, input_size).

        Returns:
            Tensor of shape (batch_size, output_size).
        """
        batch_size = x.shape[0]

        # Zero-initialize the hidden and cell states directly on the input's
        # device. requires_grad is unnecessary here: the initial states are
        # constants, not learned parameters.
        h0 = torch.zeros(self.lstm_num_layers, batch_size, self.lstm_num_hidden, device=x.device)
        c0 = torch.zeros(self.lstm_num_layers, batch_size, self.lstm_num_hidden, device=x.device)

        # Use the final hidden state of the last LSTM layer as the sequence summary
        _, (hn, _) = self.lstm(x, (h0, c0))
        x = hn[-1, :, :]

        # Feed-forward head
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear_out(x)
        return x
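    # Shape sketch for forward() (illustrative, using the construction example
    # above): a batch of 16 sequences of length 5 with the default 83 features
    # maps to a (16, 1) output.
    #
    #   x = torch.randn(16, 5, 83).to(model.device)
    #   out = model(x)   # out.shape == (16, 1)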
    def fit(
        self,
        X,
        y,
        epochs=100,
        sequence_length=5,
        batch_size=64,
        criterion=None,
        X_val=None,
        y_val=None,
        save_checkpoints=True,
        checkpoint_path="checkpoint.pt",
        early_stopping=False,
        patience=10,
        verbose=True,
        dataclass=EmulatorDataset,
    ):
        """Train the model, with optional validation, checkpointing, and early stopping."""
        # Convert pandas objects to numpy arrays *before* building tensors
        if isinstance(X, pd.DataFrame):
            X = X.values
        if isinstance(y, (pd.DataFrame, pd.Series)):
            y = y.values

        X, y = to_tensor(X).to(self.device), to_tensor(y).to(self.device)
        if y.ndimension() == 1:
            y = y.unsqueeze(1)
        # Resume from an existing checkpoint if one is present
        start_epoch = 1
        best_loss = float("inf")
        self.checkpoint_path = checkpoint_path
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.load_state_dict(checkpoint["model_state_dict"])
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            start_epoch = checkpoint["epoch"] + 1
            best_loss = checkpoint.get("best_loss", float("inf"))
            if verbose:
                print(
                    f"Resuming from checkpoint at epoch {start_epoch} "
                    f"with validation loss {best_loss:.6f}"
                )
# Check if validation data is provided
if X_val is not None and y_val is not None:
validate = True
if not early_stopping:
warnings.warn(
"Validation data provided but early_stopping is False. Early stopping is recommended for validation data."
)
            if isinstance(X_val, pd.DataFrame):
                X_val = X_val.values
            if isinstance(y_val, (pd.DataFrame, pd.Series)):
                y_val = y_val.values
            X_val, y_val = to_tensor(X_val).to(self.device), to_tensor(y_val).to(self.device)
else:
validate = False
        # Set loss criterion
        if criterion is not None:
            self.criterion = criterion
        elif self.criterion is None:
            raise ValueError("A loss criterion must be provided here or at construction.")
        self.criterion = self.criterion.to(self.device)
# Create dataset and data loader
dataset = dataclass(X, y, sequence_length=sequence_length, projection_length=self.output_sequence_length)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Set model to training mode
self.train()
self.to(self.device)
        # Initialize checkpointing (with optional early stopping)
if save_checkpoints:
if early_stopping:
checkpointer = EarlyStoppingCheckpointer(self, self.optimizer, checkpoint_path, patience, verbose)
else:
checkpointer = CheckpointSaver(self, self.optimizer, checkpoint_path, verbose)
checkpointer.best_loss = best_loss
else:
checkpointer = None
        # Training loop
        if start_epoch <= epochs:
            for epoch in range(start_epoch, epochs + 1):
                self.train()
                batch_losses = []
                # Rename loop variables so they do not shadow the full X, y tensors
                for i, (x_batch, y_batch) in enumerate(data_loader):
                    x_batch = x_batch.to(self.device)
                    y_batch = y_batch.to(self.device)
                    self.optimizer.zero_grad()
                    y_pred = self.forward(x_batch)
                    loss = self.criterion(y_pred, y_batch)
                    loss.backward()
                    self.optimizer.step()
                    batch_losses.append(loss.item())

                # Report average batch loss and, if validation data was provided,
                # validation loss
                if validate:
                    val_preds = self.predict(
                        X_val, sequence_length=sequence_length, batch_size=batch_size
                    ).to(self.device)
                    val_loss = F.mse_loss(val_preds.squeeze(), y_val.squeeze()).item()
                    if save_checkpoints:
                        checkpointer(val_loss, epoch)
                        if hasattr(checkpointer, "early_stop") and checkpointer.early_stop:
                            if verbose:
                                print("Early stopping")
                            break
                    if verbose:
                        log = getattr(checkpointer, "log", "") if checkpointer is not None else ""
                        print(
                            f"[epoch/total]: [{epoch}/{epochs}], "
                            f"train loss: {sum(batch_losses) / len(batch_losses)}, "
                            f"val mse: {val_loss:.6f} -- {log}"
                        )
                else:
                    average_batch_loss = sum(batch_losses) / len(batch_losses)
                    if verbose:
                        print(f"[epoch/total]: [{epoch}/{epochs}], train loss: {average_batch_loss}")
        else:
            if verbose:
                print(f"Training already completed ({epochs}/{epochs}).")
        self.trained = True

        # Load the best model saved during training, if a checkpoint exists
        # (none is written when save_checkpoints=False or no validation ran)
        if save_checkpoints and os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
                self.load_state_dict(checkpoint["model_state_dict"])
                self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                self.best_loss = checkpoint["best_loss"]
                self.epochs_trained = checkpoint["epoch"]
            else:
                self.load_state_dict(checkpoint)
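    # The checkpoint file is assumed (inferred from the loading code in fit())
    # to be a dict with keys "epoch", "model_state_dict", "optimizer_state_dict",
    # and "best_loss", i.e. something the checkpointer could write as:
    #
    #   torch.save(
    #       {
    #           "epoch": epoch,
    #           "model_state_dict": model.state_dict(),
    #           "optimizer_state_dict": optimizer.state_dict(),
    #           "best_loss": best_loss,
    #       },
    #       checkpoint_path,
    #   )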
    def predict(self, X, sequence_length=5, batch_size=64, dataclass=EmulatorDataset):
        """Generate predictions for X without tracking gradients."""
        self.eval()
        self.to(self.device)

        # Convert data to numpy array if pandas DataFrame
        if isinstance(X, pd.DataFrame):
            X = X.values

        # Create dataset and data loader (no targets at prediction time)
        dataset = dataclass(
            X, y=None, sequence_length=sequence_length, projection_length=self.output_sequence_length
        )
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

        preds = torch.tensor([]).to(self.device)
        with torch.no_grad():
            for X_test_batch in data_loader:
                X_test_batch = X_test_batch.to(self.device)
                y_pred = self.forward(X_test_batch)
                preds = torch.cat((preds, y_pred), 0)
        return preds
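

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the library API). It assumes
    # EmulatorDataset windows X into (sequence_length, input_size) samples and
    # pairs each window with a target row, which matches how fit() and
    # predict() consume it above. Data here are random and purely illustrative.
    n_samples, n_features = 200, 83
    X_demo = np.random.rand(n_samples, n_features).astype(np.float32)
    y_demo = np.random.rand(n_samples, 1).astype(np.float32)

    model = LSTM(lstm_num_layers=2, lstm_hidden_size=64, input_size=n_features)
    model.fit(X_demo, y_demo, epochs=2, save_checkpoints=False, verbose=True)
    preds = model.predict(X_demo)
    print("predictions:", preds.shape)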