Coverage for src / autoencodix / base / _base_preprocessor.py: 12%

380 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import abc 

2from abc import abstractmethod 

3from enum import Enum 

4from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union 

5 

6import numpy as np 

7import pandas as pd 

8import torch 

9 

10from autoencodix.data._datapackage_splitter import DataPackageSplitter 

11from autoencodix.data._datasetcontainer import DatasetContainer 

12from autoencodix.data._datasplitter import PairedUnpairedSplitter 

13from autoencodix.data._filter import DataFilter 

14from autoencodix.data._imgdataclass import ImgData 

15from autoencodix.data._nanremover import NaNRemover 

16from autoencodix.data._sc_filter import SingleCellFilter 

17from autoencodix.data.datapackage import DataPackage 

18from autoencodix.utils._bulkreader import BulkDataReader 

19from autoencodix.utils._imgreader import ImageDataReader, ImageNormalizer 

20from autoencodix.utils._screader import SingleCellDataReader 

21from autoencodix.configs.default_config import DataCase, DefaultConfig 

22from autoencodix.utils._result import Result 

23 

if TYPE_CHECKING:
    import mudata as md  # type: ignore

    # The MuData class is exported at the package top level (``mudata.MuData``);
    # there is no nested ``MuData.MuData`` attribute.
    MuData = md.MuData
else:
    # mudata is not imported at runtime here, so the alias degrades to Any.
    MuData = Any

30 

31 

32class BasePreprocessor(abc.ABC): 

33 """Contains logic for data preprocessing in the Autoencodix framework. 

34 

35 This class defines the general preprocessing workflow and provides 

36 methods for handling different data modalities and data cases. 

37 Subclasses should implement the `preprocess` method to perform 

38 specific preprocessing steps. 

39 

40 Attributes: 

41 config: A DefaultConfig object containing preprocessing configurations. 

42 processed_data: A dictionary to store processed DataPackage objects for each data split. 

43 bulk_genes_to_keep: Optional list of genes to keep for bulk data. 

44 bulk_scalers: Optional dictionary of scalers for bulk data. 

45 sc_genes_to_keep: Optional dictionary mapping modality keys to lists of genes to keep for single-cell data. 

46 sc_scalers: Optional dictionary mapping modality keys to scalers for single-cell data. 

47 sc_general_genes_to_keep: Optional dictionary mapping modality keys to lists of genes to keep filtered by non-SC specific methods. 

48 data_readers: A dictionary mapping DataCase enum values to data reader instances for different modalities. 

49 _dataset_container: Optional DatasetContainer to hold the processed datasets. 

50 """ 

51 

def __init__(
    self,
    config: DefaultConfig,
    ontologies: Optional[Union[Tuple[Any, Any], Dict[Any, Any]]] = None,
):
    """Initializes the BasePreprocessor with a configuration object.

    Args:
        config: A DefaultConfig object containing preprocessing configurations.
        ontologies: Ontology information, if provided for Ontix. It is handed to
            every ``DataFilter`` created during bulk post-split processing.
    """
    self.config = config
    self._dataset_container: Optional[DatasetContainer] = None
    # Annotated (not assigned) so the attribute holds a real dict instead of
    # the ``typing`` alias object.
    self.processed_data: Dict[str, Dict[str, Union[Any, DataPackage]]] = {}
    # Overwritten by ``_general_preprocess``; the default keeps
    # ``_process_data_case`` from hitting an AttributeError if reached directly.
    self.predict_new_data: bool = False
    # Train-derived state. These stay None until the first post-split pass
    # fits them; later passes reuse the stored values.
    self.bulk_genes_to_keep: Optional[Dict[str, List[str]]] = None
    self.bulk_scalers: Optional[Dict[str, Any]] = None
    self.sc_genes_to_keep: Optional[Dict[str, List[str]]] = None
    self.sc_scalers: Optional[Dict[str, Dict[str, Any]]] = None
    self.sc_general_genes_to_keep: Optional[Dict[str, List]] = None
    self._ontologies: Optional[Union[Tuple[Any, Any], Dict[Any, Any]]] = ontologies
    # One reader per data case. Mixed-modality cases hold a dict of readers
    # keyed by modality type ("bulk", "img", "sc").
    self.data_readers: Dict[Enum, Any] = {
        DataCase.MULTI_SINGLE_CELL: SingleCellDataReader(),
        DataCase.MULTI_BULK: BulkDataReader(config=self.config),
        DataCase.BULK_TO_BULK: BulkDataReader(config=self.config),
        DataCase.SINGLE_CELL_TO_SINGLE_CELL: SingleCellDataReader(),
        DataCase.IMG_TO_BULK: {
            "bulk": BulkDataReader(config=self.config),
            "img": ImageDataReader(config=self.config),
        },
        DataCase.SINGLE_CELL_TO_IMG: {
            "sc": SingleCellDataReader(),
            "img": ImageDataReader(config=self.config),
        },
        DataCase.IMG_TO_IMG: ImageDataReader(config=self.config),
    }

87 

@abc.abstractmethod
def preprocess(
    self,
    raw_user_data: Optional[DataPackage] = None,
    predict_new_data: bool = False,
) -> DatasetContainer:
    """Run the complete preprocessing pipeline; concrete subclasses override this.

    Args:
        raw_user_data: Data handed in directly by the user, as an alternative to
            the file paths in the config. When given, data reading is skipped.
        predict_new_data: True when preprocessing unseen data for prediction.
            The data is then only preprocessed, not split.

    Returns:
        A DatasetContainer holding the processed datasets.
    """
    ...

102 

def _general_preprocess(
    self,
    raw_user_data: Optional[DataPackage] = None,
    predict_new_data: bool = False,
) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
    """Dispatch preprocessing to the handler registered for the configured data case.

    Stores ``predict_new_data`` on the instance and resolves ``from_key`` /
    ``to_key``, from the config or from ``raw_user_data`` when it is supplied.
    It then runs the data-case specific processing function.

    Args:
        raw_user_data: Optional user-provided DataPackage. When present, the
            data reading step is skipped.
        predict_new_data: Whether the data is unseen data that should be
            preprocessed without a train/validation/test split.

    Returns:
        Mapping from split name (e.g. 'train', 'valid', 'test') to that split's
        processed data and indices.

    Raises:
        TypeError: If ``config.data_case`` is None.
        ValueError: If no processing function is registered for the data case.
    """
    self.predict_new_data = predict_new_data
    case = self.config.data_case
    if case is None:
        raise TypeError(
            "datacase can't be None. Please ensure the configuration specifies a valid DataCase."
        )

    # Translation keys come from the config unless the user passed data directly.
    keys = (
        self._get_translation_keys()
        if raw_user_data is None
        else self._get_user_translation_keys(raw_user_data=raw_user_data)
    )
    self.from_key, self.to_key = keys

    handler = self._get_process_function(datacase=case)
    if not handler:
        raise ValueError(f"Unsupported data case: {case}")
    return handler(raw_user_data=raw_user_data)

143 

def _get_process_function(self, datacase: DataCase) -> Any:
    """Look up the bound processing method for ``datacase``.

    Args:
        datacase: The DataCase of the current run.

    Returns:
        The bound method that processes ``datacase``, or None when the case has
        no registered handler.
    """
    # Single-cell and bulk each share one handler for their multi/translation cases.
    if datacase in (DataCase.MULTI_SINGLE_CELL, DataCase.SINGLE_CELL_TO_SINGLE_CELL):
        return self._process_multi_single_cell
    if datacase in (DataCase.MULTI_BULK, DataCase.BULK_TO_BULK):
        return self._process_multi_bulk_case
    if datacase == DataCase.IMG_TO_BULK:
        return self._process_img_to_bulk_case
    if datacase == DataCase.SINGLE_CELL_TO_IMG:
        return self._process_sc_to_img_case
    if datacase == DataCase.IMG_TO_IMG:
        return self._process_img_to_img_case
    return None

164 

def _process_data_case(
    self, data_package: DataPackage, modality_processors: Dict[Any, Any]
) -> Union[Dict[str, Dict[str, Union[Any, DataPackage]]], Dict[str, Any]]:
    """Run NaN removal, pre-split processing, splitting and post-split processing.

    Args:
        data_package: The DataPackage object to be processed.
        modality_processors: Maps a DataPackage key (e.g. ``"multi_sc"``,
            ``"multi_bulk"``, ``"img"``) to a ``(presplit_processor,
            postsplit_processor)`` pair. The pre-split processor receives that
            key's data and returns its replacement. The post-split processor
            receives and returns the whole split mapping.

    Returns:
        Mapping from split name to ``{"data": ..., "indices": ...}``. With
        ``predict_new_data`` set, the splits are mocked: all data is placed in
        "test" and "train", and "valid" has no data.
    """
    if self.predict_new_data:
        # Unseen data has no real splits, so everything is "test". Post-split
        # code still expects a split mapping and reads modality keys from the
        # "train" entry, so the splits are mocked here.
        mock_split: Dict[str, Dict[str, Union[Any, DataPackage]]] = {
            "test": {"data": data_package, "indices": {"paired": np.array([])}},
            "valid": {"data": None, "indices": {"paired": np.array([])}},
            "train": {"data": data_package, "indices": {"paired": np.array([])}},
        }
        if self.config.skip_preprocessing:
            return mock_split

        clean_package = self._remove_nans(data_package=data_package)
        mock_split["test"]["data"] = clean_package
        for modality_key, (
            presplit_processor,
            postsplit_processor,
        ) in modality_processors.items():
            modality_data = clean_package[modality_key]
            if not modality_data:
                continue
            clean_package[modality_key] = presplit_processor(modality_data)
            mock_split["test"]["data"] = clean_package
            mock_split = postsplit_processor(mock_split)
        return mock_split

    # Normal case: real train/valid/test splits.
    if self.config.skip_preprocessing:
        split_packages, _ = self._split_data_package(data_package=data_package)
        return split_packages

    clean_package = self._remove_nans(data_package=data_package)
    for modality_key, (presplit_processor, _) in modality_processors.items():
        modality_data = clean_package[modality_key]
        if modality_data:
            clean_package[modality_key] = presplit_processor(modality_data)

    split_packages, indices = self._split_data_package(data_package=clean_package)
    for _, postsplit_processor in modality_processors.values():
        split_packages = postsplit_processor(split_packages)

    # Build the result once, after every post-split processor has run. For
    # each split, keep only that split's entry of every index group.
    processed_splits: Dict[str, Dict[str, Any]] = {}
    for split_name, split_package in split_packages.items():
        split_indices = {
            name: {
                split: idx for split, idx in group.items() if split == split_name
            }
            for name, group in indices.items()
        }
        processed_splits[split_name] = {
            "data": split_package["data"],
            "indices": split_indices,
        }
    return processed_splits

240 

def _process_multi_single_cell(
    self, raw_user_data: Optional[DataPackage] = None
) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
    """Process the MULTI_SINGLE_CELL and SINGLE_CELL_TO_SINGLE_CELL cases.

    Builds a DataPackage from the single-cell reader unless user data is
    given. When pairing is required, it keeps only the shared sample IDs. It
    then delegates to ``_process_data_case`` with single-cell pre- and
    post-split filtering.

    Args:
        raw_user_data: Optional DataPackage containing user-provided data.

    Returns:
        Mapping from split name to that split's processed data and indices.

    Raises:
        ValueError: If pairing is required and ``multi_sc`` is None.
    """
    if raw_user_data is not None:
        data_package = raw_user_data
    else:
        reader = self.data_readers[DataCase.MULTI_SINGLE_CELL]  # type: ignore
        mudata = reader.read_data(config=self.config)
        data_package = DataPackage()
        data_package.multi_sc = mudata

    if self.config.requires_paired:
        common_ids = data_package.get_common_ids()
        if data_package.multi_sc is None:
            raise ValueError("multi_sc in data_package is None")
        paired = data_package.multi_sc["multi_sc"][common_ids]
        data_package.multi_sc = {"multi_sc": paired}

    def _presplit(modality_data: Any) -> Any:
        # Nothing to filter when the modality is absent.
        if modality_data is None:
            return None
        sc_filter = SingleCellFilter(
            data_info=self.config.data_config.data_info, config=self.config
        )
        return sc_filter.presplit_processing(multi_sc=modality_data)

    def _postsplit(split_data: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        return self._postsplit_multi_single_cell(
            split_data=split_data, datapackage_key="multi_sc"
        )

    processors = {"multi_sc": (_presplit, _postsplit)}
    return self._process_data_case(data_package, modality_processors=processors)

292 

def _process_multi_bulk_case(
    self,
    raw_user_data: Optional[DataPackage] = None,
) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
    """Process the MULTI_BULK and BULK_TO_BULK cases.

    Reads the bulk dataframes and annotation unless user data is supplied.
    When pairing is required, it restricts every present modality and
    annotation to the sample IDs shared across modalities. It then delegates to
    ``_process_data_case``. Filtering and scaling of the bulk dataframes happen
    after splitting, in ``_postsplit_multi_bulk``.

    Args:
        raw_user_data: Optional DataPackage containing user-provided data.

    Returns:
        Mapping from split name to that split's processed data and indices.

    Raises:
        ValueError: If pairing is required and the package has no annotation
            or no multi_bulk data.
    """
    if raw_user_data is None:
        bulkreader = self.data_readers[DataCase.MULTI_BULK]
        bulk_dfs, annotation = bulkreader.read_data()
        data_package = DataPackage(multi_bulk=bulk_dfs, annotation=annotation)
    else:
        data_package = raw_user_data

    if self.config.requires_paired:
        common_ids = data_package.get_common_ids()
        unpaired_data = data_package.multi_bulk
        unpaired_anno = data_package.annotation
        if unpaired_anno is None:
            raise ValueError("annotation attribute of datapackge cannot be None")
        if unpaired_data is None:
            raise ValueError("multi_bulk attribute of datapackge cannot be None")
        # Individual modalities may be None (post-split processing skips such
        # entries), so only subset the ones that are present instead of
        # failing on ``None.loc``.
        data_package.multi_bulk = {
            k: None if v is None else v.loc[common_ids]
            for k, v in unpaired_data.items()
        }
        data_package.annotation = {
            k: None if v is None else v.loc[common_ids]  # ty: ignore
            for k, v in unpaired_anno.items()  # ty: ignore
        }

    def presplit_processor(
        modality_data: Dict[str, Union[pd.DataFrame, None]],
    ) -> Dict[str, Union[pd.DataFrame, None]]:
        """For the multi_bulk modality all operations currently happen after splitting."""
        return modality_data

    def postsplit_processor(
        split_data: Dict[str, Dict[str, Any]],
    ) -> Dict[str, Dict[str, Any]]:
        return self._postsplit_multi_bulk(split_data=split_data)

    return self._process_data_case(
        data_package,
        modality_processors={
            "multi_bulk": (presplit_processor, postsplit_processor)
        },
    )

349 

350 def _calc_k_filter( 

351 self, i: int, remainder: int, base_features: int 

352 ) -> Optional[int]: 

353 if self.config.k_filter is None: 

354 return None 

355 extra = 1 if i < remainder else 0 

356 return base_features + extra 

357 

def _postsplit_multi_single_cell(
    self,
    split_data: Dict[str, Dict[str, Any]],
    datapackage_key: str = "multi_sc",
    modality_key: Optional[str] = None,
) -> Dict[str, Dict[str, Any]]:
    """Filter and scale single-cell data after splitting.

    Fitting happens when nothing has been fitted yet (all of ``sc_scalers``,
    ``sc_genes_to_keep`` and ``sc_general_genes_to_keep`` are None) or when
    ``datapackage_key`` contains ``"modality"``. In that case gene filters and
    scalers are fitted on every selected train MuData object and stored on the
    instance, and the train data is replaced by its processed version.
    Otherwise the stored state is reused and the train split is passed through
    unchanged. Every other split with data is then filtered and scaled using
    the train-derived gene maps and scalers.

    Args:
        split_data: Mapping from split name to ``{"data": ..., "indices": ...}``.
        datapackage_key: The DataPackage key holding the dict of MuData objects.
        modality_key: Optional single MuData key to process (kept for backward
            compatibility). When None, all keys in the train data are processed.

    Returns:
        Mapping from split name to that split's processed data and indices.

    Raises:
        ValueError: If the train split or its data is missing, or if
            ``modality_key`` is not present in the train data.
    """
    train_split: Optional[Dict[str, Any]] = split_data.get("train")
    if train_split is None:
        raise ValueError(
            "Train split data is None. Ensure that the data package contains valid train data."
        )
    train_data: Optional[Any] = train_split.get("data")
    if train_data is None:
        raise ValueError(
            "Train split data is None. Ensure that the data package contains valid train data."
        )

    mudata_dict = train_data[datapackage_key]
    if modality_key is None:
        modality_keys = list(mudata_dict.keys())
        print(f"Processing {len(modality_keys)} MuData objects: {modality_keys}")
    else:
        if modality_key not in mudata_dict:
            raise ValueError(
                f"Specified modality_key '{modality_key}' not found in {list(mudata_dict.keys())}"
            )
        modality_keys = [modality_key]
        print(
            f"Processing single modality (backward compatibility): {modality_key}"
        )

    def make_filter() -> Any:
        # A fresh filter per MuData object, built from the current config.
        return SingleCellFilter(
            data_info=self.config.data_config.data_info, config=self.config
        )

    nothing_fitted = (
        self.sc_scalers is None
        and self.sc_genes_to_keep is None
        and self.sc_general_genes_to_keep is None
    )
    # Fitting needs real train data. In the predict_new case the state is
    # already fitted, so this branch is not taken there.
    if nothing_fitted or "modality" in datapackage_key:
        fitted_train = {}
        scalers_by_mod: Dict[str, Any] = {}
        sc_genes_by_mod: Dict[str, Any] = {}
        general_genes_by_mod: Dict[str, Any] = {}
        for mod in modality_keys:
            print(f"Processing train modality: {mod}")
            sc_filter = make_filter()
            filtered, sc_genes = sc_filter.sc_postsplit_processing(
                mudata=mudata_dict[mod]
            )
            processed, general_genes, mod_scalers = (
                sc_filter.general_postsplit_processing(
                    mudata=filtered, scaler_map=None, gene_map=None
                )
            )
            fitted_train[mod] = processed
            scalers_by_mod[mod] = mod_scalers
            sc_genes_by_mod[mod] = sc_genes
            general_genes_by_mod[mod] = general_genes

        self.sc_scalers = scalers_by_mod
        self.sc_genes_to_keep = sc_genes_by_mod
        self.sc_general_genes_to_keep = general_genes_by_mod
        train_data[datapackage_key] = fitted_train
    else:
        scalers_by_mod = self.sc_scalers  # type: ignore
        sc_genes_by_mod = self.sc_genes_to_keep  # type: ignore
        general_genes_by_mod = self.sc_general_genes_to_keep  # type: ignore

    processed_splits: Dict[str, Dict[str, Any]] = {
        "train": {"data": train_data, "indices": split_data["train"]["indices"]}
    }

    for split, split_package in split_data.items():
        if split == "train":
            continue
        package = split_package["data"]
        if package is None:
            # Nothing to transform (e.g. the mocked "valid" split).
            processed_splits[split] = split_package
            continue

        print(f"Processing {split} split")
        transformed = {}
        for mod in modality_keys:
            print(f"Processing {split} modality: {mod}")
            sc_filter = make_filter()
            filtered, _ = sc_filter.sc_postsplit_processing(
                mudata=package[datapackage_key][mod],
                gene_map=sc_genes_by_mod[mod],
            )
            general, _, _ = sc_filter.general_postsplit_processing(
                mudata=filtered,
                gene_map=general_genes_by_mod[mod],
                scaler_map=scalers_by_mod[mod],
            )
            transformed[mod] = general

        package[datapackage_key] = transformed
        processed_splits[split] = {
            "data": package,
            "indices": split_package["indices"],
        }

    return processed_splits

517 

def _postsplit_multi_bulk(
    self,
    split_data: Dict[str, Dict[str, Any]],
    datapackage_key: str = "multi_bulk",
) -> Dict[str, Dict[str, Any]]:
    """Filter and scale bulk dataframes after splitting.

    Fitting happens when no bulk state exists yet (``bulk_scalers`` and
    ``bulk_genes_to_keep`` are None) or when ``datapackage_key`` contains
    ``"modality"``. In that case every non-None train modality is filtered, a
    scaler is fitted on the filtered train data and applied, and the gene lists
    and scalers are stored on the instance. ``config.k_filter`` is divided over
    the present modalities, and each share is written to that modality's
    ``data_info.k_filter``. Otherwise the stored state is reused and the train
    split is passed through unchanged. All other splits are filtered and scaled
    with the train-derived state.

    Args:
        split_data: Mapping from split name to ``{"data": ..., "indices": ...}``.
        datapackage_key: The DataPackage key holding the dict of bulk dataframes.

    Returns:
        Mapping from split name to that split's processed data and indices.

    Raises:
        ValueError: If the train split (or, when fitting, its data) is missing,
            or if filtering changed a modality's sample index.
    """
    train_split: Optional[Dict[str, Any]] = split_data.get("train")
    if train_split is None:
        raise ValueError(
            "Train split data is None. Ensure that the data package contains valid train data."
        )
    train_data: Optional[Any] = train_split.get("data")
    genes_to_keep_map: Dict[str, List[str]] = {}
    scalers: Dict[str, Any] = {}
    processed_splits: Dict[str, Dict[str, Any]] = {}

    if (self.bulk_scalers is None and self.bulk_genes_to_keep is None) or (
        "modality" in datapackage_key
    ):
        if train_data is None:
            raise ValueError(
                "Train split data is None. Ensure that the data package contains valid train data."
            )
        # Only present (non-None) modalities are filtered, so the k_filter
        # budget is divided among those alone. Counting None entries would
        # silently shrink the total number of kept features.
        modality_keys = [
            k for k, v in train_data[datapackage_key].items() if v is not None
        ]
        base_features, remainder = 0, 0
        if self.config.k_filter is not None and modality_keys:
            base_features, remainder = divmod(
                self.config.k_filter, len(modality_keys)
            )

        for i, k in enumerate(modality_keys):
            v = train_data[datapackage_key][k]
            # NOTE: mutates the shared config so DataFilter sees this modality's share.
            self.config.data_config.data_info[k].k_filter = self._calc_k_filter(
                i=i, base_features=base_features, remainder=remainder
            )

            data_processor = DataFilter(
                data_info=self.config.data_config.data_info[k],
                config=self.config,
                ontologies=self._ontologies,
            )
            filtered_df, genes_to_keep = data_processor.filter(df=v)
            # Validate before fitting a scaler or touching train_data.
            if not filtered_df.index.equals(v.index):
                mismatched_indices = filtered_df.index.symmetric_difference(v.index)
                raise ValueError(
                    f"Indices mismatch after filtering for modality {k}. "
                    f"Mismatched indices: {mismatched_indices}. "
                    "Ensure filtering does not alter the indices."
                )
            scaler = data_processor.fit_scaler(df=filtered_df)
            genes_to_keep_map[k] = genes_to_keep
            scalers[k] = scaler
            train_data[datapackage_key][k] = data_processor.scale(
                df=filtered_df, scaler=scaler
            )

        self.bulk_scalers = scalers
        self.bulk_genes_to_keep = genes_to_keep_map  # type: ignore
    else:
        scalers, genes_to_keep_map = self.bulk_scalers, self.bulk_genes_to_keep  # type: ignore

    processed_splits["train"] = {
        "data": train_data,
        "indices": split_data["train"]["indices"],
    }

    for split_name, split_package in split_data.items():
        if split_name == "train":
            continue
        if split_package["data"] is None:
            processed_splits[split_name] = split_data[split_name]
            continue

        processed_package = split_package["data"]
        for k, v in processed_package[datapackage_key].items():
            if v is None:
                continue
            data_processor = DataFilter(
                data_info=self.config.data_config.data_info[k],
                config=self.config,
                ontologies=self._ontologies,
            )
            filtered_df, _ = data_processor.filter(
                df=v, genes_to_keep=genes_to_keep_map[k]
            )
            if not filtered_df.index.equals(v.index):
                raise ValueError(
                    f"Indices mismatch after filtering for modality {k}. "
                    "Ensure filtering does not alter the indices."
                )
            # Assigning to an existing key while iterating .items() is safe:
            # the dict's size does not change.
            processed_package[datapackage_key][k] = data_processor.scale(
                df=filtered_df, scaler=scalers[k]
            )

        processed_splits[split_name] = {
            "data": processed_package,
            "indices": split_package["indices"],
        }

    return processed_splits

635 

def _process_img_to_bulk_case(
    self, raw_user_data: Optional[DataPackage] = None
) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
    """Process the IMG_TO_BULK case.

    Reads bulk and image data unless user data is supplied. When pairing is
    required, it restricts images, present bulk modalities and annotations to
    the shared sample IDs. Image data is normalized before splitting; bulk
    dataframes are filtered and scaled after splitting.

    Args:
        raw_user_data: Optional DataPackage containing user-provided data.
            If provided, the data reading step is skipped.

    Returns:
        Mapping from split name to that split's processed data and indices.

    Raises:
        ValueError: If pairing is required and the package has no images, no
            annotation, or no multi_bulk data.
    """
    if raw_user_data is None:
        bulkreader = self.data_readers[DataCase.IMG_TO_BULK]["bulk"]
        imgreader = self.data_readers[DataCase.IMG_TO_BULK]["img"]

        bulk_dfs, annotation_bulk = bulkreader.read_data()
        images, annotation_img = imgreader.read_data(config=self.config)
        # On key collisions the image annotation wins (it is unpacked last).
        annotation = {**annotation_bulk, **annotation_img}

        data_package = DataPackage(
            multi_bulk=bulk_dfs, img=images, annotation=annotation
        )
    else:
        data_package = raw_user_data

    if self.config.requires_paired:
        common_ids = data_package.get_common_ids()

        images = data_package.img
        if images is None:
            raise ValueError("Images cannot be None")
        data_package.img = {
            k: self.filter_imgdata_list(img_list=v, ids=common_ids)
            for k, v in images.items()
        }
        unpaired_data = data_package.multi_bulk
        unpaired_anno = data_package.annotation
        if unpaired_anno is None:
            raise ValueError("annotation attribute of datapackge cannot be None")
        if unpaired_data is None:
            raise ValueError("multi_bulk attribute of datapackge cannot be None")
        # Skip absent (None) modalities instead of failing on ``None.loc``;
        # bulk post-split processing ignores them as well.
        data_package.multi_bulk = {
            k: None if v is None else v.loc[common_ids]
            for k, v in unpaired_data.items()
        }
        data_package.annotation = {
            k: None if v is None else v.loc[common_ids]  # ty: ignore
            for k, v in unpaired_anno.items()  # ty: ignore
        }

    def presplit_processor(
        modality_data: Dict[str, Union[pd.DataFrame, List[ImgData]]],
    ) -> Dict[str, Union[pd.DataFrame, List[ImgData]]]:
        # Only images are handled before splitting; bulk data is filtered in
        # the post-split step. Existing keys are overwritten in place.
        for modality_key, data in modality_data.items():
            if self._is_image_data(data=data):
                modality_data[modality_key] = self._normalize_image_data(
                    images=data,  # type: ignore
                    info_key=modality_key,  # type: ignore
                )
        return modality_data

    def postsplit_processor(
        split_data: Dict[str, Dict[str, Any]], datapackage_key: str
    ) -> Dict[str, Dict[str, Any]]:
        if datapackage_key == "multi_bulk":
            return self._postsplit_multi_bulk(
                split_data=split_data, datapackage_key=datapackage_key
            )
        return split_data  # image data needs no post-split processing

    return self._process_data_case(
        data_package,
        modality_processors={
            # TODO: use "multi_bulk" and "img" for all translation cases and
            # adjust the processors accordingly.
            "multi_bulk": (
                presplit_processor,
                lambda data: postsplit_processor(
                    split_data=data, datapackage_key="multi_bulk"
                ),
            ),
            "img": (
                presplit_processor,
                lambda data: postsplit_processor(
                    split_data=data, datapackage_key="img"
                ),
            ),
        },
    )

733 

def _process_sc_to_img_case(
    self, raw_user_data: Optional[DataPackage] = None
) -> Dict[str, Dict[str, Union[Any, DataPackage]]]:
    """Process the SINGLE_CELL_TO_IMG case.

    Reads single-cell and image data unless user data is supplied. When
    pairing is required, it restricts both to the shared sample IDs. Images are
    normalized and single-cell data is pre-filtered before splitting.
    Single-cell data additionally goes through post-split filtering and
    scaling.

    Args:
        raw_user_data: Optional DataPackage containing user-provided data.

    Returns:
        Mapping from split name to that split's processed data and indices.

    Raises:
        ValueError: If pairing is required and ``multi_sc`` or the images are None.
    """
    if raw_user_data is not None:
        data_package = raw_user_data
    else:
        readers = self.data_readers[DataCase.SINGLE_CELL_TO_IMG]
        screader, imgreader = readers["sc"], readers["img"]
        # This case carries a single MuData entry.
        mudata_dict = screader.read_data(config=self.config)
        images, annotation = imgreader.read_data(config=self.config)
        data_package = DataPackage(
            multi_sc=mudata_dict, img=images, annotation=annotation
        )

    if self.config.requires_paired:
        common_ids = data_package.get_common_ids()
        if data_package.multi_sc is None:
            raise ValueError("multi_sc in data_package is None")
        paired_sc = data_package.multi_sc["multi_sc"][common_ids]
        data_package.multi_sc = {"multi_sc": paired_sc}
        img_lists = data_package.img
        if img_lists is None:
            raise ValueError("Images cannot be None")
        data_package.img = {
            name: self.filter_imgdata_list(img_list=imgs, ids=common_ids)
            for name, imgs in img_lists.items()
        }

    def presplit_processor(
        modality_data: Dict[str, Union[Any, List[ImgData]]],
    ) -> Dict[str, Union[Any, List[ImgData]]]:
        # Image entries are normalized in place. If none were images, the dict
        # is single-cell data and goes through the single-cell filter instead.
        found_image = False
        for key, data in modality_data.items():
            if not self._is_image_data(data=data):
                continue
            found_image = True
            modality_data[key] = self._normalize_image_data(
                images=data,  # type: ignore
                info_key=key,  # type: ignore
            )
        if found_image:
            return modality_data
        sc_filter = SingleCellFilter(
            data_info=self.config.data_config.data_info, config=self.config
        )
        return sc_filter.presplit_processing(multi_sc=modality_data)

    def postsplit_processor(
        split_data: Dict[str, Dict[str, Any]], datapackage_key: str
    ) -> Dict[str, Dict[str, Any]]:
        if datapackage_key != "multi_sc":
            # Image data needs no post-split processing.
            return split_data
        return self._postsplit_multi_single_cell(
            split_data=split_data, datapackage_key=datapackage_key
        )

    processors = {
        "multi_sc": (
            presplit_processor,
            lambda splits: postsplit_processor(
                split_data=splits, datapackage_key="multi_sc"
            ),
        ),
        "img": (
            presplit_processor,
            lambda splits: postsplit_processor(
                split_data=splits, datapackage_key="img"
            ),
        ),
    }
    return self._process_data_case(data_package, modality_processors=processors)

824 

825 def _process_img_to_img_case( 

826 self, raw_user_data: Optional[DataPackage] = None 

827 ) -> Dict[str, DataPackage]: 

828 """Process IMG_TO_IMG case. 

829 

830 Reads image data for from/to modalities, performs data splitting, 

831 NaN removal, and applies normalization to both from and to image data. 

832 

833 Args: 

834 raw_user_data: Optional DataPackage containing user-provided data. 

835 If provided, the data reading step is skipped. 

836 Returns: 

837 A dictionary containing processed DataPackage objects for each data split. 

838 Raises: 

839 TypeError: If from_key or to_key is None, indicating that translation keys must be specified. 

840 """ 

841 if raw_user_data is None: 

842 imgreader = self.data_readers[DataCase.IMG_TO_IMG] 

843 images, annotation = imgreader.read_data(config=self.config) 

844 

845 data_package = DataPackage(img=images, annotation=annotation) 

846 else: 

847 data_package = raw_user_data 

848 

849 if self.config.requires_paired: 

850 common_ids = data_package.get_common_ids() 

851 

852 images = data_package.img 

853 if images is None: 

854 raise ValueError("Images cannot be None") 

855 data_package.img = { 

856 k: self.filter_imgdata_list(img_list=v, ids=common_ids) 

857 for k, v in images.items() 

858 } 

859 

860 def presplit_processor(modality_data: Dict[str, List]) -> Dict[str, List]: 

861 """Processes img-to-img modality data with normalization for images.""" 

862 print("calling normalize image in _process_ing_to_img_case") 

863 return { 

864 k: self._normalize_image_data(v, k) for k, v in modality_data.items() 

865 } 

866 

867 def postsplit_processor( 

868 split_data: Dict[str, Dict[str, Any]], 

869 ) -> Dict[str, Dict[str, Any]]: 

870 """No postsplit processing needed for image data.""" 

871 return split_data 

872 

873 return self._process_data_case( 

874 data_package, 

875 modality_processors={ 

876 "img": ( 

877 lambda data: presplit_processor( 

878 data, 

879 ), 

880 postsplit_processor, 

881 ), 

882 }, 

883 ) 

884 

def _split_data_package(
    self, data_package: DataPackage
) -> Tuple[Dict[str, Optional[Dict[str, Any]]], Dict[str, Any]]:
    """Split ``data_package`` into train/validation/test parts.

    Works in two stages. First, ``PairedUnpairedSplitter`` computes one
    synchronized set of split indices across all modalities. Then
    ``DataPackageSplitter`` applies those indices to the package's data.

    Args:
        data_package: The DataPackage to split.

    Returns:
        A ``(split_datasets, split_indices)`` tuple:
        ``split_datasets`` is the output of ``DataPackageSplitter.split()``, and
        ``split_indices`` is the index mapping from
        ``PairedUnpairedSplitter.split()`` that produced it.
    """
    index_map = PairedUnpairedSplitter(
        data_package=data_package, config=self.config
    ).split()
    datasets = DataPackageSplitter(
        data_package=data_package, config=self.config, indices=index_map
    ).split()
    return datasets, index_map

914 

915 def _is_image_data(self, data: Any) -> bool: 

916 """Check if data is image data. 

917 

918 Determines if the provided data is a list of objects that are considered 

919 image data based on having an 'img' attribute. 

920 

921 Args: 

922 data: The data to check. 

923 

924 Returns: 

925 True if the data is image data, False otherwise. 

926 """ 

927 if data is None: 

928 return False 

929 if isinstance(data, list) and hasattr(data[0], "img"): 

930 return True 

931 return False 

932 

def _remove_nans(self, data_package: DataPackage) -> DataPackage:
    """Remove NaN entries from ``data_package`` by delegating to ``NaNRemover``.

    Which entries count as NaN-containing (e.g. rows, annotation columns) is
    decided by ``NaNRemover.remove_nan``, which is configured with
    ``self.config``.

    Args:
        data_package: The DataPackage to clean.

    Returns:
        The DataPackage returned by ``NaNRemover.remove_nan``.
    """
    remover = NaNRemover(config=self.config)
    cleaned = remover.remove_nan(data=data_package)
    return cleaned

949 

def _normalize_image_data(self, images: List, info_key: str) -> List:
    """Normalize every image in ``images`` with ``ImageNormalizer``.

    The scaling method is read from
    ``config.data_config.data_info[info_key].scaling``. The value
    ``"NOTSET"`` falls back to the global ``config.scaling``.

    Each object's ``img`` attribute is overwritten in place, so the objects
    in ``images`` are mutated. The returned list holds those same objects.

    Args:
        images: Image data objects, each with an ``img`` attribute.
        info_key: Key into ``config.data_config.data_info`` that selects the
            per-modality scaling method.

    Returns:
        A new list containing the (mutated) input objects in input order.
    """
    method = self.config.data_config.data_info[info_key].scaling
    if method == "NOTSET":
        method = self.config.scaling

    normalized: List = []
    scaler = ImageNormalizer()  # one normalizer shared by all images
    for item in images:
        item.img = scaler.normalize_image(image=item.img, method=method)
        normalized.append(item)
    return normalized

977 

978 def _get_translation_keys(self) -> Tuple[Optional[str], Optional[str]]: 

979 """ 

980 Extract from and to keys from config. 

981 

982 Retrieves the 'from' and 'to' modality keys from the data configuration 

983 based on the 'translate_direction' setting. 

984 

985 Returns: 

986 A tuple containing the from_key and to_key as strings, or None if not found. 

987 

988 Raises: 

989 ValueError: If neither 'from' nor 'to' keys are found in the data configuration. 

990 TypeError: If the translate_direction is not set for the data_info. 

991 """ 

992 from_key, to_key = None, None 

993 for k, v in self.config.data_config.data_info.items(): 

994 if v.translate_direction is None: 

995 continue 

996 if v.translate_direction == "from": 

997 from_key = k 

998 if v.translate_direction == "to": 

999 to_key = k 

1000 return from_key, to_key 

1001 

1002 def _get_user_translation_keys(self, raw_user_data: DataPackage): 

1003 if len(raw_user_data.from_modality) == 0: # type: ignore 

1004 return None, None 

1005 elif len(raw_user_data.to_modality) == 0: # type: ignore 

1006 return None, None 

1007 else: 

1008 if raw_user_data.from_modality is None or raw_user_data.to_modality is None: 

1009 raise TypeError( 

1010 "from_modality and to_modality cannot be None for Translation" 

1011 ) 

1012 try: 

1013 return next(iter(raw_user_data.from_modality.keys())), next( 

1014 iter(raw_user_data.to_modality.keys()) 

1015 ) 

1016 except Exception as e: 

1017 print("error getting from or to keys") 

1018 print(e) 

1019 print("returning None") 

1020 return None, None 

1021 

@abstractmethod
def format_reconstruction(
    self, reconstruction: Dict[str, torch.Tensor], result: Optional[Result] = None
) -> DataPackage:
    """Convert ``reconstruction`` into a DataPackage (abstract hook).

    Concrete preprocessors must override this method; the base class
    provides no implementation.

    Args:
        reconstruction: Mapping from string keys to reconstructed tensors.
        result: Optional Result object made available to the implementation.

    Returns:
        A DataPackage built by the subclass.
    """

1027 

def filter_imgdata_list(self, img_list: List[Any], ids: Any) -> List[Any]:
    """Keep only image data objects whose ``sample_id`` is in ``ids``.

    Args:
        img_list: Image data objects, each exposing a ``sample_id`` attribute.
        ids: Any iterable of sample ids to keep (list, Index, set, generator...).

    Returns:
        A new list with the matching objects, in their original order
        (duplicates in ``img_list`` are kept).
    """
    # Build the lookup set once. Membership on a list costs O(len(ids)) per
    # test, which made the filter O(len(img_list) * len(ids)). It also broke
    # when ids was a one-shot iterator.
    wanted = set(ids)
    return [imgdata for imgdata in img_list if imgdata.sample_id in wanted]