Source code for locpix.img_processing.training.train
"""Training module.
This module contains definitions relevant to
training the model.
"""
import torch
import wandb
def train_loop(
    epochs, model, optimiser, train_loader, val_loader, loss_fn, device, model_path
):
"""Function defining the training loop
Args:
epochs (int): Number of epochs to train for
optimiser (torch.optim optimiser): Optimiser
that controls training of the model
model (torch geometric model): Model that
is going to be trained
train_loader (torch dataloader): Dataloader for the
training dataset
val_loader (torch dataloader): Dataloader for the
validation dataset
loss_fn (loss function): Loss function calculate loss
between train and output
device (gpu or cpu): Device to train the model
on
model_path (string) : Where to save model to"""
    model.to(device)
    scaler = torch.cuda.amp.GradScaler()
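    # GradScaler keeps small float16 gradients from underflowing by
    # scaling the loss up before backward and unscaling the gradients
    # before the optimiser step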
    best_loss = 1e12
    save_epoch = 0
    for epoch in range(epochs):
        print("Epoch: ", epoch)
        # TODO: look at gradient accumulation and take care with multiple GPUs
        # training data
        model.train()
        running_train_loss = 0
        train_items = 0
        for index, data in enumerate(train_loader):
            img, label = data
            # set_to_none=True has a smaller memory footprint than
            # zeroing the gradients
            optimiser.zero_grad(set_to_none=True)
            # move data to device
            img = img.to(device)
            label = label.to(device)
            # give the label a channel dimension to match the output shape
            label = torch.unsqueeze(label, 1)
            # forward pass - with autocasting to reduced precision
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                output = model(img)
                loss = loss_fn(output, label)
            # weight the batch-mean loss by the batch size; .item()
            # detaches it so the autograd graph is not kept alive
            running_train_loss += loss.item() * img.shape[0]
            train_items += img.shape[0]
            # scales loss - calls backward on scaled loss creating scaled gradients
            scaler.scale(loss).backward()
            # unscales the gradients of the optimiser then optimiser.step is called
            scaler.step(optimiser)
            # update the scale for the next iteration
            scaler.update()
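            # Sketch (an assumption, not in the original source): the
            # gradient accumulation mentioned in the TODO above would
            # divide the loss by a hypothetical `accum_steps` count and
            # only step/update every `accum_steps` batches, e.g.:
            #
            #     scaler.scale(loss / accum_steps).backward()
            #     if (index + 1) % accum_steps == 0:
            #         scaler.step(optimiser)
            #         scaler.update()
            #         optimiser.zero_grad(set_to_none=True)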
        # normalise the training loss by the number of items seen
        running_train_loss /= train_items
        # validation data - no gradients are needed, so run under no_grad
        # make sure model is in eval mode
        model.eval()
        running_val_loss = 0
        val_items = 0
        for index, data in enumerate(val_loader):
            with torch.no_grad():
                img, label = data
                # move data to device
                img = img.to(device)
                label = label.to(device)
                # give the label a channel dimension to match the output shape
                label = torch.unsqueeze(label, 1)
                # forward pass - with autocasting
                with torch.autocast(device_type="cuda"):
                    output = model(img)
                    loss = loss_fn(output, label)
                running_val_loss += loss.item() * img.shape[0]
                val_items += img.shape[0]
        # normalise the validation loss by the number of items seen
        running_val_loss /= val_items
        wandb.log({"train loss": running_train_loss, "val loss": running_val_loss})
        # checkpoint the weights whenever the validation loss improves
        if running_val_loss < best_loss:
            print("Saving best model")
            best_loss = running_val_loss
            torch.save(model.state_dict(), model_path)
            save_epoch = epoch
    # load in the best model (lowest validation loss) and return it
    model.load_state_dict(torch.load(model_path))
    wandb.log({"save epoch": save_epoch})
    return model
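Below is a minimal usage sketch, not part of the original module. It assumes a CUDA GPU is available (the loop hardcodes CUDA autocast and GradScaler); the toy convolution "model", the TensorDataset loaders, the wandb project name, and the checkpoint path are illustrative stand-ins, since train_loop only requires a model, dataloaders yielding (img, label) batches, and a loss function.

    import torch
    import wandb

    from locpix.img_processing.training.train import train_loop

    # toy stand-ins - any model/loaders yielding (img, label) batches work
    model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)
    imgs = torch.randn(16, 1, 32, 32)
    labels = torch.randint(0, 2, (16, 32, 32)).float()
    dataset = torch.utils.data.TensorDataset(imgs, labels)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=4)
    val_loader = torch.utils.data.DataLoader(dataset, batch_size=4)

    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    device = torch.device("cuda")

    wandb.init(project="locpix-example")  # illustrative project name
    model = train_loop(
        2, model, optimiser, train_loader, val_loader, loss_fn, device,
        "best_model.pt",
    )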