Coverage for src / autoencodix / trainers / _ontix_trainer.py: 57%

23 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import torch 

2from typing import Optional, Type, Union, Tuple, Dict, Any 

3from autoencodix.trainers._general_trainer import GeneralTrainer 

4from autoencodix.base._base_dataset import BaseDataset 

5from autoencodix.base._base_loss import BaseLoss 

6from autoencodix.base._base_autoencoder import BaseAutoencoder 

7from autoencodix.utils._result import Result 

8from autoencodix.configs.default_config import DefaultConfig 

9 

10 

class OntixTrainer(GeneralTrainer):
    """Specialized trainer for Ontix (ontology-based) autoencoders.

    Handles ontology-specific weight masking and positive weight constraints.
    Most functionality comes from the GeneralTrainer class; the
    ontology-specific processing is added via a hook (``_ontix_hook``) that is
    meant to run after the backward pass.

    Attributes:
        Inherits all attributes from GeneralTrainer.
    """

    def __init__(
        self,
        trainset: Optional[BaseDataset],
        validset: Optional[BaseDataset],
        result: Result,
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        ontologies: Optional[Union[Tuple, Dict[Any, Any]]],
        **kwargs,
    ):
        """Initializes the OntixTrainer with the given datasets, model, and configuration.

        Args:
            trainset: The dataset used for training.
            validset: The dataset used for validation, if provided.
            result: An object to store and manage training results.
            config: Configuration object containing training hyperparameters and settings.
            model_type: The autoencoder model class to be trained.
            loss_type: The loss function class specific to the model.
            ontologies: Ontology information required for Ontix.
            **kwargs: Accepted but currently ignored; they are NOT forwarded
                to GeneralTrainer.
        """
        super().__init__(
            trainset=trainset,
            validset=validset,
            result=result,
            config=config,
            model_type=model_type,
            loss_type=loss_type,
            ontologies=ontologies,
        )

    def _ontix_hook(self):
        """Apply ontology-specific processing after the backward pass.

        All steps run under ``torch.no_grad()`` so that the in-place weight
        modifications are not recorded by autograd. In-place edits of leaf
        parameters that require grad would otherwise raise.

        Raises:
            ValueError: If the number of masks does not match the number of
                decoder layers (see ``_validate_masks``).
        """
        with torch.no_grad():
            # Positive weight constraint: the model's ``_positive_dec``
            # callable is applied to every submodule of the decoder.
            self._model._decoder.apply(self._model._positive_dec)

            # Ontology-based weight masking: check the masks line up with the
            # decoder layers, then multiply each layer's weight by its mask.
            self._validate_masks()
            self._apply_weight_masks()

    def _validate_masks(self):
        """Validate that the number of masks matches the number of decoder layers.

        Raises:
            ValueError: If ``len(self._model.masks) != len(self._model._decoder)``.
        """
        n_masks = len(self._model.masks)
        n_layers = len(self._model._decoder)
        if n_masks != n_layers:
            raise ValueError(
                f"Number of masks ({n_masks}) does not match "
                f"number of decoder layers ({n_layers})"
            )

    def _apply_weight_masks(self):
        """Multiply each decoder layer's weight in place by its paired mask.

        Masks are paired with decoder layers by position. Callers are expected
        to run ``_validate_masks`` first so that the two sequences have equal
        length. Each mask is moved to the device of the weight it multiplies,
        so the in-place ``mul_`` works wherever the mask was created. This
        should be called under ``torch.no_grad()``, as ``_ontix_hook`` does.
        """
        for layer, mask in zip(self._model._decoder, self._model.masks):
            weight = layer.weight
            # Use the weight's own device (not a global one) so the in-place
            # op never mixes devices.
            weight.mul_(mask.to(weight.device))