Coverage for src / autoencodix / data / _numeric_dataset.py: 72%
58 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from __future__ import annotations
2import torch
4import scipy as sp
5from scipy.sparse import issparse
6import numpy as np
7from typing import Optional, List, Union, Any, Dict, Tuple, no_type_check
8import pandas as pd
9from autoencodix.configs.default_config import DefaultConfig
10from autoencodix.base._base_dataset import BaseDataset, DataSetTypes
class TensorAwareDataset(BaseDataset):
    """
    Handles dtype mapping and tensor conversion logic.

    Methods here read ``self.data``, ``self.config``, ``self.sample_ids`` and
    ``self.feature_ids``; presumably these are set by ``BaseDataset.__init__``
    (TODO confirm against the base class).
    """

    @staticmethod
    def _to_tensor(
        data: Union[torch.Tensor, np.ndarray, Any], dtype: torch.dtype
    ) -> torch.Tensor:
        """
        Convert data to a new tensor with the specified dtype.

        Tensor inputs are cloned and detached, so the result neither shares
        storage with nor tracks gradients of the caller's tensor. All other
        inputs go through ``torch.tensor``, which also copies.

        Args:
            data: Input data to convert (tensor, numpy array, or anything
                ``torch.tensor`` accepts).
            dtype: Desired data type.

        Returns:
            Tensor with the specified dtype.
        """
        if isinstance(data, torch.Tensor):
            return data.clone().detach().to(dtype)
        return torch.tensor(data, dtype=dtype)

    @staticmethod
    def _map_float_precision_to_dtype(float_precision: Union[str, int]) -> torch.dtype:
        """
        Map fabric precision types to torch tensor dtypes.

        Args:
            float_precision: Precision type (e.g., 'bf16-mixed', '16-mixed').
                Integer precisions such as 64, 32 or 16 are also accepted.

        Returns:
            Corresponding torch dtype; torch.float32 if the precision is not
            recognized.
        """
        precision_mapping = {
            "transformer-engine": torch.float32,  # Default for transformer-engine
            "transformer-engine-float16": torch.float16,
            "16-true": torch.float16,
            "16-mixed": torch.float16,
            "bf16-true": torch.bfloat16,
            "bf16-mixed": torch.bfloat16,
            "32-true": torch.float32,
            "64-true": torch.float64,
            "64": torch.float64,
            "32": torch.float32,
            "16": torch.float16,
            "bf16": torch.bfloat16,
        }
        # Normalize to str so integer precisions (e.g. 64 or 16) hit the
        # "64"/"16" keys instead of silently falling back to float32.
        # Unrecognized values still default to torch.float32.
        return precision_mapping.get(str(float_precision), torch.float32)

    @staticmethod
    def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """
        Convert a tensor to a NumPy array regardless of grad, device or dtype.

        ``detach`` and ``cpu`` make the conversion work for tensors that
        require grad or live on an accelerator. NumPy has no bfloat16 type, so
        bfloat16 tensors are upcast to float32 first.

        Args:
            tensor: Tensor to convert.

        Returns:
            NumPy array holding the tensor's values.
        """
        tensor = tensor.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.to(torch.float32)
        return tensor.numpy()

    def _to_df(self) -> pd.DataFrame:
        """
        Convert the dataset to a pandas DataFrame.

        Supported ``self.data`` layouts:
            * a torch.Tensor: one row per sample, columns are ``feature_ids``;
            * a scipy sparse matrix: densified, same index/columns;
            * a list of torch.Tensor (image modality): each tensor is flattened
              into one row, with columns named ``Pixel_<i>``.

        Returns:
            DataFrame representation of the dataset, indexed by ``sample_ids``.

        Raises:
            TypeError: If ``self.data`` is none of the supported layouts.
        """
        if isinstance(self.data, torch.Tensor):
            return pd.DataFrame(
                self._tensor_to_numpy(self.data),
                columns=self.feature_ids,
                index=self.sample_ids,
            )
        if issparse(self.data):
            return pd.DataFrame(
                self.data.toarray(), columns=self.feature_ids, index=self.sample_ids
            )
        if isinstance(self.data, list) and all(
            isinstance(item, torch.Tensor) for item in self.data
        ):
            # Image modality: flatten each tensor into one row of pixel values.
            rows = [self._tensor_to_numpy(t.flatten()) for t in self.data]
            # Guard the empty list (all() is True for it) instead of hitting
            # IndexError on rows[0].
            n_pixels = len(rows[0]) if rows else 0
            return pd.DataFrame(
                rows,
                index=self.sample_ids,
                columns=["Pixel_" + str(i) for i in range(n_pixels)],
            )
        raise TypeError(
            f"Data of type {type(self.data).__name__} cannot be converted to "
            "DataFrame; expected a torch.Tensor, a scipy sparse matrix, or a "
            "list of torch.Tensor."
        )

    def _get_target_dtype(self) -> torch.dtype:
        """
        Get the target dtype based on config, with MPS compatibility check.

        Returns:
            The dtype mapped from ``self.config.float_precision``, except that
            float64 is downgraded to float32 when ``self.config.device`` refers
            to MPS.
        """
        target_dtype = self._map_float_precision_to_dtype(self.config.float_precision)

        # MPS doesn't support float64, so fall back to float32. str() plus
        # startswith also covers "mps:0" and torch.device("mps").
        if target_dtype == torch.float64 and str(self.config.device).startswith("mps"):
            import warnings

            warnings.warn(
                "MPS doesn't support float64, using float32 instead", stacklevel=2
            )
            target_dtype = torch.float32

        return target_dtype
class NumericDataset(TensorAwareDataset):
    """A custom PyTorch dataset that handles numeric feature matrices.

    Dense inputs (numpy arrays and tensors) are converted once, at
    construction, to a tensor whose dtype follows ``config.float_precision``.
    SciPy sparse inputs are kept sparse to save memory and densified one row
    at a time in ``__getitem__``.

    Attributes:
        data: The input features. A torch.Tensor for dense input, or the
            original scipy sparse matrix for sparse input.
        config: Configuration object containing settings for data processing.
        sample_ids: Optional list of sample identifiers.
        feature_ids: Optional list of feature identifiers.
        metadata: Optional pandas DataFrame/Series containing metadata.
        split_indices: Optional split indices (dict, list or numpy array).
        target_dtype: torch dtype that dense data and densified sparse rows
            are converted to.
        mytype: Enum indicating the dataset type (set to DataSetTypes.NUM).
    """

    def __init__(
        self,
        data: Union[torch.Tensor, np.ndarray, sp.sparse.spmatrix],
        config: DefaultConfig,
        sample_ids: Union[None, List[Any]] = None,
        feature_ids: Union[None, List[Any]] = None,
        metadata: Optional[Union[pd.Series, pd.DataFrame]] = None,
        split_indices: Optional[Union[Dict[str, Any], List[Any], np.ndarray]] = None,
    ):
        """
        Initialize the dataset.

        Args:
            data: Input features (dense array/tensor or scipy sparse matrix).
            config: Configuration object; must not be None.
            sample_ids: Optional sample identifiers.
            feature_ids: Optional feature identifiers.
            metadata: Optional metadata.
            split_indices: Optional split indices.

        Raises:
            ValueError: If ``self.config`` is None after base initialization.
        """
        super().__init__(
            data=data, sample_ids=sample_ids, config=config, feature_ids=feature_ids
        )

        if self.config is None:
            raise ValueError("config cannot be None")

        # Resolve the dtype once; __getitem__ reuses it for sparse rows.
        self.target_dtype = self._get_target_dtype()
        # Keep scipy sparse matrices sparse to be memory efficient for large
        # single-cell data; they are densified per row in __getitem__.
        # Convert self.data (what the isinstance check inspects) rather than
        # the raw argument, so the two cannot diverge.
        if isinstance(self.data, (np.ndarray, torch.Tensor)):
            self.data = self._to_tensor(self.data, self.target_dtype)

        self.metadata = metadata
        self.split_indices = split_indices
        self.mytype = DataSetTypes.NUM

    @no_type_check
    def __getitem__(self, index: int) -> Union[
        Tuple[
            Union[torch.Tensor, int],
            Union[torch.Tensor, "ImgData"],  # ty: ignore # noqa: F821
            Any,
        ],
        Dict[str, Tuple[Any, torch.Tensor, Any]],
    ]:
        """Retrieves a single sample and its identifier.

        Args:
            index: Index of the sample to retrieve.

        Returns:
            A tuple ``(index, row, label)``:
                * ``row`` is a dense tensor; sparse rows are converted to
                  ``self.target_dtype``.
                * ``label`` is ``self.sample_ids[index]`` when sample IDs are
                  set, otherwise ``index`` itself.

            This implementation always returns a tuple; the ``Dict`` variant
            in the annotation is never produced here.
        """
        row = self.data[index]
        label = self.sample_ids[index] if self.sample_ids is not None else index
        if issparse(row):
            # Indexing one row of a scipy sparse matrix yields a
            # (1, n_features) matrix; densify just this row and drop the
            # leading axis.
            row = torch.tensor(row.toarray(), dtype=self.target_dtype).squeeze(0)
        return index, row, label

    def __len__(self) -> int:
        """Returns the number of samples (rows) in the dataset."""
        return self.data.shape[0]