Coverage for src / autoencodix / varix.py: 93%

29 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1from typing import Dict, Optional, Type, Union, List 

2import torch 

3import numpy as np 

4 

5from autoencodix.base._base_dataset import BaseDataset 

6from autoencodix.utils._utils import config_method 

7from autoencodix.base._base_loss import BaseLoss 

8from autoencodix.base._base_pipeline import BasePipeline 

9from autoencodix.base._base_evaluator import BaseEvaluator 

10from autoencodix.base._base_trainer import BaseTrainer 

11from autoencodix.base._base_visualizer import BaseVisualizer 

12from autoencodix.base._base_preprocessor import BasePreprocessor 

13from autoencodix.base._base_autoencoder import BaseAutoencoder 

14from autoencodix.data._datasetcontainer import DatasetContainer 

15from autoencodix.data._datasplitter import DataSplitter 

16from autoencodix.data.datapackage import DataPackage 

17from autoencodix.data._numeric_dataset import NumericDataset 

18from autoencodix.data.general_preprocessor import GeneralPreprocessor 

19from autoencodix.evaluate._general_evaluator import GeneralEvaluator 

20from autoencodix.modeling._varix_architecture import VarixArchitecture 

21from autoencodix.trainers._general_trainer import GeneralTrainer 

22from autoencodix.utils._result import Result 

23from autoencodix.configs.default_config import DefaultConfig 

24from autoencodix.configs.varix_config import VarixConfig 

25from autoencodix.utils._losses import VarixLoss 

26from autoencodix.visualize._general_visualizer import GeneralVisualizer 

27 

28 

class Varix(BasePipeline):
    """Pipeline specialisation of ``BasePipeline`` for Varix.

    The preprocess, fit, predict, evaluate, and visualize methods are all
    inherited from ``BasePipeline``. Refer to the parent class for the
    complete list of attributes and methods.

    Additional Attributes:
        _default_config: A ``VarixConfig`` instance, assigned in ``__init__``.

    """

    def __init__(
        self,
        data: Optional[Union[DataPackage, DatasetContainer]] = None,
        trainer_type: Type[BaseTrainer] = GeneralTrainer,
        dataset_type: Type[BaseDataset] = NumericDataset,
        model_type: Type[BaseAutoencoder] = VarixArchitecture,
        loss_type: Type[BaseLoss] = VarixLoss,
        preprocessor_type: Type[BasePreprocessor] = GeneralPreprocessor,
        visualizer: Type[BaseVisualizer] = GeneralVisualizer,
        evaluator: Optional[Type[BaseEvaluator]] = GeneralEvaluator,
        result: Optional[Result] = None,
        datasplitter_type: Type[DataSplitter] = DataSplitter,
        custom_splits: Optional[Dict[str, np.ndarray]] = None,
        ontologies: Optional[Union[List, Dict]] = None,
        config: Optional[DefaultConfig] = None,
    ) -> None:
        """Set up a Varix pipeline from interchangeable components.

        Several components are supplied as classes instead of instances,
        since building them needs data that only exists once preprocessing
        has run.

        Every argument is forwarded to ``BasePipeline.__init__``; see the
        parent class for the description of each one.

        """
        # Assigned ahead of the parent constructor call.
        # NOTE(review): presumably BasePipeline.__init__ reads this
        # attribute -- confirm before reordering these two statements.
        self._default_config = VarixConfig()

        super().__init__(
            data=data,
            trainer_type=trainer_type,
            dataset_type=dataset_type,
            model_type=model_type,
            loss_type=loss_type,
            preprocessor_type=preprocessor_type,
            visualizer=visualizer,
            evaluator=evaluator,
            result=result,
            datasplitter_type=datasplitter_type,
            # This constructor's ``custom_splits`` is handed to the parent
            # under the singular keyword ``custom_split``.
            custom_split=custom_splits,
            ontologies=ontologies,
            config=config,
        )