Source code for locpix.img_processing.data_loading.transforms

"""Module defining transformations to apply to data"""

from torchvision import transforms
import numpy as np
import torch


class transform:
    """Callable that pushes an input and its label through one shared pipeline.

    The input is normalised, stacked with the label along a new trailing
    axis, and the stacked array is handed to a single composed transform,
    so both receive exactly the same transformation.
    """

    def __init__(self, mean, std, transform_list, dtypeconv=False):
        """Store the normalisation statistics and build the pipeline.

        Args:
            mean (float) : Mean subtracted from the input during
                normalisation
            std (float) : Std the mean-centred input is divided by
            transform_list (list) : Transforms to chain together with
                ``transforms.Compose``
            dtypeconv (bool) : Whether to convert the transformed data to
                ``torch.float32``. Only the literal ``True`` enables the
                conversion (the check is ``is True``).
        """
        self.mean, self.std = mean, std
        self.transform = transforms.Compose(transform_list)
        self.dtypeconv = dtypeconv

    def __call__(self, input, label):
        """Normalise the input, then transform input and label together.

        Args:
            input (numpy array) : Input histogram
            label (numpy array) : Histogram with labels

        Returns:
            tuple : ``(input, label)`` taken from the transformed data;
                the label is the last entry along the first axis and the
                input is everything before it.
        """
        normalised = (input - self.mean) / self.std

        # Pair the two arrays on a new last axis so that a single call to
        # the pipeline applies the same transform to both of them.
        # NOTE(review): the split below indexes the first axis, so the
        # pipeline presumably moves the stacked axis to the front (e.g.
        # via ToTensor) - confirm against the transform_list in use.
        stacked = self.transform(np.stack((normalised, label), axis=-1))

        # Deliberate identity comparison: truthy non-bool values do not
        # trigger the conversion.
        if self.dtypeconv is True:
            stacked = stacked.to(torch.float32)

        return stacked[:-1], stacked[-1]