Coverage for src/autoencodix/vanillix.py: 64%

45 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Dict, Optional, Type, Union, List 

2 

3import numpy as np 

4import torch 

5 

6from autoencodix.base._base_dataset import BaseDataset 

7from autoencodix.base._base_preprocessor import BasePreprocessor 

8from autoencodix.base._base_loss import BaseLoss 

9from autoencodix.data.datapackage import DataPackage 

10from autoencodix.base._base_pipeline import BasePipeline 

11from autoencodix.base._base_trainer import BaseTrainer 

12from autoencodix.base._base_evaluator import BaseEvaluator 

13from autoencodix.base._base_visualizer import BaseVisualizer 

14from autoencodix.base._base_autoencoder import BaseAutoencoder 

15from autoencodix.data._datasetcontainer import DatasetContainer 

16from autoencodix.data._datasplitter import DataSplitter 

17from autoencodix.data._numeric_dataset import NumericDataset 

18from autoencodix.data.general_preprocessor import GeneralPreprocessor 

19from autoencodix.evaluate._general_evaluator import GeneralEvaluator 

20from autoencodix.modeling._vanillix_architecture import VanillixArchitecture 

21from autoencodix.trainers._general_trainer import GeneralTrainer 

22from autoencodix.utils._result import Result 

23from autoencodix.configs.default_config import DefaultConfig 

24from autoencodix.configs.vanillix_config import VanillixConfig 

25from autoencodix.utils._losses import VanillixLoss 

26from autoencodix.visualize._general_visualizer import GeneralVisualizer 

27 

28 

class Vanillix(BasePipeline):
    """Vanillix specific version of the BasePipeline class.

    Inherits preprocess, fit, predict, evaluate, and visualize methods from BasePipeline.
    This class extends BasePipeline. See the parent class for a full list
    of attributes and methods. In addition, it provides ``sample_latent_space``
    for approximate generative sampling from the stored latent codes.

    Additional Attributes:
        _default_config: Is set to VanillixConfig here.

    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = GeneralTrainer,
        dataset_type: Type[BaseDataset] = NumericDataset,
        model_type: Type[BaseAutoencoder] = VanillixArchitecture,
        loss_type: Type[BaseLoss] = VanillixLoss,
        preprocessor_type: Type[BasePreprocessor] = GeneralPreprocessor,
        visualizer: Type[BaseVisualizer] = GeneralVisualizer,
        evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        config: Optional[DefaultConfig] = None,
        ontologies: Optional[Union[List, Dict]] = None,
    ) -> None:
        """Initialize Vanillix pipeline with customizable components.

        Some components are passed as types rather than instances because they require
        data that is only available after preprocessing.

        The defaults wire up the Vanillix-specific pieces: ``VanillixArchitecture``
        as the model, ``VanillixLoss`` as the loss, and ``VanillixConfig`` as the
        default configuration.

        See implementation of parent class for list of full Args.
        """
        # Assigned before super().__init__(), presumably so the base class can
        # fall back to it when `config` is None — confirm in BasePipeline.
        self._default_config = VanillixConfig()
        super().__init__(
            data=data,
            dataset_type=dataset_type,
            trainer_type=trainer_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            config=config,
            # The public argument is `custom_splits`; it is forwarded under the
            # singular keyword `custom_split`.
            custom_split=custom_splits,
            ontologies=ontologies,
        )

    def sample_latent_space(
        self,
        n_samples: int,
        split: str = "test",
        epoch: int = -1,
    ) -> torch.Tensor:
        """Samples latent space points from the empirical latent distribution.

        This method fits a diagonal Gaussian (per-dimension mean and standard
        deviation) to the stored latent codes of the specified split and epoch,
        then draws new points from it. This enables approximate generative
        sampling for autoencoders that do not model uncertainty explicitly.

        Args:
            n_samples: The number of latent points to sample. Must be a positive
                integer (a Python ``int`` or NumPy integer; ``bool`` is rejected).
            split: The split to sample from (train, valid, test), default is test.
            epoch: The epoch to sample from, default is the last epoch (-1).

        Returns:
            z: torch.Tensor of shape ``(n_samples, Z.shape[1])``, where ``Z`` is
            the stored latent-code array. It lives on the model's device and uses
            the model's dtype. If only a single latent code is stored, the spread
            is taken to be zero, so every sample equals that code.

        Raises:
            ValueError: If the model has not been trained, latent codes have not
                been computed, n_samples is not a positive integer, or the stored
                latent codes are not a non-empty 2-D array.
            TypeError: If the stored latent codes are not numpy arrays.
        """
        if getattr(self, "_trainer", None) is None:
            raise ValueError("Model is not trained yet. Please train the model first.")
        if self.result.latentspaces is None:
            raise ValueError("Model has no stored latent codes for sampling.")
        # bool is a subclass of int, so reject it explicitly; accept NumPy ints too.
        if (
            isinstance(n_samples, bool)
            or not isinstance(n_samples, (int, np.integer))
            or n_samples <= 0
        ):
            raise ValueError("n_samples must be a positive integer.")
        n_samples = int(n_samples)

        Z = self.result.latentspaces.get(split=split, epoch=epoch)

        if not isinstance(Z, np.ndarray):
            raise TypeError(
                f"Expected latent codes to be of type numpy.ndarray, got {type(Z)}."
            )
        # Rows are reduced over below (dim=0) and columns give the latent width
        # (shape[1]); an empty or non-2-D array would yield NaNs or an IndexError.
        if Z.ndim != 2 or Z.shape[0] == 0:
            raise ValueError(
                "Expected latent codes to be a non-empty 2-D array of shape "
                f"(n_observations, latent_dim), got shape {Z.shape}."
            )

        Z_t = torch.from_numpy(Z).to(
            device=self._trainer._model.device,
            dtype=self._trainer._model.dtype,
        )

        with torch.no_grad():
            # Fit empirical diagonal Gaussian.
            global_mu = Z_t.mean(dim=0)
            # torch.std is unbiased by default and returns NaN for a single
            # observation; fall back to zero spread in that case.
            if Z_t.shape[0] > 1:
                global_std = Z_t.std(dim=0)
            else:
                global_std = torch.zeros_like(global_mu)

            eps = torch.randn(
                n_samples,
                Z_t.shape[1],
                device=Z_t.device,
                dtype=Z_t.dtype,
            )

            z = global_mu + eps * global_std
        return z