Source code for locpix.img_processing.data_loading.dataset
"""Dataset module
This module defines the dataset class for SMLM image data"""
import torch
from torch.utils.data import Dataset
import numpy as np
import os
import tifffile
from . import transforms
import torchvision.transforms as T
[docs]
class ImgDataset(Dataset):
    """Pytorch dataset for the SMLM data represented as images

    Every item is an (image, mask) pair read from ``<name>.tif`` and
    ``<name>_masks.tif`` inside the input directory.

    Attributes:
        input_data (tuple) : Paths to the image files, sorted
        label_data (tuple) : Paths to the mask files, in the same
            order as input_data
        mean : Mean handed to ``transforms.transform``. Calculated from
            the images when train=True, otherwise the value passed in
        std : Standard deviation handed to ``transforms.transform``.
            Calculated from the images when train=True, otherwise the
            value passed in
        transform : Object returned by ``transforms.transform``; called
            with (image, label) for every item retrieved
    """

    def __init__(self, input_root, files, transform, train=False, mean=0, std=1):
        """
        Args:
            input_root (string) : Directory containing the SMLM data and masks
            files (list) : List of the files to include from
                the directory in this dataset, given without the
                ".tif" extension
            transform (dictionary) : Transforms to apply to
                the dataset. Recognised keys:
                "rotation" (value passed to RandomRotation),
                "perspective" (value passed to RandomPerspective),
                "h_flip", "v_flip", "erasing" (presence enables the
                transform, the value is not read) and
                "dtypeconv" (presence sets dtypeconv=True)
            train (bool) : If True the mean and standard deviation are
                calculated from the images and the mean/std arguments
                are ignored
            mean : Mean to use when train is False
            std : Standard deviation to use when train is False

        Raises:
            ValueError: If train is True but files is empty
        """
        self.input_data, self.label_data = self._paired_paths(input_root, files)

        if train:
            # calculate mean and standard deviation
            self.mean, self.std = self._image_stats(self.input_data)
        else:
            self.mean = mean
            self.std = std

        # define transforms
        output_transforms = self._build_transforms(transform)

        # presence of the "dtypeconv" key switches on the dtype conversion
        # (presumably to float32, per the original comment - confirm in
        # transforms.transform)
        self.transform = transforms.transform(
            self.mean,
            self.std,
            output_transforms,
            dtypeconv="dtypeconv" in transform,
        )

    @staticmethod
    def _paired_paths(input_root, files):
        """Return sorted (image paths, mask paths) tuples for the files"""
        # sort the pairs together so each mask stays aligned with its image
        pairs = sorted(
            (
                os.path.join(input_root, name + ".tif"),
                os.path.join(input_root, name + "_masks.tif"),
            )
            for name in files
        )
        # built explicitly rather than via zip(*pairs), which cannot be
        # unpacked into two names when there are no files
        input_paths = tuple(image_path for image_path, _ in pairs)
        label_paths = tuple(mask_path for _, mask_path in pairs)
        return input_paths, label_paths

    @staticmethod
    def _image_stats(paths):
        """Return (mean, std) over axes (0, 1) of all the images stacked"""
        if not paths:
            raise ValueError(
                "Cannot calculate mean and standard deviation: no files in dataset"
            )
        # concatenate once; concatenating inside the loop re-copies the
        # growing array for every file
        stacked = np.concatenate([tifffile.imread(path) for path in paths])
        return np.mean(stacked, axis=(0, 1)), np.std(stacked, axis=(0, 1))

    @staticmethod
    def _build_transforms(transform):
        """Return the list of torchvision transforms requested in transform"""
        # to tensor
        output_transforms = [T.ToTensor()]
        # random rotation
        if "rotation" in transform:
            output_transforms.append(T.RandomRotation(transform["rotation"]))
        # random horizontal flip
        if "h_flip" in transform:
            output_transforms.append(T.RandomHorizontalFlip())
        # random vertical flip
        if "v_flip" in transform:
            output_transforms.append(T.RandomVerticalFlip())
        # random erasing
        if "erasing" in transform:
            output_transforms.append(T.RandomErasing())
        # random perspective
        if "perspective" in transform:
            output_transforms.append(T.RandomPerspective(transform["perspective"]))
        return output_transforms

    def __getitem__(self, idx):
        """Returns an item from the dataset, according to index idx

        Args:
            idx (int or other) : Index of the data to retrieve

        Returns:
            The (image, label) pair produced by self.transform
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = tifffile.imread(self.input_data[idx])
        label = tifffile.imread(self.label_data[idx])

        image, label = self.transform(image, label)
        return image, label

    def __len__(self):
        """Length of the dataset

        Args:
            None"""
        return len(self.input_data)