Coverage for src / autoencodix / data / _stackix_preprocessor.py: 15%

155 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Any, Dict, List, Optional, Tuple, Union, no_type_check 

2 

3import numpy as np 

4import pandas as pd 

5import torch 

6 

7from scipy.sparse import issparse # type: ignore 

8from autoencodix.base._base_preprocessor import BasePreprocessor 

9import anndata as ad # type: ignore 

10from autoencodix.data._numeric_dataset import NumericDataset 

11from autoencodix.data._multimodal_dataset import MultiModalDataset 

12from autoencodix.data.datapackage import DataPackage 

13from autoencodix.data._datasetcontainer import DatasetContainer 

14from autoencodix.configs.default_config import DefaultConfig, DataCase 

15from autoencodix.utils._result import Result 

16 

17 

class StackixPreprocessor(BasePreprocessor):
    """Preprocessor for Stackix architecture, which handles multiple modalities separately.

    Unlike GeneralPreprocessor which combines all modalities, StackixPreprocessor
    keeps modalities separate for individual VAE training in the Stackix architecture.

    Attributes:
        config: Configuration parameters for preprocessing and model architecture
        _datapackage: Dictionary storing processed data splits
        _dataset_container: Container for processed datasets by split
        _layer_indices: Mapping ``modality -> {layer_name: [start, end]}`` giving the
            column range each layer occupies in that modality's combined matrix.
            Only set once ``_build_dataset_dict`` has run.
    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ) -> None:
        """Initialize the StackixPreprocessor with the given configuration.

        Args:
            config: Configuration parameters for preprocessing
            ontologies: Accepted but not used by this preprocessor; it is not
                forwarded to the base class.
        """
        super().__init__(config=config)
        self._datapackage: Optional[Dict[str, Any]] = None
        self._dataset_container: Optional[DatasetContainer] = None

    def preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> DatasetContainer:
        """Execute preprocessing steps for Stackix architecture.

        Args:
            raw_user_data: Raw user data to preprocess; passed through to
                ``_general_preprocess`` (may be None).
            predict_new_data: Passed through unchanged to ``_general_preprocess``.

        Returns:
            Container with a MultiModalDataset for each of the splits "train",
            "valid" and "test"; a split that is absent or has no data is
            stored as None.

        Raises:
            TypeError: If datapackage is None after preprocessing
        """
        self._datapackage = self._general_preprocess(
            raw_user_data, predict_new_data=predict_new_data
        )
        if self._datapackage is None:
            # Fail with a clear message instead of the accidental
            # "argument of type 'NoneType' is not iterable" further down.
            raise TypeError("Datapackage is None after preprocessing.")
        self._dataset_container = DatasetContainer()

        for split in ("train", "valid", "test"):
            if (
                split not in self._datapackage
                or self._datapackage[split].get("data") is None
            ):
                self._dataset_container[split] = None
                continue
            dataset_dict = self._build_dataset_dict(
                datapackage=self._datapackage[split]["data"],
                split_indices=self._datapackage[split]["indices"],
            )
            self._dataset_container[split] = MultiModalDataset(
                datasets=dataset_dict,
                config=self.config,
            )
        return self._dataset_container

    def _extract_primary_data(self, modality_data: Any) -> np.ndarray:
        """Return ``modality_data.X``, converted to a dense array if it is sparse.

        Args:
            modality_data: Object exposing the primary matrix as ``.X``

        Returns:
            The primary data matrix (densified via ``toarray()`` when sparse)
        """
        primary_data = modality_data.X
        if issparse(primary_data):
            primary_data = primary_data.toarray()
        return primary_data

    @no_type_check
    def _combine_layers(
        self, modality_name: str, modality_data: Any
    ) -> Tuple[np.ndarray, Dict[str, List[int]]]:
        """Combine the selected layers of a modality column-wise.

        The layers to use are read from
        ``self.config.data_config.data_info[modality_name].selected_layers``.
        The name "X" refers to the primary matrix; any other name is looked up
        in ``modality_data.layers``. Names that are neither are skipped.

        Args:
            modality_name: Name of the modality
            modality_data: Data for the modality

        Returns:
            Tuple of the combined data and a dict mapping each included layer
            name to its ``[start_idx, end_idx]`` column range (end exclusive)
            inside the combined data. When no layer was included, the combined
            data is an empty 1-D array and the dict is empty.
        """
        layer_list: List[np.ndarray] = []
        layer_indices: Dict[str, List[int]] = {}

        selected_layers: List[str] = self.config.data_config.data_info[
            modality_name
        ].selected_layers

        start_idx = 0
        for layer_name in selected_layers:
            if layer_name == "X":
                layer_data = self._extract_primary_data(modality_data)
            elif layer_name in modality_data.layers:
                layer_data = modality_data.layers[layer_name]
                if issparse(layer_data):
                    layer_data = layer_data.toarray()
            else:
                # Unknown layer names are skipped (best-effort, as before).
                continue

            end_idx = start_idx + layer_data.shape[1]
            layer_list.append(layer_data)
            layer_indices[layer_name] = [start_idx, end_idx]
            start_idx = end_idx

        combined_data: np.ndarray = (
            np.concatenate(layer_list, axis=1) if layer_list else np.array([])
        )
        return combined_data, layer_indices

    def _resolve_metadata(self, datapackage: DataPackage, dict_key: str) -> Any:
        """Pick the annotation entry that belongs to the numeric data ``dict_key``.

        Lookup order: an entry keyed like the numeric data, then an entry keyed
        "paired", then - if the annotation holds exactly one entry - that entry.

        Args:
            datapackage: DataPackage whose ``annotation`` is searched
            dict_key: Key of the numeric data inside the DataPackage

        Returns:
            The matching annotation entry, or None when the DataPackage has no
            annotation at all.

        Raises:
            ValueError: If no entry matches and the annotation does not hold
                exactly one entry.
        """
        annotation = datapackage.annotation
        if annotation is None:  # prevents error in Single Cell case
            return None

        # case where each numeric data has its own annotation/metadata
        metadata = annotation.get(dict_key)
        if metadata is None:
            # case where there is one "paired" metadata for all numeric data
            metadata = annotation.get("paired")
        if metadata is None:
            # unpaired case with one metadata that includes all samples
            # across all numeric data
            if len(annotation.keys()) != 1:
                raise ValueError(
                    f"annotation key needs to be either 'paired' match a key of the numeric data or only one key exists that holds all unpaired data, please adjust config, got: {datapackage.annotation.keys()}"
                )
            metadata = annotation.get(next(iter(annotation.keys())))
        return metadata

    def _build_dataset_dict(
        self, datapackage: DataPackage, split_indices: np.ndarray
    ) -> Dict[str, NumericDataset]:
        """Build one NumericDataset per modality found in the datapackage.

        Iterating the DataPackage yields "<attribute>.<key>" strings. Entries of
        ``multi_bulk`` become one dataset per DataFrame; for ``multi_sc`` every
        modality of the MuData object becomes its own dataset whose columns are
        that modality's selected layers placed side by side. Any other
        attribute is ignored.

        As a side effect, ``self._layer_indices`` is replaced with the layer
        column ranges of all single-cell modalities processed in this call.

        Args:
            datapackage: DataPackage containing the data to be processed
            split_indices: Indices for splitting the data, handed to every
                NumericDataset unchanged

        Returns:
            Dictionary mapping modality names to NumericDataset objects

        Raises:
            TypeError: If the ``multi_sc`` entry is an AnnData object.
            ValueError: If no annotation can be matched to a ``multi_bulk`` entry.
        """
        dataset_dict: Dict[str, NumericDataset] = {}
        layer_id_dict: Dict[str, Dict[str, List]] = {}
        for k, _ in datapackage:
            # see DataPackage __iter__ method for why this makes sense;
            # maxsplit=1 keeps keys that themselves contain a dot intact.
            attr_name, dict_key = k.split(".", 1)

            if attr_name == "multi_bulk":
                df = datapackage[attr_name][dict_key]
                dataset_dict[dict_key] = NumericDataset(
                    data=df.values,
                    config=self.config,
                    sample_ids=df.index,
                    feature_ids=df.columns,
                    metadata=self._resolve_metadata(datapackage, dict_key),
                    split_indices=split_indices,
                )
            elif attr_name == "multi_sc":
                mudata = datapackage["multi_sc"]["multi_sc"]
                if isinstance(mudata, ad.AnnData):
                    raise TypeError(
                        "Expected a MuData object, but got an AnnData object."
                    )

                # NOTE(review): progress message goes to stdout; consider logging.
                print("building dataset_dict")
                for mod_name, mod_data in mudata.mod.items():
                    layers, indices = self._combine_layers(
                        modality_name=mod_name, modality_data=mod_data
                    )
                    layer_id_dict[mod_name] = indices
                    # Each included layer contributes one full set of variables,
                    # so the feature ids are the variable names tiled per layer.
                    # (Multiplying the Index itself would not tile it.)
                    feature_ids = pd.Index(list(mod_data.var_names) * len(indices))
                    # Every modality gets a dataset built from its own layers
                    # only, so the stored layer indices (which start at 0 for
                    # each modality) address the right columns.
                    dataset_dict[mod_name] = NumericDataset(
                        data=layers,
                        config=self.config,
                        # rows of the modality's matrices follow its own obs
                        sample_ids=mod_data.obs_names,
                        feature_ids=feature_ids,
                        metadata=mod_data.obs,
                        split_indices=split_indices,
                    )
            else:
                continue
        self._layer_indices = layer_id_dict
        return dataset_dict

    def format_reconstruction(
        self, reconstruction: Any, result: Optional[Result] = None
    ) -> DataPackage:
        """Put the per-modality reconstructions back into the original data format.

        The reconstructions are taken from ``result.sub_reconstructions``; the
        ``reconstruction`` argument itself is not used. Depending on
        ``self.config.data_case`` the output holds pandas DataFrames
        (multi bulk) or a MuData object (multi single cell).

        Args:
            reconstruction: Ignored; overwritten with ``result.sub_reconstructions``
            result: Result object providing ``sub_reconstructions`` (required)

        Returns:
            DataPackage with reconstructed data in original format

        Raises:
            ValueError: If ``result`` is None or the data case is unsupported.
            TypeError: If ``result.sub_reconstructions`` is not a dict.
        """
        if result is None:
            raise ValueError(
                "Result object is not provided. This is needed for the StackixPreprocessor."
            )
        reconstruction = result.sub_reconstructions
        if not isinstance(reconstruction, dict):
            raise TypeError(
                f"Expected value to be of type dict for Stackix, got {type(reconstruction)}."
            )

        if self.config.data_case == DataCase.MULTI_BULK:
            return self._format_multi_bulk(reconstructions=reconstruction)
        if self.config.data_case == DataCase.MULTI_SINGLE_CELL:
            return self._format_multi_sc(reconstructions=reconstruction)
        raise ValueError(
            f"Unsupported data_case {self.config.data_case} for StackixPreprocessor."
        )

    def _format_multi_bulk(
        self, reconstructions: Dict[str, torch.Tensor]
    ) -> DataPackage:
        """Format reconstructed tensors as DataFrames mirroring the "test" split.

        Row and column labels as well as the annotation are taken from the
        dataset of the same name in the "test" split.

        Args:
            reconstructions: Dictionary of reconstructed tensors for each modality

        Returns:
            DataPackage with ``multi_bulk`` (DataFrames) and ``annotation`` filled

        Raises:
            ValueError: If the dataset container or its "test" split is missing.
            TypeError: If a reconstruction is not a torch.Tensor.
        """
        if self._dataset_container is None:
            raise ValueError("Dataset container is not initialized.")
        stackix_ds = self._dataset_container["test"]
        if stackix_ds is None:
            raise ValueError("No dataset found for split: test")
        dataset_dict = stackix_ds.datasets

        multi_bulk_dict = {}
        annotation_dict = {}
        for name, reconstruction in reconstructions.items():
            if not isinstance(reconstruction, torch.Tensor):
                raise TypeError(
                    f"Expected value to be of type torch.Tensor, got {type(reconstruction)}."
                )
            original_dataset = dataset_dict[name]
            multi_bulk_dict[name] = pd.DataFrame(
                # detach/cpu: Tensor.numpy() refuses tensors that require grad
                # or live on an accelerator.
                reconstruction.detach().cpu().numpy(),
                index=original_dataset.sample_ids,
                columns=original_dataset.feature_ids,
            )
            annotation_dict[name] = original_dataset.metadata

        dp = DataPackage()
        dp["multi_bulk"] = multi_bulk_dict
        dp["annotation"] = annotation_dict
        return dp

    def _format_multi_sc(self, reconstructions: Dict[str, torch.Tensor]) -> DataPackage:
        """Formats reconstructed tensors back into a MuData object for single-cell data.

        This uses the stored layer indices to accurately split the reconstructed tensor
        back into the original layers.

        Args:
            reconstructions: Dictionary of reconstructed tensors for each modality

        Returns:
            DataPackage containing the reconstructed MuData object

        Raises:
            ValueError: If the dataset container, the layer indices or the
                "test" split is missing, or a modality is unknown.
            TypeError: If a reconstruction is not a torch.Tensor.
        """
        import mudata as md

        if self._dataset_container is None:
            raise ValueError("Dataset container is not initialized.")
        if not hasattr(self, "_layer_indices"):
            raise ValueError(
                "Layer indices not found. Make sure _build_dataset_dict was called."
            )

        stackix_ds = self._dataset_container["test"]
        if stackix_ds is None:
            raise ValueError("No dataset found for split: test")

        dataset_dict = stackix_ds.datasets
        modalities = {}

        # Process each modality in the reconstruction
        for mod_name, recon_tensor in reconstructions.items():
            if not isinstance(recon_tensor, torch.Tensor):
                raise TypeError(
                    f"Expected value to be of type torch.Tensor, got {type(recon_tensor)}."
                )
            if mod_name not in dataset_dict:
                raise ValueError(f"Modality {mod_name} not found in dataset dictionary")
            original_dataset = dataset_dict[mod_name]
            layer_indices = self._layer_indices[mod_name]

            # Convert once per modality; detach/cpu because Tensor.numpy()
            # refuses tensors that require grad or live on an accelerator.
            recon_array = recon_tensor.detach().cpu().numpy()

            # NOTE(review): requires "X" among the selected layers of the modality.
            start_idx, end_idx = layer_indices["X"]
            x_data = recon_array[:, start_idx:end_idx]

            var_names = original_dataset.feature_ids
            mod_data = ad.AnnData(
                X=x_data,
                obs=original_dataset.metadata,
                var=pd.DataFrame(index=var_names[0 : x_data.shape[1]]),
            )

            # Add additional layers based on stored indices
            for layer_name, ids in layer_indices.items():
                if layer_name == "X":
                    continue  # X is already set
                mod_data.layers[layer_name] = recon_array[:, ids[0] : ids[1]]

            modalities[mod_name] = mod_data

        # Create MuData object from all modalities and wrap it
        dp = DataPackage()
        dp["multi_sc"] = {"multi_sc": md.MuData(modalities)}
        return dp