Coverage for src / autoencodix / base / _base_trainer.py: 81%

114 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:09 +0200

1import abc 

2from packaging.version import Version 

3import random 

4import numpy as np 

5import os 

6import warnings 

7from typing import Optional, Type, List, cast, Tuple, Union 

8 

9import torch 

10from lightning_fabric import Fabric 

11from torch.utils.data import DataLoader 

12 

13from autoencodix.base._base_dataset import BaseDataset 

14from autoencodix.base._base_loss import BaseLoss 

15from autoencodix.base._base_autoencoder import BaseAutoencoder 

16from autoencodix.utils._result import Result 

17from autoencodix.configs.default_config import DefaultConfig 

18 

19 

class BaseTrainer(abc.ABC):
    """General training logic for all autoencoder models.

    This class sets up the model, optimizer, and data loaders. It also handles
    reproducibility and model-specific configurations. Subclasses must implement
    model training and prediction logic.

    Attributes:
        _trainset: The dataset used for training.
        _validset: The dataset used for validation, if provided.
        _result: An object to store and manage training results.
        _config: Configuration object containing training hyperparameters and settings.
        _model_type: The autoencoder model class to be trained.
        _loss_type: The loss class; instantiated with the config in ``setup_trainer``.
        _loss_fn: Instantiated loss function specific to the model.
        _trainloader: DataLoader for the training dataset.
        _validloader: DataLoader for the validation dataset, or None when no
            validation set is used.
        _model: The instantiated model architecture.
        _optimizer: The AdamW optimizer used for training.
        _fabric: Lightning Fabric wrapper for device and precision management.
        _n_cpus: CPU count reported by ``os.cpu_count()``, or 0 if unknown.
        _input_dim: Input dimension reported by the training dataset.
        feature_order: Feature ids of the training dataset.
        ontologies: Ontology information, if provided for Ontix.
    """

    def __init__(
        self,
        trainset: Optional[BaseDataset],
        validset: Optional[BaseDataset],
        result: Result,
        config: DefaultConfig,
        model_type: Type[BaseAutoencoder],
        loss_type: Type[BaseLoss],
        ontologies: Optional[
            Union[Tuple, List]
        ] = None,  # Addition to Varix, mandatory for Ontix
        **kwargs,
    ):
        """Store the training inputs and immediately build the trainer.

        Args:
            trainset: Dataset used for training; must be a non-empty BaseDataset.
            validset: Optional dataset used for validation.
            result: Result object that collects training outputs.
            config: Training hyperparameters and settings.
            model_type: Autoencoder class to instantiate.
            loss_type: Loss class, instantiated as ``loss_type(config=config)``.
            ontologies: Ontology information forwarded to the model constructor
                (mandatory for Ontix).
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If the trainset is None or empty, or the config is None.
            TypeError: If trainset/validset are not BaseDataset instances.
        """
        self._trainset = trainset
        self._model_type = model_type
        self._validset = validset
        self._result = result
        self._config = config
        self._loss_type = loss_type
        self.ontologies = ontologies
        self.setup_trainer()

    def setup_trainer(self, old_model: Optional[torch.nn.Module] = None) -> None:
        """(Re-)build loss, seeds, Fabric, model, optimizer and data loaders.

        Args:
            old_model: An already-built model to reuse. When given, input
                validation and DataLoader creation are skipped, so the loaders
                from a previous setup must already exist, and the model is
                not compiled a second time.
        """
        if old_model is None:
            self._input_validation()
            self._init_loaders()

        self._loss_fn = self._loss_type(config=self._config)

        self._handle_reproducibility()
        # Internal data handling
        self._model: BaseAutoencoder
        self._fabric = Fabric(
            accelerator=self._config.device,
            devices=self._config.n_gpus,
            precision=self._config.float_precision,
            strategy=self._config.gpu_strategy,
        )

        self._fabric.launch()
        self._setup_fabric(old_model=old_model)

        # os.cpu_count() returns None when the count cannot be determined.
        self._n_cpus = os.cpu_count() or 0

    def _setup_fabric(self, old_model: Optional[torch.nn.Module] = None) -> None:
        """Set up the model, optimizer, and data loaders with Lightning Fabric.

        Args:
            old_model: Model to reuse instead of building a new one. When given,
                the loaders are not re-wrapped and ``torch.compile`` is skipped.
        """
        self._init_model_architecture(
            ontologies=self.ontologies, old_model=old_model
        )  # Ontix

        self._optimizer = torch.optim.AdamW(
            params=self._model.parameters(),
            lr=self._config.learning_rate,
            weight_decay=self._config.weight_decay,
        )
        if old_model is None:
            self._trainloader = self._fabric.setup_dataloaders(self._trainloader)  # type: ignore
            if self._validloader is not None:
                self._validloader = self._fabric.setup_dataloaders(self._validloader)  # type: ignore

        # Compile only on torch>=2.0 with CUDA, when requested, and only for a
        # freshly built model (a reused model is assumed to be compiled already).
        if (
            Version(torch.__version__) >= Version("2.0")
            and torch.cuda.is_available()
            and old_model is None
            and self._config.compile_model
        ):
            self._model = torch.compile(self._model)

        self._model, self._optimizer = self._fabric.setup(self._model, self._optimizer)

    def _corrected_batch_size(self, n_samples: int, split_name: str) -> int:
        """Return a batch size whose last batch never holds exactly one sample.

        A trailing batch of size one fails for models with BatchNorm, and
        dropping it would lose samples every epoch. The configured batch size
        is therefore increased until ``n_samples % batch_size != 1``, and a
        warning is emitted when it had to be changed.

        Args:
            n_samples: Number of samples in the dataset split.
            split_name: Split label used in the warning ("trainset"/"validset").

        Returns:
            The (possibly increased) batch size.
        """
        batch_size = self._config.batch_size
        corrected_bs = batch_size
        # A single bump is not always enough (e.g. 7 samples, bs 2 -> 3 still
        # leaves 1). Growth is bounded by n_samples: at bs == n_samples every
        # sample fits in one batch, and a one-sample split cannot be helped.
        while n_samples % corrected_bs == 1 and corrected_bs < n_samples:
            corrected_bs += 1
        if corrected_bs != batch_size:
            warnings.warn(
                f"increased batch_size to {corrected_bs} for {split_name}, to avoid dropping samples and having batches (makes trainingdynamics messy with missing samples per epoch) of size one (fails for Models with BachNorm)"
            )
        return corrected_bs

    def _init_loaders(self) -> None:
        """Initialize the DataLoaders for the training and validation datasets.

        The training loader shuffles; the validation loader does not. Both use
        ``_corrected_batch_size`` so no batch consists of a single sample.
        """
        trainset = cast(BaseDataset, self._trainset)
        self._trainloader = DataLoader(
            trainset,
            shuffle=True,
            batch_size=self._corrected_batch_size(len(trainset), "trainset"),
            worker_init_fn=self._seed_worker,
            pin_memory=self._config.pin_memory,
        )
        # Truthiness check: a validation set that evaluates falsy (e.g. empty,
        # when BaseDataset only defines __len__) is treated like a missing one.
        if self._validset:
            self._validloader = DataLoader(
                dataset=self._validset,
                # Previously the corrected size was computed and warned about
                # but never applied to the validation loader.
                batch_size=self._corrected_batch_size(
                    len(self._validset), "validset"
                ),
                shuffle=False,
                pin_memory=self._config.pin_memory,
            )
        else:
            self._validloader = None  # type: ignore

    def _input_validation(self) -> None:
        """Validate the datasets and config before building the trainer.

        Raises:
            ValueError: If the trainset is None or empty, or the config is None.
            TypeError: If the trainset or validset is not a BaseDataset.
        """
        if self._trainset is None:
            raise ValueError(
                "Trainset cannot be None. Check the indices you provided with a custom split or be sure that the train_ratio attribute of the config is >0."
            )
        if not isinstance(self._trainset, BaseDataset):
            raise TypeError(
                f"Expected train type to be an instance of BaseDataset, got {type(self._trainset)}."
            )
        if len(self._trainset) == 0:
            # An empty trainset would yield an empty DataLoader and no training.
            raise ValueError(
                "Trainset cannot be empty. Check the indices you provided with a custom split or be sure that the train_ratio attribute of the config is >0."
            )
        if self._validset is None:
            print("training without validation")
        elif not isinstance(self._validset, BaseDataset):
            raise TypeError(
                f"Expected valid type to be an instance of BaseDataset, got {type(self._validset)}."
            )
        if self._config is None:
            raise ValueError("Config cannot be None.")

    def _handle_reproducibility(self) -> None:
        """Set all relevant seeds when ``config.reproducible`` is enabled.

        Always seeds torch, ``random`` and NumPy and enables deterministic
        torch algorithms. CUDA additionally gets CUDA seeds and deterministic
        cuDNN; MPS gets its own seed plus a warning that full determinism is
        not guaranteed. A notice is printed only if neither backend exists.
        """
        if not self._config.reproducible:
            return

        seed = self._config.global_seed
        torch.use_deterministic_algorithms(True)
        torch.manual_seed(seed=seed)
        random.seed(seed)
        np.random.seed(seed)

        cuda_available = torch.cuda.is_available()
        mps_available = torch.backends.mps.is_available()

        if cuda_available:
            torch.cuda.manual_seed(seed=seed)
            torch.cuda.manual_seed_all(seed=seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        if mps_available:
            torch.mps.manual_seed(seed=seed)
            # NOTE: if strict determinism raises on MPS for some ops, consider
            # torch.use_deterministic_algorithms(True, warn_only=True).
            print(
                "Warning: MPS backend has limited support for deterministic algorithms. "
                "Seeding is active, but full reproducibility is not guaranteed."
            )
        if not (cuda_available or mps_available):
            # Previously this printed on every machine without MPS, including
            # CUDA machines whose seeding had just been configured above.
            print(
                f"Reproducibility settings for device {self._config.device} are not implemented or necessary i.e. for cpu."
            )

    def _seed_worker(self, worker_id: int) -> None:
        """Seed Python's and NumPy's RNGs inside a DataLoader worker.

        Args:
            worker_id: Index of the worker, offset from ``config.global_seed``.
        """
        worker_seed = self._config.global_seed + worker_id
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def _init_model_architecture(
        self,
        ontologies: Optional[Union[Tuple, List]],
        old_model: Optional[torch.nn.Module] = None,
    ) -> None:
        """Initialize the model from the model type and the input dimension.

        Args:
            ontologies: Ontology information; when not None it is passed to the
                model constructor together with the feature order (Ontix).
            old_model: Existing model to reuse as-is instead of building one.
        """
        # `is not None` (not truthiness) to match setup_trainer/_setup_fabric;
        # a module defining __len__ could otherwise be falsy.
        if old_model is not None:
            self._model = old_model
            return

        trainset = cast(BaseDataset, self._trainset)
        self._input_dim = trainset.get_input_dim()
        self.feature_order = trainset.feature_ids
        if ontologies is None:
            self._model = self._model_type(
                config=self._config, input_dim=self._input_dim
            )
        else:
            self._model = self._model_type(
                config=self._config,
                input_dim=self._input_dim,
                ontologies=ontologies,
                feature_order=self.feature_order,  # type: ignore
            )

    def _should_checkpoint(self, epoch: int) -> bool:
        """Return True every ``checkpoint_interval`` epochs and on the last epoch.

        Args:
            epoch: Zero-based epoch index.
        """
        return (
            (epoch + 1) % self._config.checkpoint_interval == 0
            or epoch == self._config.epochs - 1
        )

    @abc.abstractmethod
    def train(self, epochs_overwrite: Optional[int] = None) -> Result:
        """Train the model and return the Result.

        ``epochs_overwrite`` presumably replaces ``config.epochs`` when given;
        confirm against the subclass implementation.
        """

    @abc.abstractmethod
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode ``x`` with the trained model (semantics defined by subclasses)."""

    @abc.abstractmethod
    def predict(
        self, data: BaseDataset, model: Optional[torch.nn.Module] = None, **kwargs
    ) -> Result:
        """Run inference on ``data``, optionally with an explicit ``model``."""

    @abc.abstractmethod
    def purge(self) -> None:
        """Cleans up any resources used during training, such as cached data or large attributes."""