Coverage for src / autoencodix / data / _numeric_dataset.py: 72%

58 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from __future__ import annotations 

2import torch 

3 

4import scipy as sp 

5from scipy.sparse import issparse 

6import numpy as np 

7from typing import Optional, List, Union, Any, Dict, Tuple, no_type_check 

8import pandas as pd 

9from autoencodix.configs.default_config import DefaultConfig 

10from autoencodix.base._base_dataset import BaseDataset, DataSetTypes 

11 

12 

class TensorAwareDataset(BaseDataset):
    """Dataset base class providing dtype mapping and tensor conversion helpers.

    The instance methods read ``self.data``, ``self.sample_ids``,
    ``self.feature_ids`` and ``self.config``. These attributes are not set in
    this class; presumably ``BaseDataset`` or a subclass provides them.
    """

    @staticmethod
    def _to_tensor(
        data: Union[torch.Tensor, np.ndarray, Any], dtype: torch.dtype
    ) -> torch.Tensor:
        """Convert data to a tensor with the specified dtype.

        Args:
            data: Input data to convert.
            dtype: Desired data type.

        Returns:
            A tensor with the requested dtype. Tensor inputs are cloned and
            detached first, so the result never aliases the input.
        """
        if isinstance(data, torch.Tensor):
            return data.clone().detach().to(dtype)
        return torch.tensor(data, dtype=dtype)

    @staticmethod
    def _map_float_precision_to_dtype(float_precision: str) -> torch.dtype:
        """Map fabric precision types to torch tensor dtypes.

        Args:
            float_precision: Precision type (e.g., 'bf16-mixed', '16-mixed').

        Returns:
            The corresponding torch dtype; ``torch.float32`` when the
            precision string is not recognized.
        """
        precision_mapping = {
            "transformer-engine": torch.float32,  # Default for transformer-engine
            "transformer-engine-float16": torch.float16,
            "16-true": torch.float16,
            "16-mixed": torch.float16,
            "bf16-true": torch.bfloat16,
            "bf16-mixed": torch.bfloat16,
            "32-true": torch.float32,
            "64-true": torch.float64,
            "64": torch.float64,
            "32": torch.float32,
            "16": torch.float16,
            "bf16": torch.bfloat16,
        }
        # Default to torch.float32 if the precision is not recognized
        return precision_mapping.get(float_precision, torch.float32)

    @staticmethod
    def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Return a NumPy array holding the values of ``tensor``.

        ``Tensor.numpy()`` only works for CPU tensors that do not require grad
        and whose dtype NumPy supports, so the tensor is detached, moved to
        the CPU and, for bfloat16 (which NumPy cannot represent), widened to
        float32 before conversion.
        """
        host = tensor.detach().cpu()
        if host.dtype == torch.bfloat16:
            host = host.to(torch.float32)
        return host.numpy()

    def _to_df(self) -> pd.DataFrame:
        """Convert the dataset to a pandas DataFrame.

        Supported layouts of ``self.data``:

        * a single ``torch.Tensor``: used as-is, with ``feature_ids`` as
          columns and ``sample_ids`` as index;
        * a scipy sparse matrix: densified, same columns and index;
        * a list of tensors (image modality): each tensor is flattened into
          one row and the columns are named ``Pixel_0``, ``Pixel_1``, ...;
          an empty list yields a frame without pixel columns.

        Returns:
            DataFrame representation of the dataset.

        Raises:
            TypeError: If ``self.data`` has none of the supported layouts.
        """
        data = self.data

        if isinstance(data, torch.Tensor):
            return pd.DataFrame(
                self._tensor_to_numpy(data),
                columns=self.feature_ids,
                index=self.sample_ids,
            )

        if issparse(data):
            return pd.DataFrame(
                data.toarray(), columns=self.feature_ids, index=self.sample_ids
            )

        if isinstance(data, list) and all(
            isinstance(item, torch.Tensor) for item in data
        ):
            # Image modality: flatten each tensor and collect the results as rows.
            rows = [self._tensor_to_numpy(t.flatten()) for t in data]
            # Column count comes from the first row; guard the empty list so it
            # does not fail with an IndexError.
            n_pixels = len(rows[0]) if rows else 0
            return pd.DataFrame(
                rows,
                index=self.sample_ids,
                columns=[f"Pixel_{i}" for i in range(n_pixels)],
            )

        raise TypeError(
            "Data is not a torch.Tensor and cannot be converted to DataFrame."
        )

    def _get_target_dtype(self) -> torch.dtype:
        """Get the target dtype based on config, with MPS compatibility check.

        Returns:
            The dtype mapped from ``self.config.float_precision``, replaced by
            ``torch.float32`` when that dtype is float64 and
            ``self.config.device`` is ``"mps"``.
        """
        target_dtype = self._map_float_precision_to_dtype(self.config.float_precision)

        # MPS doesn't support float64, so fallback to float32
        if target_dtype == torch.float64 and self.config.device == "mps":
            print("Warning: MPS doesn't support float64, using float32 instead")
            target_dtype = torch.float32

        return target_dtype

118 

119 

class NumericDataset(TensorAwareDataset):
    """A custom PyTorch dataset that handles tensors.

    Attributes:
        data: The input features. Dense input (``np.ndarray`` or
            ``torch.Tensor``) is stored as a ``torch.Tensor``; any other
            input (e.g. a scipy sparse matrix) is left as it was stored by
            the parent constructor.
        config: Configuration object containing settings for data processing.
        sample_ids: Optional list of sample identifiers.
        feature_ids: Optional list of feature identifiers.
        metadata: Optional pandas DataFrame containing metadata.
        split_indices: Optional numpy array for data splitting.
        target_dtype: Torch dtype derived from the config's float precision.
        mytype: Enum indicating the dataset type (set to DataSetTypes.NUM).
    """

    def __init__(
        self,
        data: Union[torch.Tensor, np.ndarray, sp.sparse.spmatrix],
        config: DefaultConfig,
        sample_ids: Union[None, List[Any]] = None,
        feature_ids: Union[None, List[Any]] = None,
        metadata: Optional[Union[pd.Series, pd.DataFrame]] = None,
        split_indices: Optional[Union[Dict[str, Any], List[Any], np.ndarray]] = None,
    ):
        """Initialize the dataset.

        Args:
            data: Input features
            config: Configuration object
            sample_ids: Optional sample identifiers
            feature_ids: Optional feature identifiers
            metadata: Optional metadata
            split_indices: Optional split indices

        Raises:
            ValueError: If ``self.config`` is None after the parent
                constructor has run.
        """
        super().__init__(
            data=data, sample_ids=sample_ids, config=config, feature_ids=feature_ids
        )

        if self.config is None:
            raise ValueError("config cannot be None")

        # Convert data to appropriate dtype once during initialization
        self.target_dtype = self._get_target_dtype()
        # Keep data sparse if it is a scipy sparse matrix to be memory
        # efficient for large single cell data; it is converted to a dense
        # tensor per item in __getitem__.
        if isinstance(self.data, (np.ndarray, torch.Tensor)):
            self.data = self._to_tensor(data, self.target_dtype)

        self.metadata = metadata
        self.split_indices = split_indices
        self.mytype = DataSetTypes.NUM

    @staticmethod
    def _is_scalar_index(index: Any) -> bool:
        """Return True when ``index`` selects a single sample rather than a batch."""
        if isinstance(index, (torch.Tensor, np.ndarray)):
            return index.ndim == 0
        return not isinstance(index, (slice, list, tuple))

    @no_type_check
    def __getitem__(self, index: int) -> Union[
        Tuple[
            Union[torch.Tensor, int],
            Union[torch.Tensor, "ImgData"],  # ty: ignore # noqa: F821
            Any,
        ],
        Dict[str, Tuple[Any, torch.Tensor, Any]],
    ]:
        """Retrieves a single sample and its corresponding label.

        Args:
            index: Index of the sample to retrieve. A slice or a list/array
                of indices selects several samples at once.

        Returns:
            A tuple ``(index, row, label)``. ``label`` is taken from
            ``sample_ids`` when available and is otherwise ``index`` itself.
            The return annotation also allows a dictionary of such tuples,
            but this implementation always returns the tuple.
        """
        row = self.data[index]  # idx: int, slice, or list

        if self.sample_ids is None:
            label = index
        else:
            try:
                label = self.sample_ids[index]
            except TypeError:
                # A plain Python list rejects list/array indices; gather the
                # labels one by one in that case. Anything else is a real
                # error and is re-raised unchanged.
                if not isinstance(index, (list, np.ndarray, torch.Tensor)):
                    raise
                label = [self.sample_ids[i] for i in index]

        if issparse(row):
            dense = torch.tensor(row.toarray(), dtype=self.target_dtype)
            # A single sample becomes a 1-D feature vector. Batched selections
            # keep their leading sample dimension even when it has length one,
            # matching what indexing a dense tensor returns.
            row = dense.reshape(-1) if self._is_scalar_index(index) else dense

        return index, row, label

    def __len__(self) -> int:
        """Returns the number of samples (rows) in the dataset"""
        return self.data.shape[0]