Coverage for src / autoencodix / disentanglix.py: 93%
29 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1from typing import Dict, Optional, Type, Union
2import torch
3import numpy as np
5from autoencodix.base._base_dataset import BaseDataset
6from autoencodix.utils._utils import config_method
7from autoencodix.base._base_loss import BaseLoss
8from autoencodix.base._base_pipeline import BasePipeline
9from autoencodix.base._base_trainer import BaseTrainer
10from autoencodix.base._base_visualizer import BaseVisualizer
11from autoencodix.base._base_preprocessor import BasePreprocessor
12from autoencodix.base._base_evaluator import BaseEvaluator
13from autoencodix.base._base_autoencoder import BaseAutoencoder
14from autoencodix.data._datasetcontainer import DatasetContainer
15from autoencodix.data._datasplitter import DataSplitter
16from autoencodix.data.datapackage import DataPackage
17from autoencodix.data._numeric_dataset import NumericDataset
18from autoencodix.data.general_preprocessor import GeneralPreprocessor
19from autoencodix.evaluate._general_evaluator import GeneralEvaluator
20from autoencodix.modeling._varix_architecture import VarixArchitecture
21from autoencodix.trainers._general_trainer import GeneralTrainer
22from autoencodix.utils._result import Result
23from autoencodix.configs.default_config import DefaultConfig
24from autoencodix.configs.disentanglix_config import DisentanglixConfig
25from autoencodix.utils._losses import DisentanglixLoss
26from autoencodix.visualize._general_visualizer import GeneralVisualizer
class Disentanglix(BasePipeline):
    """Disentanglix-specific version of the BasePipeline.

    This class extends BasePipeline. See the parent class for a full list
    of attributes and methods.

    Additional Attributes:
        _default_config: Set to a fresh DisentanglixConfig instance before
            the parent initializer runs.
    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = GeneralTrainer,
        dataset_type: Type[BaseDataset] = NumericDataset,
        model_type: Type[BaseAutoencoder] = VarixArchitecture,
        loss_type: Type[BaseLoss] = DisentanglixLoss,
        preprocessor_type: Type[BasePreprocessor] = GeneralPreprocessor,
        visualizer: Type[BaseVisualizer] = GeneralVisualizer,
        evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        config: Optional[DefaultConfig] = None,
    ) -> None:
        """Initialize the Disentanglix pipeline with customizable components.

        By default the pipeline uses a VarixArchitecture model, the
        DisentanglixLoss loss, GeneralTrainer, NumericDataset,
        GeneralPreprocessor, GeneralVisualizer, GeneralEvaluator and
        DataSplitter. Any of these can be replaced by passing a different
        type.

        All arguments are forwarded to the parent initializer. See the init
        method of the parent class for their full semantics.

        Args:
            data: Optional input data, as a DataPackage or DatasetContainer.
            trainer_type: Trainer class to use (default: GeneralTrainer).
            dataset_type: Dataset class to use (default: NumericDataset).
            model_type: Autoencoder class to use (default: VarixArchitecture).
            loss_type: Loss class to use (default: DisentanglixLoss).
            preprocessor_type: Preprocessor class to use
                (default: GeneralPreprocessor).
            visualizer: Visualizer class to use (default: GeneralVisualizer).
            evaluator: Optional evaluator class (default: GeneralEvaluator).
            result: Optional pre-existing Result object.
            datasplitter_type: Data splitter class to use
                (default: DataSplitter).
            custom_splits: Optional mapping from split name to a numpy array,
                presumably holding sample indices (confirm against
                BasePipeline).
            config: Optional configuration object. How None is handled is
                decided by the parent initializer.
        """
        # Assigned before super().__init__, presumably so the parent
        # initializer can fall back to this pipeline-specific default when
        # ``config`` is None. NOTE(review): confirm against BasePipeline.
        self._default_config = DisentanglixConfig()
        super().__init__(
            data=data,
            dataset_type=dataset_type,
            trainer_type=trainer_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            config=config,
            # NOTE(review): this pipeline exposes ``custom_splits`` (plural)
            # but forwards it as the parent's ``custom_split`` keyword.
            # Confirm the parent's parameter name.
            custom_split=custom_splits,
        )