Coverage for src / autoencodix / base / _base_dataset.py: 67%

39 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import abc 

2import scipy as sp 

3from enum import Enum 

4from typing import Any, Dict, List, Optional, Tuple, Union 

5 

6import pandas as pd 

7import torch 

8from torch.utils.data import Dataset 

9 

10from autoencodix.data._imgdataclass import ImgData 

11 

12 

class DataSetTypes(str, Enum):
    """Enumerates the kinds of content a ``BaseDataset`` subclass declares via ``mytype``.

    Because the enum also derives from ``str``, every member compares equal to
    its plain-string value, e.g. ``DataSetTypes.NUM == "NUM"``.
    """

    NUM = "NUM"
    IMG = "IMG"

16 

17 

class BaseDataset(abc.ABC, Dataset):
    """Interface to guide implementation for custom PyTorch datasets.

    Attributes:
        data: The dataset content: a ``torch.Tensor``, a list of ``ImgData``,
            or a SciPy sparse matrix.
        raw_data: The same object that was passed as ``data`` (kept for the
            ImageDataset child class).
        config: Optional configuration object.
        sample_ids: Optional list of identifiers for each sample.
        feature_ids: Optional list of identifiers for each feature.
        mytype: Enum indicating the dataset type (e.g. ``DataSetTypes.NUM`` or
            ``DataSetTypes.IMG``). Only declared here; subclasses must assign it.
        metadata: Optional sample metadata (``pd.Series`` or ``pd.DataFrame``);
            ``None`` until set.
        datasets: Mapping from name to sub-dataset, used by the xmodalix child
            class; empty by default.
    """

    def __init__(
        self,
        data: Union[torch.Tensor, List[ImgData], sp.sparse.spmatrix],
        config: Optional[Any] = None,
        sample_ids: Optional[List[Any]] = None,
        feature_ids: Optional[List[Any]] = None,
    ):
        """Initializes the dataset.

        Args:
            data: The data to be used by the dataset.
            config: Optional configuration parameters.
            sample_ids: Optional identifiers for each sample.
            feature_ids: Optional identifiers for each feature.
        """
        self.data = data
        self.raw_data = data  # for child class ImageDataset
        self.config = config
        self.sample_ids = sample_ids
        self.feature_ids = feature_ids
        # Declaration only (no value is assigned): subclasses must set this,
        # e.g. to DataSetTypes.NUM or DataSetTypes.IMG.
        self.mytype: Enum

        # Must be None (not a tuple such as ``(None,)``) so that ``is None``
        # checks and truthiness tests behave as the Optional annotation implies.
        self.metadata: Optional[Union[pd.Series, pd.DataFrame]] = None
        self.datasets: Dict[str, BaseDataset] = {}  # for xmodalix child

    def __len__(self) -> int:
        """Returns the number of samples in the dataset.

        A list (e.g. of ``ImgData``) counts its elements; any other container
        is assumed to expose ``shape`` with samples along the first axis.

        Returns:
            The number of samples in the dataset.
        """
        if isinstance(self.data, list):
            return len(self.data)
        return self.data.shape[0]

    def get_input_dim(self) -> Union[int, Tuple[int, ...]]:
        """Gets the input dimension of the dataset (n_features).

        For tensors and SciPy sparse containers (matrices or arrays) this is
        ``shape[1]``. For a list of ``ImgData`` it is ``shape[0]`` of the first
        item's ``img``.

        Returns:
            The input dimension of the dataset's feature space.

        Raises:
            ValueError: If the data is an empty list, a list whose first
                element is not ``ImgData``, or an unsupported type.
        """
        # issparse covers both the legacy spmatrix classes and the newer
        # sparse array classes (csr_array, ...), which are not spmatrix subclasses.
        if isinstance(self.data, torch.Tensor) or sp.sparse.issparse(self.data):
            return self.data.shape[1]

        if isinstance(self.data, list):
            if len(self.data) == 0:
                raise ValueError(
                    "Dataset is ImgData, and the list of ImgData is empty, cannot determine input dimension."
                )
            if isinstance(self.data[0], ImgData):
                return self.data[0].img.shape[0]
            raise ValueError(
                f"List data is not of type ImgData, got {type(self.data[0])}, cannot determine input dimension."
            )

        raise ValueError("Unsupported data type for input dimension retrieval.")

    def _to_df(self, modality: Optional[str] = None) -> pd.DataFrame:
        """Convert the dataset to a pandas DataFrame.

        Args:
            modality: Unused by this base implementation; accepted so callers
                can pass it uniformly.

        Returns:
            DataFrame representation of the dataset, with ``feature_ids`` as
            columns and ``sample_ids`` as index.

        Raises:
            TypeError: If ``data`` is not a ``torch.Tensor``.
        """
        if isinstance(self.data, torch.Tensor):
            # detach() + cpu() so tensors that require grad or live on an
            # accelerator device can be converted; .numpy() alone raises there.
            return pd.DataFrame(
                self.data.detach().cpu().numpy(),
                columns=self.feature_ids,
                index=self.sample_ids,
            )
        raise TypeError(
            "Data is not a torch.Tensor and cannot be converted to DataFrame."
        )