Coverage for src / autoencodix / trainers / _general_trainer.py: 72%

199 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 10:28 +0200

1import torch 

2 

3import gc 

4import os 

5import numpy as np 

6from typing import Optional, Type, Union, Tuple, Any, Dict, List 

7from collections import defaultdict 

8from torch.utils.data import DataLoader 

9 

10from autoencodix.base._base_dataset import BaseDataset 

11from autoencodix.base._base_loss import BaseLoss 

12from autoencodix.base._base_autoencoder import BaseAutoencoder 

13from autoencodix.base._base_trainer import BaseTrainer 

14from autoencodix.utils._result import Result 

15from autoencodix.configs.default_config import DefaultConfig 

16from autoencodix.utils._model_output import ModelOutput 

17 

18 

19class GeneralTrainer(BaseTrainer): 

20 """Handles general training logic for autoencoder models. 

21 

22 Attributes: 

23 _trainset: The dataset used for training. 

24 _validset: The dataset used for validation, if provided. 

25 _result: An object to store and manage training results. 

26 _config: Configuration object containing training hyperparameters and settings. 

27 _model_type: The autoencoder model class to be trained. 

28 _loss_fn: Instantiated loss function specific to the model. 

29 _trainloader: DataLoader for the training dataset. 

30 _validloader: DataLoader for the validation dataset, if provided. 

31 _model: The instantiated model architecture. 

32 _optimizer: The optimizer used for training. 

33 _fabric: Lightning Fabric wrapper for device and precision management. 

34 n_train: Number of training samples. 

35 n_valid: Number of validation samples. 

36 n_test: Number of test samples (set during prediction). 

37 n_features: Number of input features. 

38 latent_dim: Dimensionality of the latent space. 

39 device: Device on which the model is located. 

40 _n_cpus: Number of CPU cores available. 

41 _reconstruction_buffer: Buffer to store reconstructions during training/validation/testing. 

42 _latentspace_buffer: Buffer to store latent representations during training/validation/testing. 

43 _mu_buffer: Buffer to store latent means (for VAE) during training/validation/testing. 

44 _sigma_buffer: Buffer to store latent log-variances (for VAE) during training/validation/testing. 

45 _sample_ids_buffer: Buffer to store sample IDs during training/validation/testing. 

46 

47 """ 

48 

def __init__(
    self,
    trainset: Optional[BaseDataset],
    validset: Optional[BaseDataset],
    result: Result,
    config: DefaultConfig,
    model_type: Type[BaseAutoencoder],
    loss_type: Type[BaseLoss],
    ontologies: Optional[Union[Tuple, List]] = None,
    **kwargs,
):
    """Set up the trainer, its size bookkeeping and the dynamics buffers.

    Args:
        trainset: Dataset used for training; may be None.
        validset: Dataset used for validation; may be None.
        result: Object that collects the training results.
        config: Configuration object; ``latent_dim`` is read from it here.
        model_type: Autoencoder class to be trained.
        loss_type: Loss class to be used for training.
        ontologies: Ontology information. When given, the latent dimension
            is taken from the model instead of from the config.
        **kwargs: Accepted but not used by this constructor.

    Raises:
        ValueError: If ``ontologies`` is given and the model has no
            ``latent_dim`` attribute.
    """
    # os.cpu_count() returns None when the count cannot be determined.
    # This is set before the base initializer runs, as in the original order.
    self._n_cpus = os.cpu_count() or 0

    super().__init__(
        trainset, validset, result, config, model_type, loss_type, ontologies
    )

    self.latent_dim = config.latent_dim
    uses_ontologies = ontologies is not None
    if uses_ontologies and not hasattr(self._model, "latent_dim"):
        raise ValueError(
            "Model must have a 'latent_dim' attribute when ontologies are provided."
        )
    if uses_ontologies:
        # With ontologies the model, not the config, defines the latent size.
        self.latent_dim = self._model.latent_dim

    # n_test stays unset until predict() is called.
    self.n_test: Optional[int] = None
    if trainset:
        self.n_train = len(trainset)
        self.n_features = trainset.get_input_dim()
    else:
        self.n_train = 0
        self.n_features = 0
    self.n_valid = len(validset) if validset else 0

    first_parameter = next(self._model.parameters())
    self.device = first_parameter.device
    self._init_buffers()

94 

def _init_buffers(self, input_data: Optional[BaseDataset] = None) -> None:
    """Allocate fresh buffers for sample IDs and model dynamics.

    Tensor buffers are zero-filled and live on the CPU when
    ``self._config.save_vram`` is set, otherwise on ``self.device``. When
    ``self._config.save_memory`` is set, only the "test" split gets tensor
    buffers. A split whose size (``n_train`` / ``n_valid`` / ``n_test``) is
    falsy gets ``None`` instead of a tensor buffer.

    Args:
        input_data: If truthy, ``n_features`` is refreshed from its
            ``get_input_dim()`` before the buffers are created.
    """
    if input_data:
        self.n_features = input_data.get_input_dim()

    split_names = ("train", "valid", "test")
    tensor_splits = ("test",) if self._config.save_memory else split_names

    def _zeros(n_rows: int, dim: Union[int, Tuple[int, ...]]):
        # CPU placement avoids GPU OOM at the cost of slower transfers.
        target = "cpu" if self._config.save_vram else self.device
        shape = (n_rows, dim) if isinstance(dim, int) else (n_rows, *dim)
        return torch.zeros(shape, device=target)

    def _id_array(n_rows: int):
        return np.empty((n_rows,), dtype=object)

    def _per_split(dim):
        buffers = {}
        for name in split_names:
            if getattr(self, f"n_{name}", 0) and name in tensor_splits:
                buffers[name] = _zeros(getattr(self, f"n_{name}"), dim)
            else:
                buffers[name] = None
        return buffers

    # Sample-ID buffers are created regardless of save_memory; the train
    # buffer is always allocated, even when it is empty.
    self._sample_ids_buffer = {
        "train": _id_array(self.n_train),
        "valid": _id_array(self.n_valid) if self.n_valid else None,
        "test": _id_array(self.n_test) if self.n_test else None,
    }

    self._latentspace_buffer = _per_split(self.latent_dim)
    self._reconstruction_buffer = _per_split(self.n_features)
    self._mu_buffer = _per_split(self.latent_dim)
    self._sigma_buffer = _per_split(self.latent_dim)

141 

def _ontix_hook(self):
    """No-op hook called in ``_train_epoch`` between the backward pass and the optimizer step."""
    return None

144 

def train(self, epochs_overwrite=None) -> Result:
    """Run the training loop over all epochs, with validation when available.

    Every epoch re-initializes the buffers, trains, optionally validates,
    logs the losses and, when ``_should_checkpoint`` says so, stores a
    checkpoint.

    Args:
        epochs_overwrite: If truthy, used as the number of epochs instead of
            the value from the config. Exists so that this method can also
            be used for pretraining.

    Returns:
        The Result object, with its ``model`` attribute set after training.
    """
    n_epochs = epochs_overwrite or self._config.epochs

    with self._fabric.autocast():
        for epoch_idx in range(n_epochs):
            self._init_buffers()
            checkpoint_now: bool = self._should_checkpoint(epoch_idx)
            self._model.train()

            train_loss, train_parts = self._train_epoch(
                should_checkpoint=checkpoint_now, epoch=epoch_idx
            )
            self._log_losses(epoch_idx, "train", train_loss, train_parts)

            if self._validset:
                val_loss, val_parts = self._validate_epoch(
                    should_checkpoint=checkpoint_now, epoch=epoch_idx
                )
                self._log_losses(epoch_idx, "valid", val_loss, val_parts)

            if checkpoint_now:
                self._store_checkpoint(epoch_idx)

    # Store the first child module of the (wrapped) model on the result.
    # NOTE(review): presumably this unwraps the Fabric wrapper - confirm.
    first_child = next(self._model.children())
    self._result.model = first_child
    return self._result

179 

def maskix_hook(self, X: torch.Tensor):
    """Return the batch unchanged.

    Per the original note, this hook only does something when it is
    overridden by MaskixTrainer.

    Args:
        X: Input batch.

    Returns:
        The very same object that was passed in.
    """
    untouched = X
    return untouched

183 

def _train_epoch(
    self, should_checkpoint: bool, epoch: int
) -> Tuple[float, Dict[str, float]]:
    """Run one training epoch: forward pass, loss, backward pass, optimizer step.

    Args:
        should_checkpoint: Whether to capture the model dynamics of every
            batch into the buffers during this epoch.
        epoch: The current epoch number (forwarded to the loss function).

    Returns:
        total_loss: Sum of the batch losses divided by the number of
            samples in the training dataset.
        sub_losses: Sub-losses per key. Keys containing "_factor" hold the
            value of the last batch; all other keys are summed over the
            batches and divided by the number of samples.
    """
    # The dataset size does not change within an epoch, so look it up once.
    n_samples = len(self._trainloader.dataset)
    total_loss = 0.0
    sub_losses: Dict[str, float] = defaultdict(float)

    for indices, features, sample_ids in self._trainloader:
        self._optimizer.zero_grad()
        X: torch.Tensor = self.maskix_hook(features)
        model_outputs = self._model(X)
        loss, batch_sub_losses = self._loss_fn(
            model_output=model_outputs,
            targets=features,
            corrupted_input=X,  # only relevant for MaskixLoss
            epoch=epoch,
            n_samples=n_samples,  # needed for disentangled loss calculations
        )
        self._fabric.backward(loss)
        # Hook sits between backward() and step(), as in the original order.
        self._ontix_hook()
        self._optimizer.step()

        total_loss += loss.item()
        for key, value in batch_sub_losses.items():
            if "_factor" in key:
                # Factors are not accumulated; keep the most recent value.
                sub_losses[key] = value.item()
            else:
                sub_losses[key] += value.item()

        if should_checkpoint:
            self._capture_dynamics(model_outputs, "train", indices, sample_ids)

    # Average the accumulated (non-factor) sub-losses over all samples.
    for key in list(sub_losses):
        if "_factor" not in key:
            sub_losses[key] = sub_losses[key] / n_samples
    total_loss = total_loss / n_samples
    return total_loss, sub_losses

236 

def _validate_epoch(
    self, should_checkpoint: bool, epoch: int
) -> Tuple[float, Dict[str, float]]:
    """Evaluate the model on the validation loader for one epoch.

    The model is switched to eval mode and gradients are disabled.

    Args:
        should_checkpoint: Whether to capture the model dynamics of every
            batch into the buffers during this epoch.
        epoch: The current epoch number (forwarded to the loss function).

    Returns:
        total_loss: Sum of the batch losses divided by the number of
            samples in the validation dataset.
        sub_losses: Sub-losses per key. Keys containing "_factor" hold the
            value of the last batch; all other keys are summed over the
            batches and divided by the number of samples.
    """
    self._model.eval()
    loader = self._validloader
    n_samples = len(loader.dataset)

    running_total = 0.0
    running_parts: Dict[str, float] = defaultdict(float)

    with torch.no_grad():
        for batch_indices, batch, batch_ids in loader:
            model_input: torch.Tensor = self.maskix_hook(batch)
            outputs = self._model(model_input)
            batch_loss, batch_parts = self._loss_fn(
                model_output=outputs,
                targets=batch,
                corrupted_input=model_input,
                epoch=epoch,
                n_samples=n_samples,  # needed for disentangled loss calculations
            )
            running_total += batch_loss.item()
            for name, value in batch_parts.items():
                if "_factor" in name:
                    # Factors are not accumulated; keep the most recent value.
                    running_parts[name] = value.item()
                else:
                    running_parts[name] += value.item()
            if should_checkpoint:
                self._capture_dynamics(outputs, "valid", batch_indices, batch_ids)

    # Average the accumulated (non-factor) sub-losses over all samples.
    for name in list(running_parts):
        if "_factor" not in name:
            running_parts[name] = running_parts[name] / n_samples
    return running_total / n_samples, running_parts

284 

def _log_losses(
    self, epoch: int, split: str, total_loss: float, sub_losses: Dict[str, float]
) -> None:
    """Record the epoch losses on the Result object and print them.

    Args:
        epoch: The current epoch number (printed 1-based).
        split: The data split ("train" or "valid").
        total_loss: The total loss for the epoch.
        sub_losses: A dictionary of sub-losses for the epoch.
    """
    self._result.losses.add(epoch=epoch, split=split, data=total_loss)
    # A plain-dict copy is stored, so the recorded values are decoupled from
    # the dictionary the caller passed in. All keys are stored unchanged.
    self._result.sub_losses.add(epoch=epoch, split=split, data=dict(sub_losses))

    self._fabric.print(
        f"Epoch {epoch + 1} - {split.capitalize()} Loss: {total_loss:.4f}"
    )
    formatted = ", ".join(f"{k}: {v:.4f}" for k, v in sub_losses.items())
    self._fabric.print(f"Sub-losses: {formatted}")

312 

def _store_checkpoint(self, epoch: int) -> None:
    """Store a model checkpoint and, unless saving memory, the training dynamics.

    Args:
        epoch: The current epoch number.
    """
    import copy  # local import keeps this block self-contained

    # state_dict() returns references to the live parameter tensors. Without
    # a copy, later optimizer steps would update the stored checkpoint in
    # place (unless the result container copies on add - not visible here).
    snapshot = copy.deepcopy(self._model.state_dict())
    self._result.model_checkpoints.add(epoch=epoch, data=snapshot)

    if self._config.save_memory:
        # In save_memory mode only the weights are kept, no dynamics.
        return
    self._dynamics_to_result(epoch, "train")
    if self._validset:
        self._dynamics_to_result(epoch, "valid")

325 

def _capture_dynamics(
    self,
    model_output: Union[ModelOutput, Any],
    split: str,
    indices: torch.Tensor,
    sample_ids: Any,
    **kwargs,
) -> None:
    """Write the dynamics of one batch (latent space, reconstruction, ...) to the buffers.

    The sample IDs are always recorded. The tensor buffers are only written
    when ``self._config.save_memory`` is off or the split is "test".

    Args:
        model_output: Model output providing ``latentspace``,
            ``reconstruction``, ``latent_mean`` and ``latent_logvar``; the
            latter two may be None and are then skipped.
        split: The data split ("train", "valid", or "test").
        indices: Positions of the batch samples within the split's buffers
            (a tensor or anything ``np.array`` accepts).
        sample_ids: Sample IDs of the batch (a tensor or anything
            ``np.array`` accepts, e.g. a list of strings).
        **kwargs: Additional arguments (not used here).
    """

    def _as_numpy(values):
        # Tensors have to be moved to host memory first; everything else
        # (lists/tuples of IDs, numpy arrays) is converted directly.
        if isinstance(values, torch.Tensor):
            return values.cpu().numpy()
        return np.array(values)

    indices_np = _as_numpy(indices)
    self._sample_ids_buffer[split][indices_np] = _as_numpy(sample_ids)

    if self._config.save_memory and split != "test":
        return

    # With save_vram the buffers live on the CPU, so move the values there.
    move_to_cpu = self._config.save_vram

    def _write(buffer, value):
        value = value.detach()
        if move_to_cpu:
            value = value.cpu()
        buffer[split][indices_np] = value

    _write(self._latentspace_buffer, model_output.latentspace)
    _write(self._reconstruction_buffer, model_output.reconstruction)
    if model_output.latent_logvar is not None:
        _write(self._sigma_buffer, model_output.latent_logvar)
    if model_output.latent_mean is not None:
        _write(self._mu_buffer, model_output.latent_mean)

389 

def _dynamics_to_result(self, epoch: int, split: str) -> None:
    """Copy the buffered dynamics of one split into the Result object.

    Latent spaces, reconstructions and sample IDs are always added. The mu
    and sigma buffers are only added when they exist and do not sum to zero.

    Args:
        epoch: The current epoch number.
        split: The data split ("train", "valid", or "test").
    """

    def _to_numpy(tensor):
        return tensor.cpu().detach().numpy()

    result = self._result
    result.latentspaces.add(
        epoch=epoch, split=split, data=_to_numpy(self._latentspace_buffer[split])
    )
    result.reconstructions.add(
        epoch=epoch, split=split, data=_to_numpy(self._reconstruction_buffer[split])
    )
    result.sample_ids.add(
        epoch=epoch, split=split, data=self._sample_ids_buffer[split]
    )

    optional_buffers = (
        (self._mu_buffer, "mus"),
        (self._sigma_buffer, "sigmas"),
    )
    for buffers, target_name in optional_buffers:
        target = getattr(result, target_name)
        values = buffers[split]
        # An all-zero sum is treated as "never written" and skipped.
        if values is None or values.sum() == 0:
            continue
        target.add(epoch=epoch, split=split, data=_to_numpy(values))

419 

def predict(
    self, data: BaseDataset, model: Optional[torch.nn.Module] = None, **kwargs
) -> Result:
    """Run inference on unseen data and store the outputs as the "test" split.

    Prediction lives on the trainer (despite violating SRP) because it
    reuses the trainer's buffers and helpers, avoiding duplicate code.

    Args:
        data: Dataset of unseen samples to run inference on.
        model: Model to run inference with. When None, the trainer's own
            model is used and a warning is emitted.
        **kwargs: Additional arguments (not used here).

    Returns:
        The Result object, with the test dynamics stored under epoch -1.
    """
    if model is None:
        import warnings

        model = self._model
        warnings.warn(
            "No model provided for prediction, using the trained model from the trainer."
        )
    model.eval()

    loader = DataLoader(
        data,
        batch_size=self._config.batch_size,
        shuffle=False,
    )
    # The test buffers are sized from the dataset that is being predicted.
    self.n_test = len(data)
    self._init_buffers(input_data=data)
    loader = self._fabric.setup_dataloaders(loader)  # type: ignore

    with self._fabric.autocast(), torch.inference_mode():
        n_done = 0
        for batch_indices, batch, batch_ids in loader:
            batch_output = model(batch)
            n_done += len(batch)
            # Carriage return keeps the progress on a single console line.
            print(f"Processed {n_done} / {self.n_test} samples", end="\r")
            self._capture_dynamics(
                model_output=batch_output,
                split="test",
                sample_ids=batch_ids,
                indices=batch_indices,
            )

    self._dynamics_to_result(epoch=-1, split="test")
    return self._result

470 

def decode(self, x: torch.Tensor):
    """Decode a tensor with the trainer's model.

    The input is moved to the Fabric device and decoded under autocast with
    gradients disabled.

    Args:
        x: Input tensor to be decoded.

    Returns:
        Whatever ``self._model.decode`` returns for the moved tensor.

    Raises:
        TypeError: If the object returned by the device move is not a
            ``torch.Tensor``.
    """
    with self._fabric.autocast():
        with torch.no_grad():
            on_device = self._fabric.to_device(obj=x)
            if isinstance(on_device, torch.Tensor):
                return self._model.decode(x=on_device)
            raise TypeError(
                f"Expected input to be a torch.Tensor, got {type(on_device)} instead."
            )

486 

def get_model(self) -> torch.nn.Module:
    """Return the model held by this trainer.

    Returns:
        The object stored in ``self._model``.
    """
    current_model = self._model
    return current_model

495 

def purge(self) -> None:
    """Drop the large trainer attributes and release cached memory.

    Deletes the data loaders, datasets, model, optimizer, loss function and
    all dynamics buffers from the instance. Attributes that are missing or
    already None are left untouched. Afterwards the garbage collector is
    run and, when CUDA is available, the CUDA cache is emptied.
    """
    attrs_to_delete = (
        "_trainloader",
        "_validloader",
        "_model",
        "_optimizer",
        "_latentspace_buffer",
        "_reconstruction_buffer",
        "_mu_buffer",
        "_sigma_buffer",
        "_sample_ids_buffer",
        "_trainset",
        "_loss_fn",
        "_validset",
    )

    for attr in attrs_to_delete:
        # No local reference to the value is kept, so nothing purged here
        # stays alive through the cleanup below.
        if getattr(self, attr, None) is not None:
            delattr(self, attr)

    # Collect first: objects kept alive only by reference cycles must be
    # freed before the CUDA caching allocator can hand their memory back.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()