Coverage for src / autoencodix / base / _base_dataset.py: 67%
39 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import abc
2import scipy as sp
3from enum import Enum
4from typing import Any, Dict, List, Optional, Tuple, Union
6import pandas as pd
7import torch
8from torch.utils.data import Dataset
10from autoencodix.data._imgdataclass import ImgData
class DataSetTypes(str, Enum):
    """String-valued tags naming the kind of data a dataset holds.

    Members also subclass ``str``, so each one compares equal to its
    plain string value.
    """

    # Member order is preserved: iteration yields NUM, then IMG.
    NUM = "NUM"
    IMG = "IMG"
class BaseDataset(abc.ABC, Dataset):
    """Interface to guide implementation for custom PyTorch datasets.

    Attributes:
        data: The dataset content (a torch.Tensor, a list of ImgData or a
            scipy sparse matrix).
        raw_data: Reference to the same object passed as ``data``; kept for
            the child class ImageDataset.
        config: Optional configuration object.
        sample_ids: Optional list of identifiers for each sample.
        feature_ids: Optional list of identifiers for each feature.
        mytype: Enum indicating the dataset type. Only declared here, never
            assigned; subclasses should set it (e.g. DataSetTypes.NUM or
            DataSetTypes.IMG).
        metadata: Optional pandas Series/DataFrame; initialised to None.
        datasets: Mapping from name to dataset, initialised empty (for the
            xmodalix child).
    """

    def __init__(
        self,
        data: Union[torch.Tensor, List["ImgData"], sp.sparse.spmatrix],
        config: Optional[Any] = None,
        sample_ids: Optional[List[Any]] = None,
        feature_ids: Optional[List[Any]] = None,
    ):
        """Initializes the dataset.

        Args:
            data: The data to be used by the dataset.
            config: Optional configuration parameters.
            sample_ids: Optional identifiers for each sample.
            feature_ids: Optional identifiers for each feature.
        """
        self.data = data
        self.raw_data = data  # for child class ImageDataset
        self.config = config
        self.sample_ids = sample_ids
        self.feature_ids = feature_ids
        # Declaration only: no value is assigned here, subclasses set it
        # (e.g., DataSetTypes.NUM or DataSetTypes.IMG).
        self.mytype: Enum
        # Must be a plain None: a trailing comma here would store the tuple
        # (None,) and make every `metadata is None` check fail.
        self.metadata: Optional[Union[pd.Series, pd.DataFrame]] = None
        self.datasets: Dict[str, "BaseDataset"] = {}  # for xmodalix child

    def __len__(self) -> int:
        """Returns the number of samples in the dataset.

        Returns:
            ``len(data)`` when ``data`` is a list, otherwise the size of its
            first axis (``data.shape[0]``).
        """
        if isinstance(self.data, list):
            return len(self.data)
        return self.data.shape[0]

    def get_input_dim(self) -> Union[int, Tuple[int, ...]]:
        """Gets the input dimension of the dataset (n_features).

        Returns:
            For a tensor or sparse matrix, the size of the second axis
            (``data.shape[1]``). For a list of ImgData, ``img.shape[0]`` of
            the first element.

        Raises:
            ValueError: If ``data`` is an empty list, a list whose first
                element is not ImgData, or any other unsupported type.
        """
        if isinstance(self.data, (torch.Tensor, sp.sparse.spmatrix)):
            return self.data.shape[1]

        if not isinstance(self.data, list):
            raise ValueError("Unsupported data type for input dimension retrieval.")

        if not self.data:
            raise ValueError(
                "Dataset is ImgData, and the list of ImgData is empty, cannot determine input dimension."
            )

        # Only the first element is inspected.
        first = self.data[0]
        if not isinstance(first, ImgData):
            raise ValueError(
                f"List data is not of type ImgData, got {type(first)}, cannot determine input dimension."
            )
        return first.img.shape[0]

    def _to_df(self, modality: Optional[str] = None) -> pd.DataFrame:
        """Convert the dataset to a pandas DataFrame.

        Args:
            modality: Not used by this base implementation.

        Returns:
            DataFrame representation of the dataset, with ``feature_ids`` as
            columns and ``sample_ids`` as index.

        Raises:
            TypeError: If ``data`` is not a torch.Tensor.
        """
        if not isinstance(self.data, torch.Tensor):
            raise TypeError(
                "Data is not a torch.Tensor and cannot be converted to DataFrame."
            )
        # detach() + cpu() make the conversion work for tensors that require
        # grad or live on another device; for plain CPU tensors the result
        # is the same as calling .numpy() directly.
        values = self.data.detach().cpu().numpy()
        return pd.DataFrame(values, columns=self.feature_ids, index=self.sample_ids)