Coverage for src / autoencodix / disentanglix.py: 93%

29 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Dict, Optional, Type, Union 

2import torch 

3import numpy as np 

4 

5from autoencodix.base._base_dataset import BaseDataset 

6from autoencodix.utils._utils import config_method 

7from autoencodix.base._base_loss import BaseLoss 

8from autoencodix.base._base_pipeline import BasePipeline 

9from autoencodix.base._base_trainer import BaseTrainer 

10from autoencodix.base._base_visualizer import BaseVisualizer 

11from autoencodix.base._base_preprocessor import BasePreprocessor 

12from autoencodix.base._base_evaluator import BaseEvaluator 

13from autoencodix.base._base_autoencoder import BaseAutoencoder 

14from autoencodix.data._datasetcontainer import DatasetContainer 

15from autoencodix.data._datasplitter import DataSplitter 

16from autoencodix.data.datapackage import DataPackage 

17from autoencodix.data._numeric_dataset import NumericDataset 

18from autoencodix.data.general_preprocessor import GeneralPreprocessor 

19from autoencodix.evaluate._general_evaluator import GeneralEvaluator 

20from autoencodix.modeling._varix_architecture import VarixArchitecture 

21from autoencodix.trainers._general_trainer import GeneralTrainer 

22from autoencodix.utils._result import Result 

23from autoencodix.configs.default_config import DefaultConfig 

24from autoencodix.configs.disentanglix_config import DisentanglixConfig 

25from autoencodix.utils._losses import DisentanglixLoss 

26from autoencodix.visualize._general_visualizer import GeneralVisualizer 

27 

28 

class Disentanglix(BasePipeline):
    """Disentanglix-specific version of the BasePipeline.

    This class extends BasePipeline. See the parent class for a full list
    of attributes and methods.

    Additional Attributes:
        _default_config: Is set to DisentanglixConfig here.
    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = GeneralTrainer,
        dataset_type: Type[BaseDataset] = NumericDataset,
        model_type: Type[BaseAutoencoder] = VarixArchitecture,
        loss_type: Type[BaseLoss] = DisentanglixLoss,
        preprocessor_type: Type[BasePreprocessor] = GeneralPreprocessor,
        visualizer: Type[BaseVisualizer] = GeneralVisualizer,
        evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        config: Optional[DefaultConfig] = None,
    ) -> None:
        """Initialize the Disentanglix pipeline with customizable components.

        By default this pipeline pairs a ``VarixArchitecture`` model with
        ``DisentanglixLoss``. Every argument is forwarded unchanged to
        ``BasePipeline.__init__``; see the parent class for the full
        semantics of each one.

        Args:
            data: Input data as a ``DataPackage`` or ``DatasetContainer``,
                or ``None``.
            trainer_type: Trainer class to use. Defaults to ``GeneralTrainer``.
            dataset_type: Dataset class to use. Defaults to ``NumericDataset``.
            model_type: Autoencoder architecture class. Defaults to
                ``VarixArchitecture``.
            loss_type: Loss class. Defaults to ``DisentanglixLoss``.
            preprocessor_type: Preprocessor class. Defaults to
                ``GeneralPreprocessor``.
            visualizer: Visualizer class. Defaults to ``GeneralVisualizer``.
            evaluator: Evaluator class, or ``None``. Defaults to
                ``GeneralEvaluator``.
            result: Optional pre-existing ``Result`` object.
            datasplitter_type: Data splitter class. Defaults to
                ``DataSplitter``.
            custom_splits: Optional predefined splits keyed by string.
                Forwarded to the parent under the keyword ``custom_split``
                (singular).
            config: Optional pipeline configuration.
        """
        # Assigned before super().__init__ runs. Presumably this lets the
        # parent fall back to a DisentanglixConfig when `config` is None.
        # NOTE(review): confirm against BasePipeline.__init__.
        self._default_config = DisentanglixConfig()
        super().__init__(
            data=data,
            dataset_type=dataset_type,
            trainer_type=trainer_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            config=config,
            # The parent's keyword is singular, unlike this class's parameter.
            custom_split=custom_splits,
        )