Coverage for src / autoencodix / data / _image_dataset.py: 39%

33 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2import pandas as pd 

3from autoencodix.configs.default_config import DefaultConfig 

4from autoencodix.data._numeric_dataset import TensorAwareDataset 

5from typing import Any, List, Tuple, Optional, Dict 

6from autoencodix.data._imgdataclass import ImgData 

7from autoencodix.base._base_dataset import DataSetTypes 

8 

9 

class ImageDataset(TensorAwareDataset):
    """Dataset of images that are converted to tensors once, at construction.

    Every ``ImgData`` entry is converted to a tensor of the target dtype
    during ``__init__``, so ``__getitem__`` is a plain list lookup and no
    conversion happens per access.

    Attributes:
        raw_data: List of ImgData objects exactly as passed in (unconverted),
            kept so the original image information stays available.
        config: Configuration object for dataset settings.
        mytype: Enum indicating the dataset type (set to DataSetTypes.IMG).
        data: List of image tensors, one per entry of ``raw_data``.
        sample_ids: ``sample_id`` of each entry, in the same order as ``data``.
        split_indices: Optional split information, stored as given (annotated
            as ``Optional[Dict[str, Any]]`` in ``__init__``).
        feature_ids: Always None for image datasets.
        metadata: Optional pandas DataFrame containing additional metadata.
    """

    def __init__(
        self,
        data: List[ImgData],
        config: DefaultConfig,
        split_indices: Optional[Dict[str, Any]] = None,
        metadata: Optional[pd.DataFrame] = None,
    ):
        """Initialize the dataset and convert all images to tensors.

        Args:
            data: List of image data objects; each must expose ``img`` and
                ``sample_id``.
            config: Configuration object. Must not be None.
            split_indices: Optional split information, stored unchanged.
            metadata: Optional DataFrame with additional metadata, stored
                unchanged.

        Raises:
            ValueError: If ``config`` is None.
        """
        # Validate first so an invalid call fails before any work is done.
        if config is None:
            raise ValueError("config cannot be None")

        self.raw_data = data  # image data before conversion to keep original infos
        self.config = config
        self.mytype = DataSetTypes.IMG

        # Convert all images to tensors with proper dtype once during
        # initialization. `_get_target_dtype` is not defined in this class;
        # it is expected to come from the base class.
        target_dtype = self._get_target_dtype()
        self.data = self._convert_all_images_to_tensors(target_dtype)

        # Extract sample_ids for consistency (same order as `self.data`).
        self.sample_ids = [img_data.sample_id for img_data in data]

        self.split_indices = split_indices
        self.feature_ids = None
        self.metadata = metadata

    def _convert_all_images_to_tensors(self, dtype: torch.dtype) -> List[torch.Tensor]:
        """Convert every image in ``raw_data`` to a tensor of the given dtype.

        Args:
            dtype: Target dtype for the tensors.

        Returns:
            List of converted image tensors, in the order of ``raw_data``.
        """
        print(f"Converting {len(self.raw_data)} images to {dtype} tensors...")
        # `_to_tensor` is not defined in this class; it is expected to come
        # from the base class.
        return [self._to_tensor(img_data.img, dtype) for img_data in self.raw_data]

    def __len__(self) -> int:
        """Return the number of converted images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Get item at index - data is already converted to proper dtype.

        Args:
            idx: Position of the sample in the dataset.

        Returns:
            Tuple of (index, image tensor, sample_id).
        """
        return idx, self.data[idx], self.sample_ids[idx]

    def get_input_dim(self) -> Tuple[int, ...]:
        """Get the input dimension of the dataset's feature space.

        The shape of the first image is reported; the images are not checked
        for a common shape here.

        Returns:
            The shape of the first image tensor.

        Raises:
            IndexError: If the dataset contains no images.
        """
        if not self.data:
            # Same exception type as the bare `self.data[0]` lookup, but with
            # a message that names the actual problem.
            raise IndexError(
                "cannot determine input dimension of an empty ImageDataset"
            )
        return self.data[0].shape  # All images should have the same shape

91 The input dimension of the dataset's feature space 

92 """ 

93 return self.data[0].shape # All images should have the same shape