Coverage for src / autoencodix / data / _image_dataset.py: 39%
33 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2import pandas as pd
3from autoencodix.configs.default_config import DefaultConfig
4from autoencodix.data._numeric_dataset import TensorAwareDataset
5from typing import Any, List, Tuple, Optional, Dict
6from autoencodix.data._imgdataclass import ImgData
7from autoencodix.base._base_dataset import DataSetTypes
class ImageDataset(TensorAwareDataset):
    """
    A custom PyTorch dataset that handles image data with proper dtype conversion.

    All images are converted to tensors once, at construction time, so that
    item access only indexes into already-converted data.

    Attributes:
        raw_data: List of ImgData objects containing original image data and metadata.
        config: Configuration object for dataset settings.
        mytype: Enum indicating the dataset type (set to DataSetTypes.IMG).
        data: List of image tensors converted to the appropriate dtype.
        sample_ids: List of identifiers for each sample.
        split_indices: Optional split information, stored exactly as passed to
            the constructor (annotated there as a dict with string keys).
        feature_ids: Optional list of identifiers for each feature (set to None for images).
        metadata: Optional pandas DataFrame containing additional metadata.
    """

    def __init__(
        self,
        data: List[ImgData],
        config: DefaultConfig,
        split_indices: Optional[Dict[str, Any]] = None,
        metadata: Optional[pd.DataFrame] = None,
    ):
        """
        Initialize the dataset.

        Args:
            data: List of image data objects; each is read through its
                ``img`` and ``sample_id`` attributes.
            config: Configuration object. Must not be None.
            split_indices: Optional split information, stored unchanged.
            metadata: Optional DataFrame with additional metadata, stored unchanged.

        Raises:
            ValueError: If ``config`` is None.
        """
        # Validate before touching any state so a bad call fails immediately.
        if config is None:
            raise ValueError("config cannot be None")

        self.raw_data = data  # image data before conversion to keep original infos
        self.config = config
        self.mytype = DataSetTypes.IMG

        # Convert all images to tensors with proper dtype once during initialization.
        # NOTE(review): _get_target_dtype is not defined in this class; presumably
        # inherited from TensorAwareDataset and driven by self.config - confirm.
        target_dtype = self._get_target_dtype()
        self.data = self._convert_all_images_to_tensors(target_dtype)

        # Extract sample_ids for consistency; order matches self.data.
        self.sample_ids = [img_data.sample_id for img_data in data]

        self.split_indices = split_indices
        self.feature_ids = None
        self.metadata = metadata

    def _convert_all_images_to_tensors(self, dtype: torch.dtype) -> List[torch.Tensor]:
        """
        Convert all images to tensors with specified dtype during initialization.

        Args:
            dtype: Target dtype for the tensors

        Returns:
            List of converted image tensors, in the same order as ``raw_data``
        """
        print(f"Converting {len(self.raw_data)} images to {dtype} tensors...")
        # NOTE(review): _to_tensor is not defined in this class; presumably
        # inherited from TensorAwareDataset - confirm.
        return [self._to_tensor(img_data.img, dtype) for img_data in self.raw_data]

    def __len__(self) -> int:
        """Return the number of converted images in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[int, torch.Tensor, Any]:
        """Get item at index - data is already converted to proper dtype

        Args:
            idx: Position of the requested sample.

        Returns:
            Tuple of (index, image tensor, sample_id)
        """
        return idx, self.data[idx], self.sample_ids[idx]

    def get_input_dim(self) -> Tuple[int, ...]:
        """
        Gets the input dimension of the dataset's feature space.

        The shape of the first image is reported; all images are assumed
        (not checked) to share that shape.

        Returns:
            The input dimension of the dataset's feature space

        Raises:
            IndexError: If the dataset contains no images.
        """
        if not self.data:
            # Same exception type as indexing an empty list would raise,
            # but with a message that names the actual problem.
            raise IndexError(
                "cannot determine input dimension of an empty ImageDataset"
            )
        return self.data[0].shape  # All images should have the same shape