Coverage for src / autoencodix / maskix.py: 76%
37 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Dict, Optional, Type, Union, Callable, Any
2import numpy as np
3import torch
5from autoencodix.base._base_dataset import BaseDataset
6from autoencodix.base._base_loss import BaseLoss
7from autoencodix.base._base_pipeline import BasePipeline
8from autoencodix.base._base_trainer import BaseTrainer
9from autoencodix.base._base_visualizer import BaseVisualizer
10from autoencodix.base._base_preprocessor import BasePreprocessor
11from autoencodix.base._base_autoencoder import BaseAutoencoder
12from autoencodix.base._base_evaluator import BaseEvaluator
13from autoencodix.data._datasetcontainer import DatasetContainer
14from autoencodix.data._datasplitter import DataSplitter
15from autoencodix.data.datapackage import DataPackage
16from autoencodix.data._numeric_dataset import NumericDataset
17from autoencodix.data.general_preprocessor import GeneralPreprocessor
18from autoencodix.evaluate._general_evaluator import GeneralEvaluator
20from autoencodix.utils._result import Result
21from autoencodix.configs.default_config import DefaultConfig
22from autoencodix.configs.maskix_config import MaskixConfig
23from autoencodix.utils._losses import MaskixLoss
24from autoencodix.modeling._maskix_architecture import MaskixArchitectureVanilla
25from autoencodix.trainers._maskix_trainer import MaskixTrainer
26from autoencodix.visualize._general_visualizer import GeneralVisualizer
27from autoencodix.utils._model_output import ModelOutput
class Maskix(BasePipeline):
    """Maskix specific version of the BasePipeline class.

    Inherits preprocess, fit, predict, evaluate, and visualize methods from BasePipeline.

    This class extends BasePipeline. See the parent class for a full list
    of attributes and methods.

    Additional Attributes:
        _default_config: Is set to a MaskixConfig instance here.
    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = MaskixTrainer,
        dataset_type: Type[BaseDataset] = NumericDataset,
        model_type: Type[BaseAutoencoder] = MaskixArchitectureVanilla,
        loss_type: Type[BaseLoss] = MaskixLoss,
        preprocessor_type: Type[BasePreprocessor] = GeneralPreprocessor,
        visualizer: Type[BaseVisualizer] = GeneralVisualizer,
        evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        config: Optional[DefaultConfig] = None,
        masking_fn: Optional[Callable] = None,
        masking_fn_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Maskix pipeline with customizable components.

        Some components are passed as types rather than instances because they require
        data that is only available after preprocessing.

        All arguments are forwarded to ``BasePipeline.__init__``; see the parent
        class for the full description of the shared arguments.

        Args:
            custom_splits: Forwarded to the parent under the keyword
                ``custom_split`` (singular).
            masking_fn: Optional callable forwarded unchanged to the parent.
                NOTE(review): presumably a custom masking/corruption function
                used during training -- confirm against the trainer.
            masking_fn_kwargs: Keyword arguments accompanying ``masking_fn``.
                ``None`` (the default) is replaced by a fresh empty dict, so
                no dict object is shared between pipeline instances.
            **kwargs: Additional keyword arguments forwarded to the parent.
        """
        # A literal ``{}`` default would be created once at definition time and
        # shared by every call that omits the argument; build a new one instead.
        if masking_fn_kwargs is None:
            masking_fn_kwargs = {}

        # Assigned before the parent constructor runs, as in the original
        # ordering. NOTE(review): presumably the parent reads it when
        # ``config`` is None -- confirm in BasePipeline.
        self._default_config = MaskixConfig()

        super().__init__(
            data=data,
            dataset_type=dataset_type,
            trainer_type=trainer_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            config=config,
            custom_split=custom_splits,
            masking_fn=masking_fn,
            masking_fn_kwargs=masking_fn_kwargs,
            **kwargs,
        )

    def impute(self, corrupted_tensor: torch.Tensor) -> ModelOutput:
        """Impute missing values in the corrupted tensor using the trained Maskix model.

        The trainer's model is switched to evaluation mode and called once on
        the input with gradient tracking disabled, inside the trainer's Fabric
        autocast context. The model is left in evaluation mode afterwards.

        Args:
            corrupted_tensor: A tensor with missing values to be imputed. It is
                moved to the model's device before the forward pass.

        Returns:
            The object returned by calling the model on ``corrupted_tensor``.

        Raises:
            ValueError: If no trainer is available (``self._trainer`` is None),
                i.e. the model has not been trained yet.
        """
        if self._trainer is None:
            raise ValueError(
                "Model needs to be trained before imputation., Run fit() first."
            )

        model = self._trainer._model
        model.eval()
        with torch.no_grad(), self._trainer._fabric.autocast():
            corrupted_tensor = corrupted_tensor.to(model.device)
            model_output = model(corrupted_tensor)
        return model_output