Coverage for src / autoencodix / base / _base_preprocessor.py: 12%
380 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import abc
2from abc import abstractmethod
3from enum import Enum
4from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
6import numpy as np
7import pandas as pd
8import torch
10from autoencodix.data._datapackage_splitter import DataPackageSplitter
11from autoencodix.data._datasetcontainer import DatasetContainer
12from autoencodix.data._datasplitter import PairedUnpairedSplitter
13from autoencodix.data._filter import DataFilter
14from autoencodix.data._imgdataclass import ImgData
15from autoencodix.data._nanremover import NaNRemover
16from autoencodix.data._sc_filter import SingleCellFilter
17from autoencodix.data.datapackage import DataPackage
18from autoencodix.utils._bulkreader import BulkDataReader
19from autoencodix.utils._imgreader import ImageDataReader, ImageNormalizer
20from autoencodix.utils._screader import SingleCellDataReader
21from autoencodix.configs.default_config import DataCase, DefaultConfig
22from autoencodix.utils._result import Result
if TYPE_CHECKING:
    import mudata as md  # type: ignore

    # ``mudata`` exposes the container class at package level as ``MuData``;
    # the former ``md.MuData.MuData`` pointed at a non-existent attribute.
    MuData = md.MuData
else:
    # ``mudata`` is not imported at runtime here, so the alias degrades to Any.
    MuData = Any
32class BasePreprocessor(abc.ABC):
33 """Contains logic for data preprocessing in the Autoencodix framework.
35 This class defines the general preprocessing workflow and provides
36 methods for handling different data modalities and data cases.
37 Subclasses should implement the `preprocess` method to perform
38 specific preprocessing steps.
40 Attributes:
41 config: A DefaultConfig object containing preprocessing configurations.
42 processed_data: A dictionary to store processed DataPackage objects for each data split.
43 bulk_genes_to_keep: Optional list of genes to keep for bulk data.
44 bulk_scalers: Optional dictionary of scalers for bulk data.
45 sc_genes_to_keep: Optional dictionary mapping modality keys to lists of genes to keep for single-cell data.
46 sc_scalers: Optional dictionary mapping modality keys to scalers for single-cell data.
47 sc_general_genes_to_keep: Optional dictionary mapping modality keys to lists of genes to keep filtered by non-SC specific methods.
48 data_readers: A dictionary mapping DataCase enum values to data reader instances for different modalities.
49 _dataset_container: Optional DatasetContainer to hold the processed datasets.
50 """
52 def __init__(
53 self,
54 config: DefaultConfig,
55 ontologies: Optional[Union[Tuple[Any, Any], Dict[Any, Any]]] = None,
56 ):
57 """Initializes the BasePreprocessor with a configuration object.
59 Args :
60 config: A DefaultConfig object containing preprocessing configurations.
61 ontologies: Ontology information, if provided for Ontix.
62 """
63 self.config = config
64 self._dataset_container: Optional[DatasetContainer] = None
65 self.processed_data = Dict[str, Dict[str, Union[Any, DataPackage]]]
66 self.bulk_genes_to_keep: Optional[Dict[str, List[str]]] = None
67 self.bulk_scalers: Optional[Dict[str, Any]] = None
68 self.sc_genes_to_keep: Optional[Dict[str, List[str]]] = None
69 self.sc_scalers: Optional[Dict[str, Dict[str, Any]]] = None
70 self.sc_general_genes_to_keep: Optional[Dict[str, List]] = None
71 self._ontologies: Optional[Union[Tuple[Any, Any], Dict[Any, Any]]] = ontologies
72 self.data_readers: Dict[Enum, Any] = {
73 DataCase.MULTI_SINGLE_CELL: SingleCellDataReader(),
74 DataCase.MULTI_BULK: BulkDataReader(config=self.config),
75 DataCase.BULK_TO_BULK: BulkDataReader(config=self.config),
76 DataCase.SINGLE_CELL_TO_SINGLE_CELL: SingleCellDataReader(),
77 DataCase.IMG_TO_BULK: {
78 "bulk": BulkDataReader(config=self.config),
79 "img": ImageDataReader(config=self.config),
80 },
81 DataCase.SINGLE_CELL_TO_IMG: {
82 "sc": SingleCellDataReader(),
83 "img": ImageDataReader(config=self.config),
84 },
85 DataCase.IMG_TO_IMG: ImageDataReader(config=self.config),
86 }
88 @abc.abstractmethod
89 def preprocess(
90 self,
91 raw_user_data: Optional[DataPackage] = None,
92 predict_new_data: bool = False,
93 ) -> DatasetContainer:
94 """To be implemented by subclasses for specific preprocessing steps.
95 Args:
96 raw_user_data: Users can provide raw data. This is an alternative way of
97 providing data via filepaths in the config. If this param is passed, we skip the data reading step.
98 predict_new_data: Indicates whether the user wants to predict with unseen data.
99 If this is the case, we don't split the data and only prerpocess.
100 """
101 pass
103 def _general_preprocess(
104 self,
105 raw_user_data: Optional[DataPackage] = None,
106 predict_new_data: bool = False,
107 ) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
108 """Orchestrates the preprocessing steps.
110 This method determines the data case from the configuration and calls
111 the appropriate processing function for that data case.
113 Args:
114 raw_user_data: Optional DataPackage containing user-provided data.
115 If provided, the data reading step is skipped.
116 predict_new_data: Boolean indicating whether to preprocess new unseen data
117 without splitting it into train/validation/test sets.
119 Returns:
120 A dictionary containing processed DataPackage objects for each data split
121 (e.g., 'train', 'validation', 'test').
123 Raises:
124 ValueError: If an unsupported data case is encountered.
125 """
126 self.predict_new_data = predict_new_data
127 datacase = self.config.data_case
128 if datacase is None:
129 raise TypeError(
130 "datacase can't be None. Please ensure the configuration specifies a valid DataCase."
131 )
132 if raw_user_data is None:
133 self.from_key, self.to_key = self._get_translation_keys()
134 else:
135 self.from_key, self.to_key = self._get_user_translation_keys(
136 raw_user_data=raw_user_data
137 )
138 process_function = self._get_process_function(datacase=datacase)
139 if process_function:
140 return process_function(raw_user_data=raw_user_data)
141 else:
142 raise ValueError(f"Unsupported data case: {datacase}")
144 def _get_process_function(self, datacase: DataCase) -> Any:
145 """Returns the appropriate processing function based on the data case.
147 Args:
148 datacase: The DataCase enum value representing the current data case.
150 Returns:
151 A callable function that performs the preprocessing for the given data case,
152 or None if the data case is not supported.
153 """
154 process_map = {
155 DataCase.MULTI_SINGLE_CELL: self._process_multi_single_cell,
156 DataCase.MULTI_BULK: self._process_multi_bulk_case,
157 DataCase.BULK_TO_BULK: self._process_multi_bulk_case,
158 DataCase.SINGLE_CELL_TO_SINGLE_CELL: self._process_multi_single_cell,
159 DataCase.IMG_TO_BULK: self._process_img_to_bulk_case,
160 DataCase.SINGLE_CELL_TO_IMG: self._process_sc_to_img_case,
161 DataCase.IMG_TO_IMG: self._process_img_to_img_case,
162 }
163 return process_map.get(datacase)
165 def _process_data_case(
166 self, data_package: DataPackage, modality_processors: Dict[Any, Any]
167 ) -> Union[Dict[str, Dict[str, Union[Any, DataPackage]]], Dict[str, Any]]:
168 """Processes the data package based on the provided modality processors.
170 This method handles the common preprocessing steps for different data cases,
171 including splitting the data package, removing NaNs, and applying
172 modality-specific processors.
174 Args::
175 data_package: The DataPackage object to be processed.
176 modality_processors: A dictionary mapping modality keys (e.g., 'multi_sc', 'from_modality')
177 to callable processor functions that will be applied to the corresponding modality data.
179 Returns:
180 A dictionary containing processed DataPackage objects for each data split.
181 """
182 if self.predict_new_data:
183 # we get the data modality keys from this structure in postsplit processing
184 # for predict_new data there do not exits real splits, because all is "test" data
185 # but the preprocessing code expects this splits, so we mock them
186 # use train, because processing logic expects train split
187 mock_split: Dict[str, Dict[str, Union[Any, DataPackage]]] = {
188 "test": {
189 "data": data_package,
190 "indices": {"paired": np.array([])},
191 },
192 "valid": {"data": None, "indices": {"paired": np.array([])}},
193 "train": {"data": data_package, "indices": {"paired": np.array([])}},
194 }
195 if self.config.skip_preprocessing:
196 return mock_split
198 clean_package = self._remove_nans(data_package=data_package)
199 mock_split["test"]["data"] = clean_package
200 for modality_key, (
201 presplit_processor,
202 postsplit_processor,
203 ) in modality_processors.items():
204 modality_data = clean_package[modality_key]
205 if modality_data:
206 processed_modality_data = presplit_processor(modality_data)
207 # mock the split
208 clean_package[modality_key] = processed_modality_data
209 mock_split["test"]["data"] = clean_package
210 mock_split = postsplit_processor(mock_split)
211 return mock_split
212 # normal case without new data -----------------------------------
213 if self.config.skip_preprocessing:
214 split_packages, _ = self._split_data_package(data_package=data_package)
215 return split_packages
216 clean_package = self._remove_nans(data_package=data_package)
217 for modality_key, (presplit_processor, _) in modality_processors.items():
218 modality_data = clean_package[modality_key]
219 if modality_data:
220 processed_modality_data = presplit_processor(modality_data)
221 clean_package[modality_key] = processed_modality_data
222 split_packages, indices = self._split_data_package(data_package=clean_package)
223 processed_splits = {}
224 for modality_key, (_, postsplit_processor) in modality_processors.items():
225 split_packages = postsplit_processor(split_packages)
226 for split_name, split_package in split_packages.items():
227 split_indices = {
228 name: {
229 split: idx
230 for split, idx in indices[name].items()
231 if split == split_name
232 }
233 for name in indices.keys()
234 }
235 processed_splits[split_name] = {
236 "data": split_package["data"],
237 "indices": split_indices,
238 }
239 return processed_splits
241 def _process_multi_single_cell(
242 self, raw_user_data: Optional[DataPackage] = None
243 ) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
244 """Process MULTI_SINGLE_CELL case
246 Reads multi-single-cell data, performs data splitting, NaN removal,
247 and applies single-cell specific filtering.
248 Args:
249 raw_user_data: Optional DataPackage containing user-provided data.
251 Returns:
252 A dictionary containing processed DataPackage objects for each data split.
253 Raises:
254 ValueError: If multi_sc in data_package is None.
256 """
257 if raw_user_data is None:
258 screader = self.data_readers[DataCase.MULTI_SINGLE_CELL] # type: ignore
260 mudata = screader.read_data(config=self.config)
261 data_package: DataPackage = DataPackage()
262 data_package.multi_sc = mudata
263 else:
264 data_package = raw_user_data
265 if self.config.requires_paired:
266 common_ids = data_package.get_common_ids()
267 if data_package.multi_sc is None:
268 raise ValueError("multi_sc in data_package is None")
269 data_package.multi_sc = {
270 "multi_sc": data_package.multi_sc["multi_sc"][common_ids]
271 }
273 def presplit_processor(modality_data: Any) -> Any:
274 if modality_data is None:
275 return modality_data
276 sc_filter = SingleCellFilter(
277 data_info=self.config.data_config.data_info, config=self.config
278 )
279 return sc_filter.presplit_processing(multi_sc=modality_data)
281 def postsplit_processor(
282 split_data: Dict[str, Dict[str, Any]],
283 ) -> Dict[str, Dict[str, Any]]:
284 return self._postsplit_multi_single_cell(
285 split_data=split_data, datapackage_key="multi_sc"
286 )
288 return self._process_data_case(
289 data_package,
290 modality_processors={"multi_sc": (presplit_processor, postsplit_processor)},
291 )
293 def _process_multi_bulk_case(
294 self,
295 raw_user_data: Optional[DataPackage] = None,
296 ) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
297 """
298 Process MULTI_BULK case.
300 Reads multi-bulk data, performs data splitting, NaN removal,
301 and applies filtering and scaling to bulk dataframes.
302 Args:
303 raw_user_data: Optional DataPackage containing user-provided data.
305 Returns:
306 A dictionary containing processed DataPackage objects for each data split.
307 """
308 if raw_user_data is None:
309 bulkreader = self.data_readers[DataCase.MULTI_BULK]
310 bulk_dfs, annotation = bulkreader.read_data()
312 data_package = DataPackage(multi_bulk=bulk_dfs, annotation=annotation)
313 else:
314 data_package = raw_user_data
315 if self.config.requires_paired:
316 common_ids = data_package.get_common_ids()
317 unpaired_data = data_package.multi_bulk
318 unpaired_anno = data_package.annotation
319 if unpaired_anno is None:
320 raise ValueError("annotation attribute of datapackge cannot be None")
321 if unpaired_data is None:
322 raise ValueError("multi_bulk attribute of datapackge cannot be None")
323 data_package.multi_bulk = {
324 k: v.loc[common_ids] for k, v in unpaired_data.items()
325 }
327 data_package.annotation = {
328 k: v.loc[common_ids] # ty: ignore
329 for k, v in unpaired_anno.items() # ty: ignore
330 }
332 def presplit_processor(
333 modality_data: Dict[str, Union[pd.DataFrame, None]],
334 ) -> Dict[str, Union[pd.DataFrame, None]]:
335 """For the multi_bulk modality we perform all operations after splitting at the moment."""
336 return modality_data
338 def postsplit_processor(
339 split_data: Dict[str, Dict[str, Any]],
340 ) -> Dict[str, Dict[str, Any]]:
341 return self._postsplit_multi_bulk(split_data=split_data)
343 return self._process_data_case(
344 data_package,
345 modality_processors={
346 "multi_bulk": (presplit_processor, postsplit_processor)
347 },
348 )
350 def _calc_k_filter(
351 self, i: int, remainder: int, base_features: int
352 ) -> Optional[int]:
353 if self.config.k_filter is None:
354 return None
355 extra = 1 if i < remainder else 0
356 return base_features + extra
358 def _postsplit_multi_single_cell(
359 self,
360 split_data: Dict[str, Dict[str, Any]],
361 datapackage_key: str = "multi_sc",
362 modality_key: Optional[str] = None,
363 ) -> Dict[str, Dict[str, Any]]:
364 """Post-split processing for multi-single-cell data.
365 This method applies filtering and scaling to the single-cell data after it has been split.
366 Now supports multiple MuData objects in the input dictionary.
368 Args:
369 split_data: A dictionary containing the split data for each data split.
370 datapackage_key: The key in the DataPackage that contains the multi-single-cell data.
371 modality_key: Optional specific modality key for backward compatibility.
372 If provided, only processes that specific modality.
373 If None, processes all modalities in the dictionary.
375 Returns:
376 A dictionary containing processed DataPackage objects for each data split.
378 Raises:
379 ValueError: If the train split data is None.
380 """
381 processed_splits: Dict[str, Dict[str, Any]] = {}
382 train_split: Optional[Dict[str, Any]] = split_data.get("train")
384 if train_split is None:
385 raise ValueError(
386 "Train split data is None. Ensure that the data package contains valid train data."
387 )
389 train_data: Optional[Any] = train_split.get("data")
390 if train_data is None:
391 raise ValueError(
392 "Train split data is None. Ensure that the data package contains valid train data."
393 )
395 # Get all modality keys from the train data
396 mudata_dict = train_data[datapackage_key]
398 if modality_key is not None:
399 if modality_key not in mudata_dict:
400 raise ValueError(
401 f"Specified modality_key '{modality_key}' not found in {list(mudata_dict.keys())}"
402 )
403 modality_keys = [modality_key]
404 print(
405 f"Processing single modality (backward compatibility): {modality_key}"
406 )
407 else:
408 modality_keys = list(mudata_dict.keys())
409 print(f"Processing {len(modality_keys)} MuData objects: {modality_keys}")
411 # Initialize storage for scalers and gene filters for each modality
412 # if we do this for the first time, we need a train split and we dont
413 # fitted any scalers or features to keep yet.
414 # that's why in the predict_new case we can keep the mocksplit for train None
415 # because we never get in this if
416 if (
417 self.sc_scalers is None
418 and self.sc_genes_to_keep is None
419 and self.sc_general_genes_to_keep is None
420 ) or ("modality" in datapackage_key):
421 # Process each MuData object in the train split
422 processed_mudata_dict = {}
423 all_scalers = {}
424 all_sc_genes_to_keep = {}
425 all_general_genes_to_keep = {}
427 for current_modality_key in modality_keys:
428 print(f"Processing train modality: {current_modality_key}")
430 sc_filter = SingleCellFilter(
431 data_info=self.config.data_config.data_info, config=self.config
432 )
434 # Single-cell specific filtering
435 filtered_train, sc_genes_to_keep = sc_filter.sc_postsplit_processing(
436 mudata=mudata_dict[current_modality_key]
437 )
439 # General post-processing
440 processed_train, general_genes_to_keep, scalers = (
441 sc_filter.general_postsplit_processing(
442 mudata=filtered_train, scaler_map=None, gene_map=None
443 )
444 )
446 # Store processed data and filters for this modality
447 processed_mudata_dict[current_modality_key] = processed_train
448 all_scalers[current_modality_key] = scalers
449 all_sc_genes_to_keep[current_modality_key] = sc_genes_to_keep
450 all_general_genes_to_keep[current_modality_key] = general_genes_to_keep
452 # Store all scalers and gene filters
453 self.sc_scalers = all_scalers
454 self.sc_genes_to_keep = all_sc_genes_to_keep
455 self.sc_general_genes_to_keep = all_general_genes_to_keep
457 # Update train data with processed MuData objects
458 train_data[datapackage_key] = processed_mudata_dict
460 else:
461 # Use existing scalers and gene filters
462 all_scalers = self.sc_scalers # type: ignore
463 all_sc_genes_to_keep = self.sc_genes_to_keep # type: ignore
464 all_general_genes_to_keep = self.sc_general_genes_to_keep # type: ignore
466 # Store processed train split
467 processed_splits["train"] = {
468 "data": train_data,
469 "indices": split_data["train"]["indices"],
470 }
472 # Process other splits (val, test, etc.)
473 for split, split_package in split_data.items():
474 if split == "train":
475 continue
477 data_package = split_package["data"]
478 if data_package is None:
479 processed_splits[split] = split_package
480 continue
482 print(f"Processing {split} split")
483 processed_mudata_dict = {}
485 # Process each MuData object in this split
486 for current_modality_key in modality_keys:
487 print(f"Processing {split} modality: {current_modality_key}")
489 sc_filter = SingleCellFilter(
490 data_info=self.config.data_config.data_info, config=self.config
491 )
493 # Apply single-cell filtering using train-derived gene map
494 filtered_sc_data, _ = sc_filter.sc_postsplit_processing(
495 mudata=data_package[datapackage_key][current_modality_key],
496 gene_map=all_sc_genes_to_keep[current_modality_key],
497 )
499 # Apply general processing using train-derived scalers and gene map
500 processed_general_data, _, _ = sc_filter.general_postsplit_processing(
501 mudata=filtered_sc_data,
502 gene_map=all_general_genes_to_keep[current_modality_key],
503 scaler_map=all_scalers[current_modality_key],
504 )
506 processed_mudata_dict[current_modality_key] = processed_general_data
508 # Update data package with all processed MuData objects
509 data_package[datapackage_key] = processed_mudata_dict
511 processed_splits[split] = {
512 "data": data_package,
513 "indices": split_package["indices"],
514 }
516 return processed_splits
518 def _postsplit_multi_bulk(
519 self,
520 split_data: Dict[str, Dict[str, Any]],
521 datapackage_key: str = "multi_bulk",
522 ) -> Dict[str, Dict[str, Any]]:
523 """Post-split processing for multi-bulk data.
525 This method applies filtering and scaling to the bulk dataframes after they have been split.
527 Args:
528 split_data: A dictionary containing the split data for each data split.
529 datapackage_key: The key in the DataPackage that contains the multi-bulk data.
530 Returns:
531 A dictionary containing processed DataPackage objects for each data split.
532 Raises:
533 ValueError: If the train split data is None.
534 """
536 train_split: Optional[Dict[str, Any]] = split_data.get("train")
537 if train_split is None:
538 raise ValueError(
539 "Train split data is None. Ensure that the data package contains valid train data."
540 )
541 train_data: Optional[Any] = train_split.get("data")
542 genes_to_keep_map: Dict[str, List[str]] = {}
543 scalers: Dict[str, Any] = {}
544 processed_splits: Dict[str, Dict[str, Any]] = {}
546 if (self.bulk_scalers is None and self.bulk_genes_to_keep is None) or (
547 "modality" in datapackage_key
548 ):
549 if train_data is None:
550 raise ValueError(
551 "Train split data is None. Ensure that the data package contains valid train data."
552 )
553 n_modalities: int = len(train_data[datapackage_key].keys())
554 remainder: int = 0
555 base_features = 0
556 if self.config.k_filter is not None:
557 base_features = self.config.k_filter // n_modalities
558 remainder = self.config.k_filter % n_modalities
560 # Get valid modality keys (those that are not None)
561 modality_keys = [
562 k for k, v in train_data[datapackage_key].items() if v is not None
563 ]
565 for i, k in enumerate(modality_keys):
566 v = train_data[datapackage_key][k]
567 cur_k_filter = self._calc_k_filter(
568 i=i, base_features=base_features, remainder=remainder
569 )
570 self.config.data_config.data_info[k].k_filter = cur_k_filter
572 data_processor = DataFilter(
573 data_info=self.config.data_config.data_info[k],
574 config=self.config,
575 ontologies=self._ontologies,
576 )
577 filtered_df, genes_to_keep = data_processor.filter(df=v)
578 scaler = data_processor.fit_scaler(df=filtered_df)
579 genes_to_keep_map[k] = genes_to_keep
580 scalers[k] = scaler
581 scaled_df = data_processor.scale(df=filtered_df, scaler=scaler)
582 train_data[datapackage_key][k] = scaled_df
583 # Check if indices stayed the same after filtering
584 if not filtered_df.index.equals(v.index):
585 mismatched_indices = filtered_df.index.symmetric_difference(v.index)
586 raise ValueError(
587 f"Indices mismatch after filtering for modality {k}. "
588 f"Mismatched indices: {mismatched_indices}. "
589 "Ensure filtering does not alter the indices."
590 )
592 self.bulk_scalers = scalers
593 self.bulk_genes_to_keep = genes_to_keep_map # type: ignore
594 else:
595 scalers, genes_to_keep_map = self.bulk_scalers, self.bulk_genes_to_keep # type: ignore
597 processed_splits["train"] = {
598 "data": train_data,
599 "indices": split_data["train"]["indices"],
600 }
602 for split_name, split_package in split_data.items():
603 if split_name == "train":
604 continue
605 if split_package["data"] is None:
606 processed_splits[split_name] = split_data[split_name]
607 continue
609 processed_package = split_package["data"]
610 for k, v in processed_package[datapackage_key].items():
611 if v is None:
612 continue
613 data_processor = DataFilter(
614 data_info=self.config.data_config.data_info[k],
615 config=self.config,
616 ontologies=self._ontologies,
617 )
618 filtered_df, _ = data_processor.filter(
619 df=v, genes_to_keep=genes_to_keep_map[k]
620 )
621 scaled_df = data_processor.scale(df=filtered_df, scaler=scalers[k])
622 processed_package[datapackage_key][k] = scaled_df
623 if not filtered_df.index.equals(v.index):
624 raise ValueError(
625 f"Indices mismatch after filtering for modality {k}. "
626 "Ensure filtering does not alter the indices."
627 )
629 processed_splits[split_name] = {
630 "data": processed_package,
631 "indices": split_package["indices"],
632 }
634 return processed_splits
636 def _process_img_to_bulk_case(
637 self, raw_user_data: Optional[DataPackage] = None
638 ) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
639 """Process IMG_TO_BULK case
641 Reads image and bulk data, prepares from/to modalities (IMG->BULK or BULK->IMG),
642 performs data splitting, NaN removal, and applies normalization to image data
643 and filtering/scaling to bulk dataframes.
644 Args:
645 raw_user_data: Optional DataPackage containing user-provided data.
646 If provided, the data reading step is skipped.
647 Returns:
648 A dictionary containing processed DataPackage objects for each data split.
649 Raises:
650 TypeError: If from_key or to_key is None, indicating that translation keys must be specified.
651 """
653 if raw_user_data is None:
654 bulkreader = self.data_readers[DataCase.IMG_TO_BULK]["bulk"]
655 imgreader = self.data_readers[DataCase.IMG_TO_BULK]["img"]
657 bulk_dfs, annotation_bulk = bulkreader.read_data()
658 images, annotation_img = imgreader.read_data(config=self.config)
660 annotation = {**annotation_bulk, **annotation_img}
662 data_package = DataPackage(
663 multi_bulk=bulk_dfs, img=images, annotation=annotation
664 )
666 else:
667 data_package = raw_user_data
669 if self.config.requires_paired:
670 common_ids = data_package.get_common_ids()
672 images = data_package.img
673 if images is None:
674 raise ValueError("Images cannot be None")
675 data_package.img = {
676 k: self.filter_imgdata_list(img_list=v, ids=common_ids)
677 for k, v in images.items()
678 }
679 unpaired_data = data_package.multi_bulk
680 unpaired_anno = data_package.annotation
681 if unpaired_anno is None:
682 raise ValueError("annotation attribute of datapackge cannot be None")
683 if unpaired_data is None:
684 raise ValueError("multi_bulk attribute of datapackge cannot be None")
685 data_package.multi_bulk = {
686 k: v.loc[common_ids] for k, v in unpaired_data.items()
687 }
689 data_package.annotation = {
690 k: v.loc[common_ids] # ty: ignore
691 for k, v in unpaired_anno.items() # ty: ignore
692 }
694 def presplit_processor(
695 modality_data: Dict[str, Union[pd.DataFrame, List[ImgData]]],
696 ) -> Dict[str, Union[pd.DataFrame, List[ImgData]]]:
697 for modality_key, data in modality_data.items():
698 if self._is_image_data(data=data):
699 modality_data[modality_key] = self._normalize_image_data(
700 images=data, # type: ignore
701 info_key=modality_key, # type: ignore
702 )
703 # we don't need to filter bulk data here
704 # because we do it in the postsplit step
705 return modality_data
707 def postsplit_processor(
708 split_data: Dict[str, Dict[str, Any]], datapackage_key: str
709 ) -> Dict[str, Dict[str, Any]]:
710 if datapackage_key == "multi_bulk":
711 return self._postsplit_multi_bulk(
712 split_data=split_data, datapackage_key=datapackage_key
713 )
714 return split_data # for img data we don't need to do anything
716 return self._process_data_case(
717 data_package,
718 modality_processors={
719 "multi_bulk": ( # TODO change to multi_bulk and img for all translation cases and ajdust processors accordingly
720 lambda data: presplit_processor(modality_data=data),
721 lambda data: postsplit_processor(
722 split_data=data, datapackage_key="multi_bulk"
723 ),
724 ),
725 "img": (
726 lambda data: presplit_processor(modality_data=data),
727 lambda data: postsplit_processor(
728 split_data=data, datapackage_key="img"
729 ),
730 ),
731 },
732 )
734 def _process_sc_to_img_case(
735 self, raw_user_data: Optional[DataPackage] = None
736 ) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
737 """Process SC_TO_IMG case.
739 Reads single-cell and image data, prepares from/to modalities (SC->IMG or IMG->SC),
740 performs data splitting, NaN removal, and applies single-cell specific filtering
741 to single-cell data and normalization to image data.
743 Args:
744 raw_user_data: Optional DataPackage containing user-provided data.
746 Returns:
747 A dictionary containing processed DataPackage objects for each data split.
748 """
749 if raw_user_data is None:
750 screader = self.data_readers[DataCase.SINGLE_CELL_TO_IMG]["sc"]
751 imgreader = self.data_readers[DataCase.SINGLE_CELL_TO_IMG]["img"]
753 # only one mudata type in this case we know this
754 mudata_dict = screader.read_data(config=self.config)
755 images, annotation = imgreader.read_data(config=self.config)
757 data_package = DataPackage(
758 multi_sc=mudata_dict, img=images, annotation=annotation
759 )
760 else:
761 data_package = raw_user_data
762 if self.config.requires_paired:
763 common_ids = data_package.get_common_ids()
764 if data_package.multi_sc is None:
765 raise ValueError("multi_sc in data_package is None")
766 data_package.multi_sc = {
767 "multi_sc": data_package.multi_sc["multi_sc"][common_ids]
768 }
769 images = data_package.img
770 if images is None:
771 raise ValueError("Images cannot be None")
772 data_package.img = {
773 k: self.filter_imgdata_list(img_list=v, ids=common_ids)
774 for k, v in images.items()
775 }
777 def presplit_processor(
778 modality_data: Dict[str, Union[Any, List[ImgData]]],
779 ) -> Dict[str, Union[Any, List[ImgData]]]:
780 was_image = False
781 for modality_key, data in modality_data.items():
782 if self._is_image_data(data=data):
783 was_image = True
784 modality_data[modality_key] = self._normalize_image_data(
785 images=data, # type: ignore
786 info_key=modality_key, # type: ignore
787 )
789 if was_image:
790 return modality_data
791 else:
792 sc_filter = SingleCellFilter(
793 data_info=self.config.data_config.data_info, config=self.config
794 )
795 return sc_filter.presplit_processing(multi_sc=modality_data)
797 def postsplit_processor(
798 split_data: Dict[str, Dict[str, Any]], datapackage_key: str
799 ) -> Dict[str, Dict[str, Any]]:
800 if datapackage_key == "multi_sc":
801 return self._postsplit_multi_single_cell(
802 split_data=split_data, datapackage_key=datapackage_key
803 )
804 # No postsplit processing needed for image data
805 return split_data
807 return self._process_data_case(
808 data_package,
809 modality_processors={
810 "multi_sc": (
811 lambda data: presplit_processor(modality_data=data),
812 lambda data: postsplit_processor(
813 split_data=data, datapackage_key="multi_sc"
814 ),
815 ),
816 "img": (
817 lambda data: presplit_processor(modality_data=data),
818 lambda data: postsplit_processor(
819 split_data=data, datapackage_key="img"
820 ),
821 ),
822 },
823 )
def _process_img_to_img_case(
    self, raw_user_data: Optional[DataPackage] = None
) -> Dict[str, DataPackage]:
    """Process the IMG_TO_IMG data case.

    Reads image data and annotation with the IMG_TO_IMG reader unless
    ``raw_user_data`` is supplied. When ``config.requires_paired`` is set,
    every image modality is restricted to the sample ids shared across the
    package. The package is then handed to ``_process_data_case``; for the
    ``"img"`` modality, image normalization runs before splitting and no
    extra step runs after splitting.

    Args:
        raw_user_data: Optional DataPackage containing user-provided data.
            If provided, the data reading step is skipped.

    Returns:
        A dictionary containing processed DataPackage objects for each data
        split, as produced by ``_process_data_case``.

    Raises:
        ValueError: If ``config.requires_paired`` is set and the DataPackage
            holds no image data (``img`` is None).
    """
    if raw_user_data is None:
        imgreader = self.data_readers[DataCase.IMG_TO_IMG]
        images, annotation = imgreader.read_data(config=self.config)
        data_package = DataPackage(img=images, annotation=annotation)
    else:
        data_package = raw_user_data

    if self.config.requires_paired:
        # Paired training needs every modality to cover exactly the same samples.
        common_ids = data_package.get_common_ids()
        images = data_package.img
        if images is None:
            raise ValueError("Images cannot be None")
        data_package.img = {
            k: self.filter_imgdata_list(img_list=v, ids=common_ids)
            for k, v in images.items()
        }

    def presplit_processor(modality_data: Dict[str, List]) -> Dict[str, List]:
        """Normalize each image modality using its configured scaling method."""
        return {
            k: self._normalize_image_data(v, k) for k, v in modality_data.items()
        }

    def postsplit_processor(
        split_data: Dict[str, Dict[str, Any]],
    ) -> Dict[str, Dict[str, Any]]:
        """No postsplit processing needed for image data."""
        return split_data

    return self._process_data_case(
        data_package,
        modality_processors={
            "img": (presplit_processor, postsplit_processor),
        },
    )
# Shared split step: one set of synchronized indices is computed for the whole
# package and then applied to every modality.
def _split_data_package(
    self, data_package: DataPackage
) -> Tuple[Dict[str, Optional[Dict[str, Any]]], Dict[str, Any]]:
    """Split ``data_package`` into train/validation/test parts.

    Index generation and index application are done by two collaborators.
    ``PairedUnpairedSplitter`` produces one synchronized set of split indices
    covering all modalities. ``DataPackageSplitter`` then slices the package
    with exactly those indices.

    Args:
        data_package: The DataPackage to be split.

    Returns:
        ``(split_datasets, split_indices)``: the per-split data returned by
        ``DataPackageSplitter.split()`` and the index configuration returned
        by ``PairedUnpairedSplitter.split()``.
    """
    indices = PairedUnpairedSplitter(
        data_package=data_package, config=self.config
    ).split()
    splits = DataPackageSplitter(
        data_package=data_package, config=self.config, indices=indices
    ).split()
    return splits, indices
915 def _is_image_data(self, data: Any) -> bool:
916 """Check if data is image data.
918 Determines if the provided data is a list of objects that are considered
919 image data based on having an 'img' attribute.
921 Args:
922 data: The data to check.
924 Returns:
925 True if the data is image data, False otherwise.
926 """
927 if data is None:
928 return False
929 if isinstance(data, list) and hasattr(data[0], "img"):
930 return True
931 return False
def _remove_nans(self, data_package: DataPackage) -> DataPackage:
    """Return ``data_package`` with NaN entries removed.

    This is a thin delegation to ``NaNRemover`` built from the current
    config. The exact removal rules (which parts of the package are
    inspected) live in ``NaNRemover.remove_nan``.

    Args:
        data_package: The DataPackage from which to remove NaNs.

    Returns:
        The DataPackage returned by ``NaNRemover.remove_nan``.
    """
    return NaNRemover(config=self.config).remove_nan(data=data_package)
def _normalize_image_data(self, images: List, info_key: str) -> List:
    """Normalize every image in ``images`` and return them in a new list.

    The scaling method is ``config.data_config.data_info[info_key].scaling``.
    The sentinel value ``"NOTSET"`` falls back to the global
    ``config.scaling``. A single ``ImageNormalizer`` is shared by all images.

    Args:
        images: Image data objects, each exposing an ``img`` attribute.
        info_key: Key into ``data_info`` selecting the scaling method.

    Returns:
        A new list holding the same objects, whose ``img`` attribute has been
        replaced in place by the normalized image.
    """
    method = self.config.data_config.data_info[info_key].scaling
    if method == "NOTSET":
        method = self.config.scaling
    normalizer = ImageNormalizer()

    def _normalize_one(item):
        # Overwrites the pixel data on the caller's object; no copy is made.
        item.img = normalizer.normalize_image(image=item.img, method=method)
        return item

    return [_normalize_one(item) for item in images]
978 def _get_translation_keys(self) -> Tuple[Optional[str], Optional[str]]:
979 """
980 Extract from and to keys from config.
982 Retrieves the 'from' and 'to' modality keys from the data configuration
983 based on the 'translate_direction' setting.
985 Returns:
986 A tuple containing the from_key and to_key as strings, or None if not found.
988 Raises:
989 ValueError: If neither 'from' nor 'to' keys are found in the data configuration.
990 TypeError: If the translate_direction is not set for the data_info.
991 """
992 from_key, to_key = None, None
993 for k, v in self.config.data_config.data_info.items():
994 if v.translate_direction is None:
995 continue
996 if v.translate_direction == "from":
997 from_key = k
998 if v.translate_direction == "to":
999 to_key = k
1000 return from_key, to_key
def _get_user_translation_keys(
    self, raw_user_data: DataPackage
) -> Tuple[Optional[str], Optional[str]]:
    """Return the first from/to modality keys of user-supplied translation data.

    Args:
        raw_user_data: DataPackage whose ``from_modality`` and
            ``to_modality`` attributes are expected to be mappings keyed by
            modality name.

    Returns:
        ``(from_key, to_key)``: the first key of ``from_modality`` and of
        ``to_modality``. Returns ``(None, None)`` if either mapping is empty.
        Also returns ``(None, None)`` if the keys cannot be read, in which
        case a warning is emitted.

    Raises:
        TypeError: If ``from_modality`` is None, or if ``from_modality`` is
            non-empty and ``to_modality`` is None.
    """
    import warnings

    none_msg = "from_modality and to_modality cannot be None for Translation"
    from_modality = raw_user_data.from_modality
    # Test for None before len(): len(None) would raise a TypeError with an
    # unhelpful message, and previously made the None check unreachable.
    if from_modality is None:
        raise TypeError(none_msg)
    if len(from_modality) == 0:
        return None, None
    to_modality = raw_user_data.to_modality
    if to_modality is None:
        raise TypeError(none_msg)
    if len(to_modality) == 0:
        return None, None
    try:
        return next(iter(from_modality.keys())), next(iter(to_modality.keys()))
    except (AttributeError, TypeError, StopIteration) as e:
        # Best-effort: non-mapping containers yield no keys instead of crashing.
        warnings.warn(f"error getting from or to keys: {e}; returning None")
        return None, None
@abstractmethod
def format_reconstruction(
    self, reconstruction: Dict[str, torch.Tensor], result: Optional[Result] = None
) -> DataPackage:
    """Convert a reconstruction mapping into a DataPackage.

    This is abstract: concrete preprocessors must override it. The mapping's
    keys are presumably modality/data names; confirm against implementers.

    Args:
        reconstruction: Mapping from string keys to reconstructed tensors.
        result: Optional Result object that implementations may consult.

    Returns:
        A DataPackage holding the formatted reconstruction.
    """
def filter_imgdata_list(self, img_list: List[Any], ids: Any) -> List[Any]:
    """Keep only the image data objects whose ``sample_id`` is in ``ids``.

    Args:
        img_list: Image data objects, each exposing a ``sample_id`` attribute.
        ids: Collection of sample ids to keep.

    Returns:
        A new list with the matching objects, in their original order.
    """
    lookup = ids
    if isinstance(ids, (list, tuple)):
        # A list/tuple membership test scans linearly. Build a set once so
        # the filter is O(len(img_list) + len(ids)) instead of their product.
        try:
            lookup = set(ids)
        except TypeError:
            # Unhashable ids: fall back to the original linear membership test.
            lookup = ids
    return [imgdata for imgdata in img_list if imgdata.sample_id in lookup]