Coverage for src / autoencodix / data / general_preprocessor.py: 17%
158 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Any, Dict, List, Optional, Tuple, Union
3import mudata as md # type: ignore
4import numpy as np
5import pandas as pd
6import torch
7from anndata import AnnData # type: ignore
8from scipy.sparse import issparse # type: ignore
9import scipy as sp
11from autoencodix.base._base_dataset import BaseDataset
12from autoencodix.base._base_preprocessor import BasePreprocessor
13from autoencodix.data._datasetcontainer import DatasetContainer
14from autoencodix.data._numeric_dataset import NumericDataset
15from autoencodix.data.datapackage import DataPackage
16from autoencodix.utils._result import Result
17from autoencodix.configs.default_config import DataCase, DefaultConfig
class GeneralPreprocessor(BasePreprocessor):
    """Preprocessor for handling multi-modal data.

    The modalities of each split are concatenated feature-wise into a single
    matrix that is wrapped in a ``NumericDataset``. While concatenating, the
    column range occupied by every modality (and, for single-cell data, by
    every layer) is recorded so that a reconstruction can later be cut back
    into its original pieces by ``format_reconstruction``.

    Attributes:
        _datapackage_dict: Dictionary holding DataPackage objects for each data split.
        _dataset_container: Container holding processed datasets for each split.
        _reverse_mapping_multi_bulk: Reverse mapping for multi-bulk data
            reconstruction: split -> modality -> (column indices, feature ids).
        _reverse_mapping_multi_sc: Reverse mapping for multi-single-cell data
            reconstruction: split -> modality -> layer -> (column indices,
            feature ids).
    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ) -> None:
        """Initialise the preprocessor with empty per-split reverse mappings.

        Args:
            config: Configuration forwarded to ``BasePreprocessor``.
            ontologies: Optional ontologies forwarded to ``BasePreprocessor``.
        """
        super().__init__(config=config, ontologies=ontologies)
        self._datapackage_dict: Optional[Dict[str, Any]] = None
        self._dataset_container: Optional[DatasetContainer] = None
        # Reverse mappings for reconstruction
        self._reverse_mapping_multi_bulk: Dict[
            str, Dict[str, Tuple[List[int], List[str]]]
        ] = {"train": {}, "test": {}, "valid": {}}
        # Single-cell mappings carry one extra nesting level (the layer name).
        self._reverse_mapping_multi_sc: Dict[
            str, Dict[str, Dict[str, Tuple[List[int], List[str]]]]
        ] = {"train": {}, "test": {}, "valid": {}}

    def _combine_layers(
        self, modality_name: str, modality_data: Any
    ) -> List[np.ndarray]:
        """Collect the configured layers of one modality, in configuration order.

        Args:
            modality_name: Key into ``config.data_config.data_info``.
            modality_data: Object exposing ``X`` and ``layers``.

        Returns:
            One matrix per selected layer. A selected layer that is neither
            ``"X"`` nor present in ``modality_data.layers`` is skipped.
        """
        layer_list: List[np.ndarray] = []
        selected_layers = self.config.data_config.data_info[
            modality_name
        ].selected_layers
        for layer_name in selected_layers:
            if layer_name == "X":
                layer_list.append(modality_data.X)
            elif layer_name in modality_data.layers:
                layer_list.append(modality_data.layers[layer_name])
        return layer_list

    def _combine_modality_data(
        self,
        mudata: md.MuData,  # ty: ignore[invalid-type-form]
    ) -> Union[np.ndarray, sp.sparse.spmatrix]:  # ty: ignore[invalid-type-form]
        """Concatenate all selected layers of all modalities column-wise.

        As a side effect the single-cell reverse mapping of the current split
        (``self._split``) is rebuilt from scratch.

        Args:
            mudata: Container whose ``mod`` maps modality names to modality data.

        Returns:
            A CSR matrix if every collected layer is sparse, otherwise a dense
            array (sparse layers are densified before concatenation).
        """
        # Rebuild the mapping for this split from scratch. Re-assigning an
        # already existing key would keep its old dict position and could
        # leave entries of modalities that are no longer present, so feature
        # ids collected from the mapping could disagree with the column order.
        split_mapping: Dict[str, Dict[str, Tuple[List[int], List[str]]]] = {}
        self._reverse_mapping_multi_sc[self._split] = split_mapping

        modality_data_list: List[np.ndarray] = []
        start_idx = 0

        for modality_name, modality_data in mudata.mod.items():
            layer_mapping: Dict[str, Tuple[List[int], List[str]]] = {}
            split_mapping[modality_name] = layer_mapping
            layers = self.config.data_config.data_info[modality_name].selected_layers
            # Every layer of a modality shares the modality's variable names.
            feature_ids = modality_data.var_names.tolist()
            for layer_name in layers:
                if layer_name == "X":
                    n_feats = modality_data.shape[1]
                else:
                    n_feats = modality_data.layers[layer_name].shape[1]

                end_idx = start_idx + n_feats
                layer_mapping[layer_name] = (
                    list(range(start_idx, end_idx)),
                    feature_ids,
                )
                start_idx = end_idx

            combined_layers = self._combine_layers(
                modality_name=modality_name, modality_data=modality_data
            )
            modality_data_list.extend(combined_layers)

        if all(issparse(arr) for arr in modality_data_list):
            return sp.sparse.hstack(modality_data_list, format="csr")

        dense_layers = [
            arr.toarray() if issparse(arr) else arr  # ty: ignore
            for arr in modality_data_list
        ]
        return np.concatenate(dense_layers, axis=1)

    def _create_numeric_dataset(
        self,
        data: Union[np.ndarray, sp.sparse.spmatrix],
        config: DefaultConfig,
        split_ids: np.ndarray,
        metadata: pd.DataFrame,
        ids: List[str],
        feature_ids: List[str],
    ) -> NumericDataset:
        """Wrap a combined matrix and its bookkeeping in a ``NumericDataset``.

        Args:
            data: Combined feature matrix (dense or sparse).
            config: Configuration passed on to the dataset.
            split_ids: Passed to the dataset as ``split_indices``.
            metadata: Passed to the dataset as ``metadata``.
            ids: Passed to the dataset as ``sample_ids``.
            feature_ids: Passed to the dataset as ``feature_ids``.

        Returns:
            The constructed ``NumericDataset``.
        """
        # keep sparse data sparse until batch level in training for memory efficiency
        return NumericDataset(
            data=data,
            config=config,
            split_indices=split_ids,
            metadata=metadata,
            sample_ids=ids,
            feature_ids=feature_ids,
        )

    def _process_data_package(self, data_dict: Dict[str, Any]) -> BaseDataset:
        """Turn the data of one split into a single numeric dataset.

        Args:
            data_dict: Mapping with the keys ``"data"`` (the data package of
                the split) and ``"indices"`` (the split indices).

        Returns:
            Dataset holding the column-wise concatenation of all modalities.

        Raises:
            ValueError: If a multi-bulk entry is not a ``pandas.DataFrame``.
            NotImplementedError: If the bulk modalities differ in sample count,
                if the single-cell data has no ``"multi_sc"`` entry, or if the
                package holds neither multi-bulk nor multi-single-cell data.
        """
        data, split_ids = data_dict["data"], data_dict["indices"]
        # MULTI-BULK
        if data.multi_bulk is not None:
            metadata = data.annotation
            bulk_dict: Dict[str, pd.DataFrame] = data.multi_bulk

            # Check if all DataFrames have the same number of samples
            sample_counts: Dict[str, int] = {}
            for subkey, df in bulk_dict.items():
                if not isinstance(df, pd.DataFrame):
                    raise ValueError(
                        f"Expected a DataFrame for '{subkey}', got {type(df)}"
                    )
                sample_counts[subkey] = df.shape[0]

            # Validate all modalities have the same number of samples
            if len(set(sample_counts.values())) > 1:
                sample_count_str = ", ".join(
                    f"{k}: {v} samples" for k, v in sample_counts.items()
                )
                raise NotImplementedError(
                    "Different sample counts across modalities are not currently "
                    "supported for Varix and Vanillix. "
                    "Set requires_paired=True in config. "
                    f"Found: {sample_count_str}. "
                    "All modalities must have the same number of samples."
                )

            # Reset the bulk mapping of this split so that modalities from an
            # earlier preprocess() call cannot survive in it.
            bulk_mapping: Dict[str, Tuple[List[int], List[str]]] = {}
            self._reverse_mapping_multi_bulk[self._split] = bulk_mapping

            combined_cols: List[str] = []
            start_idx = 0
            for subkey, df in bulk_dict.items():
                columns = df.columns.tolist()
                end_idx = start_idx + df.shape[1]
                bulk_mapping[subkey] = (list(range(start_idx, end_idx)), columns)
                combined_cols.extend(columns)
                start_idx = end_idx

            # NOTE(review): pd.concat aligns rows on the index; this presumably
            # relies on all modalities sharing the same sample index - confirm.
            combined_df = pd.concat(bulk_dict.values(), axis=1)
            return self._create_numeric_dataset(
                data=combined_df.values,
                config=self.config,
                split_ids=split_ids,
                metadata=metadata,
                ids=combined_df.index.tolist(),
                feature_ids=combined_cols,
            )
        # MULTI-SINGLE-CELL
        elif data.multi_sc is not None:
            mudata: md.MuData = data.multi_sc.get(  # ty: ignore[invalid-type-form]
                "multi_sc", None
            )  # ty: ignore[invalid-type-form]
            # Guard before the first attribute access, otherwise a missing
            # entry surfaces as an AttributeError instead of this message.
            if mudata is None:
                raise NotImplementedError(
                    "Unpaired multi Single Cell case not implemented for Varix and Vanillix, set requires_paired=True in config"
                )
            # for single cell we know, we have a shared metadata
            # so we can use the first modality as reference
            # otherwise when using .obs from mudata, we get
            # unintuitive column names with modality name prefix
            first_mod = next(iter(mudata.mod.values()))

            combined_data = self._combine_modality_data(mudata)

            # collect feature IDs in concatenation order
            feature_ids: List[str] = []
            for layers in self._reverse_mapping_multi_sc[self._split].values():
                for _, fids in layers.values():
                    feature_ids.extend(fids)
            return self._create_numeric_dataset(
                data=combined_data,
                config=self.config,
                split_ids=split_ids,
                metadata=first_mod.obs,
                ids=mudata.obs_names.tolist(),
                feature_ids=feature_ids,
            )
        else:
            raise NotImplementedError(
                "GeneralPreprocessor only handles multi_bulk or multi_sc."
            )

    def preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> DatasetContainer:
        """Run the common preprocessing and build one dataset per split.

        Args:
            raw_user_data: Optional data package forwarded to
                ``_general_preprocess``.
            predict_new_data: Forwarded to ``_general_preprocess``.

        Returns:
            Container with a dataset for each of "train", "test" and "valid";
            splits without data are stored as ``None``.

        Raises:
            TypeError: If the common preprocessing returns ``None``.
        """
        # run common preprocessing
        self._datapackage_dict = self._general_preprocess(
            raw_user_data=raw_user_data, predict_new_data=predict_new_data
        )
        if self._datapackage_dict is None:
            raise TypeError("Datapackage cannot be None")

        # prepare container
        ds_container: DatasetContainer = DatasetContainer()

        for split in ["train", "test", "valid"]:
            split_data = self._datapackage_dict.get(split)
            # The helpers read self._split to pick the reverse mapping to fill.
            self._split = split
            if not split_data or split_data["data"] is None:
                ds_container[split] = None  # type: ignore
                continue
            ds_container[split] = self._process_data_package(split_data)
        self._dataset_container = ds_container
        return ds_container

    def format_reconstruction(
        self, reconstruction: torch.Tensor, result: Optional[Result] = None
    ) -> DataPackage:
        """Split a reconstructed matrix back into its original modalities.

        The split the reconstruction belongs to is inferred from its number of
        rows and stored in ``self._split``.

        Args:
            reconstruction: Reconstructed matrix, samples along the first axis.
            result: Unused by this implementation.

        Returns:
            Data package holding the reconstructed modalities.

        Raises:
            NotImplementedError: If the configured data case is neither
                multi-bulk nor multi-single-cell.
            ValueError: If no split matches the number of samples.
        """
        self._split = self._match_split(n_samples=reconstruction.shape[0])
        if self.config.data_case == DataCase.MULTI_BULK:
            return self._reverse_multi_bulk(reconstruction)
        elif self.config.data_case == DataCase.MULTI_SINGLE_CELL:
            return self._reverse_multi_sc(reconstruction)
        else:
            raise NotImplementedError(
                f"Reconstruction not implemented for {self.config.data_case}"
            )

    def _match_split(self, n_samples: int) -> str:
        """Match the split based on the number of samples.

        Args:
            n_samples: Number of samples of the reconstruction.

        Returns:
            Name of the first split whose paired sample count equals
            ``n_samples``.

        Raises:
            ValueError: If ``preprocess`` has not produced any data yet or no
                split has a matching sample count.
        """
        if self._datapackage_dict is None:
            raise ValueError(
                "No preprocessed data available; call preprocess() before "
                "formatting a reconstruction."
            )
        for split, split_data in self._datapackage_dict.items():
            # Mirror the guard used in preprocess(): a split may be absent.
            if not split_data:
                continue
            data = split_data.get("data")
            if data is None:
                continue
            ref_n = data.get_n_samples()["paired_count"]
            if n_samples == ref_n["paired_count"]:
                return split
        raise ValueError(
            f"Cannot find matching split for {n_samples} samples in the dataset."
        )

    def _reverse_multi_bulk(
        self, reconstruction: Union[np.ndarray, torch.Tensor]
    ) -> DataPackage:
        """Rebuild one DataFrame per bulk modality from a reconstruction.

        Args:
            reconstruction: Reconstructed matrix of the current split.

        Returns:
            Data package whose ``multi_bulk`` maps each modality to its
            reconstructed DataFrame and whose ``annotation`` is the metadata
            of the split's dataset.
        """
        data_package = DataPackage(
            multi_bulk={},
            multi_sc=None,
            annotation=None,
            img=None,
            from_modality=None,
            to_modality=None,
        )
        dataset = self._dataset_container[self._split]
        # reconstruct each bulk subkey
        dfs: Dict[str, pd.DataFrame] = {}
        for subkey, (indices, fids) in self._reverse_mapping_multi_bulk[
            self._split
        ].items():
            arr = self._slice_tensor(
                reconstruction=reconstruction,
                indices=indices,
            )
            dfs[subkey] = pd.DataFrame(
                arr,
                columns=fids,
                index=dataset.sample_ids,
            )
        data_package.annotation = dataset.metadata
        data_package.multi_bulk = dfs
        return data_package

    def _slice_tensor(
        self, reconstruction: Union[np.ndarray, torch.Tensor], indices: List[int]
    ) -> np.ndarray:
        """Select columns of a reconstruction and return them as a NumPy array.

        Args:
            reconstruction: Two-dimensional tensor or array.
            indices: Column indices to select.

        Returns:
            The selected columns; tensors are detached and moved to the CPU
            before conversion.

        Raises:
            TypeError: If ``reconstruction`` is neither a tensor nor an array.
        """
        if isinstance(reconstruction, torch.Tensor):
            return reconstruction[:, indices].detach().cpu().numpy()
        if isinstance(reconstruction, np.ndarray):
            return reconstruction[:, indices]
        raise TypeError(
            f"Expected reconstruction to be a torch.Tensor or np.ndarray, got {type(reconstruction)}"
        )

    def _reverse_multi_sc(self, reconstruction: torch.Tensor) -> DataPackage:
        """Rebuild a MuData object from a single-cell reconstruction.

        Args:
            reconstruction: Reconstructed matrix of the current split.

        Returns:
            Data package whose ``multi_sc`` holds the rebuilt MuData under the
            key ``"multi_sc"`` and whose ``annotation`` is the metadata of the
            split's dataset.
        """
        data_package = DataPackage(
            multi_bulk=None,
            multi_sc=None,
            annotation=None,
            img=None,
            from_modality=None,
            to_modality=None,
        )
        dataset = self._dataset_container[self._split]
        modalities: Dict[str, AnnData] = {}

        for modality_name, layers in self._reverse_mapping_multi_sc[
            self._split
        ].items():
            # rebuild each layer as DataFrame
            layers_dict: Dict[str, pd.DataFrame] = {}
            for layer_name, (indices, fids) in layers.items():
                arr = self._slice_tensor(reconstruction=reconstruction, indices=indices)
                layers_dict[layer_name] = pd.DataFrame(
                    arr,
                    columns=fids,
                    index=dataset.sample_ids,
                )

            # extract X layer for AnnData var; without an "X" entry the
            # variable index is empty and X stays None
            feature_ids = layers.get("X", (None, []))[1]
            var = pd.DataFrame(index=feature_ids)
            X_df = layers_dict.pop("X", None)
            adata = AnnData(
                X=X_df.values if X_df is not None else None,
                obs=dataset.metadata,
                var=var,
                layers={k: v.values for k, v in layers_dict.items()},
            )
            modalities[modality_name] = adata

        data_package.multi_sc = {"multi_sc": md.MuData(modalities)}
        data_package.annotation = dataset.metadata
        return data_package