Coverage for src / autoencodix / data / _xmodal_preprocessor.py: 24%
85 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Any, Dict, Optional, Tuple, Union
3import mudata as md
5import anndata as ad
6import pandas as pd
7import torch
9from autoencodix.base._base_dataset import BaseDataset
10from autoencodix.data._datasetcontainer import DatasetContainer
11from autoencodix.data._image_dataset import ImageDataset
12from autoencodix.data._imgdataclass import ImgData
13from autoencodix.data._numeric_dataset import NumericDataset
14from autoencodix.data.datapackage import DataPackage
15from autoencodix.data.general_preprocessor import GeneralPreprocessor
16from autoencodix.configs.default_config import DefaultConfig
17from autoencodix.data._multimodal_dataset import MultiModalDataset
class XModalPreprocessor(GeneralPreprocessor):
    """Preprocessor for cross-modal data, handling multiple data types and their transformations.

    Attributes:
        data_config: Configuration specific to data handling (``config.data_config``).
        dataset_dicts: Per-split dictionaries returned by ``_general_preprocess``
            (looked up under "train", "test" and "valid"); set by ``preprocess``.
    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ):
        """Initializes the XModalPreprocessor.

        Args:
            config: Configuration object for the preprocessor.
            ontologies: Optional ontologies, forwarded unchanged to ``GeneralPreprocessor``.
        """
        super().__init__(config=config, ontologies=ontologies)
        self.data_config = config.data_config

    def preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> DatasetContainer:
        """Preprocess the data and wrap every available split in a MultiModalDataset.

        Args:
            raw_user_data: Optional raw data provided by the user.
            predict_new_data: Flag indicating if new data is being predicted.

        Returns:
            A DatasetContainer holding one MultiModalDataset per split. Splits that
            ``_general_preprocess`` did not produce are passed as ``None``
            (NOTE(review): assumes DatasetContainer accepts ``None`` for absent
            splits — confirm).

        Raises:
            TypeError: If a split's "data" entry is not a DataPackage.
        """
        import logging

        logger = logging.getLogger(__name__)
        self.dataset_dicts = self._general_preprocess(
            raw_user_data=raw_user_data, predict_new_data=predict_new_data
        )
        datasets: Dict[str, MultiModalDataset] = {}
        for split in ("train", "test", "valid"):
            cur_split = self.dataset_dicts.get(split)
            if cur_split is None:
                logger.debug("split is None: %s", split)
                continue
            cur_data = cur_split.get("data")
            if not isinstance(cur_data, DataPackage):
                raise TypeError(
                    f"expected type of cur_data to be DataPackage, got {type(cur_data)}"
                )
            datasets[split] = MultiModalDataset(
                datasets=self._process_dp(
                    dp=cur_data, indices=cur_split.get("indices")
                ),
                config=self.config,
            )

        # .get: a split skipped above must not turn into a KeyError here.
        return DatasetContainer(
            train=datasets.get("train"),
            test=datasets.get("test"),
            valid=datasets.get("valid"),
        )

    def format_reconstruction(self, reconstruction, result=None):
        """Format a reconstruction back into the input data layout.

        Not implemented for cross-modal data: this is a no-op that returns ``None``.
        NOTE(review): implement, or raise NotImplementedError, once the expected
        output format for cross-modal reconstructions is defined.
        """
        return None

    @staticmethod
    def _resolve_metadata(dp: DataPackage, sub_key: str) -> Any:
        """Look up the annotation belonging to one data entry of ``dp``.

        Lookup order: ``dp.annotation[sub_key]``, then ``dp.annotation["paired"]``,
        then — if the annotation holds exactly one entry — that single entry
        (unpaired case where one frame covers the samples of all modalities).

        Args:
            dp: The DataPackage whose ``annotation`` mapping is searched.
            sub_key: The part of the data entry's key after the first ".".

        Returns:
            The matching annotation, or ``None`` when ``dp.annotation`` is None.

        Raises:
            ValueError: If no rule matches and the annotation has more than one key.
        """
        if dp.annotation is None:  # prevents error in SingleCell case
            return None
        metadata = dp.annotation.get(sub_key)
        if metadata is None:
            metadata = dp.annotation.get("paired")
        if metadata is None:
            if len(dp.annotation.keys()) != 1:
                raise ValueError(
                    "annotation key needs to be either 'paired', match a key of the "
                    "numeric data, or be the single key that holds all unpaired data; "
                    f"please adjust config, got: {dp.annotation.keys()}"
                )
            metadata = dp.annotation.get(next(iter(dp.annotation.keys())))
        return metadata

    def _sc_dataset(
        self, k: str, data: Any, indices: Dict[str, Any]
    ) -> Optional[NumericDataset]:
        """Build the NumericDataset for one multi_sc entry from its ``.X`` matrix.

        Args:
            k: Full DataPackage key of the entry (used in the warning message).
            data: Either a dict of AnnData objects (unpaired case) or a MuData.
            indices: The indices for splitting the data.

        Returns:
            A NumericDataset built from the last contained modality, or ``None``
            when ``data`` contains no modality at all.

        Raises:
            TypeError: If ``data`` (or an item of the dict) has an unsupported type.
        """
        import warnings

        if isinstance(data, dict):
            # unpaired multi_sc case with adata dicts
            for adata_name, adata_v in data.items():
                self._validate_layers(data_name=adata_name)
                if not isinstance(adata_v, ad.AnnData):
                    raise TypeError(
                        f"Input data has unsupported data type: {type(adata_v)}"
                    )
            modalities = list(data.values())
        elif isinstance(data, md.MuData):
            for mod_key, _ in data.mod.items():
                self._validate_layers(data_name=mod_key)
            modalities = list(data.mod.values())
        else:
            raise TypeError(f"Input data has unsupported data type: {type(data)}")

        if not modalities:
            return None
        # All modalities share the dataset key ``k``, so only one can be kept;
        # make the loss visible instead of silently overwriting.
        if len(modalities) > 1:
            warnings.warn(
                f"multi_sc entry {k!r} contains {len(modalities)} modalities, but "
                "only the last one is used; the others are discarded."
            )
        adata = modalities[-1]
        return NumericDataset(
            data=adata.X,
            config=self.config,
            sample_ids=adata.obs_names,
            feature_ids=adata.var_names,
            split_indices=indices,
            metadata=adata.obs,
        )

    def _process_dp(
        self, dp: DataPackage, indices: Dict[str, Any]
    ) -> Dict[str, BaseDataset]:
        """Processes a DataPackage into a dictionary of BaseDataset objects.

        Args:
            dp: The DataPackage to process; iterating it yields ``(key, value)``
                pairs whose key has the form "<attribute>.<sub_key>".
            indices: The indices for splitting the data.

        Returns:
            A dictionary mapping the DataPackage keys of the data entries to
            BaseDataset objects. "annotation" entries produce no dataset.

        Raises:
            ValueError: For malformed multi_bulk/img data or unresolvable metadata.
            TypeError: For unsupported multi_sc data types.
            NotImplementedError: For DataPackage attributes this method does not know.
        """
        dataset_dict: Dict[str, BaseDataset] = {}
        for k, data in dp:
            # maxsplit=1 so sub-keys that themselves contain "." still unpack
            dp_key, sub_key = k.split(".", 1)

            if dp_key == "annotation":
                # annotations are consumed as metadata by the other entries
                continue

            if dp_key == "multi_bulk":
                metadata = self._resolve_metadata(dp, sub_key)
                if not isinstance(data, pd.DataFrame):
                    raise ValueError(
                        f"Expected data for multi_bulk: {k}, {data} to be pd.DataFrame, got {type(data)}"
                    )
                if metadata is None:
                    raise ValueError("metadata cannot be None")
                dataset_dict[k] = NumericDataset(
                    data=data.values,
                    config=self.config,
                    sample_ids=data.index,
                    feature_ids=data.columns,
                    split_indices=indices,
                    # one annotation frame may cover all modalities: restrict and
                    # align it to this frame's samples
                    metadata=metadata.loc[data.index],
                )
            elif dp_key == "img":
                metadata = self._resolve_metadata(dp, sub_key)
                if not isinstance(data, list):
                    raise ValueError(
                        f"Expected data for img: {k} to be a list of ImgData, got {type(data)}"
                    )
                if not data or not isinstance(data[0], ImgData):
                    raise ValueError(
                        f"Expected data for img: {k} to be a non-empty list of ImgData"
                    )
                dataset_dict[k] = ImageDataset(
                    data=data,
                    config=self.config,
                    split_indices=indices,
                    metadata=metadata,
                )
            elif dp_key == "multi_sc":
                sc_dataset = self._sc_dataset(k=k, data=data, indices=indices)
                if sc_dataset is not None:
                    dataset_dict[k] = sc_dataset
            else:
                raise NotImplementedError(
                    f"Got datapackage attribute: {k}, probably you have added an attribute to the Datapackage class without adjusting this method. Only supports: ['multi_bulk', 'multi_sc', 'img' and 'annotation']"
                )
        return dataset_dict

    def _validate_layers(self, data_name: str) -> None:
        """Warn when the layers configured for ``data_name`` are not exactly ["X"].

        Xmodalix only reads the ``.X`` matrix of single-cell data, so any other
        layer selection is ignored.

        Args:
            data_name: Key into ``config.data_config.data_info``.
        """
        import warnings

        selected_layers = self.config.data_config.data_info[data_name].selected_layers
        # Any selection other than exactly ["X"] (including an empty one) is ignored.
        if list(selected_layers or []) != ["X"]:
            warnings.warn(
                "Xmodalix works only with X layer of single cell data as of now. "
                "Using X Layer, discarding selected layers"
            )