Coverage for src / autoencodix / maskix.py: 76%

37 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Dict, Optional, Type, Union, Callable, Any 

2import numpy as np 

3import torch 

4 

5from autoencodix.base._base_dataset import BaseDataset 

6from autoencodix.base._base_loss import BaseLoss 

7from autoencodix.base._base_pipeline import BasePipeline 

8from autoencodix.base._base_trainer import BaseTrainer 

9from autoencodix.base._base_visualizer import BaseVisualizer 

10from autoencodix.base._base_preprocessor import BasePreprocessor 

11from autoencodix.base._base_autoencoder import BaseAutoencoder 

12from autoencodix.base._base_evaluator import BaseEvaluator 

13from autoencodix.data._datasetcontainer import DatasetContainer 

14from autoencodix.data._datasplitter import DataSplitter 

15from autoencodix.data.datapackage import DataPackage 

16from autoencodix.data._numeric_dataset import NumericDataset 

17from autoencodix.data.general_preprocessor import GeneralPreprocessor 

18from autoencodix.evaluate._general_evaluator import GeneralEvaluator 

19 

20from autoencodix.utils._result import Result 

21from autoencodix.configs.default_config import DefaultConfig 

22from autoencodix.configs.maskix_config import MaskixConfig 

23from autoencodix.utils._losses import MaskixLoss 

24from autoencodix.modeling._maskix_architecture import MaskixArchitectureVanilla 

25from autoencodix.trainers._maskix_trainer import MaskixTrainer 

26from autoencodix.visualize._general_visualizer import GeneralVisualizer 

27from autoencodix.utils._model_output import ModelOutput 

28 

29 

class Maskix(BasePipeline):
    """Maskix specific version of the BasePipeline class.

    Inherits preprocess, fit, predict, evaluate, and visualize methods from BasePipeline.

    This class extends BasePipeline. See the parent class for a full list
    of attributes and methods.

    Additional Attributes:
        _default_config: Is set to a MaskixConfig instance here.

    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = MaskixTrainer,
        dataset_type: Type[BaseDataset] = NumericDataset,
        model_type: Type[BaseAutoencoder] = MaskixArchitectureVanilla,
        loss_type: Type[BaseLoss] = MaskixLoss,
        preprocessor_type: Type[BasePreprocessor] = GeneralPreprocessor,
        visualizer: Type[BaseVisualizer] = GeneralVisualizer,
        evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        config: Optional[DefaultConfig] = None,
        masking_fn: Optional[Callable] = None,
        masking_fn_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Maskix pipeline with customizable components.

        Some components are passed as types rather than instances because they require
        data that is only available after preprocessing.

        See parent class for full list of Arguments.

        Args:
            masking_fn: Optional callable forwarded unchanged to BasePipeline.
                NOTE(review): presumably a custom masking/corruption function
                used by the trainer -- confirm against BasePipeline.
            masking_fn_kwargs: Keyword arguments forwarded together with
                ``masking_fn``. ``None`` (the default) is replaced by a fresh
                empty dict for every call.
            **kwargs: Any further keyword arguments, forwarded to BasePipeline.

        """
        # Set before calling the parent constructor.
        # NOTE(review): presumably BasePipeline.__init__ reads _default_config
        # when no explicit config is given -- confirm against the parent class.
        self._default_config = MaskixConfig()

        # A literal ``{}`` default would be one dict object shared by every
        # Maskix instance created without this argument; build a new one here.
        if masking_fn_kwargs is None:
            masking_fn_kwargs = {}

        super().__init__(
            data=data,
            dataset_type=dataset_type,
            trainer_type=trainer_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            config=config,
            # The parent keyword is singular (custom_split); this class exposes
            # the plural name custom_splits.
            custom_split=custom_splits,
            masking_fn=masking_fn,
            masking_fn_kwargs=masking_fn_kwargs,
            **kwargs,
        )

    def impute(self, corrupted_tensor: torch.Tensor) -> ModelOutput:
        """Impute missing values in the corrupted tensor using the trained Maskix model.

        The trainer's model is switched to evaluation mode and stays in that
        mode after this call returns. The forward pass runs without gradient
        tracking and inside the trainer's Fabric autocast context.

        Args:
            corrupted_tensor: A tensor with missing values to be imputed. It is
                moved to the model's device before the forward pass.
        Returns:
            model_output: The output of the model's forward pass on the
                corrupted tensor.
        Raises:
            ValueError: If no trainer is available (``self._trainer`` is None),
                i.e. the model has not been trained yet.
        """
        if self._trainer is None:
            raise ValueError(
                "Model needs to be trained before imputation., Run fit() first."
            )
        self._trainer._model.eval()
        with torch.no_grad(), self._trainer._fabric.autocast():
            # Input must live on the same device as the model.
            corrupted_tensor = corrupted_tensor.to(self._trainer._model.device)
            model_output = self._trainer._model(corrupted_tensor)
        return model_output

108 return model_output