## PSA: This script is not ready for general use, as it is still hardcoded to the specific use case of the project it was developed for.
# Standard library imports
import os
import time
import warnings
from queue import Queue
from threading import Thread
from typing import Union, Optional, Tuple, List, Callable
# Dependency imports
import torch
from torch.utils.data import DataLoader, IterableDataset
from torchvision.io import read_image
from torchvision.io.image import ImageReadMode
# Backend imports
from .implicit_mount import RemotePathIterator
class RemotePathDataset(IterableDataset):
'''
Creates a :py:class:`torch.utils.data.IterableDataset` from a :py:class:`pyremotedata.implicit_mount.RemotePathIterator`.
By default the dataset will return the image as a tensor and the remote path as a string.
    **Hierarchical mode**
    If `hierarchical` >= 1, the dataset is in "hierarchical mode" and will return the image as a tensor and the label as a list of integers (one class index per level in the hierarchy).
    The `class_handles` property can be used to get the class-to-index mappings for the dataset.
    By default the dataset uses a parser which assumes that the hierarchical levels are encoded in the remote path as directories, like so (see the example below):
    `.../level_n/.../level_1/level_0/image.jpg`
    where `n = (hierarchical - 1)` and `level_0` is the leaf level.
Args:
remote_path_iterator (RemotePathIterator): The :py:class:`pyremotedata.implicit_mount.RemotePathIterator` to create the dataset from.
prefetch (int): The number of items to prefetch from the :py:class:`pyremotedata.implicit_mount.RemotePathIterator`.
transform (callable, optional): A function/transform that takes in an image as a :py:class:`torch.Tensor` and returns a transformed version.
target_transform (callable, optional): A function/transform that takes in the label (after potential parsing by `parse_hierarchical`) and transforms it.
device (torch.device, optional): The device to move the tensors to.
dtype (torch.dtype, optional): The data type to convert the tensors to.
hierarchical (int, optional): The number of hierarchical levels to use for the labels. Default: 0, i.e. no hierarchy.
        hierarchy_parser (callable, optional): A function to parse the hierarchical levels from the remote path. Default: None, i.e. use the default parser.
        shuffle (bool, optional): Whether to shuffle the remote paths before iteration. Default: False.
return_remote_path (bool, optional): Whether to return the remote path. Default: False.
return_local_path (bool, optional): Whether to return the local path. Default: False.
verbose (bool, optional): Whether to print verbose output. Default: False.
Yields:
(tuple): A tuple containing the following elements:
- (torch.Tensor): The image as a tensor.
- (Union[str, List[int]]): The label as the remote path or as a list of class indices.
- (Optional[str]): The local path, if `return_local_path` is True.
- (Optional[str]): The remote path, if `return_remote_path` is True.
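
    Example:
        A minimal sketch of hierarchical use; it assumes a connected :py:class:`pyremotedata.implicit_mount.RemotePathIterator` named `rpi` whose remote paths look like `family/genus/image.jpg` (purely illustrative names)::

            dataset = RemotePathDataset(rpi, hierarchical=2)
            # With the default parser, "family/genus/image.jpg" parses to
            # ["family", "genus"], where "genus" is the leaf level (level 0)
            image, label = next(iter(dataset))  # label == [genus_idx, family_idx]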
'''
def __init__(
self,
remote_path_iterator : "RemotePathIterator",
prefetch: int=64,
transform : Optional[Callable]=None,
target_transform : Optional[Callable]=None,
device: Union["torch.device", None]=None,
dtype: Union[torch.dtype, None]=None,
hierarchical: int=0,
hierarchy_parser: Optional[Callable]=None,
shuffle: bool=False,
return_remote_path: bool=False,
return_local_path: bool=False,
verbose: bool=False
):
# Check if remote_path_iterator is of type RemotePathIterator
if not isinstance(remote_path_iterator, RemotePathIterator):
raise ValueError("Argument remote_path_iterator must be of type pyremotedata.implicit_mount.RemotePathIterator.")
# Check if prefetch is an integer
if not isinstance(prefetch, int):
raise ValueError("Argument prefetch must be an integer.")
# Check if prefetch is greater than 0
if prefetch < 1:
raise ValueError("Argument prefetch must be greater than 0.")
## General parameters
        if not isinstance(verbose, bool):
            raise ValueError("Argument verbose must be a boolean.")
self.verbose : bool = verbose
## PyTorch specific parameters
# Get the classes and their indices
self.hierarchical : int = hierarchical
if self.hierarchical < 1:
self.hierarchical = False
def error_hierarchy(*args, **kwargs):
raise ValueError("Hierarchical mode disabled (`hierarchical` < 1), but `pyremotedata.dataloader.RemotePathDataset.parse_hierarchy` function was called.")
self.parse_hierarchy = error_hierarchy
else:
if hierarchy_parser is None:
self.parse_hierarchy = lambda path: path.split('/')[-(1 + self.hierarchical):-1]
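                # e.g. with hierarchical=2, "family/genus/image.jpg" parses to ["family", "genus"]
                # (path order, top level to leaf; the names are purely illustrative)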
elif callable(hierarchy_parser):
self.parse_hierarchy = hierarchy_parser
else:
raise ValueError("Argument `hierarchy_parser` must be a callable or None.")
self.classes = [[] for _ in range(self.hierarchical)]
self.n_classes = [0 for _ in range(self.hierarchical)]
self.class_to_idx = [{} for _ in range(self.hierarchical)]
self.idx_to_class = [{} for _ in range(self.hierarchical)]
for path in remote_path_iterator.remote_paths:
            for level, cls in enumerate(reversed(self.parse_hierarchy(path))):
                # Check the level bound first; otherwise indexing self.classes[level] below would raise an IndexError before this error is reached
                if level >= self.hierarchical:
                    raise ValueError(f"Error parsing class from {path}. Got {cls} at level {level}, but the number of specified hierarchical levels is {self.hierarchical} with levels 0-{self.hierarchical-1}.")
                if cls in self.classes[level]:
                    continue
                self.classes[level].append(cls)
for level in range(self.hierarchical):
            self.classes[level] = sorted(self.classes[level])
            self.n_classes[level] = len(self.classes[level])
            self.class_to_idx[level] = {cls: i for i, cls in enumerate(self.classes[level])}
            self.idx_to_class[level] = {i: cls for i, cls in enumerate(self.classes[level])}
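        # At this point e.g. self.class_to_idx[0] might be {"genus_a": 0, "genus_b": 1} and
        # self.idx_to_class[0] {0: "genus_a", 1: "genus_b"} for the leaf level (illustrative names)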
# Set the transforms
self.transform = transform
self.target_transform = target_transform
# Set the device and dtype
self.device = device
self.dtype = dtype
## Backend specific parameters
# Store the remote_path_iterator backend
self.remote_path_iterator = remote_path_iterator
self.return_remote_path = return_remote_path
self.return_local_path = return_local_path
self.shuffle = shuffle
## Multi-threading parameters
# Set the number of workers (threads) for (multi)processing
self.num_workers = 1
# We don't want to start the buffer filling thread until the dataloader is called for iteration
self.producer_thread = None
# We need to keep track of whether the buffer filling thread has been initiated or not
self.thread_initiated = False
# Initialize the worker threads
self.consumer_threads = []
self.consumers = 0
self.stop_consumer_threads = True
# Set the buffer filling parameters (Watermark Buffering)
self.buffer_minfill, self.buffer_maxfill = 0.4, 0.6
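        # e.g. with prefetch=64 the producer pauses filling at 38 items (int(0.6 * 64)) and
        # resumes once the buffer drains below 25 items (int(0.4 * 64))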
# Initialize the buffers
self.buffer = Queue(maxsize=prefetch) # Tune maxsize as needed
self.processed_buffer = Queue(maxsize=prefetch) # Tune maxsize as needed
@property
def class_handles(self):
return {
'classes': self.classes,
'n_classes': self.n_classes,
'class_to_idx': self.class_to_idx,
'idx_to_class': self.idx_to_class,
'hierarchical': self.hierarchical
}
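    # The setter makes it possible to reuse the label mappings of another dataset,
    # e.g. (hypothetical names): val_dataset.class_handles = train_dataset.class_handles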
@class_handles.setter
def class_handles(self, value : dict):
if not isinstance(value, dict):
raise ValueError("Argument value must be a dictionary.")
if value["hierarchical"]:
assert isinstance(value['classes'], list), ValueError("Argument value['classes'] must be a list, when hierarchical is True.")
else:
assert not isinstance(value['classes'], list), ValueError("Argument value['classes'] must not be a list, when hierarchical is False.")
self.classes = value['classes']
self.n_classes = value['n_classes']
self.class_to_idx = value['class_to_idx']
self.idx_to_class = value['idx_to_class']
self.hierarchical = value['hierarchical']
def _shuffle(self):
if not self.shuffle:
raise RuntimeError("Shuffle called, but shuffle is set to False.")
if self.thread_initiated:
raise RuntimeError("Shuffle called, but buffer filling thread is still active.")
self.remote_path_iterator.shuffle()
def _shutdown_and_reset(self):
        # Handle shutdown logic (should probably be moved to a dedicated reset function that is called on StopIteration, or perhaps in __iter__)
self.stop_consumer_threads = True # Signal the consumer threads to stop
for i, consumer in enumerate(self.consumer_threads):
if consumer is None:
continue
if not consumer.is_alive():
continue
print(f"Waiting for worker {i} to finish.")
consumer.join(timeout = 1 / self.num_workers) # Wait for the consumer thread to finish
if self.producer_thread is not None:
self.producer_thread.join(timeout=1) # Wait for the producer thread to finish
self.producer_thread = None # Reset the producer thread
self.consumer_threads = [] # Reset the consumer threads
self.consumers = 0 # Reset the number of consumers
self.thread_initiated = False # Reset the thread initiated flag
self.buffer.queue.clear() # Clear the buffer
self.processed_buffer.queue.clear() # Clear the processed buffer
self.remote_path_iterator.__del__(force=True) # Close the remote_path_iterator and clean the temporary directory
        assert self.buffer.qsize() == 0, "Buffer not empty after iterator end."
        assert self.processed_buffer.qsize() == 0, "Processed buffer not empty after iterator end."
def _init_buffer(self):
# Check if the buffer filling thread has been initiated
if not self.thread_initiated:
# Start the buffer filling thread
self.producer_thread = Thread(target=self._fill_buffer)
self.producer_thread.daemon = True
self.producer_thread.start()
# Set the flag to indicate that the thread has been initiated
self.thread_initiated = True
else:
# Raise an error if the buffer filling thread has already been initiated
raise RuntimeError("Buffer filling thread already initiated.")
def _fill_buffer(self):
# Calculate the min and max fill values for the buffer
min_fill = int(self.buffer.maxsize * self.buffer_minfill)
max_fill = int(self.buffer.maxsize * self.buffer_maxfill)
if self.verbose:
print(f"Buffer min fill: {min_fill}, max fill: {max_fill}")
print("Producer thread started.")
wait_for_min_fill = False
for item in self.remote_path_iterator:
# Get the current size of the buffer
current_buffer_size = self.buffer.qsize()
# Decide whether to fill the buffer based on its current size
if not wait_for_min_fill:
wait_for_min_fill = (current_buffer_size >= max_fill)
# Sleep logic which ensures that the buffer doesn't switch between filling and not filling too often (Watermark Buffering)
while wait_for_min_fill:
time.sleep(0.1)
current_buffer_size = self.buffer.qsize() # Update the current buffer size
wait_for_min_fill = current_buffer_size >= min_fill # Wait until the buffer drops below min_fill
# Fill the buffer
self.buffer.put(item)
if self.verbose:
print("Producer signalling end of iterator.")
# Signal the end of the iterator to the consumers by putting None in the buffer until all consumer threads have finished
while self.consumers > 0:
time.sleep(0.01)
self.buffer.put(None)
if self.verbose:
print("Producer emptying buffer.")
# Wait for the consumer threads to finish then clear the buffer
while self.buffer.qsize() > 0:
self.buffer.get()
self.processed_buffer.put(None) # Signal the end of the processed buffer to the main thread
if self.verbose:
print("Producer thread finished.")
def _process_buffer(self):
if self.verbose:
print("Consumer thread started.")
self.consumers += 1
while True:
qsize = self.buffer.qsize()
while qsize < (self.num_workers * 2):
time.sleep(0.05 / self.num_workers)
qsize = self.buffer.qsize()
if self.producer_thread is None or not self.producer_thread.is_alive():
break
# Get the next item from the buffer
item = self.buffer.get()
if self.verbose:
print("Consumer thread got item")
# Check if the buffer is empty, signaling the end of the iterator
if item is None or self.stop_consumer_threads:
break # Close the thread
# Preprocess the item (e.g. read image, apply transforms, etc.) and put it in the processed buffer
            processed_item = self.parse_item(*item)  # item is guaranteed to be non-None here (see the break above)
if processed_item is not None:
self.processed_buffer.put(processed_item)
if self.verbose:
print("Consumer thread processed item")
self.consumers -= 1
if self.verbose:
print("Consumer thread finished.")
def __iter__(self):
# Check if the buffer filling thread has been initiated
# If it has, reset the dataloader state and close all threads
# (Only one iteration is allowed per dataloader instance)
if self.thread_initiated:
warnings.warn("Iterator called, but buffer filling thread is still active. Resetting the dataloader state.")
self._shutdown_and_reset()
# If shuffle is set to True, shuffle the remote_path_iterator
if self.shuffle:
self._shuffle()
# Initialize the buffer filling thread
self._init_buffer()
# Check number of workers
if self.num_workers == 0:
self.num_workers = 1
if self.num_workers < 1:
raise ValueError("Number of workers must be greater than 0.")
self.stop_consumer_threads = False
# Start consumer threads for processing
for _ in range(self.num_workers):
consumer_thread = Thread(target=self._process_buffer)
consumer_thread.daemon = True
consumer_thread.start()
self.consumer_threads.append(consumer_thread)
return self
def __next__(self):
# Fetch from processed_buffer instead
processed_item = self.processed_buffer.get()
# Check if the processed buffer is empty, signaling the end of the iterator
if processed_item is None:
self._shutdown_and_reset()
raise StopIteration
        # Restart crashed consumer threads (iterate over a copy, since the list is mutated inside the loop)
        for thread in list(self.consumer_threads):
            if not thread.is_alive():
                self.consumer_threads.remove(thread)
                new_consumer_thread = Thread(target=self._process_buffer)
                new_consumer_thread.daemon = True
                new_consumer_thread.start()
                self.consumer_threads.append(new_consumer_thread)
# Otherwise, return the processed item
return processed_item
def __len__(self):
return len(self.remote_path_iterator)
    def parse_item(self, local_path : str, remote_path : str) -> Optional[Union[Tuple[torch.Tensor, Union[str, List[int]]], Tuple[torch.Tensor, Union[str, List[int]], str], Tuple[torch.Tensor, Union[str, List[int]], str, str]]]:
## Image processing
        # Check if the image format is supported (jpeg/jpg/png); unsupported images are skipped (return None) instead of raising an error
        image_type = os.path.splitext(local_path)[-1].lower()
        if image_type not in ['.jpg', '.jpeg', '.png']:
            return None
try:
image = read_image(local_path, mode=ImageReadMode.RGB)
except Exception as e:
print(f"Error reading image {remote_path} ({e}).")
return None
# Remove the alpha channel if present
if image.shape[0] == 4:
image = image[:3]
if image.shape[0] != 3:
print(f"Error reading image {remote_path}.")
return None
# Apply transforms (preprocessing)
if self.transform:
image = self.transform(image)
if self.dtype is not None:
image = image.to(dtype=self.dtype)
if self.device is not None:
image = image.to(device=self.device)
if self.hierarchical:
## Label processing
# Get the label by parsing the remote path
hierarchy = self.parse_hierarchy(remote_path)
            # Map each class name in the hierarchy to its class index
            # (reversed so that level 0 is the leaf level, matching the class indices built in __init__)
            label = [self.class_to_idx[level][cls] for level, cls in enumerate(reversed(hierarchy))]
            if len(label) == 0:
                raise ValueError(f"Error parsing label from {remote_path}.")
else:
label = remote_path
# Apply label transforms
if self.target_transform:
label = self.target_transform(label)
# TODO: Does the label need to be converted to a tensor and moved to the device? (Probably not) If so, does it need to be a long tensor?
# if self.device is not None:
# label = label.to(device=self.device)
## Return the image and label (and optionally the local and remote paths)
if self.return_local_path and self.return_remote_path:
return image, label, local_path, remote_path
elif self.return_local_path:
return image, label, local_path
elif self.return_remote_path:
return image, label, remote_path
else:
return image, label
class RemotePathDataLoader(DataLoader):
"""
A custom :py:class:`torch.utils.data.DataLoader` for :py:class:`pyremotedata.dataloader.RemotePathDataset`.
This DataLoader is designed to work with :py:class:`pyremotedata.dataloader.RemotePathDataset` and does not support all the arguments of the standard :py:class:`torch.utils.data.DataLoader`.
Unsupported arguments:
- sampler
- batch_sampler
Args:
dataset (RemotePathDataset): The :py:class:`pyremotedata.dataloader.RemotePathDataset` dataset to load from.
        num_workers (int, optional): The number of worker threads used by the dataset for processing. Default: 0, which is treated as 1; negative values raise an error.
shuffle (bool, optional): Whether to shuffle the dataset between epochs. Default: False.
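
    Example:
        A minimal sketch; it assumes an existing :py:class:`pyremotedata.dataloader.RemotePathDataset` named `dataset` whose transform produces uniformly sized tensors (otherwise the default collate function cannot stack them into batches)::

            loader = RemotePathDataLoader(dataset, batch_size=16, num_workers=2, shuffle=True)
            for images, labels in loader:
                ...  # images is a batch of stacked image tensors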
"""
    def __init__(self, dataset: "RemotePathDataset", num_workers : int=0, shuffle : bool=False, *args, **kwargs):
        if not isinstance(dataset, RemotePathDataset):
            raise ValueError("Argument dataset must be of type `pyremotedata.dataloader.RemotePathDataset`.")
        # Snipe arguments from the user which would break the custom dataloader (e.g. sampler and batch_sampler)
        unsupported_kwargs = ['sampler', 'batch_sampler']
        for unzkw in unsupported_kwargs:
            value = kwargs.pop(unzkw, None)
            if value is not None:
                warnings.warn(f"Argument {unzkw} is not supported in `pyremotedata.dataloader.RemotePathDataLoader`. {unzkw}={value} will be ignored.")
        # Override the num_workers argument handling (default is 0); threading is handled by the dataset, not the DataLoader
        dataset.num_workers = num_workers
        # Override the shuffle argument handling (default is False); shuffling is likewise handled by the dataset
        dataset.shuffle = shuffle
        # Initialize the dataloader (any extra DataLoader options, e.g. batch_size, should be passed as keyword arguments)
super(RemotePathDataLoader, self).__init__(
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
dataset=dataset,
*args,
**kwargs
)
    # TODO: This cannot be set before calling super().__init__(), so perhaps it can be overridden after initialization instead
    # def __setattr__(self, name, value):
    #     if name in ['batch_sampler', 'sampler']:
    #         raise AttributeError(f"Changing {name} is not allowed in this custom DataLoader.")
    #     super(RemotePathDataLoader, self).__setattr__(name, value)