Coverage for src / autoencodix / base / _base_trainer.py: 81%
114 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 10:09 +0200
1import abc
2from packaging.version import Version
3import random
4import numpy as np
5import os
6import warnings
7from typing import Optional, Type, List, cast, Tuple, Union
9import torch
10from lightning_fabric import Fabric
11from torch.utils.data import DataLoader
13from autoencodix.base._base_dataset import BaseDataset
14from autoencodix.base._base_loss import BaseLoss
15from autoencodix.base._base_autoencoder import BaseAutoencoder
16from autoencodix.utils._result import Result
17from autoencodix.configs.default_config import DefaultConfig
class BaseTrainer(abc.ABC):
    """General training logic for all autoencoder models.

    This class sets up the model, optimizer, and data loaders. It also handles
    reproducibility and model-specific configurations. Subclasses must implement
    model training and prediction logic.

    Attributes:
        _trainset: The dataset used for training.
        _validset: The dataset used for validation, if provided.
        _result: An object to store and manage training results.
        _config: Configuration object containing training hyperparameters and settings.
        _model_type: The autoencoder model class to be trained.
        _loss_type: The loss class; instantiated with the config into ``_loss_fn``.
        _loss_fn: Instantiated loss function specific to the model.
        _trainloader: DataLoader for the training dataset.
        _validloader: DataLoader for the validation dataset, or None when no
            (non-empty) validation set was provided.
        _model: The instantiated model architecture.
        _optimizer: The AdamW optimizer used for training.
        _fabric: Lightning Fabric wrapper for device and precision management.
        _n_cpus: CPU count reported by ``os.cpu_count()`` (0 if it is unknown).
        ontologies: Ontology information, if provided for Ontix.
        feature_order: Feature ids of the training set; set when a new model is built.
    """

    def __init__(
        self,
        trainset: Optional[BaseDataset],
        validset: Optional[BaseDataset],
        result: Result,
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        ontologies: Optional[
            Union[Tuple, List]
        ] = None,  # Addition to Varix, mandatory for Ontix
        **kwargs,
    ):
        """Store the trainer inputs and immediately run :meth:`setup_trainer`.

        Args:
            trainset: Training dataset; must be a ``BaseDataset`` (validated).
            validset: Optional validation dataset.
            result: Result container the trainer fills.
            config: Training configuration.
            model_type: Autoencoder class to instantiate.
            loss_type: Loss class to instantiate with ``config``.
            ontologies: Ontology information forwarded to the model (Ontix).
            **kwargs: Accepted for subclass compatibility; not used here.
        """
        self._trainset = trainset
        self._model_type = model_type
        self._validset = validset
        self._result = result
        self._config = config
        self._loss_type = loss_type
        self.ontologies = ontologies
        self.setup_trainer()

    def setup_trainer(self, old_model: Optional[torch.nn.Module] = None) -> None:
        """Build loss, seeds, Fabric, model and optimizer (and loaders on first setup).

        Args:
            old_model: Model to reuse. When given, input validation and
                DataLoader creation are skipped and the existing loaders are kept.
        """
        if old_model is None:
            self._input_validation()
            self._init_loaders()

        self._loss_fn = self._loss_type(config=self._config)

        self._handle_reproducibility()
        # Internal data handling
        self._model: BaseAutoencoder
        self._fabric = Fabric(
            accelerator=self._config.device,
            devices=self._config.n_gpus,
            precision=self._config.float_precision,
            strategy=self._config.gpu_strategy,
        )

        self._fabric.launch()
        self._setup_fabric(old_model=old_model)

        # os.cpu_count() may return None when the count cannot be determined.
        self._n_cpus = os.cpu_count() or 0

    def _setup_fabric(self, old_model: Optional[torch.nn.Module] = None) -> None:
        """Set up the model, optimizer, and data loaders with Lightning Fabric.

        Args:
            old_model: Model to reuse instead of building a new one. When given,
                the DataLoaders are not re-wrapped and ``torch.compile`` is skipped.
        """
        self._init_model_architecture(
            ontologies=self.ontologies, old_model=old_model
        )  # Ontix

        self._optimizer = torch.optim.AdamW(
            params=self._model.parameters(),
            lr=self._config.learning_rate,
            weight_decay=self._config.weight_decay,
        )
        if old_model is None:
            self._trainloader = self._fabric.setup_dataloaders(self._trainloader)  # type: ignore
            if self._validloader is not None:
                self._validloader = self._fabric.setup_dataloaders(self._validloader)  # type: ignore

        if (
            Version(torch.__version__) >= Version("2.0")
            and torch.cuda.is_available()
            and old_model is None
            and self._config.compile_model
        ):  # don't compile twice: a reused model presumably went through this already
            self._model = torch.compile(self._model)

        self._model, self._optimizer = self._fabric.setup(self._model, self._optimizer)

    @staticmethod
    def _corrected_batch_size(n_samples: int, batch_size: int) -> int:
        """Return the smallest batch size >= ``batch_size`` whose last batch is not one sample.

        Adding one is not always enough (7 samples: ``7 % 2 == 7 % 3 == 1``), so
        keep incrementing. The loop terminates: at ``n_samples`` the remainder is 0,
        and above it the remainder is ``n_samples`` itself (> 1). A dataset with
        a single sample cannot be fixed, so the batch size is returned unchanged.
        """
        corrected = batch_size
        while n_samples > 1 and n_samples % corrected == 1:
            corrected += 1
        return corrected

    def _loader_batch_size(self, dataset, split_name: str) -> int:
        """Return the batch size to use for ``dataset``, warning when it was enlarged.

        Args:
            dataset: Dataset supporting ``len()``.
            split_name: Name used in the warning ("trainset" / "validset").
        """
        batch_size = self._config.batch_size
        corrected_bs = self._corrected_batch_size(len(dataset), batch_size)
        if corrected_bs != batch_size:
            warnings.warn(
                f"increased batch_size to {corrected_bs} for {split_name}, to avoid dropping samples and having batches (makes trainingdynamics messy with missing samples per epoch) of size one (fails for Models with BachNorm)"
            )
        return corrected_bs

    def _init_loaders(self) -> None:
        """Initialize the DataLoaders for the training and validation datasets.

        Both loaders use a batch size enlarged as needed so that no batch holds
        a single sample. ``_validloader`` is None when no validation set was
        given or it is empty.
        """
        self._trainloader = DataLoader(
            self._trainset,  # type: ignore[arg-type]  # checked in _input_validation
            shuffle=True,
            batch_size=self._loader_batch_size(self._trainset, "trainset"),
            worker_init_fn=self._seed_worker,
            pin_memory=self._config.pin_memory,
        )
        # Truthiness on purpose: an empty validation set also gets no loader.
        if self._validset:
            self._validloader = DataLoader(
                dataset=self._validset,
                batch_size=self._loader_batch_size(self._validset, "validset"),
                shuffle=False,
                pin_memory=self._config.pin_memory,
            )
        else:
            self._validloader = None  # type: ignore

    def _input_validation(self) -> None:
        """Validate the datasets and config passed to the constructor.

        Raises:
            ValueError: If the trainset or the config is None.
            TypeError: If trainset or (a non-None) validset is not a BaseDataset.
        """
        if self._trainset is None:
            raise ValueError(
                "Trainset cannot be None. Check the indices you provided with a custom split or be sure that the train_ratio attribute of the config is >0."
            )
        if not isinstance(self._trainset, BaseDataset):
            raise TypeError(
                f"Expected train type to be an instance of BaseDataset, got {type(self._trainset)}."
            )
        if self._validset is None:
            print("training without validation")
        elif not isinstance(self._validset, BaseDataset):
            raise TypeError(
                f"Expected valid type to be an instance of BaseDataset, got {type(self._validset)}."
            )
        if self._config is None:
            raise ValueError("Config cannot be None.")

    def _handle_reproducibility(self) -> None:
        """Seed all relevant RNGs and enable deterministic algorithms.

        Does nothing unless ``config.reproducible`` is set. On MPS the generator
        is seeded but a warning is printed because full reproducibility is not
        guaranteed there. Nothing is raised.
        """
        if not self._config.reproducible:
            return

        seed = self._config.global_seed
        torch.use_deterministic_algorithms(True)
        torch.manual_seed(seed=seed)
        random.seed(seed)
        np.random.seed(seed)

        has_cuda = torch.cuda.is_available()
        has_mps = torch.backends.mps.is_available()
        if has_cuda:
            torch.cuda.manual_seed(seed=seed)
            torch.cuda.manual_seed_all(seed=seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        if has_mps:
            torch.mps.manual_seed(seed=seed)
            print(
                "Warning: MPS backend has limited support for deterministic algorithms. "
                "Seeding is active, but full reproducibility is not guaranteed."
            )
        if not (has_cuda or has_mps):
            # Only the CPU path is left; the seeds above already cover it.
            print(
                f"Reproducibility settings for device {self._config.device} are not implemented or necessary i.e. for cpu."
            )

    def _seed_worker(self, worker_id: int) -> None:
        """Seed ``numpy`` and ``random`` in a DataLoader worker (``worker_init_fn``).

        PyTorch only calls this when the loader uses worker processes
        (``num_workers > 0``); the training loader built here uses the default.
        """
        worker_seed = self._config.global_seed + worker_id
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def _init_model_architecture(
        self,
        ontologies: Optional[Union[Tuple, List]],
        old_model: Optional[torch.nn.Module] = None,
    ) -> None:
        """Initialize the model from the model type and the trainset's input dimension.

        Args:
            ontologies: Ontology information; when not None it is forwarded to the
                model together with the training feature order (Ontix).
            old_model: If not None, it is reused as-is and nothing else is done.
        """
        # `is not None` rather than truthiness: a Module defining __len__ could be falsy.
        if old_model is not None:
            self._model = old_model
            return

        self._input_dim = cast(BaseDataset, self._trainset).get_input_dim()
        self.feature_order = self._trainset.feature_ids  # type: ignore
        if ontologies is None:
            self._model = self._model_type(
                config=self._config, input_dim=self._input_dim
            )
        else:
            self._model = self._model_type(
                config=self._config,
                input_dim=self._input_dim,
                ontologies=ontologies,
                feature_order=self.feature_order,  # type: ignore
            )

    def _should_checkpoint(self, epoch: int) -> bool:
        """Return True every ``checkpoint_interval`` epochs and on the final epoch.

        Args:
            epoch: Zero-based epoch index.
        """
        return (
            (epoch + 1) % self._config.checkpoint_interval == 0
            or epoch == self._config.epochs - 1
        )

    @abc.abstractmethod
    def train(self, epochs_overwrite: Optional[int] = None) -> Result:
        """Train the model and return the populated Result."""
        pass

    @abc.abstractmethod
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a tensor with the trained model."""
        pass

    @abc.abstractmethod
    def predict(
        self, data: BaseDataset, model: Optional[torch.nn.Module] = None, **kwargs
    ) -> Result:
        """Run inference on ``data`` (optionally with a given model) and return a Result."""
        pass

    @abc.abstractmethod
    def purge(self) -> None:
        """Cleans up any resources used during training, such as cached data or large attributes."""
        pass