Coverage for src / autoencodix / trainers / _ontix_trainer.py: 57%
23 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import torch
2from typing import Optional, Type, Union, Tuple, Dict, Any
3from autoencodix.trainers._general_trainer import GeneralTrainer
4from autoencodix.base._base_dataset import BaseDataset
5from autoencodix.base._base_loss import BaseLoss
6from autoencodix.base._base_autoencoder import BaseAutoencoder
7from autoencodix.utils._result import Result
8from autoencodix.configs.default_config import DefaultConfig
class OntixTrainer(GeneralTrainer):
    """Specialized trainer for Ontix (ontology-based) autoencoders.

    Handles ontology-specific weight masking and positive weight constraints. Uses most of the
    functionality from the GeneralTrainer class; the ontology-specific behavior is added via a
    hook (``_ontix_hook``) intended to run after the backward pass.

    Attributes:
        Inherits all attributes from GeneralTrainer.
    """

    def __init__(
        self,
        trainset: Optional[BaseDataset],
        validset: Optional[BaseDataset],
        result: Result,
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        ontologies: Optional[Union[Tuple, Dict[Any, Any]]],
        **kwargs,
    ):
        """Initializes the OntixTrainer with the given datasets, model, and configuration.

        Args:
            trainset: The dataset used for training.
            validset: The dataset used for validation, if provided.
            result: An object to store and manage training results.
            config: Configuration object containing training hyperparameters and settings.
            model_type: The autoencoder model class to be trained.
            loss_type: The loss function class specific to the model.
            ontologies: Ontology information required for Ontix; passed to
                GeneralTrainer unchanged.
            **kwargs: Accepted for call-site compatibility but NOT forwarded to
                GeneralTrainer; they are currently ignored.
        """
        # NOTE(review): ``kwargs`` is deliberately not passed on (original behavior is
        # preserved). Confirm whether GeneralTrainer is meant to receive them.
        super().__init__(
            trainset=trainset,
            validset=validset,
            result=result,
            config=config,
            model_type=model_type,
            loss_type=loss_type,
            ontologies=ontologies,
        )

    def _ontix_hook(self) -> None:
        """Apply ontology-specific constraints to the decoder after the backward pass.

        Steps, in order:
          1. Check that there is one mask per decoder layer. This happens before any
             weight is touched, so a mismatch leaves the decoder unmodified.
          2. Apply ``self._model._positive_dec`` to every module of the decoder
             (via ``nn.Module.apply``) to enforce the positive-weight constraint.
          3. Multiply each decoder layer's weight in place by its ontology mask.

        Steps 2 and 3 run under ``torch.no_grad()``. They edit parameter tensors in
        place, and such edits must not be recorded by autograd. If ``_positive_dec``
        already guards itself (e.g. by writing through ``.data``), the extra no-grad
        context is harmless.

        NOTE(review): if this hook runs between ``backward()`` and ``optimizer.step()``,
        the optimizer step can move masked weights away from zero again. Confirm the
        call site in GeneralTrainer.

        Raises:
            ValueError: If the number of masks differs from the number of decoder layers.
        """
        with torch.no_grad():
            # Validate first: fail before mutating any decoder weight.
            self._validate_masks()
            # Positive weight constraint on the decoder.
            self._model._decoder.apply(self._model._positive_dec)
            # Ontology-based weight masking.
            self._apply_weight_masks()

    def _validate_masks(self) -> None:
        """Validate that the number of masks matches the number of decoder layers.

        Raises:
            ValueError: If ``len(self._model.masks)`` differs from
                ``len(self._model._decoder)``.
        """
        n_masks = len(self._model.masks)
        n_layers = len(self._model._decoder)
        if n_masks != n_layers:
            raise ValueError(
                f"Number of masks ({n_masks}) does not match "
                f"number of decoder layers ({n_layers})"
            )

    def _apply_weight_masks(self) -> None:
        """Multiply each decoder layer's weight in place by its ontology mask.

        Mask ``i`` is applied to ``self._model._decoder[i].weight``. Each mask is moved
        to the device of the weight it multiplies (not a global device), so ``mul_``
        cannot fail on a device mismatch. The caller must hold ``torch.no_grad()``
        (``_ontix_hook`` does), because ``mul_`` mutates a parameter in place.

        Assumes each mask has the same shape as (or broadcasts to) the corresponding
        weight. TODO confirm against how ``self._model.masks`` is built.
        """
        for i, mask in enumerate(self._model.masks):
            weight = self._model._decoder[i].weight
            weight.mul_(mask.to(weight.device))