Coverage for src / autoencodix / base / _base_pipeline.py: 39%
438 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import abc
2from typing import Dict, Optional, Tuple, Type, Union, Any, Literal, List, Callable
3import warnings
4import anndata as ad # type: ignore
5import numpy as np
6import pandas as pd
7import torch
8from mudata import MuData # type: ignore
9from torch.utils.data import Dataset
11# ML evaluation
12from sklearn import linear_model # type: ignore
13from sklearn.base import ClassifierMixin, RegressorMixin, is_classifier, is_regressor # type: ignore
16from autoencodix.data._datasetcontainer import DatasetContainer
17from autoencodix.data._datasplitter import DataSplitter
18from autoencodix.data.datapackage import DataPackage
20# from autoencodix.evaluate.evaluate import Evaluator
21from autoencodix.base._base_evaluator import BaseEvaluator
22from autoencodix.utils._result import Result
23from autoencodix.utils.prompts import PROMPT
24from autoencodix.utils._utils import Loader, Saver, get_dataset, preprocess_explanations
25from autoencodix.utils.adata_converter import AnnDataConverter
26from autoencodix.utils._explainer import FeatureImportanceExplainer
27from autoencodix.utils._llm_explainer import LLMExplainer
28from autoencodix.configs.default_config import DataCase, DataInfo, DefaultConfig
30from ._base_autoencoder import BaseAutoencoder
31from ._base_dataset import BaseDataset
32from ._base_loss import BaseLoss
33from ._base_preprocessor import BasePreprocessor
34from ._base_trainer import BaseTrainer
35from ._base_visualizer import BaseVisualizer
38class BasePipeline(abc.ABC):
39 """Provides a standardized interface for building model pipelines.
41 Implements methods for preprocessing data, training models, making predictions,
42 evaluating performance, and visualizing results. Subclasses customize behavior
43 by providing specific implementations for processing, training, evaluation,
44 and visualization. For example when using the Stackix Model, we would use
45 the StackixPreprocessor Type for preprocessing.
47 Attributes:
48 config: Configuration for the pipeline's components and behavior.
49 preprocessed_data: Pre-split and processed data that can be provided by user.
50 raw_user_data: Raw input data for processing (DataFrames, MuData, etc.).
51 result: Storage container for all pipeline outputs.
52 _preprocessor: Component that filters, scales, and cleans data.
53 _visualizer: Component that generates visual representations of results.
54 _dataset_type: Base class for dataset implementations.
55 _trainer_type: Base class for trainer implementations.
56 _model_type: Base class for model architecture implementations.
57 _loss_type: Base class for loss function implementations.
58 _datasets: Split datasets after preprocessing.
59 _evaluator: Component that assesses model performance. Not implemented yet
60 _data_splitter: Component that divides data into train/validation/test sets.
61 _ontologies: Tuple of dictionaries containing the ontologies to be used to construct sparse decoder layers.
62 If a list is provided, it is assumed to be a list of file paths to ontology files.
63 First item in list or tuple will be treated as first layer (after latent space) and so on.
65 """
67 def __init__(
68 self,
69 dataset_type: Type[BaseDataset],
70 trainer_type: Type[BaseTrainer],
71 model_type: Type[BaseAutoencoder],
72 loss_type: Type[BaseLoss],
73 datasplitter_type: Type[DataSplitter],
74 preprocessor_type: Type[BasePreprocessor],
75 data: Optional[
76 Union[DataPackage, DatasetContainer, ad.AnnData, MuData, pd.DataFrame, dict] # type: ignore[invalid-type-form]
77 ],
78 visualizer: Optional[Type[BaseVisualizer]] = None,
79 evaluator: Optional[Type[BaseEvaluator]] = None,
80 result: Optional[Result] = None,
81 config: Optional[DefaultConfig] = None,
82 custom_split: Optional[Dict[str, np.ndarray]] = None,
83 ontologies: Optional[Union[Tuple, Dict[Any, Any]]] = None,
84 masking_fn: Optional[Callable] = None,
85 masking_fn_kwargs: Dict[str, Any] = {},
86 **kwargs: dict,
87 ) -> None: # ty: ignore[call-non-callable]
88 """Initializes the pipeline with components and configuration.
90 Args:
91 dataset_type: Class for dataset implementations.
92 trainer_type: Class for model training implementations.
93 model_type: Class for model architecture implementations.
94 loss_type: Class for loss function implementations.
95 datasplitter_type: Class for data splitting implementation.
96 preprocessor_type: Class for data preprocessing implementation.
97 visualizer: Component for generating visualizations.
98 data: Input data to be processed or already processed data.
99 evaluator: Component for assessing model performance.
100 result: Storage container for pipeline outputs.
101 config: Configuration parameters for all pipeline components.
102 custom_split: User-provided data splits (train/validation/test).
103 **kwargs: Additional keyword arguments.
105 Raises:
106 TypeError: If inputs have incorrect types.
107 """
108 if not hasattr(self, "_default_config"):
109 raise ValueError("""
110 The _default_config attribute has not been specified in your pipeline class.
112 Example:
113 self._default_config = XModalixConfig()
115 This error typically occurs when a new architecture is added without setting the
116 _default_config in its corresponding pipeline class.
118 For more details, please refer to the 'how to add a new architecture' section in our documentation.
119 """)
120 self.model_map = kwargs.pop("model_map", None)
121 self._validate_config(config=config)
122 self._validate_user_input(data=data)
123 self.masking_fn = masking_fn
124 self.masking_fn_kwargs = masking_fn_kwargs
125 processed_data = data if isinstance(data, DatasetContainer) else None
126 raw_user_data = (
127 data
128 if isinstance(data, (DataPackage, ad.AnnData, MuData, pd.DataFrame, dict))
129 else None
130 )
131 if processed_data is not None and not isinstance(
132 processed_data, DatasetContainer
133 ):
134 raise TypeError(
135 f"Expected data type to be DatasetContainer, got {type(processed_data)}."
136 )
138 self.preprocessed_data: Optional[DatasetContainer] = processed_data
139 self.raw_user_data: Union[
140 DataPackage, ad.AnnData, MuData, pd.DataFrame, dict # type: ignore[invalid-type-form]
141 ] = raw_user_data
142 self._trainer_type = trainer_type
143 self._trainer: Optional[BaseTrainer] = None
144 self._model_type = model_type
145 self._loss_type = loss_type
146 self._preprocessor_type = preprocessor_type
147 if self.raw_user_data is not None:
148 self.raw_user_data, datacase = self._handle_direct_user_data(
149 data=self.raw_user_data,
150 )
151 self.config.data_case = datacase
152 self._fill_data_info()
154 self.ontologies = ontologies
155 self._preprocessor = self._preprocessor_type(
156 config=self.config, ontologies=self.ontologies
157 )
159 self.visualizer = (
160 visualizer() # ty: ignore[call-non-callable]
161 if visualizer is not None
162 else BaseVisualizer() # ty: ignore[call-non-callable]
163 ) # ty: ignore[call-non-callable]
164 self.evaluator = (
165 evaluator() # ty: ignore[call-non-callable]
166 if evaluator is not None
167 else BaseEvaluator() # ty: ignore[call-non-callable]
168 ) # ty: ignore[call-non-callable]
169 self.result = result if result is not None else Result()
170 self._dataset_type = dataset_type
171 self._data_splitter = datasplitter_type(
172 config=self.config, custom_splits=custom_split
173 )
175 self._datasets: Optional[DatasetContainer] = (
176 processed_data # None, or user input
177 )
179 def _validate_config(self, config: Any) -> None:
180 """Sets config to default if None, or validates its type.
181 Args:
182 config: Configuration object to validate or set to default.
183 Raises:
184 TypeError: If config is not of type DefaultConfig
185 """
186 if config is None:
187 self.config = self._default_config # type: ignore
188 else:
189 if not isinstance(config, DefaultConfig):
190 raise TypeError(
191 f"Expected config type to be DefaultConfig, got {type(config)}."
192 )
193 if not isinstance(config, type(self._default_config)): # type: ignore
194 warnings.warn(
195 f"Your config is of type: {type(config)}, for this pipeline the default params of: {type(self._default_config)} work best"
196 )
197 self.config = config
199 def _validate_user_input(self, data: Any) -> None:
200 """Ensures that user-provided data is of a valid type.
201 Args:
202 data: User-provided data to validate.
203 Raises:
204 TypeError: If data is not of a supported type.
205 """
206 if not isinstance(
207 data,
208 (
209 DataPackage,
210 ad.AnnData,
211 MuData,
212 pd.DataFrame,
213 dict,
214 type(None),
215 DatasetContainer,
216 ),
217 ):
218 raise TypeError(
219 f"Expected data type to be one of [DataPackage, AnnData, MuData, "
220 f"pd.DataFrame, dict, DatasetContainer], got {type(data)}."
221 )
223 def _handle_direct_user_data(
224 self,
225 data,
226 ) -> Tuple[DataPackage, DataCase]:
227 """Converts raw user data into a standardized DataPackage format.
229 Args:
230 data: Raw input data in various formats.
232 Returns:
233 DataPackage containing the standardized data
234 DataCase, muliti_single_cell or multi_bulk, etc.
236 Raises:
237 TypeError: If data format is not supported.
238 ValueError: If data doesn't meet format requirements or data_case
239 cannot be inferred.
240 """
241 print(f"in handle_direct_user_data with data: {type(data)}")
242 data_case = self.config.data_case
243 if isinstance(data, DataPackage):
244 data_package = data
245 data_case = self.config.data_case
246 elif isinstance(data, ad.AnnData):
247 mudata = MuData({"user-data": data})
248 data_package = DataPackage(multi_sc={"multi_sc": mudata})
249 if self.config.data_case is None:
250 data_case = DataCase.MULTI_SINGLE_CELL
251 elif isinstance(data, MuData):
252 data_package = DataPackage(multi_sc={"multi_sc": data})
253 if self.config.data_case is None:
254 data_case = DataCase.MULTI_SINGLE_CELL
255 elif isinstance(data, pd.DataFrame):
256 data_package = DataPackage(multi_bulk={"user-data": data})
257 if self.config.data_case is None:
258 data_case = DataCase.MULTI_BULK
259 elif isinstance(data, dict):
260 # Check if all values in the dictionary are pandas DataFrames
261 if all(isinstance(value, pd.DataFrame) for value in data.values()):
262 data_package = DataPackage(multi_bulk=data)
263 if self.config.data_case is None:
264 data_case = DataCase.MULTI_BULK
265 else:
266 raise ValueError(
267 "All values in the dictionary must be pandas DataFrames."
268 )
269 if data_case is None:
270 raise ValueError("data_case must be provided if it cannot be inferred.")
272 return data_package, data_case
274 def _validate_raw_user_data(self) -> None:
275 """Validates the format and content of user-provided raw data.
277 Ensures that raw_user_data is a valid DataPackage with properly formatted
278 attributes.
280 Raises:
281 TypeError: If raw_user_data is not a DataPackage.
282 ValueError: If DataPackage attributes aren't dictionaries or all are None.
283 """
284 if not isinstance(self.raw_user_data, DataPackage):
285 raise TypeError(
286 f"Expected raw_user_data to be of type DataPackage, got "
287 f"{type(self.raw_user_data)}."
288 )
290 all_none = True
291 for attr_name in self.raw_user_data.__annotations__:
292 attr_value = getattr(self.raw_user_data, attr_name)
293 if attr_value is not None:
294 all_none = False
295 if not isinstance(attr_value, dict):
296 raise ValueError(
297 f"Attribute '{attr_name}' of raw_user_data must be a dictionary, "
298 f"got {type(attr_value)}."
299 )
301 if all_none:
302 raise ValueError(
303 "All attributes of raw_user_data are None. At least one must be non-None."
304 )
306 def _fill_data_info(self) -> None:
307 """Populates the config's data_info with entries for all data keys.
309 Creates DataInfo objects for each data key found in raw_user_data
310 if they don't already exist in the configuration.
311 This method is needed, when the user provides data via the Pipeline and
312 not via the config.
313 """
314 all_keys = []
315 for k in self.raw_user_data.__annotations__:
316 attr_value = getattr(self.raw_user_data, k)
317 all_keys.append(k)
318 if isinstance(attr_value, dict):
319 all_keys.extend(attr_value.keys())
320 for k, v in attr_value.items():
321 if isinstance(v, MuData):
322 all_keys.extend(v.mod.keys())
323 for k in all_keys:
324 if self.config.data_config.data_info.get(k) is None:
325 self.config.data_config.data_info[k] = DataInfo()
327 def _validate_user_data(self):
328 """Validates user-provided data based on its source and format.
330 Performs different validation based on whether the user provided
331 preprocessed data, raw data, or a data configuration.
333 Raises:
334 Various exceptions depending on validation results.
335 """
336 if self.raw_user_data is None:
337 if self._datasets is not None: # case when user passes preprocessed data
338 self._validate_container()
339 else: # user passes data via config
340 self._validate_config_data()
341 else:
342 self._validate_raw_user_data()
344 def _validate_container(self):
345 """Validates that a DatasetContainer has at least one valid dataset.
347 Ensures the container has properly formatted datasets and at least
348 one split is present.
350 Raises:
351 ValueError: If container validation fails.
352 """
353 if self.preprocessed_data is None:
354 raise ValueError("DatasetContainer is None. Please provide valid datasets.")
355 none_count = 0
356 if not isinstance(self.preprocessed_data.train, Dataset):
357 if self.preprocessed_data.train is not None:
358 raise ValueError(
359 f"Train dataset has to be either None or Dataset, got "
360 f"{type(self.preprocessed_data.train)}"
361 )
362 none_count += 1
363 if not isinstance(self.preprocessed_data.test, Dataset):
364 if self.preprocessed_data.test is not None:
365 raise ValueError(
366 f"Test dataset has to be either None or Dataset, got "
367 f"{type(self.preprocessed_data.test)}"
368 )
369 none_count += 1
371 if not isinstance(self.preprocessed_data.valid, Dataset):
372 if self.preprocessed_data.valid is not None:
373 raise ValueError(
374 f"Valid dataset has to be either None or Dataset, got "
375 f"{type(self.preprocessed_data.valid)}"
376 )
377 none_count += 1
378 if none_count == 3:
379 raise ValueError("At least one split needs to be provided")
381 def _validate_config_data(self):
382 """Validates the data configuration provided via config.
384 Ensures the data configuration has the necessary components based
385 on the data types being processed.
387 Raises:
388 ValueError: If data configuration validation fails.
389 """
390 data_info_dict = self.config.data_config.data_info
391 if not data_info_dict:
392 raise ValueError("data_info dictionary is empty.")
394 # Check if there's at least one non-annotation file
395 non_annotation_files = {
396 key: info
397 for key, info in data_info_dict.items()
398 if info.data_type != "ANNOTATION"
399 }
401 if not non_annotation_files:
402 raise ValueError("At least one non-annotation file must be provided.")
404 # Check if there's any non-single-cell data
405 non_single_cell_data = {
406 key: info
407 for key, info in data_info_dict.items()
408 if not info.is_single_cell and info.data_type != "ANNOTATION"
409 }
411 # If there's non-single-cell data, check for annotation file
412 if non_single_cell_data:
413 annotation_files = {
414 key: info
415 for key, info in data_info_dict.items()
416 if info.data_type == "ANNOTATION"
417 }
419 if not annotation_files:
420 raise ValueError(
421 "When working with non-single-cell data, an annotation file must be "
422 "provided."
423 )
425 def preprocess(self, config: Optional[Union[None, DefaultConfig]] = None, **kwargs):
426 """Filters, normalizes and prepares data for model training.
428 Processes raw input data into the format required by the model and creates
429 train/validation/test splits as needed.
431 Args:
432 config: Optional custom configuration for preprocessing.
433 **kwargs: Additional configuration parameters as keyword arguments.
435 Raises:
436 NotImplementedError: If preprocessor is not initialized.
437 """
438 if self._preprocessor_type is None:
439 raise NotImplementedError("Preprocessor not initialized")
440 self._validate_user_data()
441 if self.preprocessed_data is None:
442 self.preprocessed_data = self._preprocessor.preprocess(
443 raw_user_data=self.raw_user_data, # type: ignore
444 )
445 self.result.datasets = self.preprocessed_data
446 self._datasets = self.preprocessed_data
447 else:
448 self._datasets = self.preprocessed_data
449 self.result.datasets = self.preprocessed_data
451 def fit(self, config: Optional[Union[None, DefaultConfig]] = None, **kwargs):
452 """Trains the model on preprocessed data.
454 Creates and configures a trainer instance, then executes the training
455 process using the preprocessed datasets.
457 Args:
458 config: Optional custom configuration for training.
459 **kwargs: Additional configuration parameters as keyword arguments.
461 Raises:
462 ValueError: If datasets aren't available for training.
463 """
464 if self._datasets is None:
465 raise ValueError(
466 "Datasets not built. Please run the preprocess method first."
467 )
469 self._trainer = self._trainer_type(
470 trainset=self._datasets.train,
471 validset=self._datasets.valid,
472 result=self.result,
473 config=self.config,
474 model_type=self._model_type,
475 loss_type=self._loss_type,
476 ontologies=self.ontologies, # Ontix
477 masking_fn=self.masking_fn if hasattr(self, "masking_fn") else None,
478 masking_fn_kwargs=(
479 self.masking_fn_kwargs if hasattr(self, "masking_fn_kwargs") else None
480 ),
481 model_map=self.model_map,
482 )
484 trainer_result: Result = self._trainer.train()
485 self.result.update(other=trainer_result)
487 def predict(
488 self,
489 data: Optional[
490 Union[
491 DataPackage,
492 DatasetContainer,
493 ad.AnnData,
494 MuData, # ty: ignore[invalid-type-form]
495 ] # ty: ignore[invalid-type-form]
496 ] = None, # ty: ignore[invalid-type-form]
497 config: Optional[Union[None, DefaultConfig]] = None,
498 from_key: Optional[str] = None,
499 to_key: Optional[str] = None,
500 **kwargs,
501 ):
502 """Generates predictions using the trained model.
504 Uses the trained model to make predictions on test data or new data
505 provided by the user. Processes the results and stores them in the
506 result container.
508 Args:
509 data: Optional new data for predictions.
510 config: Optional custom configuration for prediction.
511 **kwargs: Additional configuration parameters as keyword arguments.
513 Raises:
514 NotImplementedError: If required components aren't initialized.
515 ValueError: If no test data is available or data format is invalid.
516 """
517 self._validate_prediction_requirements()
518 if self._trainer is None:
519 raise ValueError(
520 "Trainer not initialized, call fit first. If you used .save and .load, then you shoul not call .fit, then this is a bug."
521 "In this case please submit an issue."
522 )
524 self._trainer.setup_trainer(old_model=self.result.model)
525 original_input = data
526 predict_data = self._prepare_prediction_data(data=data)
528 predictor_results = self._generate_predictions(
529 predict_data=predict_data,
530 )
532 self._process_latent_results(
533 predictor_results=predictor_results, predict_data=predict_data
534 )
535 self._postprocess_reconstruction(
536 predictor_results=predictor_results,
537 original_input=original_input,
538 predict_data=predict_data,
539 )
540 self.result.update(predictor_results)
541 return self.result
543 def _validate_prediction_requirements(self):
544 """Validate that required components are initialized."""
545 if self._preprocessor is None:
546 raise NotImplementedError("Preprocessor not initialized")
547 if self.result.model is None:
548 raise NotImplementedError(
549 "Model not trained. Please run the fit method first"
550 )
552 def _prepare_prediction_data(
553 self,
554 data: Optional[
555 Union[
556 DataPackage,
557 DatasetContainer,
558 ad.AnnData,
559 MuData, # ty: ignore[invalid-type-form]
560 ] # ty: ignore[invalid-type-form]
561 ] = None, # ty: ignore[invalid-type-form]
562 ) -> DatasetContainer:
563 """Prepare and validate input data for prediction.
564 Args:
565 data: Optional new data for predictions. If None, uses existing datasets.
566 Returns:
567 DatasetContainer: The prepared dataset container for predictions.
568 Raises:
569 ValueError: If data type is unsupported or no test data is available.
570 """
571 if data is None:
572 return self._get_existing_datasets()
573 elif isinstance(data, DatasetContainer):
574 return self._handle_dataset_container(data=data)
575 elif isinstance(data, (DataPackage, ad.AnnData, MuData, dict, pd.DataFrame)):
576 return self._handle_user_data(data=data)
577 else:
578 raise ValueError(f"Unsupported data type: {type(data)}")
580 def _get_existing_datasets(self) -> DatasetContainer:
581 """Get existing preprocessed datasets and validate them for prediction.
582 Returns:
583 DatasetContainer: The preprocessed datasets available for prediction.
584 Raises:
585 ValueError: If no datasets are available or no test data is present.
586 """
587 if self._datasets is None:
588 raise ValueError(
589 "No data provided for prediction and no preprocessed datasets "
590 "available. Please run the preprocess method first or provide "
591 "data for prediction."
592 )
593 if self._datasets.test is None:
594 raise ValueError("No test data available for prediction")
595 return self._datasets
597 def _handle_dataset_container(self, data: DatasetContainer) -> DatasetContainer:
598 """Handle DatasetContainer input for prediction.
599 Args:
600 data: DatasetContainer containing preprocessed datasets.
601 Returns:
602 DatasetContainer: The processed dataset container for predictions.
603 """
604 self.result.new_datasets = data
606 if hasattr(self._preprocessor, "_dataset_container"):
607 self._preprocessor._dataset_container = data
609 return data
611 def _handle_user_data(self, data: Any) -> DatasetContainer:
612 """Handle user-provided data (DataPackage, AnnData, etc.).
613 Args:
614 data: Raw user data in various formats (DataPackage, AnnData, etc.).
615 Returns:
616 DatasetContainer: The processed dataset container for predictions.
617 Raises:
618 ValueError: If data type is unsupported or no test data is available.
619 """
620 processed_data, _ = self._handle_direct_user_data(data=data)
621 predict_data = self._preprocessor.preprocess(
622 raw_user_data=processed_data, predict_new_data=True
623 )
624 self.result.new_datasets = predict_data
625 return predict_data
627 def _validate_prediction_data(self, predict_data: DatasetContainer):
628 """Validate that prediction data has required test split."""
629 if predict_data.test is None:
630 raise ValueError(
631 f"The data for prediction need to be a DatasetContainer with a test "
632 f"attribute, got: {predict_data}"
633 )
635 def _generate_predictions(
636 self,
637 predict_data: DatasetContainer,
638 ):
639 """Generate predictions using the trained model.
640 Args:
641 predict_data: DatasetContainer with preprocessed datasets for prediction.
642 Returns:
643 Predictor results containing latent spaces and reconstructions.
645 """
646 self._validate_prediction_data(predict_data=predict_data)
647 return self._trainer.predict(
648 data=predict_data.test,
649 model=self.result.model,
650 ) # type: ignore
652 def _process_latent_results(
653 self, predictor_results, predict_data: DatasetContainer
654 ):
655 """Process and store latent space results.
656 Args:
657 predictor_results: Results from the prediction step containing latents.
658 predict_data: DatasetContainer with preprocessed datasets for prediction.
659 """
660 latent = predictor_results.latentspaces.get(epoch=-1, split="test")
661 if isinstance(latent, dict):
662 print("Detected dictionary in latent results, extracting array...")
663 latent = next(iter(latent.values())) # TODO better adjust for xmodal
664 self.result.adata_latent = ad.AnnData(latent)
665 self.result.adata_latent.obs_names = predict_data.test.sample_ids # type: ignore
666 self.result.adata_latent.uns["var_names"] = predict_data.test.feature_ids # type: ignore
667 self.result.update(predictor_results)
669 def _postprocess_reconstruction(
670 self, predictor_results, original_input, predict_data: DatasetContainer
671 ):
672 """Postprocess reconstruction results based on input type.
674 This outpus the reconstruction in the same format as the original input data,
675 whether it is a DatasetContainer, DataPackage, AnnData, MuData, or other formats.
677 Args:
678 predictor_results: Results from the prediction step containing reconstructions.
679 original_input: Original input data format (if provided).
680 predict_data: DatasetContainer with preprocessed datasets for prediction.
681 Raises:
682 ValueError: If reconstruction fails or data types are incompatible.
683 """
684 raw_recon: Union[Dict, np.ndarray, torch.Tensor] = (
685 self.result.reconstructions.get(epoch=-1, split="test")
686 )
687 if isinstance(raw_recon, np.ndarray):
688 raw_recon = torch.from_numpy(raw_recon) # type: ignore
689 elif isinstance(raw_recon, dict):
690 raw_recon = raw_recon.get("translation") # type: ignore
691 if raw_recon is None:
692 raise ValueError(
693 f"Raw recon is dict, but has no translation key, this should not happen: {raw_recon}"
694 )
695 raw_recon = torch.from_numpy(raw_recon) # type: ignore
696 else:
697 raise ValueError(
698 f"type of raw_recon has to be 'dict' or 'np.ndarray', got: {type(raw_recon)}"
699 )
701 if original_input is None:
702 # Using existing datasets
703 self._handle_dataset_container_reconstruction(
704 raw_recon=raw_recon, # type: ignore
705 dataset_container=predict_data,
706 context="existing datasets",
707 )
708 elif isinstance(original_input, DatasetContainer):
709 self._handle_dataset_container_reconstruction(
710 raw_recon=raw_recon, # type: ignore
711 dataset_container=original_input,
712 context="provided DatasetContainer",
713 )
714 elif self.config.data_case == DataCase.MULTI_SINGLE_CELL:
715 self._handle_multi_single_cell_reconstruction(
716 raw_recon=raw_recon,
717 predictor_results=predictor_results, # type: ignore
718 )
719 elif isinstance(
720 original_input, (DataPackage, ad.AnnData, MuData, dict, pd.DataFrame)
721 ):
722 self._handle_user_data_reconstruction(
723 raw_recon=raw_recon, predictor_results=predictor_results
724 )
725 else:
726 self._handle_unsupported_reconstruction()
728 def _handle_dataset_container_reconstruction(
729 self,
730 raw_recon: torch.Tensor,
731 dataset_container: DatasetContainer,
732 context: str = "DatasetContainer",
733 ):
734 """Handle reconstruction for DatasetContainer input.
735 Args:
736 raw_recon: Raw reconstruction tensor from the model.
737 dataset_container: Original DatasetContainer provided by the user.
738 context: Description of the data context for error messages.
739 Raises:
740 ValueError: If no test data is available in the container.
741 """
743 # if dataset_container.test is None:
744 # raise ValueError(f"No test data available in {context} for reconstruction.")
745 # temp = copy.deepcopy(dataset_container.test)
746 # temp.data = raw_recon
747 # self.result.final_reconstruction = temp
748 pass
750 def _handle_multi_single_cell_reconstruction(
751 self, raw_recon: torch.Tensor, predictor_results: Result
752 ):
753 """Handle reconstruction for multi-single-cell data
754 Args:
755 raw_recon: Raw reconstruction tensor from the model.
756 predictor_results: Results from the prediction step containing reconstructions.
757 Raises:
758 ValueError: If reconstruction formatting fails or data types are incompatible.
759 """
760 pkg = self._preprocessor.format_reconstruction(
761 reconstruction=raw_recon, result=predictor_results
762 )
763 if not isinstance(pkg.multi_sc, dict):
764 raise ValueError(
765 "Expected pkg.multi_sc to be a dictionary, got "
766 f"{type(pkg.multi_sc)} instead."
767 )
768 self.result.final_reconstruction = pkg.multi_sc["multi_sc"]
770 def _handle_user_data_reconstruction(
771 self, raw_recon: torch.Tensor, predictor_results
772 ):
773 """Handle reconstruction for user-provided data formats.
774 Args:
775 raw_recon: Raw reconstruction tensor from the model.
776 predictor_results: Results from the prediction step containing reconstructions.
777 """
778 pkg = self._preprocessor.format_reconstruction(
779 reconstruction=raw_recon, result=predictor_results
780 )
781 self.result.final_reconstruction = pkg
783 def _handle_unsupported_reconstruction(self):
784 """Handle cases where reconstruction formatting is not available."""
785 print(
786 "Reconstruction Formatting (the process of using the reconstruction "
787 "output of the autoencoder models and combine it with metadata to get "
788 "the exact same data structure as the raw input data i.e, a DataPackage, "
789 "DatasetContainer, or AnnData) not available for this data type or case."
790 )
792 def decode(
793 self, latent: Union[torch.Tensor, ad.AnnData, pd.DataFrame]
794 ) -> Union[torch.Tensor, ad.AnnData, pd.DataFrame]:
795 """Transforms latent space representations back to input space.
797 Handles various input formats for the latent representation and
798 returns the decoded data in a matching format.
800 Args:
801 latent: Latent space representation to decode.
803 Returns:
804 Decoded data in a format matching the input.
806 Raises:
807 TypeError: If no model has been trained or input type is invalid.
808 ValueError: If latent dimensions are incompatible with the model.
809 """
810 if self.result.model is None:
811 raise TypeError("No model trained yet, use fit() or run() method first")
812 recons: torch.Tensor
813 if isinstance(latent, ad.AnnData):
814 latent_data = torch.tensor(
815 latent.X, dtype=torch.float32
816 ) # Ensure float for compatibility
818 expected_latent_dim = self.config.latent_dim
819 if not latent_data.shape[1] == expected_latent_dim:
820 raise ValueError(
821 f"Input AnnData's .X has shape {latent_data.shape}, but the model "
822 f"expects a latent vector of size {expected_latent_dim}. Consider "
823 f"projecting the AnnData to the correct latent space first."
824 )
825 latent_tensor = latent_data
827 recons = self._trainer.decode(x=latent_tensor)
828 if self._datasets is None:
829 raise ValueError(
830 "No datasets available in the DatasetContainer to reconstruct "
831 "AnnData objects. Please provide a valid DatasetContainer."
832 )
833 if self._datasets.train is None:
834 raise ValueError(
835 "The train dataset in the DatasetContainer is None. "
836 "Please provide a valid train dataset to reconstruct AnnData objects."
837 )
838 if not isinstance(self._datasets.train, BaseDataset):
839 raise TypeError(
840 "The train dataset in the DatasetContainer must be a BaseDataset "
841 "to reconstruct AnnData objects."
842 )
843 recons_adata = ad.AnnData(
844 X=recons.to("cpu").detach().numpy(),
845 obs=pd.DataFrame(index=latent.obs_names),
846 var=pd.DataFrame(index=self._datasets.train.feature_ids),
847 )
849 return recons_adata
850 elif isinstance(latent, pd.DataFrame):
851 latent_tensor = torch.tensor(latent.values, dtype=torch.float32)
852 recons = self._trainer.decode(x=latent_tensor)
853 return pd.DataFrame(
854 recons.to("cpu").detach().numpy(),
855 index=latent.index,
856 columns=latent.columns,
857 )
858 elif isinstance(latent, torch.Tensor):
859 # Check size compatibility
860 expected_latent_dim = self.config.latent_dim
861 if not latent.shape[1] == expected_latent_dim:
862 if self._trainer._model._mu.out_features == latent.shape[1]:
863 warnings.warn(
864 f"latent_prior has latent dimension {latent.shape[1]}, "
865 "which matches the input feature dimension of the model. Did you "
866 "mean to provide latent vectors of dimension "
867 "For Ontix this is the default behaviour and the warning can be ignored. "
868 f"{self.config.latent_dim}?"
869 )
870 else:
871 raise ValueError(
872 f"latent_prior has incompatible latent dimension {latent.shape[1]}, "
873 f"expected {self.config.latent_dim}. or {self._trainer._model._mu.out_features}."
874 )
876 latent_tensor = latent
877 else:
878 raise TypeError(
879 f"Input 'latent' must be either a torch.Tensor or an AnnData object, "
880 f"not {type(latent)}."
881 )
883 return self._trainer.decode(x=latent_tensor)
def evaluate(
    self,
    ml_model_class: Optional[ClassifierMixin] = None,
    ml_model_regression: Optional[RegressorMixin] = None,
    params: Union[list, str, None] = None,
    metric_class: str = "roc_auc_ovo",  # see https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-string-names
    metric_regression: str = "r2",
    reference_methods: Optional[list] = None,
    reference_reducer: Optional[dict] = None,
    split_type: Literal["use-split", "CV-5", "LOOC"] = "use-split",
    n_downsample: Optional[int] = 10000,
    top_k_classes: Optional[int] = 20,
    exclude_classes: Union[list, None] = None,
) -> Result:
    """Evaluates the learned latent spaces on downstream ML tasks.

    Delegates the work to ``self.evaluator.evaluate`` using the preprocessed
    datasets and the current ``Result``. The returned result is stored on
    ``self.result`` and plotted via ``self.visualizer._plot_evaluation``.

    Args:
        ml_model_class: sklearn-style classifier handed to the evaluator.
            Defaults to a fresh ``LogisticRegression()`` on every call.
        ml_model_regression: sklearn-style regressor handed to the evaluator.
            Defaults to a fresh ``LinearRegression()`` on every call.
        params: Annotation columns to evaluate. ``None`` or an empty list falls
            back to ``config.data_config.annotation_columns``. The string
            ``"all"`` is passed through to the evaluator unchanged.
        metric_class: sklearn scoring string used for classification.
        metric_regression: sklearn scoring string used for regression.
        reference_methods: Reference methods to compare against. Options are
            "PCA", "UMAP", "TSNE" and "RandomFeature". Defaults to none.
        reference_reducer: Optional pre-fitted reducer objects keyed by method
            name, e.g. ``{"PCA": pca_reducer, "UMAP": umap_reducer}``.
        split_type: Split strategy: "use-split", "CV-5" or "LOOC".
        n_downsample: Forwarded to the evaluator. Presumably caps the number
            of samples used.
        top_k_classes: Forwarded to the evaluator. Presumably limits
            evaluation to the k most frequent classes.
        exclude_classes: Classes to exclude from evaluation, or ``None``.

    Returns:
        Result: The updated result, which is also stored on ``self.result``.

    Raises:
        NotImplementedError: If no evaluator is set or the model is not trained.
        ValueError: If no parameters can be determined, if "RandomFeature" is
            requested without datasets, or if no latent spaces were computed.
    """
    if self.evaluator is None:
        raise NotImplementedError("Evaluator not initialized")
    if self.result.model is None:
        raise NotImplementedError(
            "Model not trained. Please run the fit method first"
        )

    # Resolve defaults per call. Objects placed directly in the signature
    # would be created once at definition time and then shared (and possibly
    # fitted or mutated) across every call.
    if ml_model_class is None:
        ml_model_class = linear_model.LogisticRegression()
    if ml_model_regression is None:
        ml_model_regression = linear_model.LinearRegression()
    if reference_methods is None:
        reference_methods = []
    if reference_reducer is None:
        reference_reducer = {}

    if not is_classifier(ml_model_class):
        warnings.warn(
            "The provided model is not a sklearn-type classifier. "
            "Evaluation continues but may produce incorrect results or errors."
        )
    if not is_regressor(ml_model_regression):
        warnings.warn(
            "The provided model is not a sklearn-type regressor. "
            "Evaluation continues but may produce incorrect results or errors."
        )

    # Fall back to the configured annotation columns when none were given.
    if params is None or len(params) == 0:
        if self.config.data_config.annotation_columns is None:
            params = []  # type: ignore
        else:
            params = self.config.data_config.annotation_columns  # type: ignore

    if len(params) == 0:
        raise ValueError(
            "No parameters specified for evaluation. Please provide a list of "
            "parameters or ensure that annotation_columns are set in the config."
        )

    if "RandomFeature" in reference_methods:
        if self._datasets is None:
            raise ValueError(
                "Datasets not available for adding RandomFeature. Please keep "
                "preprocessed data available before evaluation."
            )

    if len(self.result.latentspaces._data) == 0:
        raise ValueError(
            "No latent spaces found in results. Please run predict() to "
            "calculate embeddings before evaluation."
        )

    self.result = self.evaluator.evaluate(
        datasets=self._datasets,
        result=self.result,
        ml_model_class=ml_model_class,
        ml_model_regression=ml_model_regression,
        params=params,
        metric_class=metric_class,
        metric_regression=metric_regression,
        reference_methods=reference_methods,
        reference_reducer=reference_reducer,
        split_type=split_type,
        n_downsample=n_downsample,
        top_k_classes=top_k_classes,
        exclude_classes=exclude_classes,
    )

    _: Any = self.visualizer._plot_evaluation(result=self.result)

    return self.result
def visualize(self, config: Optional[DefaultConfig] = None, **kwargs):
    """Creates visualizations of model results and performance.

    Args:
        config: Optional custom configuration for visualization. When
            ``None``, the pipeline's own ``self.config`` is used.
        **kwargs: Accepted for interface compatibility. Currently unused.

    Raises:
        NotImplementedError: If visualizer is not initialized.
    """
    if self.visualizer is None:
        raise NotImplementedError("Visualizer not initialized")

    # Honour an explicitly supplied config. Previously the argument was
    # silently ignored and self.config was always used.
    vis_config = config if config is not None else self.config
    self.visualizer.visualize(result=self.result, config=vis_config)
def show_result(self, split: str = "all", **kwargs):
    """Displays key visualizations of model results.

    This method generates the following visualizations:
    1. Loss Curves: Displays the absolute loss curves to provide insights into
       the model's training and validation performance over epochs.
    2. Latent Space Ridgeline Plot: Visualizes the distribution of the latent
       space representations across different dimensions, offering a high-level
       overview of the learned embeddings.
    3. Latent Space 2D Scatter Plot: Projects the latent space into two dimensions
       for a detailed view of the clustering or separation of data points.

    These visualizations help in understanding the model's performance and
    the structure of the latent space representations. Steps whose data is
    missing are skipped with a warning instead of failing.

    Args:
        split: Data split forwarded to ``show_latent_space``. Defaults to "all".
        **kwargs: Only ``params`` is read. It is passed as ``param`` to
            ``show_latent_space``. When omitted, it falls back to
            ``config.data_config.annotation_columns`` if that is set. Other
            keyword arguments are ignored.

    Raises:
        NotImplementedError: If the visualizer is not initialized.
    """
    # Same guard as visualize(): fail with a clear message instead of an
    # AttributeError on None.
    if self.visualizer is None:
        raise NotImplementedError("Visualizer not initialized")

    print("Creating plots ...")

    params: Optional[Union[List[str], str]] = kwargs.pop("params", None)
    # Fall back to the configured annotation columns when no params are given.
    if params is None and self.config.data_config.annotation_columns:
        params = self.config.data_config.annotation_columns

    if len(self.result.losses._data) != 0:
        self.visualizer.show_loss(plot_type="absolute")
    else:
        warnings.warn(
            "No loss data found in results. Skipping loss curve visualization."
        )

    if len(self.result.latentspaces._data) != 0:
        self.visualizer.show_latent_space(
            result=self.result, plot_type="Ridgeline", split=split, param=params
        )
        self.visualizer.show_latent_space(
            result=self.result, plot_type="2D-scatter", split=split, param=params
        )
    else:
        warnings.warn(
            "No latent spaces found in results. Please run predict() to "
            "calculate embeddings."
        )
def run(
    self, data: Optional[Union[DatasetContainer, DataPackage]] = None
) -> Result:
    """Runs the whole pipeline end to end and returns its result.

    Stages execute strictly in this order: ``preprocess`` -> ``fit`` ->
    ``predict(data=data)`` -> ``visualize``.

    Args:
        data: Optional data handed to ``predict`` (overrides test data).

    Returns:
        The pipeline's ``result`` after every stage has run.
    """
    # Data preparation and model training.
    self.preprocess()
    self.fit()

    # Inference and plotting.
    self.predict(data=data)
    self.visualize()

    final_result = self.result
    return final_result
def save(self, file_path: str, save_all: bool = False):
    """Persists this pipeline to ``file_path`` using ``Saver``.

    Args:
        file_path: Destination path for the saved pipeline.
        save_all: Forwarded unchanged to ``Saver``. See ``Saver`` for which
            additional artifacts it controls.
    """
    Saver(file_path, save_all=save_all).save(self)
@classmethod
def load(cls, file_path) -> Any:
    """Restores a previously saved pipeline.

    Args:
        file_path: Location the pipeline was saved to.

    Returns:
        Whatever ``Loader(file_path).load()`` reconstructs, presumably the
        saved pipeline instance.
    """
    return Loader(file_path).load()
def sample_latent_space(
    self,
    n_samples: int,
    split: str = "test",
    epoch: int = -1,
) -> torch.Tensor:
    """Samples latent points from a Gaussian summary of the learned posterior.

    Reads the per-sample means (``result.mus``) and the values stored in
    ``result.sigmas`` for the given split and epoch. The ``sigmas`` values are
    passed to the model's ``reparameterize`` as ``logvar``. Both are averaged
    over samples into a single diagonal Gaussian, and ``n_samples`` latent
    points are drawn from it.

    Args:
        n_samples: Number of latent points to draw. Must be a positive integer.
        split: The split to take statistics from (train, valid, test).
            Default is test.
        epoch: The epoch to take statistics from. Default is the last epoch (-1).

    Returns:
        z: torch.Tensor - The sampled latent points, produced by
        ``reparameterize`` from ``n_samples`` expanded rows on the model's
        device and dtype.

    Raises:
        ValueError: If the model has not been trained, latent statistics are
            missing, ``n_samples`` is not a positive integer, or the stored
            statistics are empty or have mismatching shapes.
        TypeError: If mu or logvar are not numpy arrays.
    """
    if getattr(self, "_trainer", None) is None:
        raise ValueError("Model is not trained yet. Please train the model first.")
    if self.result.mus is None or self.result.sigmas is None:
        raise ValueError("Model has not learned the latent space distribution yet.")
    if not isinstance(n_samples, int) or n_samples <= 0:
        raise ValueError("n_samples must be a positive integer.")

    mu = self.result.mus.get(split=split, epoch=epoch)
    logvar = self.result.sigmas.get(split=split, epoch=epoch)

    if not isinstance(mu, np.ndarray):
        raise TypeError(
            f"Expected value to be of type numpy.ndarray, got {type(mu)}. "
            "This can happen if the model was not trained with VAE loss or if you forgot to run predict()"
        )
    if not isinstance(logvar, np.ndarray):
        raise TypeError(
            f"Expected value to be of type numpy.ndarray, got {type(logvar)}."
        )
    # An empty split would otherwise average to NaN and silently produce
    # NaN samples.
    if mu.size == 0:
        raise ValueError(
            f"No latent statistics stored for split '{split}' (epoch {epoch}); "
            "cannot sample from an empty split."
        )
    if mu.shape != logvar.shape:
        raise ValueError(
            f"Shape mismatch between mus {mu.shape} and sigmas {logvar.shape}."
        )

    model = self._trainer._model
    mu_t = torch.from_numpy(mu).to(device=model.device, dtype=model.dtype)
    logvar_t = torch.from_numpy(logvar).to(device=model.device, dtype=model.dtype)

    with torch.no_grad():
        # NOTE(review): averaging means and log-variances collapses the
        # aggregated posterior (a mixture over samples) into one Gaussian
        # centred at the global mean. Its variance ignores the spread of the
        # per-sample means, so the samples are likely more concentrated than
        # the real embeddings. Confirm that this is intended.
        global_mu = mu_t.mean(dim=0)
        global_logvar = logvar_t.mean(dim=0)

        mu_exp = global_mu.expand(n_samples, -1)
        logvar_exp = global_logvar.expand(n_samples, -1)

        z = model.reparameterize(mu_exp, logvar_exp)
    return z
def generate(
    self,
    n_samples: Optional[int] = None,
    latent_prior: Optional[Union[np.ndarray, torch.Tensor]] = None,
    split: str = "test",
    epoch: int = -1,
) -> torch.Tensor:
    """Generates new samples in the input space by decoding latent points.

    Either decodes a user-supplied ``latent_prior``, or, when none is given,
    draws ``n_samples`` points via ``sample_latent_space`` and decodes those.

    Args:
        n_samples: Number of samples to draw when ``latent_prior`` is None.
            When ``latent_prior`` is given, its rows determine how many
            samples are decoded and ``n_samples`` is not used.
        latent_prior: Optional latent points to decode instead of sampling.
            A single 1-D latent vector is treated as a batch of one; a 2-D
            input is a batch of latent vectors. It is moved to the model's
            device and dtype before decoding.
        split: The split to sample from (train, valid, test). Default is test.
        epoch: The epoch to sample from. Default is the last epoch (-1).

    Returns:
        torch.Tensor: The generated samples in the input space.

    Raises:
        ValueError: If neither a positive ``n_samples`` nor ``latent_prior``
            is given, if the prior is not 1-D or 2-D, or if its latent
            dimension matches neither ``config.latent_dim`` nor the output
            size of the model's ``_mu`` layer.
        TypeError: If latent_prior is not a numpy array or tensor.
    """
    self._trainer.setup_trainer(old_model=self.result.model)

    if (not isinstance(n_samples, int) or n_samples <= 0) and latent_prior is None:
        raise ValueError(
            "n_samples must be a positive integer or latent_prior provided."
        )

    if latent_prior is None:
        latent_prior = self.sample_latent_space(
            n_samples=n_samples, split=split, epoch=epoch
        )

    model = self._trainer._model
    if isinstance(latent_prior, np.ndarray):
        latent_prior = torch.from_numpy(latent_prior)
    if not isinstance(latent_prior, torch.Tensor):
        raise TypeError(
            f"latent_prior must be numpy.ndarray or torch.Tensor, got {type(latent_prior)}."
        )
    # Tensors get the same device/dtype treatment as numpy input; otherwise
    # a CPU or float64 prior fails inside the decoder.
    latent_prior = latent_prior.to(device=model.device, dtype=model.dtype)

    # A single latent vector becomes a batch of one. The shape check below
    # and decode() both index dimension 1.
    if latent_prior.dim() == 1:
        latent_prior = latent_prior.unsqueeze(0)
    if latent_prior.dim() != 2:
        raise ValueError(
            f"latent_prior must be 1-D or 2-D, got shape {tuple(latent_prior.shape)}."
        )

    n_latent = latent_prior.shape[1]
    expected_dim = self.config.latent_dim
    if n_latent != expected_dim:
        model_latent_dim = model._mu.out_features
        if model_latent_dim == n_latent:
            # The model's own latent layer can differ from config.latent_dim
            # (e.g. Ontix); such priors are accepted with a warning.
            warnings.warn(
                f"latent_prior has latent dimension {n_latent}, which matches "
                "the output dimension of the model's latent (mu) layer but not "
                f"config.latent_dim={expected_dim}. Did you mean to provide "
                f"latent vectors of dimension {expected_dim}? For Ontix this is "
                "the default behaviour and the warning can be ignored."
            )
        else:
            raise ValueError(
                f"latent_prior has incompatible latent dimension {n_latent}, "
                f"expected {expected_dim} or {model_latent_dim}."
            )

    with torch.no_grad():
        generated = self.decode(latent=latent_prior)
    return generated
def explain(
    self,
    method: Literal["DeepLiftShap", "IntegratedGradients"] = "DeepLiftShap",
    sel_latent_dim: Union[list, int, str, None] = None,
    input_type: Literal["random", "grouped"] = "grouped",
    input_group: Optional[str] = None,
    baseline_type: Literal["mean", "random", "grouped"] = "mean",
    baseline_group: str = "all",
    anno_col: Optional[str] = None,
    n_subset: int = 100,
    seed_int: int = 12,
    split: Literal["train", "test", "valid"] = "train",
    llm_explain: bool = False,
    llm_client: Literal["ollama", "mistral", "openrouter", "scads-llm"] = "mistral",
    llm_model: str = "mistral-large-latest",
    top_n_genes: int = 40,
    prompt: str = PROMPT,
) -> pd.DataFrame:
    """Runs the feature-importance explainer and returns gene-by-latent-dimension attribution scores.

    Args:
        method: Specifies which attribution algorithm to use for explaining the model.
        sel_latent_dim: Specifies which latent dimension(s) to explain. Accepts
            an int index, an ontology key (str), or a list of either. Ontology
            keys are mapped to their position in ``self.ontologies[0]``. If
            None, all dimensions will be explained.
        input_type: Specifies whether the feature-importance algorithm should use 'random' samples or 'grouped' samples (see input_group) as input for the attribution computation.
        input_group: If input_type is 'grouped', this specifies which subset in anno_col to filter for to use as input for the attribution computation.
        baseline_type: Specifies whether the feature-importance algorithm should use the 'mean' baseline, a 'random'-sample baseline, or samples from a 'grouped' baseline (see baseline_group).
        baseline_group: Specifies whether the baseline is computed using all data (default) or which subset in anno_col to filter for.
        anno_col: If baseline_group is not 'all', this specifies the annotation column used to filter the baseline data.
        n_subset: Specifies the number of cells to use when subsampling for the attribution computation.
        seed_int: Defines the random seed used for reproducible subsampling and attribution calculations.
        split: The split to use for feature importance calculation (train, valid, test), default is train.
        llm_explain: Whether to additionally generate LLM-written explanations
            of the top-attributed genes. These are stored on
            ``self.result.embedding_explanations``.
        llm_client: The LLM client used for the LLM explanations.
        llm_model: The LLM model used for the LLM explanations.
        top_n_genes: How many top (contribution to embedding) genes to consider for LLM explanation.
        prompt: Prompt template handed to the LLM explainer.

    Returns:
        pd.DataFrame: A DataFrame of attribution scores with genes as the index
        and latent dimensions as columns. Each entry represents the
        contribution of a given gene to a specific latent dimension. It is
        also stored on ``self.result.embedding_attributions``.

    Raises:
        ValueError: If no dataset or model is available, if the dataset cannot
            be converted to AnnData, or if ``sel_latent_dim`` is invalid.
    """
    my_converter = AnnDataConverter()
    dataset: Optional[DatasetContainer] = get_dataset(self.result)
    if dataset is None:
        raise ValueError(
            "No dataset available for explanation. "
            "This happens if you used .save and .load, and did not run .predict before. "
            "This can also happen if you run .explain before .preprocess or .fit."
        )
    adata: Optional[Dict[str, ad.AnnData]] = my_converter.dataset_to_adata(
        dataset, split=split
    )
    if adata is None:
        raise ValueError(
            "There has been an error converting the dataset to AnnData."
        )
    model = self.result.model
    if model is None:
        raise ValueError(
            "No model available for explanation. "
            "This happens if you used .save and .load, and did not run .fit before. "
            "This can also happen if you run .explain before .fit."
        )
    # Validate sel_latent_dim and convert ontology keys (str) to int indices.
    if sel_latent_dim is not None:
        if isinstance(sel_latent_dim, str):
            if self.ontologies is None:
                raise ValueError(
                    "sel_latent_dim is a string, but no ontologies are available in the pipeline. Please provide ontologies in the config or set sel_latent_dim to an int or list of ints."
                )
            if sel_latent_dim not in self.ontologies[0].keys():
                raise ValueError(
                    f"sel_latent_dim is set to '{sel_latent_dim}', but this is not a key in the available ontologies: {list(self.ontologies[0].keys())}. Please provide a valid ontology key or set sel_latent_dim to an int or list of ints."
                )
            sel_latent_dim = list(self.ontologies[0].keys()).index(sel_latent_dim)
        elif isinstance(sel_latent_dim, int):
            if sel_latent_dim < 0 or sel_latent_dim >= self.config.latent_dim:
                raise ValueError(
                    f"sel_latent_dim is set to {sel_latent_dim}, but this is out of bounds for the latent dimension size of {self.config.latent_dim}. Please provide a valid integer index or set sel_latent_dim to None to explain all dimensions."
                )
        elif isinstance(sel_latent_dim, list):
            # Entries may be ontology keys or int indices; validate each one.
            sel_latent_dim_validated = []
            for dim in sel_latent_dim:
                if isinstance(dim, str):
                    if self.ontologies is None:
                        raise ValueError(
                            "sel_latent_dim is a list of strings, but no ontologies are available in the pipeline. Please provide ontologies in the config or set sel_latent_dim to a list of ints."
                        )
                    if dim not in self.ontologies[0].keys():
                        raise ValueError(
                            f"sel_latent_dim contains '{dim}', but this is not a key in the available ontologies: {list(self.ontologies[0].keys())}. Please provide valid ontology keys or set sel_latent_dim to a list of ints."
                        )
                    sel_latent_dim_validated.append(
                        list(self.ontologies[0].keys()).index(dim)
                    )
                elif isinstance(dim, int):
                    if dim < 0 or dim >= self.config.latent_dim:
                        raise ValueError(
                            f"sel_latent_dim contains {dim}, but this is out of bounds for the latent dimension size of {self.config.latent_dim}. Please provide valid integer indices or set sel_latent_dim to None to explain all dimensions."
                        )
                    sel_latent_dim_validated.append(dim)
                else:
                    raise ValueError(
                        f"sel_latent_dim list contains entries of type {type(dim)}, but only str and int are allowed. Please ensure all entries are either strings corresponding to ontology keys or integers corresponding to latent dimension indices."
                    )
            sel_latent_dim = sel_latent_dim_validated
        else:
            raise ValueError(
                f"sel_latent_dim must be either an int, a str, a list of ints or a list of strs, got {type(sel_latent_dim)}. Please provide a valid type for sel_latent_dim."
            )

    print("Start feature attribution calculation")
    explainer = FeatureImportanceExplainer(
        adata=adata["global"],
        model=model,
        n_subset=n_subset,
        method=method,
        sel_latent_dim=sel_latent_dim,
        input_type=input_type,
        input_group=input_group,
        baseline_type=baseline_type,
        baseline_group=baseline_group,
        anno_col=anno_col,
        seed_int=seed_int,
    )
    df_attributions = explainer.explain()

    if llm_explain:
        max_llm_dims = 8
        gene_attributions: Dict[str, List] = preprocess_explanations(
            df=df_attributions, n=top_n_genes, max_dims=max_llm_dims
        )
        # preprocess_explanations is called with max_dims, so its output is
        # presumably already capped. Also compare against the number of
        # attribution columns (latent dimensions, per the Returns section) so
        # that this warning can actually fire.
        if (
            len(gene_attributions) > max_llm_dims
            or df_attributions.shape[1] > max_llm_dims
        ):
            warnings.warn(
                "You requested LLM-generated explanations for each latent dimension. "
                "However, more than eight latent dimensions were provided. To avoid excessively long "
                "and computationally intensive output, only the eight most informative latent "
                "dimensions will be analyzed."
            )

        llm_explainer = LLMExplainer(
            client_name=llm_client,
            model_name=llm_model,
            genes_to_latent=gene_attributions,  # top genes per latent dimension
            prompt=prompt,
        )
        print("Start LLM explanation generation")
        explanation = llm_explainer.explain()
        self.result.embedding_explanations = explanation
    self.result.embedding_attributions = df_attributions
    return df_attributions
def impute(self, corrupted_tensor: torch.Tensor):
    """Placeholder for imputation; only the Maskix pipeline implements it.

    Args:
        corrupted_tensor: Input that would be imputed. It is ignored here.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = "Impute method only implemented for Maskix pipeline."
    raise NotImplementedError(message)