Coverage for src / autoencodix / data / _image_processor.py: 25%
40 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Any, Dict, Optional, Tuple, Union
4from autoencodix.data._datasetcontainer import DatasetContainer
5from autoencodix.data._image_dataset import ImageDataset
6from autoencodix.data.datapackage import DataPackage
7from autoencodix.data.general_preprocessor import GeneralPreprocessor
8from autoencodix.configs.default_config import DefaultConfig
class ImagePreprocessor(GeneralPreprocessor):
    """Preprocessor for image data, turning each data split into an ImageDataset.

    Runs the shared preprocessing of ``GeneralPreprocessor`` and then wraps the
    image data of every available split (train/test/valid) into an
    ``ImageDataset``. Only a single image dataset per DataPackage is used.

    Attributes:
        data_config: Configuration specific to data handling and preprocessing
            (taken from ``config.data_config``).
        dataset_dicts: Per-split dictionaries returned by ``_general_preprocess``;
            assigned by ``preprocess``.
    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ):
        super().__init__(config=config, ontologies=ontologies)
        self.data_config = config.data_config

    def preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> DatasetContainer:
        """Preprocess the data according to the configuration.

        Args:
            raw_user_data: The raw data package provided by the user.
            predict_new_data: Flag indicating if new data is being predicted.

        Returns:
            A DatasetContainer with the processed train, test and valid datasets.
            A split that ``_general_preprocess`` did not produce is passed to
            the container as ``None``.

        Raises:
            TypeError: If a split's ``"data"`` entry is not a DataPackage.
        """
        self.dataset_dicts = self._general_preprocess(
            raw_user_data=raw_user_data, predict_new_data=predict_new_data
        )
        datasets: Dict[str, ImageDataset] = {}
        for split in ["train", "test", "valid"]:
            cur_split = self.dataset_dicts.get(split)
            if cur_split is None:
                print(f"split is None: {split}")
                continue
            cur_data = cur_split.get("data")
            if not isinstance(cur_data, DataPackage):
                raise TypeError(
                    f"expected type of cur_data to be DataPackage, got {type(cur_data)}"
                )
            cur_indices = cur_split.get("indices")
            datasets[split] = self._process_dp(dp=cur_data, indices=cur_indices)

        # Missing splits are skipped in the loop above, so look them up with
        # .get(); plain indexing would raise KeyError for exactly that case.
        # NOTE(review): assumes DatasetContainer accepts None for absent splits.
        return DatasetContainer(
            train=datasets.get("train"),
            test=datasets.get("test"),
            valid=datasets.get("valid"),
        )

    def _process_dp(
        self, dp: DataPackage, indices: Optional[Dict[str, Any]]
    ) -> ImageDataset:
        """Wrap the image data of a DataPackage into an ImageDataset.

        Only the first entry of ``dp.img`` is used; a warning is emitted when
        more than one image dataset is present. Metadata is looked up in
        ``dp.annotation`` under the same key, falling back to the ``"paired"``
        entry.

        Args:
            dp: DataPackage whose ``img`` attribute holds the image data.
            indices: Split indices forwarded to ``ImageDataset`` as
                ``split_indices`` (may be None if the split dict has none).

        Returns:
            An ImageDataset built from the selected image data.

        Raises:
            ValueError: If ``dp.img`` is None or an empty dict.
            TypeError: If ``dp.img`` is not a dict.
        """
        if dp.img is None:
            raise ValueError("no img attribute found in datapackage")
        # Check the type before using dict methods so a wrong type yields the
        # intended TypeError rather than an AttributeError.
        if not isinstance(dp.img, dict):
            raise TypeError(
                f"Expected `img` attribute of DataPackage to be `dict`, got {type(dp.img)}"
            )
        # An empty dict would otherwise leak a bare StopIteration from next().
        if not dp.img:
            raise ValueError("img attribute of datapackage is empty")
        first_key = next(iter(dp.img))
        if len(dp.img) > 1:
            import warnings

            # Implicit string concatenation keeps source indentation out of
            # the message (a backslash continuation would embed it).
            warnings.warn(
                f"got multiple image datasets for Imagix: {dp.img.keys()}, "
                f"we only support a single image dataset in this case, using: {first_key}"
            )

        metadata = None
        if dp.annotation is not None:
            metadata = dp.annotation.get(first_key)
            if metadata is None:
                # Fall back to the "paired" annotation entry.
                metadata = dp.annotation.get("paired")

        data = dp.img[first_key]
        return ImageDataset(
            data=data,
            config=self.config,
            split_indices=indices,
            metadata=metadata,
        )