Coverage for src / autoencodix / data / _xmodal_preprocessor.py: 24%

85 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Any, Dict, Optional, Tuple, Union 

2 

3import mudata as md 

4 

5import anndata as ad 

6import pandas as pd 

7import torch 

8 

9from autoencodix.base._base_dataset import BaseDataset 

10from autoencodix.data._datasetcontainer import DatasetContainer 

11from autoencodix.data._image_dataset import ImageDataset 

12from autoencodix.data._imgdataclass import ImgData 

13from autoencodix.data._numeric_dataset import NumericDataset 

14from autoencodix.data.datapackage import DataPackage 

15from autoencodix.data.general_preprocessor import GeneralPreprocessor 

16from autoencodix.configs.default_config import DefaultConfig 

17from autoencodix.data._multimodal_dataset import MultiModalDataset 

18 

19 

class XModalPreprocessor(GeneralPreprocessor):
    """Preprocessor for cross-modal data, handling multiple data types and their transformations.

    Attributes:
        data_config: Configuration specific to data handling.
        dataset_dicts: Dictionary holding datasets for different splits
            (train, test, valid). Assigned by ``preprocess``.
    """

    def __init__(
        self, config: DefaultConfig, ontologies: Optional[Union[Tuple, Dict]] = None
    ):
        """Initializes the XModalPreprocessor.

        Args:
            config: Configuration object for the preprocessor.
            ontologies: Optional ontologies for data processing.
        """
        super().__init__(config=config, ontologies=ontologies)
        self.data_config = config.data_config

    def preprocess(
        self,
        raw_user_data: Optional[DataPackage] = None,
        predict_new_data: bool = False,
    ) -> DatasetContainer:
        """Preprocess the data according to the configuration.

        Runs ``_general_preprocess`` and wraps the DataPackage of every
        available split in a MultiModalDataset.

        Args:
            raw_user_data: Optional raw data provided by the user.
            predict_new_data: Flag indicating if new data is being predicted.

        Returns:
            A DatasetContainer holding one MultiModalDataset per split. A split
            that is absent from the general preprocessing result is passed on
            as ``None``.

        Raises:
            TypeError: If the "data" entry of a split is not a DataPackage.
        """
        self.dataset_dicts = self._general_preprocess(
            raw_user_data=raw_user_data, predict_new_data=predict_new_data
        )
        datasets: Dict[str, Optional[MultiModalDataset]] = {}
        for split in ("train", "test", "valid"):
            cur_split = self.dataset_dicts.get(split)
            if cur_split is None:
                print(f"split is None: {split}")
                # Record the gap explicitly so building the container below
                # does not fail with a KeyError for the skipped split.
                # NOTE(review): assumes DatasetContainer accepts None for a
                # split -- confirm against its definition.
                datasets[split] = None
                continue
            cur_data = cur_split.get("data")
            if not isinstance(cur_data, DataPackage):
                raise TypeError(
                    f"expected type of cur_data to be DataPackage, got {type(cur_data)}"
                )
            datasets[split] = MultiModalDataset(
                datasets=self._process_dp(
                    dp=cur_data, indices=cur_split.get("indices")
                ),
                config=self.config,
            )

        return DatasetContainer(
            train=datasets["train"], test=datasets["test"], valid=datasets["valid"]
        )

    def format_reconstruction(self, reconstruction, result=None):
        """Placeholder that performs no formatting and returns ``None``.

        Args:
            reconstruction: Ignored.
            result: Ignored.
        """
        pass

    def _process_dp(
        self, dp: DataPackage, indices: Dict[str, Any]
    ) -> Dict[str, BaseDataset]:
        """Processes a DataPackage into a dictionary of BaseDataset objects.

        Iterating the DataPackage yields ``(key, value)`` pairs whose key has
        the form ``"<attribute>.<sub_key>"``; the attribute part selects which
        dataset type is built.

        Args:
            dp: The DataPackage to process.
            indices: The indices for splitting the data.

        Returns:
            A dictionary mapping modality names to BaseDataset objects.

        Raises:
            ValueError: If the annotation cannot be matched to a modality, if
                multi_bulk data is not a DataFrame or has no metadata, or if
                img data is not a non-empty list of ImgData.
            TypeError: If multi_sc data is neither a dict of AnnData nor a MuData.
            NotImplementedError: If the DataPackage attribute is not supported.
        """
        dataset_dict: Dict[str, BaseDataset] = {}
        for k, v in dp:
            # maxsplit=1: a modality name that itself contains a dot must not
            # break the two-way unpacking.
            dp_key, sub_key = k.split(".", 1)
            data = v
            metadata = self._select_metadata(dp=dp, sub_key=sub_key)

            if dp_key == "multi_bulk":
                if not isinstance(data, pd.DataFrame):
                    raise ValueError(
                        f"Expected data for multi_bulk: {k}, {v} to be pd.DataFrame, got {type(data)}"
                    )
                if metadata is None:
                    raise ValueError("metadata cannot be None")
                # needed when we have only one annotation df containing
                # metadata for all modalities
                metadata_num = metadata.loc[data.index]
                dataset_dict[k] = NumericDataset(
                    data=data.values,
                    config=self.config,
                    sample_ids=data.index,
                    feature_ids=data.columns,
                    split_indices=indices,
                    metadata=metadata_num,
                )
            elif dp_key == "img":
                if not isinstance(data, list):
                    raise ValueError(
                        f"Expected data for img: {k} to be a list of ImgData, got {type(data)}"
                    )
                if not data:
                    raise ValueError(
                        f"Expected data for img: {k} to be a non-empty list of ImgData"
                    )
                if not isinstance(data[0], ImgData):
                    raise ValueError(
                        f"Expected elements of img: {k} to be ImgData, got {type(data[0])}"
                    )
                dataset_dict[k] = ImageDataset(
                    data=data,
                    config=self.config,
                    split_indices=indices,
                    metadata=metadata,
                )
            elif dp_key == "multi_sc":
                # NOTE(review): both loops below store every modality under the
                # same key ``k``, so with more than one modality only the last
                # one is kept. Behavior is left unchanged here -- confirm
                # whether a per-modality key is intended.
                if isinstance(data, dict):
                    # unpaired multi_sc case with adata dicts
                    for adata_name, adata_v in data.items():
                        self._validate_layers(data_name=adata_name)
                        if not isinstance(adata_v, ad.AnnData):
                            raise TypeError(
                                f"Input data has unsupported data type: {type(adata_v)}"
                            )
                        dataset_dict[k] = self._anndata_to_numeric(
                            adata=adata_v, indices=indices
                        )
                elif isinstance(data, md.MuData):
                    for mod_key, mod_data in data.mod.items():
                        self._validate_layers(data_name=mod_key)
                        dataset_dict[k] = self._anndata_to_numeric(
                            adata=mod_data, indices=indices
                        )
                else:
                    raise TypeError(
                        f"Input data has unsupported data type: {type(data)}"
                    )
            elif dp_key == "annotation":
                # Annotations are consumed as metadata above, not as datasets.
                pass
            else:
                raise NotImplementedError(
                    f"Got datapackage attribute: {k}, probably you have added an attribute to the Datapackage class without adjusting this method. Only supports: ['multi_bulk', 'multi_sc', 'img' and 'annotation']"
                )
        return dataset_dict

    @staticmethod
    def _select_metadata(dp: DataPackage, sub_key: str):
        """Picks the annotation entry that belongs to the modality ``sub_key``.

        Lookup order: an entry named like the modality, then the entry
        ``"paired"``, then the single remaining entry (one annotation covering
        all unpaired data).

        Args:
            dp: The DataPackage whose ``annotation`` mapping is searched.
            sub_key: Modality name taken from the DataPackage key.

        Returns:
            The matching annotation entry, or ``None`` when the DataPackage has
            no annotation at all (prevents an error in the single-cell case).

        Raises:
            ValueError: If no entry matches and there is not exactly one entry.
        """
        annotation = dp.annotation
        if annotation is None:
            return None
        for candidate in (sub_key, "paired"):
            metadata = annotation.get(candidate)
            if metadata is not None:
                return metadata
        if len(annotation.keys()) != 1:
            raise ValueError(
                f"annotation key needs to be either 'paired' match a key of the numeric data or only one key exists that holds all unpaired data, please adjust config, got: {annotation.keys()}"
            )
        return annotation.get(next(iter(annotation.keys())))

    def _anndata_to_numeric(self, adata, indices: Dict[str, Any]) -> NumericDataset:
        """Builds a NumericDataset from the ``X`` matrix of one AnnData object.

        Args:
            adata: Object providing ``X``, ``obs``, ``obs_names`` and ``var_names``.
            indices: The indices for splitting the data.

        Returns:
            A NumericDataset that uses ``adata.obs`` as metadata.
        """
        return NumericDataset(
            data=adata.X,
            config=self.config,
            sample_ids=adata.obs_names,
            feature_ids=adata.var_names,
            split_indices=indices,
            metadata=adata.obs,
        )

    def _validate_layers(self, data_name: str) -> None:
        """Warns when the configured layers for ``data_name`` are not exactly ``["X"]``.

        Only the ``X`` matrix is used when building single-cell datasets, so
        any other layer selection is discarded; the warning makes that visible.

        Args:
            data_name: Key into ``config.data_config.data_info``.
        """
        selected_layers = self.config.data_config.data_info[data_name].selected_layers
        # Anything other than the single layer "X" (extra layers, a different
        # single layer, or an empty selection) is ignored downstream.
        if list(selected_layers) != ["X"]:
            import warnings

            warnings.warn(
                "Xmodalix works only with X layer of single cell data as of now. "
                "Using X Layer, discarding selected layers"
            )