Coverage for src / autoencodix / stackix.py: 67%
45 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Dict, Optional, Type, Union, List
2import torch
3import numpy as np
5import anndata as ad # type: ignore
6from autoencodix.base._base_dataset import BaseDataset
7from autoencodix.utils._utils import config_method
8from autoencodix.base._base_loss import BaseLoss
9from autoencodix.base._base_pipeline import BasePipeline
10from autoencodix.base._base_trainer import BaseTrainer
11from autoencodix.base._base_visualizer import BaseVisualizer
12from autoencodix.base._base_preprocessor import BasePreprocessor
13from autoencodix.base._base_autoencoder import BaseAutoencoder
14from autoencodix.base._base_evaluator import BaseEvaluator
15from autoencodix.data._datasetcontainer import DatasetContainer
16from autoencodix.data._datasplitter import DataSplitter
17from autoencodix.data.datapackage import DataPackage
18from autoencodix.data._numeric_dataset import NumericDataset
19from autoencodix.evaluate._general_evaluator import GeneralEvaluator
20from autoencodix.modeling._varix_architecture import VarixArchitecture
21from autoencodix.utils._result import Result
22from autoencodix.configs.default_config import DefaultConfig
24from autoencodix.configs.stackix_config import StackixConfig
25from autoencodix.utils._losses import VarixLoss
26from autoencodix.data._stackix_preprocessor import StackixPreprocessor
27from autoencodix.data._stackix_dataset import StackixDataset
28from autoencodix.trainers._stackix_trainer import StackixTrainer
29from autoencodix.visualize._general_visualizer import GeneralVisualizer
class Stackix(BasePipeline):
    """Stackix pipeline: train one VAE per modality, then stack their latent spaces.

    This pipeline uses:
        1. StackixPreprocessor to prepare data for multi-modality training
        2. StackixTrainer to train individual VAEs, extract latent spaces, and
           train the final stacked model

    Like other pipelines, it follows the standard BasePipeline interface and
    workflow.

    Additional Attributes:
        _default_config: Is set to a StackixConfig instance here.
    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = StackixTrainer,
        dataset_type: Type[BaseDataset] = StackixDataset,
        model_type: Type[BaseAutoencoder] = VarixArchitecture,
        loss_type: Type[BaseLoss] = VarixLoss,
        preprocessor_type: Type[BasePreprocessor] = StackixPreprocessor,
        visualizer: Type[BaseVisualizer] = GeneralVisualizer,
        evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        config: Optional[DefaultConfig] = None,
        ontologies: Optional[Union[List, Dict]] = None,
    ) -> None:
        """Initialize the Stackix pipeline.

        See parent class for full list of Args.

        Raises:
            TypeError: If, after the parent initializer has run, ``self.config``
                is not a StackixConfig instance.
        """
        # Must be set before calling the parent initializer.
        # NOTE(review): presumably the parent falls back to this when ``config``
        # is None -- confirm against BasePipeline.
        self._default_config = StackixConfig()
        super().__init__(
            data=data,
            # Fallback, but not directly used.
            dataset_type=dataset_type or NumericDataset,
            trainer_type=trainer_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            config=config,
            # The parent keyword is singular (``custom_split``).
            custom_split=custom_splits,
            ontologies=ontologies,
        )
        if not isinstance(self.config, StackixConfig):
            raise TypeError(
                "For Stackix Pipeline, we only allow StackixConfig as type for "
                f"config, got {type(self.config)}"
            )

    def _process_latent_results(
        self, predictor_results: Result, predict_data: DatasetContainer
    ) -> None:
        """Build ``self.result.adata_latent`` from the StackixTrainer prediction results.

        Reads the latent space and the sample ids stored for ``epoch=-1`` and
        ``split="test"``, wraps the latent space in an AnnData object, names its
        variables ``Latent_0 .. Latent_{k-1}`` and then merges
        ``predictor_results`` into ``self.result``.

        This method overrides the BasePipeline implementation.

        If no latent space is stored, a warning is issued and the method returns
        without touching ``self.result``. If the latent space exists but no
        sample ids are stored, a warning is issued and AnnData's default
        observation names are kept instead of failing on the assignment.

        Args:
            predictor_results: Result object after predict step.
            predict_data: Not used here, only to keep interface structure.
        """
        import warnings

        latent = predictor_results.latentspaces.get(epoch=-1, split="test")
        sample_ids = predictor_results.sample_ids.get(epoch=-1, split="test")
        if latent is None:
            warnings.warn(
                "No latent space found in predictor results. Cannot create AnnData object."
            )
            return

        self.result.adata_latent = ad.AnnData(X=latent)
        if sample_ids is None:
            # Assigning None to obs_names raises; keep the default names so the
            # latent space is still usable.
            warnings.warn(
                "No sample ids found in predictor results. "
                "Keeping default observation names for adata_latent."
            )
        else:
            self.result.adata_latent.obs_names = sample_ids
        self.result.adata_latent.var_names = [
            f"Latent_{i}" for i in range(latent.shape[1])  # ty: ignore
        ]

        # Update the main result object with the rest of the prediction results.
        self.result.update(predictor_results)

        print("Successfully created annotated latent space object (adata_latent).")