Coverage for src / autoencodix / data / _stackix_preprocessor.py: 15%
155 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Any, Dict, List, Optional, Tuple, Union, no_type_check
3import numpy as np
4import pandas as pd
5import torch
7from scipy.sparse import issparse # type: ignore
8from autoencodix.base._base_preprocessor import BasePreprocessor
9import anndata as ad # type: ignore
10from autoencodix.data._numeric_dataset import NumericDataset
11from autoencodix.data._multimodal_dataset import MultiModalDataset
12from autoencodix.data.datapackage import DataPackage
13from autoencodix.data._datasetcontainer import DatasetContainer
14from autoencodix.configs.default_config import DefaultConfig, DataCase
15from autoencodix.utils._result import Result
class StackixPreprocessor(BasePreprocessor):
    """Preprocessor for the Stackix architecture, which keeps modalities separate.

    Unlike GeneralPreprocessor, which combines all modalities, StackixPreprocessor
    builds one NumericDataset per modality so that an individual VAE can be
    trained on each of them in the Stackix architecture.

    Attributes:
        config: Configuration parameters for preprocessing and model architecture.
        _datapackage: Dictionary storing the processed data per split.
        _dataset_container: Container for the processed datasets by split.
        _layer_indices: Only set once ``_build_dataset_dict`` has run. Maps
            modality name -> layer name -> ``[start, end]`` column range inside
            that modality's combined matrix (filled for ``multi_sc`` data only).
    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ) -> None:
        """Initialize the StackixPreprocessor with the given configuration.

        Args:
            config: Configuration parameters for preprocessing.
            ontologies: Currently unused; it is neither stored nor forwarded
                to the base class.
        """
        super().__init__(config=config)
        self._datapackage: Optional[Dict[str, Any]] = None
        self._dataset_container: Optional[DatasetContainer] = None

    def preprocess(
        self, raw_user_data: Optional[DataPackage] = None, predict_new_data=False
    ) -> DatasetContainer:
        """Execute the preprocessing steps for the Stackix architecture.

        Args:
            raw_user_data: Raw user data, forwarded to ``_general_preprocess``.
            predict_new_data: Forwarded unchanged to ``_general_preprocess``.

        Returns:
            Container with a MultiModalDataset for every split that has data
            and ``None`` for every split that is missing or has no data.

        Raises:
            TypeError: If the datapackage is None after preprocessing.
        """
        self._datapackage = self._general_preprocess(
            raw_user_data, predict_new_data=predict_new_data
        )
        # Fail with the documented error instead of an incidental
        # "argument of type 'NoneType' is not iterable" further down.
        if self._datapackage is None:
            raise TypeError("Datapackage cannot be None after preprocessing.")

        self._dataset_container = DatasetContainer()
        for split in ("train", "valid", "test"):
            if (
                split not in self._datapackage
                or self._datapackage[split].get("data") is None
            ):
                self._dataset_container[split] = None
                continue
            dataset_dict = self._build_dataset_dict(
                datapackage=self._datapackage[split]["data"],
                split_indices=self._datapackage[split]["indices"],
            )
            self._dataset_container[split] = MultiModalDataset(
                datasets=dataset_dict,
                config=self.config,
            )
        return self._dataset_container

    def _extract_primary_data(self, modality_data: Any) -> np.ndarray:
        """Return ``modality_data.X``, converted to a dense array if it is sparse."""
        primary_data = modality_data.X
        if issparse(primary_data):
            primary_data = primary_data.toarray()
        return primary_data

    @no_type_check
    def _combine_layers(
        self, modality_name: str, modality_data: Any
    ) -> Tuple[np.ndarray, Dict[str, List[int]]]:
        """Concatenate the configured layers of one modality column-wise.

        The layers to use are read from
        ``self.config.data_config.data_info[modality_name].selected_layers``.
        The name ``"X"`` refers to ``modality_data.X``; any other name is looked
        up in ``modality_data.layers``. Names found in neither place are skipped.

        Args:
            modality_name: Name of the modality (key into the data config).
            modality_data: Data for the modality.

        Returns:
            A tuple ``(combined, layer_indices)``. ``combined`` is the
            column-wise concatenation of the used layers, or an empty 1-D array
            when no layer was used. ``layer_indices`` maps each used layer name
            to its ``[start, end]`` column range inside ``combined``.
        """
        layer_list: List[np.ndarray] = []
        layer_indices: Dict[str, List[int]] = {}

        selected_layers: List[str] = self.config.data_config.data_info[
            modality_name
        ].selected_layers

        start_idx = 0
        for layer_name in selected_layers:
            if layer_name == "X":
                layer_data = self._extract_primary_data(modality_data)
            elif layer_name in modality_data.layers:
                layer_data = modality_data.layers[layer_name]
                if issparse(layer_data):
                    layer_data = layer_data.toarray()
            else:
                # Configured layer that this modality does not have: skip it.
                continue

            end_idx = start_idx + layer_data.shape[1]
            layer_indices[layer_name] = [start_idx, end_idx]
            layer_list.append(layer_data)
            start_idx = end_idx

        combined_data: np.ndarray = (
            np.concatenate(layer_list, axis=1) if layer_list else np.array([])
        )
        return combined_data, layer_indices

    @staticmethod
    def _resolve_metadata(datapackage: DataPackage, dict_key: str) -> Any:
        """Pick the annotation entry that belongs to ``dict_key``.

        Lookup order: an entry keyed by ``dict_key``, then an entry keyed
        ``"paired"``, then the only entry if exactly one exists.

        Args:
            datapackage: DataPackage whose ``annotation`` attribute is searched.
            dict_key: Key of the numeric data entry the metadata is needed for.

        Returns:
            The matching annotation entry, or None if the package has no
            annotation at all.

        Raises:
            ValueError: If no entry matches and the annotation does not consist
                of exactly one key.
        """
        annotation = datapackage.annotation
        if annotation is None:  # prevents error in Single Cell case
            return None

        # case where each numeric data has its own annotation/metadata
        metadata = annotation.get(dict_key)
        if metadata is None:
            # case where there is one "paired" metadata for all numeric data
            metadata = annotation.get("paired")
        if metadata is None:
            # unpaired case with a single metadata entry that includes the
            # samples of all numeric data
            if len(annotation.keys()) != 1:
                raise ValueError(
                    f"annotation key needs to be either 'paired' match a key of the numeric data or only one key exists that holds all unpaired data, please adjust config, got: {annotation.keys()}"
                )
            metadata = annotation.get(next(iter(annotation.keys())))
        return metadata

    def _build_dataset_dict(
        self, datapackage: DataPackage, split_indices: np.ndarray
    ) -> Dict[str, NumericDataset]:
        """Build one NumericDataset per modality found in the datapackage.

        ``multi_bulk`` entries become one dataset per DataFrame. For
        ``multi_sc`` every modality of the MuData object becomes one dataset
        holding only that modality's combined layers. All other entries are
        ignored. As a side effect ``self._layer_indices`` is replaced with the
        layer column ranges collected during this call.

        Args:
            datapackage: DataPackage containing the data to be processed.
            split_indices: Indices of the split, passed on to every dataset.

        Returns:
            Dictionary mapping modality names to NumericDataset objects.

        Raises:
            TypeError: If the ``multi_sc`` entry is an AnnData object.
            ValueError: If no annotation entry can be resolved for a key, or
                if a single-cell modality has none of its selected layers.
        """
        dataset_dict: Dict[str, NumericDataset] = {}
        layer_id_dict: Dict[str, Dict[str, List[int]]] = {}

        for k, _ in datapackage:
            # Keys look like "<attribute>.<dict key>", see DataPackage.__iter__.
            # maxsplit=1 keeps dict keys that themselves contain a dot intact.
            attr_name, dict_key = k.split(".", 1)
            metadata = self._resolve_metadata(datapackage, dict_key)

            if attr_name == "multi_bulk":
                df = datapackage[attr_name][dict_key]
                dataset_dict[dict_key] = NumericDataset(
                    data=df.values,
                    config=self.config,
                    sample_ids=df.index,
                    feature_ids=df.columns,
                    metadata=metadata,
                    split_indices=split_indices,
                )
            elif attr_name == "multi_sc":
                mudata = datapackage["multi_sc"]["multi_sc"]
                if isinstance(mudata, ad.AnnData):
                    raise TypeError(
                        "Expected a MuData object, but got an AnnData object."
                    )

                for mod_name, mod_data in mudata.mod.items():
                    layers, indices = self._combine_layers(
                        modality_name=mod_name, modality_data=mod_data
                    )
                    if not indices:
                        raise ValueError(
                            f"None of the selected layers were found for modality {mod_name}."
                        )
                    layer_id_dict[mod_name] = indices
                    # Each dataset holds only its own modality: the stored
                    # layer indices start at 0 per modality, so mixing in the
                    # columns of other modalities would make them point at
                    # the wrong data.
                    # One copy of the variable names per combined layer so the
                    # number of feature ids matches the number of columns.
                    feature_ids = pd.Index(
                        np.tile(np.asarray(mod_data.var_names), len(indices))
                    )
                    dataset_dict[mod_name] = NumericDataset(
                        data=layers,
                        config=self.config,
                        # Rows come from this modality, so its own obs names
                        # are the ones that label them.
                        sample_ids=mod_data.obs_names,
                        feature_ids=feature_ids,
                        metadata=mod_data.obs,
                        split_indices=split_indices,
                    )

        self._layer_indices = layer_id_dict
        return dataset_dict

    def format_reconstruction(
        self, reconstruction: Any, result: Optional[Result] = None
    ) -> DataPackage:
        """Bring the per-modality reconstructions back into DataPackage form.

        The reconstructions are read from ``result.sub_reconstructions``; the
        ``reconstruction`` argument itself is overwritten and never used.

        Args:
            reconstruction: Ignored, see above.
            result: Result object providing ``sub_reconstructions``. Required
                despite the ``None`` default.

        Returns:
            DataPackage with the reconstructed data as pd.DataFrame objects
            (multi bulk) or as a MuData object (multi single cell).

        Raises:
            ValueError: If ``result`` is None or ``config.data_case`` is not
                supported.
            TypeError: If ``result.sub_reconstructions`` is not a dict.
        """
        if result is None:
            raise ValueError(
                "Result object is not provided. This is needed for the StackixPreprocessor."
            )
        reconstruction = result.sub_reconstructions
        if not isinstance(reconstruction, dict):
            raise TypeError(
                f"Expected value to be of type dict for Stackix, got {type(reconstruction)}."
            )

        if self.config.data_case == DataCase.MULTI_BULK:
            return self._format_multi_bulk(reconstructions=reconstruction)
        if self.config.data_case == DataCase.MULTI_SINGLE_CELL:
            return self._format_multi_sc(reconstructions=reconstruction)
        raise ValueError(
            f"Unsupported data_case {self.config.data_case} for StackixPreprocessor."
        )

    def _format_multi_bulk(
        self, reconstructions: Dict[str, torch.Tensor]
    ) -> DataPackage:
        """Format reconstructed tensors as DataFrames for multi bulk data.

        Sample ids, feature ids and metadata are taken from the datasets of
        the "test" split.

        Args:
            reconstructions: Dictionary of reconstructed tensors per modality.

        Returns:
            DataPackage with ``multi_bulk`` (DataFrames) and ``annotation``
            (metadata) filled per modality.

        Raises:
            ValueError: If the dataset container or its test split is missing.
            TypeError: If a reconstruction is not a torch.Tensor.
        """
        if self._dataset_container is None:
            raise ValueError("Dataset container is not initialized.")
        stackix_ds = self._dataset_container["test"]
        if stackix_ds is None:
            raise ValueError("No dataset found for split: test")
        dataset_dict = stackix_ds.datasets

        multi_bulk_dict = {}
        annotation_dict = {}
        for name, reconstruction in reconstructions.items():
            if not isinstance(reconstruction, torch.Tensor):
                raise TypeError(
                    f"Expected value to be of type torch.Tensor, got {type(reconstruction)}."
                )
            source_ds = dataset_dict[name]
            multi_bulk_dict[name] = pd.DataFrame(
                # detach/cpu so tensors that track gradients or live on an
                # accelerator can be converted as well.
                reconstruction.detach().cpu().numpy(),
                index=source_ds.sample_ids,
                columns=source_ds.feature_ids,
            )
            annotation_dict[name] = source_ds.metadata

        dp = DataPackage()
        dp["multi_bulk"] = multi_bulk_dict
        dp["annotation"] = annotation_dict
        return dp

    def _format_multi_sc(self, reconstructions: Dict[str, torch.Tensor]) -> DataPackage:
        """Format reconstructed tensors back into a MuData object.

        The stored layer indices are used to split every reconstructed tensor
        back into the original layers. Metadata and variable names are taken
        from the datasets of the "test" split.

        Args:
            reconstructions: Dictionary of reconstructed tensors per modality.

        Returns:
            DataPackage containing the reconstructed MuData object under
            ``["multi_sc"]["multi_sc"]``.

        Raises:
            ValueError: If the dataset container, its test split, the layer
                indices or a modality's dataset is missing.
            TypeError: If a reconstruction is not a torch.Tensor.
            KeyError: If a modality has no stored indices or no "X" layer.
        """
        import mudata as md

        if self._dataset_container is None:
            raise ValueError("Dataset container is not initialized.")
        if not hasattr(self, "_layer_indices"):
            raise ValueError(
                "Layer indices not found. Make sure _build_dataset_dict was called."
            )

        stackix_ds = self._dataset_container["test"]
        if stackix_ds is None:
            raise ValueError("No dataset found for split: test")
        dataset_dict = stackix_ds.datasets

        modalities = {}
        for mod_name, recon_tensor in reconstructions.items():
            if not isinstance(recon_tensor, torch.Tensor):
                raise TypeError(
                    f"Expected value to be of type torch.Tensor, got {type(recon_tensor)}."
                )
            if mod_name not in dataset_dict:
                raise ValueError(f"Modality {mod_name} not found in dataset dictionary")
            original_dataset = dataset_dict[mod_name]
            layer_indices = self._layer_indices[mod_name]

            # Convert once; detach/cpu so tensors that track gradients or live
            # on an accelerator can be converted as well.
            recon_array = recon_tensor.detach().cpu().numpy()

            start_idx, end_idx = layer_indices["X"]
            x_data = recon_array[:, start_idx:end_idx]

            var_names = original_dataset.feature_ids
            mod_data = ad.AnnData(
                X=x_data,
                obs=original_dataset.metadata,
                # feature_ids may cover several layers; X only needs as many
                # names as it has columns.
                var=pd.DataFrame(index=var_names[0 : x_data.shape[1]]),
            )

            # Add the remaining layers based on the stored indices.
            for layer_name, ids in layer_indices.items():
                if layer_name == "X":
                    continue  # X is already set
                mod_data.layers[layer_name] = recon_array[:, ids[0] : ids[1]]

            modalities[mod_name] = mod_data

        dp = DataPackage()
        dp["multi_sc"] = {"multi_sc": md.MuData(modalities)}
        return dp