Coverage for src / autoencodix / base / _base_pipeline.py: 39%

438 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import abc 

2from typing import Dict, Optional, Tuple, Type, Union, Any, Literal, List, Callable 

3import warnings 

4import anndata as ad # type: ignore 

5import numpy as np 

6import pandas as pd 

7import torch 

8from mudata import MuData # type: ignore 

9from torch.utils.data import Dataset 

10 

11# ML evaluation 

12from sklearn import linear_model # type: ignore 

13from sklearn.base import ClassifierMixin, RegressorMixin, is_classifier, is_regressor # type: ignore 

14 

15 

16from autoencodix.data._datasetcontainer import DatasetContainer 

17from autoencodix.data._datasplitter import DataSplitter 

18from autoencodix.data.datapackage import DataPackage 

19 

20# from autoencodix.evaluate.evaluate import Evaluator 

21from autoencodix.base._base_evaluator import BaseEvaluator 

22from autoencodix.utils._result import Result 

23from autoencodix.utils.prompts import PROMPT 

24from autoencodix.utils._utils import Loader, Saver, get_dataset, preprocess_explanations 

25from autoencodix.utils.adata_converter import AnnDataConverter 

26from autoencodix.utils._explainer import FeatureImportanceExplainer 

27from autoencodix.utils._llm_explainer import LLMExplainer 

28from autoencodix.configs.default_config import DataCase, DataInfo, DefaultConfig 

29 

30from ._base_autoencoder import BaseAutoencoder 

31from ._base_dataset import BaseDataset 

32from ._base_loss import BaseLoss 

33from ._base_preprocessor import BasePreprocessor 

34from ._base_trainer import BaseTrainer 

35from ._base_visualizer import BaseVisualizer 

36 

37 

38class BasePipeline(abc.ABC): 

39 """Provides a standardized interface for building model pipelines. 

40 

41 Implements methods for preprocessing data, training models, making predictions, 

42 evaluating performance, and visualizing results. Subclasses customize behavior 

43 by providing specific implementations for processing, training, evaluation, 

44 and visualization. For example when using the Stackix Model, we would use 

45 the StackixPreprocessor Type for preprocessing. 

46 

47 Attributes: 

48 config: Configuration for the pipeline's components and behavior. 

49 preprocessed_data: Pre-split and processed data that can be provided by user. 

50 raw_user_data: Raw input data for processing (DataFrames, MuData, etc.). 

51 result: Storage container for all pipeline outputs. 

52 _preprocessor: Component that filters, scales, and cleans data. 

53 _visualizer: Component that generates visual representations of results. 

54 _dataset_type: Base class for dataset implementations. 

55 _trainer_type: Base class for trainer implementations. 

56 _model_type: Base class for model architecture implementations. 

57 _loss_type: Base class for loss function implementations. 

58 _datasets: Split datasets after preprocessing. 

59 _evaluator: Component that assesses model performance. Not implemented yet 

60 _data_splitter: Component that divides data into train/validation/test sets. 

61 _ontologies: Tuple of dictionaries containing the ontologies to be used to construct sparse decoder layers. 

62 If a list is provided, it is assumed to be a list of file paths to ontology files. 

63 First item in list or tuple will be treated as first layer (after latent space) and so on. 

64 

65 """ 

66 

def __init__(
    self,
    dataset_type: Type[BaseDataset],
    trainer_type: Type[BaseTrainer],
    model_type: Type[BaseAutoencoder],
    loss_type: Type[BaseLoss],
    datasplitter_type: Type[DataSplitter],
    preprocessor_type: Type[BasePreprocessor],
    data: Optional[
        Union[DataPackage, DatasetContainer, ad.AnnData, MuData, pd.DataFrame, dict]  # type: ignore[invalid-type-form]
    ],
    visualizer: Optional[Type[BaseVisualizer]] = None,
    evaluator: Optional[Type[BaseEvaluator]] = None,
    result: Optional[Result] = None,
    config: Optional[DefaultConfig] = None,
    custom_split: Optional[Dict[str, np.ndarray]] = None,
    ontologies: Optional[Union[Tuple, Dict[Any, Any]]] = None,
    masking_fn: Optional[Callable] = None,
    masking_fn_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Initializes the pipeline with components and configuration.

    Args:
        dataset_type: Class for dataset implementations.
        trainer_type: Class for model training implementations.
        model_type: Class for model architecture implementations.
        loss_type: Class for loss function implementations.
        datasplitter_type: Class for data splitting implementation; it is
            instantiated with the resolved config and ``custom_split``.
        preprocessor_type: Class for data preprocessing implementation; it is
            instantiated with the resolved config and ``ontologies``.
        data: Input data. A DatasetContainer is treated as already
            preprocessed. DataPackage, AnnData, MuData, pd.DataFrame or a dict
            of DataFrames is treated as raw data and converted to a
            DataPackage. None means the data is specified via the config.
        visualizer: Visualizer class; ``BaseVisualizer`` is used when None.
        evaluator: Evaluator class; ``BaseEvaluator`` is used when None.
        result: Storage container for pipeline outputs; a new ``Result`` is
            created when None.
        config: Configuration parameters for all pipeline components; the
            pipeline's ``_default_config`` is used when None.
        custom_split: User-provided data splits (train/validation/test).
        ontologies: Ontologies used to construct sparse decoder layers;
            forwarded to the preprocessor (and stored for the trainer).
        masking_fn: Optional masking function, stored for the trainer.
        masking_fn_kwargs: Keyword arguments for ``masking_fn``. When None, a
            fresh empty dict is created for this pipeline instance.
        **kwargs: Additional keyword arguments. ``model_map`` is popped and
            stored as ``self.model_map``; other keys are currently ignored.

    Raises:
        ValueError: If the pipeline class does not define ``_default_config``,
            if a dict input contains non-DataFrame values, or if the data case
            cannot be inferred from raw data.
        TypeError: If ``config`` or ``data`` has an unsupported type.
    """
    if not hasattr(self, "_default_config"):
        raise ValueError("""
            The _default_config attribute has not been specified in your pipeline class.

            Example:
                self._default_config = XModalixConfig()

            This error typically occurs when a new architecture is added without setting the
            _default_config in its corresponding pipeline class.

            For more details, please refer to the 'how to add a new architecture' section in our documentation.
            """)
    self.model_map = kwargs.pop("model_map", None)
    self._validate_config(config=config)
    self._validate_user_input(data=data)
    self.masking_fn = masking_fn
    # A literal ``{}`` default would be one dict shared by every pipeline
    # instance (and by the trainers that receive it); create one per instance.
    self.masking_fn_kwargs = {} if masking_fn_kwargs is None else masking_fn_kwargs

    # ``data`` has already been type-checked above: a DatasetContainer is
    # preprocessed input, every other accepted non-None type is raw input.
    processed_data: Optional[DatasetContainer] = (
        data if isinstance(data, DatasetContainer) else None
    )
    raw_user_data = (
        data
        if isinstance(data, (DataPackage, ad.AnnData, MuData, pd.DataFrame, dict))
        else None
    )

    self.preprocessed_data: Optional[DatasetContainer] = processed_data
    self.raw_user_data: Optional[
        Union[DataPackage, ad.AnnData, MuData, pd.DataFrame, dict]  # type: ignore[invalid-type-form]
    ] = raw_user_data
    self._trainer_type = trainer_type
    self._trainer: Optional[BaseTrainer] = None
    self._model_type = model_type
    self._loss_type = loss_type
    self._preprocessor_type = preprocessor_type
    if self.raw_user_data is not None:
        # Normalize raw input to a DataPackage and record the (possibly
        # inferred) data case on the config.
        # NOTE(review): when ``config`` was None this mutates
        # ``self._default_config``; confirm it is not shared between instances.
        self.raw_user_data, datacase = self._handle_direct_user_data(
            data=self.raw_user_data,
        )
        self.config.data_case = datacase
        self._fill_data_info()

    self.ontologies = ontologies
    self._preprocessor = self._preprocessor_type(
        config=self.config, ontologies=self.ontologies
    )

    self.visualizer = (
        visualizer()  # ty: ignore[call-non-callable]
        if visualizer is not None
        else BaseVisualizer()  # ty: ignore[call-non-callable]
    )
    self.evaluator = (
        evaluator()  # ty: ignore[call-non-callable]
        if evaluator is not None
        else BaseEvaluator()  # ty: ignore[call-non-callable]
    )
    self.result = result if result is not None else Result()
    self._dataset_type = dataset_type
    self._data_splitter = datasplitter_type(
        config=self.config, custom_splits=custom_split
    )

    # None until ``preprocess`` runs, unless the user passed a DatasetContainer.
    self._datasets: Optional[DatasetContainer] = processed_data

178 

def _validate_config(self, config: Any) -> None:
    """Resolve ``self.config``, falling back to the pipeline's default config.

    Any DefaultConfig instance is accepted. If it is not an instance of the
    type of ``self._default_config``, a warning is emitted because the
    pipeline-specific config type carries the recommended defaults.

    Args:
        config: User-supplied configuration, or None to use
            ``self._default_config``.

    Raises:
        TypeError: If ``config`` is given but is not a DefaultConfig.
    """
    if config is not None:
        if not isinstance(config, DefaultConfig):
            raise TypeError(
                f"Expected config type to be DefaultConfig, got {type(config)}."
            )
        preferred_type = type(self._default_config)  # type: ignore
        if not isinstance(config, preferred_type):
            warnings.warn(
                f"Your config is of type: {type(config)}, for this pipeline the default params of: {preferred_type} work best"
            )
        self.config = config
        return
    self.config = self._default_config  # type: ignore

198 

def _validate_user_input(self, data: Any) -> None:
    """Reject pipeline input whose type cannot be handled.

    Accepted types: DataPackage, AnnData, MuData, pd.DataFrame, dict,
    DatasetContainer, and None.

    Args:
        data: Object the user passed as pipeline input.

    Raises:
        TypeError: If ``data`` is not one of the accepted types.
    """
    accepted_types = (
        DataPackage,
        ad.AnnData,
        MuData,
        pd.DataFrame,
        dict,
        type(None),
        DatasetContainer,
    )
    if isinstance(data, accepted_types):
        return
    raise TypeError(
        f"Expected data type to be one of [DataPackage, AnnData, MuData, "
        f"pd.DataFrame, dict, DatasetContainer], got {type(data)}."
    )

222 

def _handle_direct_user_data(
    self,
    data,
) -> Tuple[DataPackage, DataCase]:
    """Converts raw user data into a standardized DataPackage format.

    Conversion rules:
      * DataPackage: used unchanged; the data case must come from the config.
      * AnnData: wrapped as a one-modality MuData (modality ``"user-data"``)
        and stored in ``multi_sc["multi_sc"]``.
      * MuData: stored in ``multi_sc["multi_sc"]``.
      * pd.DataFrame: stored in ``multi_bulk["user-data"]``.
      * dict of pd.DataFrame: used directly as ``multi_bulk``.
    When ``config.data_case`` is None, it is inferred as MULTI_SINGLE_CELL
    for AnnData/MuData and MULTI_BULK for DataFrame/dict input.

    Args:
        data: Raw input data in one of the formats listed above.

    Returns:
        Tuple of the DataPackage containing the standardized data and the
        DataCase (e.g. multi_single_cell or multi_bulk).

    Raises:
        TypeError: If data format is not supported.
        ValueError: If a dict contains non-DataFrame values, or the data_case
            is neither configured nor inferable.
    """
    data_case = self.config.data_case
    if isinstance(data, DataPackage):
        data_package = data
    elif isinstance(data, ad.AnnData):
        mudata = MuData({"user-data": data})
        data_package = DataPackage(multi_sc={"multi_sc": mudata})
        if data_case is None:
            data_case = DataCase.MULTI_SINGLE_CELL
    elif isinstance(data, MuData):
        data_package = DataPackage(multi_sc={"multi_sc": data})
        if data_case is None:
            data_case = DataCase.MULTI_SINGLE_CELL
    elif isinstance(data, pd.DataFrame):
        data_package = DataPackage(multi_bulk={"user-data": data})
        if data_case is None:
            data_case = DataCase.MULTI_BULK
    elif isinstance(data, dict):
        if not all(isinstance(value, pd.DataFrame) for value in data.values()):
            raise ValueError(
                "All values in the dictionary must be pandas DataFrames."
            )
        data_package = DataPackage(multi_bulk=data)
        if data_case is None:
            data_case = DataCase.MULTI_BULK
    else:
        # Previously this fell through to an UnboundLocalError on data_package.
        raise TypeError(
            "Expected data type to be one of [DataPackage, AnnData, MuData, "
            f"pd.DataFrame, dict], got {type(data)}."
        )
    if data_case is None:
        raise ValueError("data_case must be provided if it cannot be inferred.")

    return data_package, data_case

273 

def _validate_raw_user_data(self) -> None:
    """Check that ``raw_user_data`` is a usable, non-empty DataPackage.

    Every annotated field of the DataPackage must be None or a dict, and at
    least one field must be non-None.

    Raises:
        TypeError: If raw_user_data is not a DataPackage.
        ValueError: If a non-None field is not a dict, or every field is None.
    """
    package = self.raw_user_data
    if not isinstance(package, DataPackage):
        raise TypeError(
            f"Expected raw_user_data to be of type DataPackage, got "
            f"{type(package)}."
        )

    found_value = False
    for field_name in package.__annotations__:
        field_value = getattr(package, field_name)
        if field_value is None:
            continue
        found_value = True
        if not isinstance(field_value, dict):
            raise ValueError(
                f"Attribute '{field_name}' of raw_user_data must be a dictionary, "
                f"got {type(field_value)}."
            )

    if not found_value:
        raise ValueError(
            "All attributes of raw_user_data are None. At least one must be non-None."
        )

305 

def _fill_data_info(self) -> None:
    """Ensure ``config.data_config.data_info`` has an entry for every data key.

    Keys are collected in this order for each annotated DataPackage field:
    the field name itself, then the keys of the field's value if it is a dict,
    then the modality names of any MuData stored in that dict. Each key whose
    entry is missing (or None) receives a default ``DataInfo()``.

    This is needed when the user provides data via the Pipeline rather than
    via the config.
    """
    package = self.raw_user_data
    collected_keys = []
    for field_name in package.__annotations__:
        field_value = getattr(package, field_name)
        # NOTE(review): field names are registered even when the field is
        # None; confirm this is intended.
        collected_keys.append(field_name)
        if not isinstance(field_value, dict):
            continue
        collected_keys.extend(field_value.keys())
        for entry in field_value.values():
            if isinstance(entry, MuData):
                collected_keys.extend(entry.mod.keys())

    for key in collected_keys:
        if self.config.data_config.data_info.get(key) is None:
            self.config.data_config.data_info[key] = DataInfo()

326 

def _validate_user_data(self):
    """Run the validator matching how the user supplied data.

    * Raw data passed to the pipeline -> ``_validate_raw_user_data``.
    * No raw data but datasets present (preprocessed input) ->
      ``_validate_container``.
    * Neither -> data is expected via the config -> ``_validate_config_data``.

    Raises:
        Whatever the selected validator raises.
    """
    if self.raw_user_data is not None:
        self._validate_raw_user_data()
    elif self._datasets is not None:
        self._validate_container()
    else:
        self._validate_config_data()

343 

def _validate_container(self):
    """Check that the user-supplied DatasetContainer holds at least one split.

    Each of the ``train``, ``test`` and ``valid`` splits must be either None
    or a ``torch.utils.data.Dataset``, and they must not all be None.

    Raises:
        ValueError: If there is no container, a split has the wrong type,
            or every split is None.
    """
    container = self.preprocessed_data
    if container is None:
        raise ValueError("DatasetContainer is None. Please provide valid datasets.")

    empty_splits = 0
    for split_name in ("train", "test", "valid"):
        split = getattr(container, split_name)
        if isinstance(split, Dataset):
            continue
        if split is not None:
            raise ValueError(
                f"{split_name.capitalize()} dataset has to be either None or Dataset, got "
                f"{type(split)}"
            )
        empty_splits += 1

    if empty_splits == 3:
        raise ValueError("At least one split needs to be provided")

380 

def _validate_config_data(self):
    """Validate the data description provided through the config.

    Requirements on ``config.data_config.data_info``:
      * it is non-empty;
      * at least one entry has a data_type other than "ANNOTATION";
      * if any non-annotation entry is not single-cell, at least one
        "ANNOTATION" entry exists.

    Raises:
        ValueError: If any requirement is violated.
    """
    infos = self.config.data_config.data_info
    if not infos:
        raise ValueError("data_info dictionary is empty.")

    entries = list(infos.values())

    if not [info for info in entries if info.data_type != "ANNOTATION"]:
        raise ValueError("At least one non-annotation file must be provided.")

    bulk_entries = [
        info
        for info in entries
        if not info.is_single_cell and info.data_type != "ANNOTATION"
    ]
    if not bulk_entries:
        return

    annotation_entries = [info for info in entries if info.data_type == "ANNOTATION"]
    if not annotation_entries:
        raise ValueError(
            "When working with non-single-cell data, an annotation file must be "
            "provided."
        )

424 

def preprocess(self, config: Optional[Union[None, DefaultConfig]] = None, **kwargs):
    """Filter, normalize and split the data so it is ready for training.

    A DatasetContainer supplied at construction time is used as-is. Otherwise
    the preprocessor converts ``raw_user_data`` into one. Either way the
    container is stored on ``self._datasets`` and ``self.result.datasets``.

    Args:
        config: Accepted for interface compatibility; not used by this base
            implementation.
        **kwargs: Accepted for interface compatibility; not used by this base
            implementation.

    Raises:
        NotImplementedError: If no preprocessor type was configured.
    """
    if self._preprocessor_type is None:
        raise NotImplementedError("Preprocessor not initialized")
    self._validate_user_data()

    if self.preprocessed_data is None:
        self.preprocessed_data = self._preprocessor.preprocess(
            raw_user_data=self.raw_user_data,  # type: ignore
        )

    self.result.datasets = self.preprocessed_data
    self._datasets = self.preprocessed_data

450 

def fit(self, config: Optional[Union[None, DefaultConfig]] = None, **kwargs):
    """Train the model on the preprocessed train/valid splits.

    Builds a trainer from the configured trainer type, runs it, and merges its
    result into ``self.result``.

    Args:
        config: Accepted for interface compatibility; not used by this base
            implementation.
        **kwargs: Accepted for interface compatibility; not used by this base
            implementation.

    Raises:
        ValueError: If ``preprocess`` has not produced datasets yet.
    """
    datasets = self._datasets
    if datasets is None:
        raise ValueError(
            "Datasets not built. Please run the preprocess method first."
        )

    # Defensive lookups: presumably for subclasses that do not run
    # BasePipeline.__init__ (which always sets these attributes).
    masking_fn = getattr(self, "masking_fn", None)
    masking_fn_kwargs = getattr(self, "masking_fn_kwargs", None)

    self._trainer = self._trainer_type(
        trainset=datasets.train,
        validset=datasets.valid,
        result=self.result,
        config=self.config,
        model_type=self._model_type,
        loss_type=self._loss_type,
        ontologies=self.ontologies,  # Ontix
        masking_fn=masking_fn,
        masking_fn_kwargs=masking_fn_kwargs,
        model_map=self.model_map,
    )

    self.result.update(other=self._trainer.train())

486 

def predict(
    self,
    data: Optional[
        Union[
            DataPackage,
            DatasetContainer,
            ad.AnnData,
            MuData,  # ty: ignore[invalid-type-form]
        ]  # ty: ignore[invalid-type-form]
    ] = None,  # ty: ignore[invalid-type-form]
    config: Optional[Union[None, DefaultConfig]] = None,
    from_key: Optional[str] = None,
    to_key: Optional[str] = None,
    **kwargs,
):
    """Generates predictions using the trained model.

    Predicts on the test split of the pipeline's own datasets (``data=None``)
    or on new data supplied by the user. The latent space of the test split is
    stored in ``self.result.adata_latent``, the reconstruction is formatted
    where supported, and the predictor results are merged into
    ``self.result``.

    Args:
        data: Optional new data for predictions. A DatasetContainer is used
            as-is; raw data (DataPackage, AnnData, MuData; dict and
            pd.DataFrame are also accepted) is preprocessed first.
        config: Accepted for interface compatibility; not used here.
        from_key: Not used by this base implementation.
        to_key: Not used by this base implementation.
        **kwargs: Not used by this base implementation.

    Returns:
        The pipeline's ``Result`` object, updated with the predictions.

    Raises:
        NotImplementedError: If the preprocessor is missing or no model has
            been trained.
        ValueError: If the trainer is not initialized, no test data is
            available, or the data format is invalid.
    """
    self._validate_prediction_requirements()
    if self._trainer is None:
        raise ValueError(
            "Trainer not initialized, call fit first. If you used .save and .load, "
            "you should not need to call .fit, so this is a bug. "
            "In this case please submit an issue."
        )

    # Hand the trained model (result.model) back to the trainer before predicting.
    self._trainer.setup_trainer(old_model=self.result.model)
    original_input = data
    predict_data = self._prepare_prediction_data(data=data)

    predictor_results = self._generate_predictions(
        predict_data=predict_data,
    )

    # Also merges predictor_results into self.result, which
    # _postprocess_reconstruction reads from.
    self._process_latent_results(
        predictor_results=predictor_results, predict_data=predict_data
    )
    self._postprocess_reconstruction(
        predictor_results=predictor_results,
        original_input=original_input,
        predict_data=predict_data,
    )
    self.result.update(predictor_results)
    return self.result

542 

def _validate_prediction_requirements(self):
    """Ensure a preprocessor exists and a trained model is stored.

    Raises:
        NotImplementedError: If the preprocessor is None, or
            ``self.result.model`` is None.
    """
    if self._preprocessor is None:
        raise NotImplementedError("Preprocessor not initialized")
    trained_model = self.result.model
    if trained_model is not None:
        return
    raise NotImplementedError(
        "Model not trained. Please run the fit method first"
    )

551 

def _prepare_prediction_data(
    self,
    data: Optional[
        Union[
            DataPackage,
            DatasetContainer,
            ad.AnnData,
            MuData,  # ty: ignore[invalid-type-form]
        ]  # ty: ignore[invalid-type-form]
    ] = None,  # ty: ignore[invalid-type-form]
) -> DatasetContainer:
    """Turn the ``predict`` input into a DatasetContainer.

    * None -> the pipeline's existing datasets (must contain a test split).
    * DatasetContainer -> used directly.
    * DataPackage / AnnData / MuData / dict / pd.DataFrame -> preprocessed
      as new data.

    Args:
        data: Optional new data for predictions.

    Returns:
        DatasetContainer: The container to predict on.

    Raises:
        ValueError: If the data type is unsupported or no test data is available.
    """
    if data is None:
        return self._get_existing_datasets()
    if isinstance(data, DatasetContainer):
        return self._handle_dataset_container(data=data)

    raw_input_types = (DataPackage, ad.AnnData, MuData, dict, pd.DataFrame)
    if not isinstance(data, raw_input_types):
        raise ValueError(f"Unsupported data type: {type(data)}")
    return self._handle_user_data(data=data)

579 

def _get_existing_datasets(self) -> DatasetContainer:
    """Return the pipeline's preprocessed datasets for prediction.

    Returns:
        DatasetContainer: ``self._datasets``, which has a test split.

    Raises:
        ValueError: If no datasets exist or the test split is None.
    """
    existing = self._datasets
    if existing is None:
        raise ValueError(
            "No data provided for prediction and no preprocessed datasets "
            "available. Please run the preprocess method first or provide "
            "data for prediction."
        )
    if existing.test is not None:
        return existing
    raise ValueError("No test data available for prediction")

596 

def _handle_dataset_container(self, data: DatasetContainer) -> DatasetContainer:
    """Register a user-supplied DatasetContainer as the prediction input.

    Stores it in ``self.result.new_datasets`` and, if the preprocessor has a
    ``_dataset_container`` attribute, points that attribute at it as well.

    Args:
        data: DatasetContainer containing preprocessed datasets.

    Returns:
        DatasetContainer: ``data`` unchanged.
    """
    self.result.new_datasets = data
    preprocessor = self._preprocessor
    if hasattr(preprocessor, "_dataset_container"):
        setattr(preprocessor, "_dataset_container", data)
    return data

610 

def _handle_user_data(self, data: Any) -> DatasetContainer:
    """Preprocess raw user data (DataPackage, AnnData, ...) for prediction.

    The data is first normalized to a DataPackage, then passed to the
    preprocessor with ``predict_new_data=True``. The resulting container is
    also stored in ``self.result.new_datasets``.

    Args:
        data: Raw user data in one of the supported formats.

    Returns:
        DatasetContainer: The processed dataset container for predictions.

    Raises:
        ValueError: Propagated from normalization/preprocessing of invalid data.
    """
    package, _ignored_case = self._handle_direct_user_data(data=data)
    container = self._preprocessor.preprocess(
        raw_user_data=package, predict_new_data=True
    )
    self.result.new_datasets = container
    return container

626 

def _validate_prediction_data(self, predict_data: DatasetContainer):
    """Raise ValueError unless ``predict_data`` has a test split."""
    if predict_data.test is not None:
        return
    raise ValueError(
        "The data for prediction need to be a DatasetContainer with a test "
        f"attribute, got: {predict_data}"
    )

634 

def _generate_predictions(
    self,
    predict_data: DatasetContainer,
):
    """Run the trainer's ``predict`` on the test split with the trained model.

    Args:
        predict_data: DatasetContainer with preprocessed datasets; its test
            split must not be None.

    Returns:
        Whatever ``self._trainer.predict`` returns (used as a Result holding
        latent spaces and reconstructions by the caller).

    Raises:
        ValueError: If ``predict_data`` has no test split.
    """
    self._validate_prediction_data(predict_data=predict_data)
    test_split = predict_data.test
    trained_model = self.result.model
    return self._trainer.predict(data=test_split, model=trained_model)  # type: ignore

651 

def _process_latent_results(
    self, predictor_results, predict_data: DatasetContainer
):
    """Store the final-epoch test latent space as an AnnData on the result.

    If the latent entry is a dict (e.g. several modalities), only its first
    value is used. Rows are named by the test split's sample ids. The test
    split's feature ids are stored in ``uns["var_names"]`` (presumably the
    input-space feature names, since latent dimensions have no names).
    Finally ``predictor_results`` is merged into ``self.result``.

    Args:
        predictor_results: Prediction Result containing ``latentspaces``.
        predict_data: DatasetContainer whose test split was predicted on.
    """
    latent_entry = predictor_results.latentspaces.get(epoch=-1, split="test")
    if isinstance(latent_entry, dict):
        print("Detected dictionary in latent results, extracting array...")
        latent_entry = next(iter(latent_entry.values()))  # TODO better adjust for xmodal

    test_split = predict_data.test
    self.result.adata_latent = ad.AnnData(latent_entry)
    self.result.adata_latent.obs_names = test_split.sample_ids  # type: ignore
    self.result.adata_latent.uns["var_names"] = test_split.feature_ids  # type: ignore
    self.result.update(predictor_results)

668 

def _postprocess_reconstruction(
    self, predictor_results, original_input, predict_data: DatasetContainer
):
    """Postprocess reconstruction results based on input type.

    This outputs the reconstruction in the same format as the original input
    data, whether it is a DatasetContainer, DataPackage, AnnData, MuData, or
    other formats. The test-split reconstruction of the last epoch is read
    from ``self.result.reconstructions`` (not from ``predictor_results``), so
    the predictor results must already have been merged into ``self.result``.

    Args:
        predictor_results: Results from the prediction step containing
            reconstructions; forwarded to the formatting handlers.
        original_input: Data the user passed to ``predict`` (None when the
            pipeline's existing datasets were used).
        predict_data: DatasetContainer with preprocessed datasets for prediction.

    Raises:
        ValueError: If the stored reconstruction is a dict without a
            ``"translation"`` entry, or is not a dict, ``np.ndarray`` or
            ``torch.Tensor``.
    """
    recon = self.result.reconstructions.get(epoch=-1, split="test")
    if isinstance(recon, dict):
        # Dict reconstructions (e.g. translation models) carry the relevant
        # output under "translation".
        translation = recon.get("translation")
        if translation is None:
            raise ValueError(
                "Raw recon is dict, but has no translation key, this should not happen: "
                f"keys={list(recon.keys())}"
            )
        recon = translation

    # The reconstruction may come back as a NumPy array or already as a tensor.
    if isinstance(recon, np.ndarray):
        raw_recon = torch.from_numpy(recon)
    elif isinstance(recon, torch.Tensor):
        raw_recon = recon
    else:
        raise ValueError(
            "type of raw_recon has to be 'dict', 'np.ndarray' or 'torch.Tensor', "
            f"got: {type(recon)}"
        )

    if original_input is None:
        # Using existing datasets
        self._handle_dataset_container_reconstruction(
            raw_recon=raw_recon,
            dataset_container=predict_data,
            context="existing datasets",
        )
    elif isinstance(original_input, DatasetContainer):
        self._handle_dataset_container_reconstruction(
            raw_recon=raw_recon,
            dataset_container=original_input,
            context="provided DatasetContainer",
        )
    elif self.config.data_case == DataCase.MULTI_SINGLE_CELL:
        self._handle_multi_single_cell_reconstruction(
            raw_recon=raw_recon,
            predictor_results=predictor_results,  # type: ignore
        )
    elif isinstance(
        original_input, (DataPackage, ad.AnnData, MuData, dict, pd.DataFrame)
    ):
        self._handle_user_data_reconstruction(
            raw_recon=raw_recon, predictor_results=predictor_results
        )
    else:
        self._handle_unsupported_reconstruction()

727 

def _handle_dataset_container_reconstruction(
    self,
    raw_recon: torch.Tensor,
    dataset_container: DatasetContainer,
    context: str = "DatasetContainer",
):
    """Handle reconstruction for DatasetContainer input.

    Currently a no-op: the reconstruction is not reformatted and
    ``self.result.final_reconstruction`` is left untouched. No exception is
    raised. The arguments are kept for interface stability.

    Args:
        raw_recon: Raw reconstruction tensor from the model (unused).
        dataset_container: DatasetContainer the reconstruction belongs to (unused).
        context: Description of the data context, intended for error
            messages (unused).
    """
    # TODO: formatting for DatasetContainer inputs is disabled. Earlier draft
    # kept for reference (it would also need ``import copy``):
    # if dataset_container.test is None:
    #     raise ValueError(f"No test data available in {context} for reconstruction.")
    # temp = copy.deepcopy(dataset_container.test)
    # temp.data = raw_recon
    # self.result.final_reconstruction = temp
    pass

749 

def _handle_multi_single_cell_reconstruction(
    self, raw_recon: torch.Tensor, predictor_results: Result
):
    """Format a multi-single-cell reconstruction and store it on the result.

    The preprocessor's ``format_reconstruction`` output must expose a dict
    ``multi_sc``; its ``"multi_sc"`` entry becomes
    ``self.result.final_reconstruction``.

    Args:
        raw_recon: Raw reconstruction tensor from the model.
        predictor_results: Results from the prediction step.

    Raises:
        ValueError: If ``multi_sc`` of the formatted package is not a dict.
    """
    formatted = self._preprocessor.format_reconstruction(
        reconstruction=raw_recon, result=predictor_results
    )
    sc_payload = formatted.multi_sc
    if isinstance(sc_payload, dict):
        self.result.final_reconstruction = sc_payload["multi_sc"]
        return
    raise ValueError(
        "Expected pkg.multi_sc to be a dictionary, got "
        f"{type(sc_payload)} instead."
    )

769 

def _handle_user_data_reconstruction(
    self, raw_recon: torch.Tensor, predictor_results
):
    """Store the preprocessor-formatted reconstruction for raw user input.

    Args:
        raw_recon: Raw reconstruction tensor from the model.
        predictor_results: Results from the prediction step.
    """
    self.result.final_reconstruction = self._preprocessor.format_reconstruction(
        reconstruction=raw_recon, result=predictor_results
    )

782 

def _handle_unsupported_reconstruction(self):
    """Print a notice that reconstruction formatting is unavailable here."""
    notice = (
        "Reconstruction Formatting (the process of using the reconstruction "
        "output of the autoencoder models and combine it with metadata to get "
        "the exact same data structure as the raw input data i.e, a DataPackage, "
        "DatasetContainer, or AnnData) not available for this data type or case."
    )
    print(notice)

791 

def decode(
    self, latent: Union[torch.Tensor, ad.AnnData, pd.DataFrame]
) -> Union[torch.Tensor, ad.AnnData, pd.DataFrame]:
    """Transforms latent space representations back to input space.

    The output container mirrors the input container:

    * ``torch.Tensor`` -> the tensor returned by the trainer's ``decode``.
    * ``pd.DataFrame`` -> a DataFrame keeping the input's row index. Columns are
      the training dataset's ``feature_ids`` when they are available and match
      the reconstruction width; otherwise a default integer column index is used.
    * ``ad.AnnData`` -> an AnnData whose ``obs`` index is the input's
      ``obs_names`` and whose ``var`` index is the training dataset's
      ``feature_ids``. A sparse ``.X`` is densified before decoding.

    Latent width rules (identical for all input types): a width equal to
    ``config.latent_dim`` is accepted silently. A width equal to the size of
    the model's ``_mu`` layer (if the model has one) is accepted with a warning;
    this is the default for Ontix. Any other width is rejected.

    Args:
        latent: Latent space representation to decode, one row per sample.

    Returns:
        Decoded data in a format matching the input.

    Raises:
        TypeError: If no model has been trained, the input type is unsupported,
            or (AnnData input) the train dataset is not a ``BaseDataset``.
        ValueError: If the latent width is incompatible with the model, or
            (AnnData input) no train dataset is available for feature names.
    """
    if self.result.model is None:
        raise TypeError("No model trained yet, use fit() or run() method first")

    expected_latent_dim = self.config.latent_dim

    def _model_latent_dim() -> Optional[int]:
        # Size of the model's ``_mu`` layer if the model exposes one (e.g. VAE-type
        # models); it may legitimately differ from config.latent_dim (Ontix).
        mu_layer = getattr(self._trainer._model, "_mu", None)
        return getattr(mu_layer, "out_features", None)

    def _accepts_latent_dim(n_dims: int) -> bool:
        # True when ``n_dims`` can be decoded; warns if only the _mu size matches.
        if n_dims == expected_latent_dim:
            return True
        if _model_latent_dim() == n_dims:
            warnings.warn(
                f"Input latent has dimension {n_dims}, which matches the size of the "
                f"model's mu layer but not config.latent_dim. Did you mean to provide "
                f"latent vectors of dimension {expected_latent_dim}? For Ontix this is "
                "the default behaviour and the warning can be ignored."
            )
            return True
        return False

    if isinstance(latent, ad.AnnData):
        x = latent.X
        if hasattr(x, "toarray"):  # scipy.sparse matrices -> dense ndarray
            x = x.toarray()
        latent_tensor = torch.tensor(np.asarray(x), dtype=torch.float32)
        if not _accepts_latent_dim(latent_tensor.shape[1]):
            raise ValueError(
                f"Input AnnData's .X has shape {latent_tensor.shape}, but the model "
                f"expects a latent vector of size {expected_latent_dim}. Consider "
                f"projecting the AnnData to the correct latent space first."
            )
        # Validate the feature-name source before spending compute on decoding.
        if self._datasets is None:
            raise ValueError(
                "No datasets available in the DatasetContainer to reconstruct "
                "AnnData objects. Please provide a valid DatasetContainer."
            )
        if self._datasets.train is None:
            raise ValueError(
                "The train dataset in the DatasetContainer is None. "
                "Please provide a valid train dataset to reconstruct AnnData objects."
            )
        if not isinstance(self._datasets.train, BaseDataset):
            raise TypeError(
                "The train dataset in the DatasetContainer must be a BaseDataset "
                "to reconstruct AnnData objects."
            )
        recons = self._trainer.decode(x=latent_tensor)
        return ad.AnnData(
            X=recons.to("cpu").detach().numpy(),
            obs=pd.DataFrame(index=latent.obs_names),
            var=pd.DataFrame(index=self._datasets.train.feature_ids),
        )

    if isinstance(latent, pd.DataFrame):
        latent_tensor = torch.tensor(
            latent.to_numpy(dtype=np.float32), dtype=torch.float32
        )
        if not _accepts_latent_dim(latent_tensor.shape[1]):
            raise ValueError(
                f"Input DataFrame has {latent_tensor.shape[1]} columns, but the model "
                f"expects a latent vector of size {expected_latent_dim}."
            )
        recons_np = self._trainer.decode(x=latent_tensor).to("cpu").detach().numpy()
        # The reconstruction lives in input-feature space, so the latent column
        # names cannot be reused; label with training feature ids when they fit.
        train_ds = None if self._datasets is None else self._datasets.train
        columns = None
        if isinstance(train_ds, BaseDataset):
            feature_ids = train_ds.feature_ids
            if feature_ids is not None and len(feature_ids) == recons_np.shape[1]:
                columns = feature_ids
        return pd.DataFrame(recons_np, index=latent.index, columns=columns)

    if isinstance(latent, torch.Tensor):
        n_dims = latent.shape[1]
        if not _accepts_latent_dim(n_dims):
            model_dim = _model_latent_dim()
            allowed = (
                f"{expected_latent_dim}"
                if model_dim is None
                else f"{expected_latent_dim} or {model_dim}"
            )
            raise ValueError(
                f"Input latent has incompatible latent dimension {n_dims}, "
                f"expected {allowed}."
            )
        return self._trainer.decode(x=latent)

    raise TypeError(
        "Input 'latent' must be a torch.Tensor, a pandas DataFrame or an AnnData "
        f"object, not {type(latent)}."
    )

884 

def evaluate(
    self,
    ml_model_class: Optional[ClassifierMixin] = None,
    ml_model_regression: Optional[RegressorMixin] = None,
    params: Optional[Union[list, str]] = None,
    metric_class: str = "roc_auc_ovo",
    metric_regression: str = "r2",
    reference_methods: Optional[list] = None,
    reference_reducer: Optional[dict] = None,
    split_type: Literal["use-split", "CV-5", "LOOC"] = "use-split",
    n_downsample: Optional[int] = 10000,
    top_k_classes: Optional[int] = 20,
    exclude_classes: Optional[list] = None,
) -> Result:
    """Evaluate the learned embeddings with downstream ML models and plot the results.

    All arguments are forwarded to ``self.evaluator.evaluate`` together with the
    pipeline's datasets and current result. The updated result is stored on
    ``self.result`` and plotted via the visualizer when one is initialized.

    Args:
        ml_model_class: sklearn-style classifier forwarded to the evaluator.
            Defaults to a fresh ``LogisticRegression()`` on every call. A
            warning is emitted if it is not recognised as a sklearn classifier.
        ml_model_regression: sklearn-style regressor forwarded to the evaluator.
            Defaults to a fresh ``LinearRegression()`` on every call. A warning
            is emitted if it is not recognised as a sklearn regressor.
        params: Annotation columns to evaluate, or the string ``"all"``. When
            None or empty, ``config.data_config.annotation_columns`` is used.
        metric_class: scikit-learn scoring name for classification, e.g.
            ``"roc_auc_ovo"``.
        metric_regression: scikit-learn scoring name for regression, e.g.
            ``"r2"``.
        reference_methods: Reference embeddings to compare against. Options
            are "PCA", "UMAP", "TSNE", and "RandomFeature". Defaults to none.
        reference_reducer: Optional pre-fitted reducers keyed by method name,
            e.g. ``{"PCA": pca_reducer}``. Defaults to none.
        split_type: Evaluation split strategy: "use-split", "CV-5" or "LOOC".
        n_downsample: Forwarded to the evaluator (sample cap for evaluation).
        top_k_classes: Forwarded to the evaluator.
        exclude_classes: Classes to exclude from evaluation, if given.

    Returns:
        The updated ``Result`` (also assigned to ``self.result``).

    Raises:
        NotImplementedError: If the evaluator is not initialized or no model
            has been trained.
        ValueError: If no parameters can be determined, if "RandomFeature" is
            requested without datasets, or if no latent spaces were computed.
    """
    if self.evaluator is None:
        raise NotImplementedError("Evaluator not initialized")
    if self.result.model is None:
        raise NotImplementedError(
            "Model not trained. Please run the fit method first"
        )

    # Build defaults per call: estimator instances and list/dict defaults in the
    # signature would otherwise be created once and shared (and possibly
    # mutated/refit) across every call and every pipeline.
    if ml_model_class is None:
        ml_model_class = linear_model.LogisticRegression()
    if ml_model_regression is None:
        ml_model_regression = linear_model.LinearRegression()
    if reference_methods is None:
        reference_methods = []
    if reference_reducer is None:
        reference_reducer = {}

    if not is_classifier(ml_model_class):
        warnings.warn(
            "The provided model is not a sklearn-type classifier. "
            "Evaluation continues but may produce incorrect results or errors."
        )
    if not is_regressor(ml_model_regression):
        warnings.warn(
            "The provided model is not a sklearn-type regressor. "
            "Evaluation continues but may produce incorrect results or errors."
        )

    if params is None or len(params) == 0:
        annotation_columns = self.config.data_config.annotation_columns
        params = [] if annotation_columns is None else annotation_columns  # type: ignore

    if len(params) == 0:
        raise ValueError(
            "No parameters specified for evaluation. Please provide a list of "
            "parameters or ensure that annotation_columns are set in the config."
        )

    if "RandomFeature" in reference_methods and self._datasets is None:
        raise ValueError(
            "Datasets not available for adding RandomFeature. Please keep "
            "preprocessed data available before evaluation."
        )

    if len(self.result.latentspaces._data) == 0:
        raise ValueError(
            "No latent spaces found in results. Please run predict() to "
            "calculate embeddings before evaluation."
        )

    self.result = self.evaluator.evaluate(
        datasets=self._datasets,
        result=self.result,
        ml_model_class=ml_model_class,
        ml_model_regression=ml_model_regression,
        params=params,
        metric_class=metric_class,
        metric_regression=metric_regression,
        reference_methods=reference_methods,
        reference_reducer=reference_reducer,
        split_type=split_type,
        n_downsample=n_downsample,
        top_k_classes=top_k_classes,
        exclude_classes=exclude_classes,
    )

    # The evaluation result is already stored; don't lose it to an
    # AttributeError when no visualizer is attached.
    if self.visualizer is not None:
        self.visualizer._plot_evaluation(result=self.result)
    else:
        warnings.warn("Visualizer not initialized; skipping evaluation plots.")

    return self.result

967 

def visualize(self, config: Optional[DefaultConfig] = None, **kwargs):
    """Creates visualizations of model results and performance.

    Args:
        config: Optional custom configuration for visualization. When given,
            it is passed to the visualizer instead of the pipeline's own
            ``self.config``.
        **kwargs: Accepted for API compatibility; currently not used.

    Raises:
        NotImplementedError: If visualizer is not initialized.
    """
    if self.visualizer is None:
        raise NotImplementedError("Visualizer not initialized")

    # Previously the ``config`` argument was silently ignored.
    effective_config = self.config if config is None else config
    self.visualizer.visualize(result=self.result, config=effective_config)

982 

def show_result(self, split: str = "all", **kwargs):
    """Render the standard overview plots for the current results.

    Plots are produced in this order:

    1. Loss curves (absolute), showing training/validation loss over epochs.
    2. A ridgeline plot of the latent-space distributions across dimensions.
    3. A 2D scatter projection of the latent space.

    The loss plot is skipped with a warning when no loss data exists. Both
    latent plots are skipped with a warning when no latent spaces exist.

    Args:
        split: Data split whose latent space is plotted; forwarded to the
            visualizer.
        **kwargs: Only ``params`` is read: the annotation column(s) passed to
            the latent plots. When omitted, it falls back to
            ``config.data_config.annotation_columns`` if that is set. Other
            keys are ignored.
    """
    print("Creating plots ...")

    params: Optional[Union[List[str], str]] = kwargs.pop("params", None)
    if params is None:
        # Fall back to the configured annotation columns, but only if non-empty.
        configured_columns = self.config.data_config.annotation_columns
        if configured_columns:
            params = configured_columns

    losses_present = len(self.result.losses._data) != 0
    if not losses_present:
        warnings.warn(
            "No loss data found in results. Skipping loss curve visualization."
        )
    else:
        self.visualizer.show_loss(plot_type="absolute")

    if len(self.result.latentspaces._data) == 0:
        warnings.warn(
            "No latent spaces found in results. Please run predict() to "
            "calculate embeddings."
        )
        return

    for latent_plot in ("Ridgeline", "2D-scatter"):
        self.visualizer.show_latent_space(
            result=self.result, plot_type=latent_plot, split=split, param=params
        )

1024 

def run(
    self, data: Optional[Union[DatasetContainer, DataPackage]] = None
) -> Result:
    """Run every pipeline stage end to end and return the populated result.

    Stages execute in this order: ``preprocess()``, ``fit()``,
    ``predict(data=data)`` and finally ``visualize()``.

    Args:
        data: Optional data handed to ``predict()``; when given it overrides
            the test data used for prediction.

    Returns:
        The pipeline's ``Result`` after all stages have run.
    """
    for stage in (self.preprocess, self.fit):
        stage()
    self.predict(data=data)
    self.visualize()
    return self.result

1043 

def save(self, file_path: str, save_all: bool = False):
    """Serialize this pipeline to disk using ``Saver``.

    Args:
        file_path: Destination path for the saved pipeline.
        save_all: Forwarded unchanged to ``Saver``; its exact effect (which
            parts of the pipeline are written) is defined by ``Saver``.
    """
    Saver(file_path, save_all=save_all).save(self)

1052 

@classmethod
def load(cls, file_path) -> Any:
    """Reconstruct a pipeline from ``file_path`` using ``Loader``.

    Args:
        file_path: Path to the saved pipeline.

    Returns:
        Whatever ``Loader.load()`` returns, i.e. the loaded pipeline instance.
    """
    return Loader(file_path).load()

1065 

def sample_latent_space(
    self,
    n_samples: int,
    split: str = "test",
    epoch: int = -1,
) -> torch.Tensor:
    """Draw latent points from a Gaussian summarizing one split's posteriors.

    The stored per-sample means (``result.mus``) and the values in
    ``result.sigmas`` (treated here as log-variances) for ``split``/``epoch``
    are each averaged over samples into a single diagonal Gaussian. Then
    ``n_samples`` points are drawn from it with the model's
    ``reparameterize``. Note that this is one averaged Gaussian, not the full
    per-sample mixture, so samples concentrate around the split's mean
    embedding.

    Args:
        n_samples: Number of latent points to draw; must be a positive int.
        split: The split to sample from (train, valid, test), default is test.
        epoch: The epoch to sample from, default is the last epoch (-1).

    Returns:
        z: torch.Tensor of shape ``(n_samples, latent_dim)`` on the model's
        device and dtype.

    Raises:
        ValueError: If the model has not been trained, latent statistics are
            missing, ``n_samples`` is not a positive int, the stored statistics
            are empty or not 2-D, or mus and logvars differ in shape.
        TypeError: If mu or logvar are not numpy arrays.
    """
    if getattr(self, "_trainer", None) is None:
        raise ValueError("Model is not trained yet. Please train the model first.")
    if self.result.mus is None or self.result.sigmas is None:
        raise ValueError("Model has not learned the latent space distribution yet.")
    if not isinstance(n_samples, int) or n_samples <= 0:
        raise ValueError("n_samples must be a positive integer.")

    mu = self.result.mus.get(split=split, epoch=epoch)
    logvar = self.result.sigmas.get(split=split, epoch=epoch)

    if not isinstance(mu, np.ndarray):
        raise TypeError(
            f"Expected value to be of type numpy.ndarray, got {type(mu)}. "
            "This can happen if the model was not trained with VAE loss or if you forgot to run predict()"
        )
    if not isinstance(logvar, np.ndarray):
        raise TypeError(
            f"Expected value to be of type numpy.ndarray, got {type(logvar)}."
        )
    # An empty split would otherwise average to NaN and silently return NaN samples.
    if mu.ndim != 2 or mu.shape[0] == 0:
        raise ValueError(
            f"Expected stored mus for split '{split}' to be a non-empty 2-D array "
            f"(n_samples, latent_dim), got shape {mu.shape}."
        )
    if logvar.shape != mu.shape:
        raise ValueError(
            f"Shape mismatch between stored mus {mu.shape} and logvars {logvar.shape}."
        )

    model = self._trainer._model
    mu_t = torch.from_numpy(mu).to(device=model.device, dtype=model.dtype)
    logvar_t = torch.from_numpy(logvar).to(device=model.device, dtype=model.dtype)

    with torch.no_grad():
        mean_mu = mu_t.mean(dim=0)
        mean_logvar = logvar_t.mean(dim=0)
        z = model.reparameterize(
            mean_mu.expand(n_samples, -1), mean_logvar.expand(n_samples, -1)
        )
    return z

1129 

def generate(
    self,
    n_samples: Optional[int] = None,
    latent_prior: Optional[Union[np.ndarray, torch.Tensor]] = None,
    split: str = "test",
    epoch: int = -1,
) -> torch.Tensor:
    """Generates new samples by decoding latent vectors.

    If ``latent_prior`` is None, ``n_samples`` latent points are drawn with
    ``sample_latent_space(n_samples, split, epoch)``. Otherwise the given prior
    is decoded as-is and ``n_samples``, ``split`` and ``epoch`` are ignored.

    Args:
        n_samples: Number of samples to draw from the learned latent
            distribution. Must be a positive int when ``latent_prior`` is None.
        latent_prior: Optional custom latent vectors, either a single vector
            of shape ``(latent_dim,)`` or a batch of shape
            ``(batch, latent_dim)``. It is moved to the model's device and
            dtype before decoding.
        split: The split to sample from (train, valid, test), default is test.
        epoch: The epoch to sample from, default is the last epoch (-1).

    Returns:
        torch.Tensor: The generated samples in the input space.

    Raises:
        ValueError: If neither a positive ``n_samples`` nor ``latent_prior``
            is given, or if the latent prior has an incompatible latent
            dimension.
        TypeError: If latent_prior is not a numpy array or tensor.
    """
    self._trainer.setup_trainer(old_model=self.result.model)

    if latent_prior is None:
        if not isinstance(n_samples, int) or n_samples <= 0:
            raise ValueError(
                "n_samples must be a positive integer or latent_prior provided."
            )
        latent_prior = self.sample_latent_space(
            n_samples=n_samples, split=split, epoch=epoch
        )

    if isinstance(latent_prior, np.ndarray):
        latent_prior = torch.from_numpy(latent_prior)
    if not isinstance(latent_prior, torch.Tensor):
        raise TypeError(
            f"latent_prior must be numpy.ndarray or torch.Tensor, got {type(latent_prior)}."
        )

    model = self._trainer._model
    # Tensors get the same device/dtype treatment numpy input always received.
    latent_prior = latent_prior.to(device=model.device, dtype=model.dtype)
    if latent_prior.ndim == 1:
        # A single latent vector: add the batch dimension decoding expects.
        latent_prior = latent_prior.unsqueeze(0)

    n_dims = latent_prior.shape[1]
    if n_dims != self.config.latent_dim:
        # A width equal to the model's _mu layer size (Ontix default) is let
        # through here; decode() performs its own check and emits the warning,
        # so it is not duplicated.
        mu_layer = getattr(model, "_mu", None)
        if getattr(mu_layer, "out_features", None) != n_dims:
            raise ValueError(
                f"latent_prior has incompatible latent dimension {n_dims}, "
                f"expected {self.config.latent_dim}."
            )

    with torch.no_grad():
        generated = self.decode(latent=latent_prior)
    return generated

1201 

def explain(
    self,
    method: Literal["DeepLiftShap", "IntegratedGradients"] = "DeepLiftShap",
    sel_latent_dim: Union[list, int, str, None] = None,
    input_type: Literal["random", "grouped"] = "grouped",
    input_group: Optional[str] = None,
    baseline_type: Literal["mean", "random", "grouped"] = "mean",
    baseline_group: str = "all",
    anno_col: Optional[str] = None,
    n_subset: int = 100,
    seed_int: int = 12,
    split: Literal["train", "test", "valid"] = "train",
    llm_explain: bool = False,
    llm_client: Literal["ollama", "mistral", "openrouter", "scads-llm"] = "mistral",
    llm_model: str = "mistral-large-latest",
    top_n_genes: int = 40,
    prompt: str = PROMPT,
) -> pd.DataFrame:
    """Runs the feature-importance explainer and returns gene-by-latent-dimension attribution scores.

    The attributions are stored on ``self.result.embedding_attributions``.
    With ``llm_explain=True``, LLM-generated explanations for up to eight
    latent dimensions are additionally stored on
    ``self.result.embedding_explanations``.

    Args:
        method: Specifies which attribution algorithm to use for explaining the model.
        sel_latent_dim: Latent dimension(s) to explain, as an int index, an
            ontology key (str), or a list of either. Ontology keys are
            converted to their position in ``self.ontologies[0]``. If None,
            all dimensions will be explained.
        input_type: Specifies whether the feature-importance algorithm should use 'random' samples or 'grouped' samples (see input_group) as input for the attribution computation.
        input_group: If input_type is 'grouped', this specifies which subset in anno_col to filter for to use as input for the attribution computation.
        baseline_type: Specifies whether the feature-importance algorithm should use the 'mean' baseline, a 'random'-sample baseline, or samples from a 'grouped' baseline (see baseline_group).
        baseline_group: Specifies whether the baseline is computed using all data (default) or which subset in anno_col to filter for.
        anno_col: If baseline_group is not 'all', this specifies the annotation column used to filter the baseline data.
        n_subset: Specifies the number of cells to use when subsampling for the attribution computation.
        seed_int: Defines the random seed used for reproducible subsampling and attribution calculations.
        split: The split to use for feature importance calculation (train, valid, test), default is train.
        llm_explain: Whether to additionally generate LLM explanations of the top attributed genes.
        llm_client: The LLM client to use for the explanation.
        llm_model: The LLM model to use for the explanation.
        top_n_genes: How many top (contribution to embedding) genes to consider for LLM explanation.
        prompt: Prompt passed to the LLM explainer; defaults to the package's ``PROMPT``.

    Returns:
        pd.DataFrame: Attribution scores with genes as the index and latent
        dimensions as columns. Each entry is the contribution of a gene to a
        specific latent dimension.

    Raises:
        ValueError: If no dataset or model is available, the dataset cannot be
            converted to AnnData, or ``sel_latent_dim`` is invalid.
    """
    my_converter = AnnDataConverter()
    dataset: Optional[DatasetContainer] = get_dataset(self.result)
    if dataset is None:
        raise ValueError(
            "No dataset available for explanation. "
            "This happens if you used .save and .load, and did not run .predict before. "
            "This can also happen if you run .explain before .preprocess or .fit."
        )
    adata: Optional[Dict[str, ad.AnnData]] = my_converter.dataset_to_adata(
        dataset, split=split
    )
    if adata is None:
        raise ValueError(
            "There has been an error converting the dataset to AnnData."
        )
    model = self.result.model
    if model is None:
        raise ValueError(
            "No model available for explanation. "
            "This happens if you used .save and .load, and did not run .fit before. "
            "This can also happen if you run .explain before .fit."
        )

    # Validate sel_latent_dim and convert ontology keys (str) to int indices.
    if sel_latent_dim is not None:
        if isinstance(sel_latent_dim, str):
            if self.ontologies is None:
                raise ValueError(
                    "sel_latent_dim is a string, but no ontologies are available in the pipeline. Please provide ontologies in the config or set sel_latent_dim to an int or list of ints."
                )
            ontology_keys = list(self.ontologies[0].keys())
            if sel_latent_dim not in ontology_keys:
                raise ValueError(
                    f"sel_latent_dim is set to '{sel_latent_dim}', but this is not a key in the available ontologies: {ontology_keys}. Please provide a valid ontology key or set sel_latent_dim to an int or list of ints."
                )
            sel_latent_dim = ontology_keys.index(sel_latent_dim)
        elif isinstance(sel_latent_dim, int):
            if sel_latent_dim < 0 or sel_latent_dim >= self.config.latent_dim:
                raise ValueError(
                    f"sel_latent_dim is set to {sel_latent_dim}, but this is out of bounds for the latent dimension size of {self.config.latent_dim}. Please provide a valid integer index or set sel_latent_dim to None to explain all dimensions."
                )
        elif isinstance(sel_latent_dim, list):
            sel_latent_dim_validated = []
            for dim in sel_latent_dim:
                if isinstance(dim, str):
                    if self.ontologies is None:
                        raise ValueError(
                            "sel_latent_dim is a list of strings, but no ontologies are available in the pipeline. Please provide ontologies in the config or set sel_latent_dim to a list of ints."
                        )
                    ontology_keys = list(self.ontologies[0].keys())
                    if dim not in ontology_keys:
                        raise ValueError(
                            f"sel_latent_dim contains '{dim}', but this is not a key in the available ontologies: {ontology_keys}. Please provide valid ontology keys or set sel_latent_dim to a list of ints."
                        )
                    sel_latent_dim_validated.append(ontology_keys.index(dim))
                elif isinstance(dim, int):
                    if dim < 0 or dim >= self.config.latent_dim:
                        raise ValueError(
                            f"sel_latent_dim contains {dim}, but this is out of bounds for the latent dimension size of {self.config.latent_dim}. Please provide valid integer indices or set sel_latent_dim to None to explain all dimensions."
                        )
                    sel_latent_dim_validated.append(dim)
                else:
                    raise ValueError(
                        f"sel_latent_dim list contains entries of type {type(dim)}, but only str and int are allowed. Please ensure all entries are either strings corresponding to ontology keys or integers corresponding to latent dimension indices."
                    )
            sel_latent_dim = sel_latent_dim_validated
        else:
            raise ValueError(
                f"sel_latent_dim must be either an int, a str, a list of ints or a list of strs, got {type(sel_latent_dim)}. Please provide a valid type for sel_latent_dim."
            )

    print("Start feature attribution calculation")
    explainer = FeatureImportanceExplainer(
        adata=adata["global"],
        model=model,
        n_subset=n_subset,
        method=method,
        sel_latent_dim=sel_latent_dim,
        input_type=input_type,
        input_group=input_group,
        baseline_type=baseline_type,
        baseline_group=baseline_group,
        anno_col=anno_col,
        seed_int=seed_int,
    )
    df_attributions = explainer.explain()

    if llm_explain:
        max_llm_dims = 8
        gene_attributions: Dict[str, List] = preprocess_explanations(
            df=df_attributions, n=top_n_genes, max_dims=max_llm_dims
        )
        # gene_attributions is requested with max_dims=8, so the number of
        # explained dimensions must be taken from the attribution table
        # (latent dimensions as columns) for this warning to be able to fire.
        n_attr_dims = (
            df_attributions.shape[1]
            if isinstance(df_attributions, pd.DataFrame)
            else len(gene_attributions)
        )
        if n_attr_dims > max_llm_dims:
            warnings.warn(
                "You requested LLM-generated explanations for each latent dimension. "
                "However, more than eight latent dimensions were provided. To avoid excessively long "
                "and computationally intensive output, only the eight most informative latent "
                "dimensions will be analyzed."
            )

        llm_explainer = LLMExplainer(
            client_name=llm_client,
            model_name=llm_model,
            genes_to_latent=gene_attributions,
            prompt=prompt,
        )
        print("Start LLM explanation generation")
        self.result.embedding_explanations = llm_explainer.explain()

    self.result.embedding_attributions = df_attributions
    return df_attributions

1353 

def impute(self, corrupted_tensor: torch.Tensor):
    """Impute values in a corrupted input; not supported by this pipeline.

    Args:
        corrupted_tensor: Input whose values should be imputed. It is unused
            here.

    Raises:
        NotImplementedError: Always. Per the error message, only the Maskix
            pipeline implements imputation.
    """
    message = "Impute method only implemented for Maskix pipeline."
    raise NotImplementedError(message)