Coverage for src / autoencodix / vanillix.py: 64%
45 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Dict, Optional, Type, Union, List
3import numpy as np
4import torch
6from autoencodix.base._base_dataset import BaseDataset
7from autoencodix.base._base_preprocessor import BasePreprocessor
8from autoencodix.base._base_loss import BaseLoss
9from autoencodix.data.datapackage import DataPackage
10from autoencodix.base._base_pipeline import BasePipeline
11from autoencodix.base._base_trainer import BaseTrainer
12from autoencodix.base._base_evaluator import BaseEvaluator
13from autoencodix.base._base_visualizer import BaseVisualizer
14from autoencodix.base._base_autoencoder import BaseAutoencoder
15from autoencodix.data._datasetcontainer import DatasetContainer
16from autoencodix.data._datasplitter import DataSplitter
17from autoencodix.data._numeric_dataset import NumericDataset
18from autoencodix.data.general_preprocessor import GeneralPreprocessor
19from autoencodix.evaluate._general_evaluator import GeneralEvaluator
20from autoencodix.modeling._vanillix_architecture import VanillixArchitecture
21from autoencodix.trainers._general_trainer import GeneralTrainer
22from autoencodix.utils._result import Result
23from autoencodix.configs.default_config import DefaultConfig
24from autoencodix.configs.vanillix_config import VanillixConfig
25from autoencodix.utils._losses import VanillixLoss
26from autoencodix.visualize._general_visualizer import GeneralVisualizer
class Vanillix(BasePipeline):
    """Vanillix-specific version of the BasePipeline class.

    Inherits the preprocess, fit, predict, evaluate, and visualize methods from
    BasePipeline. This class extends BasePipeline; see the parent class for the
    full list of attributes and methods.

    Additional Attributes:
        _default_config: Set to a fresh ``VanillixConfig`` instance before the
            parent initializer runs.
    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = GeneralTrainer,
        dataset_type: Type[BaseDataset] = NumericDataset,
        model_type: Type[BaseAutoencoder] = VanillixArchitecture,
        loss_type: Type[BaseLoss] = VanillixLoss,
        preprocessor_type: Type[BasePreprocessor] = GeneralPreprocessor,
        visualizer: Type[BaseVisualizer] = GeneralVisualizer,
        evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        config: Optional[DefaultConfig] = None,
        ontologies: Optional[Union[List, Dict]] = None,
    ) -> None:
        """Initialize the Vanillix pipeline with customizable components.

        Some components are passed as types rather than instances because they
        require data that is only available after preprocessing.

        See the parent class implementation for the full list of Args. Note that
        ``custom_splits`` is forwarded to the parent under the keyword
        ``custom_split``.
        """
        # Assigned before super().__init__ so it already exists while the parent
        # initializer runs (presumably the parent falls back to it when
        # ``config`` is None -- confirm in BasePipeline).
        self._default_config = VanillixConfig()
        super().__init__(
            data=data,
            dataset_type=dataset_type,
            trainer_type=trainer_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            config=config,
            custom_split=custom_splits,
            ontologies=ontologies,
        )

    def sample_latent_space(
        self,
        n_samples: int,
        split: str = "test",
        epoch: int = -1,
    ) -> torch.Tensor:
        """Sample latent space points from the empirical latent distribution.

        Fits a diagonal Gaussian (per-dimension mean and unbiased standard
        deviation) to the stored latent codes of the given split and epoch,
        then draws new points from it. This enables approximate generative
        sampling for autoencoders that do not model uncertainty explicitly.

        Args:
            n_samples: The number of latent points to sample. Must be a positive
                integer (a Python ``int`` or a NumPy integer; ``bool`` is
                rejected).
            split: The split to sample from (train, valid, test), default is test.
            epoch: The epoch to sample from, default is the last epoch (-1).

        Returns:
            A tensor of shape ``(n_samples, latent_dim)`` holding the sampled
            latent points, on the model's device and in the model's dtype.

        Raises:
            ValueError: If the model has not been trained, latent codes have not
                been computed, n_samples is not a positive integer, the stored
                latent codes are not a 2-D array, or fewer than two latent codes
                are available (the unbiased standard deviation would be NaN).
            TypeError: If the stored latent codes are not numpy arrays.
        """
        if not hasattr(self, "_trainer") or self._trainer is None:
            raise ValueError("Model is not trained yet. Please train the model first.")
        if self.result.latentspaces is None:
            raise ValueError("Model has no stored latent codes for sampling.")
        # ``bool`` is a subclass of ``int``, so it must be excluded explicitly;
        # NumPy integers (e.g. results of array arithmetic) are accepted.
        if (
            isinstance(n_samples, bool)
            or not isinstance(n_samples, (int, np.integer))
            or n_samples <= 0
        ):
            raise ValueError("n_samples must be a positive integer.")
        n_samples = int(n_samples)

        Z = self.result.latentspaces.get(split=split, epoch=epoch)

        if not isinstance(Z, np.ndarray):
            raise TypeError(
                f"Expected latent codes to be of type numpy.ndarray, got {type(Z)}."
            )
        # The sampling below indexes ``shape[1]`` as the latent dimension, so the
        # codes must be a 2-D (n_codes, latent_dim) matrix.
        if Z.ndim != 2:
            raise ValueError(
                "Expected latent codes of shape (n_codes, latent_dim), "
                f"got shape {Z.shape}."
            )
        # With fewer than two rows the mean (0 rows) or the unbiased std (1 row)
        # is NaN, which would silently turn every sample into NaN.
        if Z.shape[0] < 2:
            raise ValueError(
                "At least two latent codes are required to fit the latent "
                f"distribution, got {Z.shape[0]}."
            )

        # NOTE(review): relies on the model exposing ``device`` and ``dtype``
        # attributes; a plain torch.nn.Module defines neither -- confirm that
        # BaseAutoencoder provides them.
        Z_t = torch.from_numpy(Z).to(
            device=self._trainer._model.device,
            dtype=self._trainer._model.dtype,
        )

        with torch.no_grad():
            # Fit empirical diagonal Gaussian
            global_mu = Z_t.mean(dim=0)
            global_std = Z_t.std(dim=0)

            eps = torch.randn(
                n_samples,
                Z_t.shape[1],
                device=Z_t.device,
                dtype=Z_t.dtype,
            )

            # Draw mu + eps * std with eps ~ N(0, I), one row per sample.
            z = global_mu + eps * global_std
        return z