Coverage for src / autoencodix / data / general_preprocessor.py: 17%

158 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Any, Dict, List, Optional, Tuple, Union 

2 

3import mudata as md # type: ignore 

4import numpy as np 

5import pandas as pd 

6import torch 

7from anndata import AnnData # type: ignore 

8from scipy.sparse import issparse # type: ignore 

9import scipy as sp 

10 

11from autoencodix.base._base_dataset import BaseDataset 

12from autoencodix.base._base_preprocessor import BasePreprocessor 

13from autoencodix.data._datasetcontainer import DatasetContainer 

14from autoencodix.data._numeric_dataset import NumericDataset 

15from autoencodix.data.datapackage import DataPackage 

16from autoencodix.utils._result import Result 

17from autoencodix.configs.default_config import DataCase, DefaultConfig 

18 

19 

class GeneralPreprocessor(BasePreprocessor):
    """Preprocessor for handling multi-modal data.

    Attributes:
        _datapackage_dict: Dictionary holding DataPackage objects for each data split.
        _dataset_container: Container holding processed datasets for each split.
        _reverse_mapping_multi_bulk: Reverse mapping for multi-bulk data reconstruction,
            keyed split -> modality -> (column indices, feature ids).
        _reverse_mapping_multi_sc: Reverse mapping for multi-single-cell data reconstruction,
            keyed split -> modality -> layer -> (column indices, feature ids).

    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ) -> None:
        """Initialise the preprocessor with empty per-split reverse mappings.

        Args:
            config: Configuration forwarded to the base preprocessor.
            ontologies: Optional ontologies forwarded to the base preprocessor.
        """
        super().__init__(config=config, ontologies=ontologies)
        self._datapackage_dict: Optional[Dict[str, Any]] = None
        self._dataset_container: Optional[DatasetContainer] = None
        # Reverse mappings for reconstruction
        self._reverse_mapping_multi_bulk: Dict[
            str, Dict[str, Tuple[List[int], List[str]]]
        ] = {"train": {}, "test": {}, "valid": {}}
        self._reverse_mapping_multi_sc: Dict[
            str, Dict[str, Dict[str, Tuple[List[int], List[str]]]]
        ] = {"train": {}, "test": {}, "valid": {}}

    def _combine_layers(
        self, modality_name: str, modality_data: Any
    ) -> List[np.ndarray]:
        """Collect the configured layers of one modality in configuration order.

        Args:
            modality_name: Key into ``config.data_config.data_info``.
            modality_data: Object exposing ``.X`` and ``.layers``.

        Returns:
            The selected layer matrices. A layer name that is neither ``"X"``
            nor present in ``modality_data.layers`` is skipped.
        """
        selected_layers = self.config.data_config.data_info[
            modality_name
        ].selected_layers
        layer_list: List[np.ndarray] = []
        for layer_name in selected_layers:
            if layer_name == "X":
                layer_list.append(modality_data.X)
            elif layer_name in modality_data.layers:
                layer_list.append(modality_data.layers[layer_name])
        return layer_list

    def _combine_modality_data(
        self,
        mudata: md.MuData,  # ty: ignore[invalid-type-form]
    ) -> Union[np.ndarray, sp.sparse.spmatrix]:  # ty: ignore[invalid-type-form]
        """Concatenate all selected layers of all modalities column-wise.

        While concatenating, the column range and feature ids of every
        (modality, layer) pair are recorded in ``_reverse_mapping_multi_sc``
        for the current split.

        Args:
            mudata: Container whose ``.mod`` maps modality names to data objects.

        Returns:
            A CSR matrix if every collected layer is sparse, otherwise a dense
            array (sparse layers are densified before concatenation).
        """
        # Reset single-cell reverse mapping for this split so that modalities
        # recorded by an earlier run cannot leak into the reconstruction.
        split_mapping: Dict[str, Dict[str, Tuple[List[int], List[str]]]] = {}
        self._reverse_mapping_multi_sc[self._split] = split_mapping

        modality_data_list: List[np.ndarray] = []
        start_idx = 0

        for modality_name, modality_data in mudata.mod.items():
            layer_mapping: Dict[str, Tuple[List[int], List[str]]] = {}
            split_mapping[modality_name] = layer_mapping
            layers = self.config.data_config.data_info[modality_name].selected_layers
            for layer_name in layers:
                if layer_name == "X":
                    n_feats = modality_data.shape[1]
                else:
                    n_feats = modality_data.layers[layer_name].shape[1]

                end_idx = start_idx + n_feats
                layer_mapping[layer_name] = (
                    list(range(start_idx, end_idx)),
                    modality_data.var_names.tolist(),
                )
                start_idx = end_idx

            modality_data_list.extend(
                self._combine_layers(
                    modality_name=modality_name, modality_data=modality_data
                )
            )

        if all(issparse(arr) for arr in modality_data_list):
            return sp.sparse.hstack(modality_data_list, format="csr")

        dense_layers = [
            arr.toarray() if issparse(arr) else arr  # ty: ignore
            for arr in modality_data_list
        ]
        return np.concatenate(dense_layers, axis=1)

    def _create_numeric_dataset(
        self,
        data: Union[np.ndarray, sp.sparse.spmatrix],
        config: DefaultConfig,
        split_ids: np.ndarray,
        metadata: pd.DataFrame,
        ids: List[str],
        feature_ids: List[str],
    ) -> NumericDataset:
        """Wrap a combined data matrix in a ``NumericDataset``.

        Args:
            data: Combined (dense or sparse) data matrix.
            config: Configuration passed to the dataset.
            split_ids: Split indices passed as ``split_indices``.
            metadata: Metadata passed to the dataset.
            ids: Sample ids passed as ``sample_ids``.
            feature_ids: Feature ids in column order.

        Returns:
            The constructed dataset.
        """
        # keep sparse data sparse until batch level in training for memory efficiency
        return NumericDataset(
            data=data,
            config=config,
            split_indices=split_ids,
            metadata=metadata,
            sample_ids=ids,
            feature_ids=feature_ids,
        )

    def _process_data_package(self, data_dict: Dict[str, Any]) -> BaseDataset:
        """Turn the data package of one split into a dataset.

        Args:
            data_dict: Mapping with the keys ``"data"`` (the data package) and
                ``"indices"`` (the split indices).

        Returns:
            A dataset holding the column-wise concatenation of all modalities.

        Raises:
            ValueError: If a multi-bulk entry is not a DataFrame.
            NotImplementedError: If modalities have different sample counts,
                if no paired MuData object is available, or if the package
                holds neither multi-bulk nor multi-single-cell data.
        """
        data, split_ids = data_dict["data"], data_dict["indices"]

        # MULTI-BULK
        if data.multi_bulk is not None:
            # reset bulk mapping so entries from an earlier run cannot linger
            bulk_mapping: Dict[str, Tuple[List[int], List[str]]] = {}
            self._reverse_mapping_multi_bulk[self._split] = bulk_mapping
            metadata = data.annotation
            bulk_dict: Dict[str, pd.DataFrame] = data.multi_bulk

            # Check if all DataFrames have the same number of samples
            sample_counts: Dict[str, int] = {}
            for subkey, df in bulk_dict.items():
                if not isinstance(df, pd.DataFrame):
                    raise ValueError(
                        f"Expected a DataFrame for '{subkey}', got {type(df)}"
                    )
                sample_counts[subkey] = df.shape[0]

            # Validate all modalities have the same number of samples
            if len(set(sample_counts.values())) > 1:
                sample_count_str = ", ".join(
                    f"{k}: {v} samples" for k, v in sample_counts.items()
                )
                raise NotImplementedError(
                    "Different sample counts across modalities are not currently "
                    "supported for Varix and Vanillix. "
                    "Set requires_paired=True in config. "
                    f"Found: {sample_count_str}. "
                    "All modalities must have the same number of samples."
                )

            combined_cols: List[str] = []
            start_idx = 0
            for subkey, df in bulk_dict.items():
                columns = df.columns.tolist()
                end_idx = start_idx + df.shape[1]
                bulk_mapping[subkey] = (list(range(start_idx, end_idx)), columns)
                combined_cols.extend(columns)
                start_idx = end_idx

            combined_df = pd.concat(list(bulk_dict.values()), axis=1)
            return self._create_numeric_dataset(
                data=combined_df.values,
                config=self.config,
                split_ids=split_ids,
                metadata=metadata,
                ids=combined_df.index.tolist(),
                feature_ids=combined_cols,
            )

        # MULTI-SINGLE-CELL
        if data.multi_sc is not None:
            mudata: md.MuData = data.multi_sc.get(  # ty: ignore[invalid-type-form]
                "multi_sc", None
            )  # ty: ignore[invalid-type-form]
            # The guard must run before any attribute of mudata is touched,
            # otherwise a missing object surfaces as an AttributeError.
            if mudata is None:
                raise NotImplementedError(
                    "Unpaired multi Single Cell case not implemented for Varix and Vanillix, set requires_paired=True in config"
                )
            # for single cell we know, we have a shared metadata
            # so we can use the first modality as reference
            # otherwise when using .obs from mudata, we get
            # unintuitive column names with modality name prefix
            first_mod = next(iter(mudata.mod.values()))
            combined_data = self._combine_modality_data(mudata)

            # collect feature IDs in concatenation order
            feature_ids: List[str] = []
            for layers in self._reverse_mapping_multi_sc[self._split].values():
                for _, fids in layers.values():
                    feature_ids.extend(fids)
            return self._create_numeric_dataset(
                data=combined_data,
                config=self.config,
                split_ids=split_ids,
                metadata=first_mod.obs,
                ids=mudata.obs_names.tolist(),
                feature_ids=feature_ids,
            )

        raise NotImplementedError(
            "GeneralPreprocessor only handles multi_bulk or multi_sc."
        )

    def preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> DatasetContainer:
        """Run the common preprocessing and build one dataset per split.

        Args:
            raw_user_data: Optional user-supplied data, forwarded to
                ``_general_preprocess``.
            predict_new_data: Forwarded to ``_general_preprocess``.

        Returns:
            A container with an entry for ``"train"``, ``"test"`` and
            ``"valid"``; splits without data are set to ``None``.

        Raises:
            TypeError: If the common preprocessing returns ``None``.
        """
        # run common preprocessing
        self._datapackage_dict = self._general_preprocess(
            raw_user_data=raw_user_data, predict_new_data=predict_new_data
        )
        if self._datapackage_dict is None:
            raise TypeError("Datapackage cannot be None")

        # prepare container
        ds_container: DatasetContainer = DatasetContainer()

        for split in ["train", "test", "valid"]:
            split_data = self._datapackage_dict.get(split)
            # _process_data_package and its helpers read the active split
            # from self._split when filling the reverse mappings.
            self._split = split
            if not split_data or split_data["data"] is None:
                ds_container[split] = None  # type: ignore
                continue
            ds_container[split] = self._process_data_package(split_data)

        self._dataset_container = ds_container
        return ds_container

    def format_reconstruction(
        self, reconstruction: torch.Tensor, result: Optional[Result] = None
    ) -> DataPackage:
        """Convert a reconstructed matrix back into a ``DataPackage``.

        The split is inferred from the number of rows of ``reconstruction``.

        Args:
            reconstruction: Reconstructed data, samples along the first axis.
            result: Not used by this implementation.

        Returns:
            The reconstruction split back into its original modalities.

        Raises:
            NotImplementedError: If ``config.data_case`` is neither multi-bulk
                nor multi-single-cell.
            ValueError: If no split matches the number of samples.
        """
        self._split = self._match_split(n_samples=reconstruction.shape[0])
        data_case = self.config.data_case
        if data_case == DataCase.MULTI_BULK:
            return self._reverse_multi_bulk(reconstruction)
        if data_case == DataCase.MULTI_SINGLE_CELL:
            return self._reverse_multi_sc(reconstruction)
        raise NotImplementedError(
            f"Reconstruction not implemented for {self.config.data_case}"
        )

    def _match_split(self, n_samples: int) -> str:
        """Match the split based on the number of samples.

        Args:
            n_samples: Number of rows in the reconstruction.

        Returns:
            The name of the first split whose paired sample count equals
            ``n_samples``.

        Raises:
            ValueError: If ``preprocess`` has not produced any data yet or if
                no split has a matching sample count.
        """
        if self._datapackage_dict is None:
            raise ValueError(
                "No preprocessed data available; call preprocess() before "
                "formatting a reconstruction."
            )
        for split, split_data in self._datapackage_dict.items():
            # Splits without data are skipped, mirroring preprocess().
            if not split_data:
                continue
            data = split_data.get("data")
            if data is None:
                continue
            # NOTE(review): the count is read from a nested "paired_count"
            # entry - presumably get_n_samples() returns a dict of dicts;
            # confirm against DataPackage.
            paired_count = data.get_n_samples()["paired_count"]["paired_count"]
            if n_samples == paired_count:
                return split
        raise ValueError(
            f"Cannot find matching split for {n_samples} samples in the dataset."
        )

    def _reverse_multi_bulk(
        self, reconstruction: Union[np.ndarray, torch.Tensor]
    ) -> DataPackage:
        """Split a reconstruction into one DataFrame per bulk modality.

        Args:
            reconstruction: Reconstructed matrix for the current split.

        Returns:
            A ``DataPackage`` whose ``multi_bulk`` holds one DataFrame per
            recorded modality and whose ``annotation`` is the dataset metadata.
        """
        data_package = DataPackage(
            multi_bulk={},
            multi_sc=None,
            annotation=None,
            img=None,
            from_modality=None,
            to_modality=None,
        )
        dataset = self._dataset_container[self._split]
        sample_ids = dataset.sample_ids

        # reconstruct each bulk subkey
        dfs: Dict[str, pd.DataFrame] = {}
        for subkey, (indices, fids) in self._reverse_mapping_multi_bulk[
            self._split
        ].items():
            arr = self._slice_tensor(reconstruction=reconstruction, indices=indices)
            dfs[subkey] = pd.DataFrame(arr, columns=fids, index=sample_ids)

        data_package.annotation = dataset.metadata
        data_package.multi_bulk = dfs
        return data_package

    def _slice_tensor(
        self, reconstruction: Union[np.ndarray, torch.Tensor], indices: List[int]
    ) -> np.ndarray:
        """Select columns of a reconstruction and return them as a NumPy array.

        Args:
            reconstruction: Two-dimensional tensor or array.
            indices: Column indices to select.

        Returns:
            The selected columns; tensors are detached and moved to the CPU.

        Raises:
            TypeError: If ``reconstruction`` is neither a tensor nor an array.
        """
        if isinstance(reconstruction, torch.Tensor):
            return reconstruction[:, indices].detach().cpu().numpy()
        if isinstance(reconstruction, np.ndarray):
            return reconstruction[:, indices]
        raise TypeError(
            f"Expected reconstruction to be a torch.Tensor or np.ndarray, got {type(reconstruction)}"
        )

    def _reverse_multi_sc(self, reconstruction: torch.Tensor) -> DataPackage:
        """Split a reconstruction into one AnnData object per modality.

        Args:
            reconstruction: Reconstructed matrix for the current split.

        Returns:
            A ``DataPackage`` whose ``multi_sc["multi_sc"]`` is a MuData object
            built from the rebuilt modalities and whose ``annotation`` is the
            dataset metadata.
        """
        data_package = DataPackage(
            multi_bulk=None,
            multi_sc=None,
            annotation=None,
            img=None,
            from_modality=None,
            to_modality=None,
        )
        dataset = self._dataset_container[self._split]
        metadata = dataset.metadata
        modalities: Dict[str, AnnData] = {}

        for modality_name, layers in self._reverse_mapping_multi_sc[
            self._split
        ].items():
            # rebuild each layer as a plain array in its original column order
            layer_arrays: Dict[str, np.ndarray] = {
                layer_name: self._slice_tensor(
                    reconstruction=reconstruction, indices=indices
                )
                for layer_name, (indices, _) in layers.items()
            }

            # Feature ids for var: every layer of a modality is recorded with
            # the same var_names, so fall back to any layer when "X" was not
            # selected instead of building an empty var index.
            if "X" in layers:
                feature_ids = layers["X"][1]
            elif layers:
                feature_ids = next(iter(layers.values()))[1]
            else:
                feature_ids = []

            x_values = layer_arrays.pop("X", None)
            modalities[modality_name] = AnnData(
                X=x_values,
                obs=metadata,
                var=pd.DataFrame(index=feature_ids),
                layers=layer_arrays,
            )

        data_package.multi_sc = {"multi_sc": md.MuData(modalities)}
        data_package.annotation = metadata
        return data_package