# Coverage: tests/unit/test_training_manager.py, 98% of 216 statements (coverage.py v7.6.10)
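"""Unit tests for trnbl.training_manager.TrainingManager, exercised against the LocalLogger backend."""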

from typing import Dict, Any, Union, Callable
import time
from unittest.mock import MagicMock
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pytest

from trnbl.training_interval import IntervalValueError
from trnbl.training_manager import TrainingManager, TrainingManagerInitError
from trnbl.loggers.local import LocalLogger
from trnbl.loggers.base import TrainingLoggerBase

# Temporary directory for testing
TEMP_PATH: Path = Path("tests/_temp")
TEMP_PATH.mkdir(parents=True, exist_ok=True)


# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self) -> None:
        super(SimpleModel, self).__init__()
        self.fc: nn.Linear = nn.Linear(10, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


# Dummy dataset and dataloader
inputs: torch.Tensor = torch.randn(100, 10)
targets: torch.Tensor = torch.randn(100, 1)
dataset: TensorDataset = TensorDataset(inputs, targets)
dataloader: DataLoader = DataLoader(dataset, batch_size=10)

# Logger configuration
logger_config: Dict[str, Any] = {
    "project": "test_project",
    "metric_names": ["loss"],
    "group": "test_run",
    "train_config": {"batch_size": 10, "learning_rate": 0.001, "epochs": 10},
    "base_path": TEMP_PATH,
}

# Initialize the model, criterion, and optimizer
model: nn.Module = SimpleModel()
criterion: nn.Module = nn.MSELoss()
optimizer: optim.Optimizer = optim.SGD(model.parameters(), lr=0.001)
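# NOTE: the model, criterion, and optimizer above are module-level objects shared by the tests below.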


def test_training_manager_initialization():
    logger = LocalLogger(**logger_config)
    training_manager = TrainingManager(
        model=model, dataloader=dataloader, logger=logger, epochs_total=10
    )
    assert training_manager.model == model
    assert training_manager.logger == logger
    assert training_manager.epochs_total == 10
    training_manager.logger.finish()


def test_training_manager_batch_update():
    logger = LocalLogger(**logger_config)
    training_manager = TrainingManager(
        model=model, dataloader=dataloader, logger=logger, epochs_total=1
    )
    training_manager._save_checkpoint = MagicMock()

    # Simulate a training batch update
    inputs, targets = next(iter(dataloader))
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    training_manager.batch_update(samples=len(targets), train_loss=loss.item())

    # Check if metrics were logged
    assert len(training_manager.logger.metrics_list) > 0
    assert training_manager.logger.metrics_list[-1]["train_loss"] == loss.item()

    # Check if a checkpoint was saved (based on interval)
    if training_manager.batches % training_manager.checkpoint_interval == 0:
        training_manager._save_checkpoint.assert_called()

    training_manager.logger.finish()


def test_training_manager_epoch_update():
    logger = LocalLogger(**logger_config)
    training_manager = TrainingManager(
        model=model, dataloader=dataloader, logger=logger, epochs_total=1
    )

    # Simulate an epoch update
    training_manager.epoch_update()
    assert training_manager.epochs == 1
    assert len(training_manager.logger.log_list) > 0
    assert training_manager.logger.log_list[-1]["message"] == "completed epoch 1/1"
    training_manager.logger.finish()


def test_training_manager_checkpoint_saving():
    logger = LocalLogger(**logger_config)
    training_manager = TrainingManager(
        model=model, dataloader=dataloader, logger=logger, epochs_total=1
    )
    training_manager._save_checkpoint(alias="test_checkpoint")

    # Check if the checkpoint artifact was logged
    assert len(training_manager.logger.artifacts_list) > 0
    assert "test_checkpoint" in training_manager.logger.artifacts_list[-1]["aliases"]
    training_manager.logger.finish()


@pytest.fixture
def training_manager() -> TrainingManager:
    """Provide a fresh TrainingManager (with its own LocalLogger) for each test."""
    logger: LocalLogger = LocalLogger(**logger_config)
    return TrainingManager(
        model=model, dataloader=dataloader, logger=logger, epochs_total=1
    )


def test_training_manager_initialization_comprehensive(
    training_manager: TrainingManager,
) -> None:
    assert isinstance(training_manager.model, nn.Module)
    assert isinstance(training_manager.logger, TrainingLoggerBase)
    assert training_manager.epochs_total == 1
    assert training_manager.epochs == 0
    assert training_manager.batches_per_epoch == len(dataloader)
    assert training_manager.batch_size == dataloader.batch_size
    assert training_manager.batches_total == len(dataloader)
    assert training_manager.batches == 0
    # TODO: why does it think a `Dataset[Any]` has no length?
    assert training_manager.samples_per_epoch == len(dataloader.dataset)  # type: ignore[arg-type]
    assert training_manager.samples_total == len(dataloader.dataset)  # type: ignore[arg-type]
    assert training_manager.samples == 0
    assert training_manager.checkpoints == 0


def test_training_manager_enter(training_manager: TrainingManager) -> None:
    with training_manager as tm:
        assert tm == training_manager


def test_training_manager_exit_normal(training_manager: TrainingManager) -> None:
    training_manager._save_checkpoint = MagicMock()  # type: ignore[method-assign]
    training_manager.logger.finish = MagicMock()  # type: ignore[method-assign]

    with training_manager:
        pass

    training_manager._save_checkpoint.assert_called_with(alias="final")
    training_manager.logger.finish.assert_called_once()


def test_training_manager_exit_exception(training_manager: TrainingManager) -> None:
    training_manager._save_checkpoint = MagicMock()  # type: ignore[method-assign]
    training_manager.logger.error = MagicMock()  # type: ignore[method-assign]
    training_manager.logger.finish = MagicMock()  # type: ignore[method-assign]

    with pytest.raises(ValueError):
        with training_manager:
            raise ValueError("Test exception")

    training_manager._save_checkpoint.assert_called_with(alias="exception")
    training_manager.logger.error.assert_called_once()
    training_manager.logger.finish.assert_called_once()


def test_training_manager_get_elapsed_time(training_manager: TrainingManager) -> None:
    start_time: float = time.time()
    training_manager.start_time = start_time
    # smoke test: only checks that computing the elapsed time does not raise
    training_manager.get_elapsed_time()


def test_training_manager_training_status(training_manager: TrainingManager) -> None:
    status: Dict[str, Union[int, float]] = training_manager.training_status()
    assert all(
        key in status
        for key in [
            "elapsed_time",
            "samples",
            "batches",
            "epochs",
            "latest_checkpoint",
        ]
    )
    assert all(isinstance(value, (int, float)) for value in status.values())


def test_training_manager_get_format_kwargs(training_manager: TrainingManager) -> None:
    kwargs: Dict[str, Union[str, int, float]] = training_manager._get_format_kwargs()
    assert all(
        key in kwargs
        for key in [
            "run_path",
            "elapsed_time",
            "samples",
            "batches",
            "epochs",
            "latest_checkpoint",
        ]
    )
    assert isinstance(kwargs["run_path"], str)
    assert all(
        isinstance(value, (int, float))
        for key, value in kwargs.items()
        if key != "run_path"
    )


def test_training_manager_batch_update_new(training_manager: TrainingManager) -> None:
    training_manager._save_checkpoint = MagicMock()  # type: ignore[method-assign]

    initial_samples: int = training_manager.samples
    initial_batches: int = training_manager.batches

    training_manager.batch_update(samples=10, loss=0.5)

    assert training_manager.samples == initial_samples + 10
    assert training_manager.batches == initial_batches + 1
    assert len(training_manager.logger.metrics_list) > 0
    assert training_manager.logger.metrics_list[-1]["loss"] == 0.5

    # the type-ignore is acceptable: if checkpoint_interval has an unexpected type, the test simply fails
    if training_manager.batches % training_manager.checkpoint_interval == 0:  # type: ignore[operator]
        training_manager._save_checkpoint.assert_called_once()


def test_training_manager_epoch_update_new(training_manager: TrainingManager) -> None:
    initial_epochs: int = training_manager.epochs

    training_manager.epoch_update()

    assert training_manager.epochs == initial_epochs + 1
    assert len(training_manager.logger.log_list) > 0
    assert (
        f"completed epoch {initial_epochs + 1}/{training_manager.epochs_total}"
        in training_manager.logger.log_list[-1]["message"]
    )


def test_training_manager_save_checkpoint(training_manager: TrainingManager) -> None:
    initial_checkpoints: int = training_manager.checkpoints

    training_manager._save_checkpoint(alias="test_checkpoint")

    assert training_manager.checkpoints == initial_checkpoints + 1
    assert len(training_manager.logger.artifacts_list) > 0
    assert "test_checkpoint" in training_manager.logger.artifacts_list[-1]["aliases"]


def test_training_manager_full_training_loop() -> None:
    logger: LocalLogger = LocalLogger(**logger_config)
    training_manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=logger,
        epochs_total=2,
        checkpoint_interval="1 epochs",
        evals=[
            ("1 epochs", lambda m: {"eval_loss": criterion(m(inputs), targets).item()})
        ],
    )

    with training_manager as tm:
        for epoch in range(2):
            for batch_inputs, batch_targets in dataloader:
                outputs: torch.Tensor = model(batch_inputs)
                loss: torch.Tensor = criterion(outputs, batch_targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tm.batch_update(samples=len(batch_targets), train_loss=loss.item())

            tm.epoch_update()

    assert tm.epochs_total == 2
    assert tm.batches == 20  # 2 epochs * 10 batches per epoch
    assert tm.samples == 200  # 2 epochs * 100 samples per epoch
    assert tm.checkpoints == 3  # 2 epoch checkpoints + 1 final checkpoint on exit
    assert len(tm.logger.metrics_list) == 20  # 1 metric log per batch
    assert (
        len(tm.logger.log_list) >= 4
    )  # At least 1 log for init, 2 for epoch completions, 1 for training complete
    assert (
        len(tm.logger.artifacts_list) == 3
    )  # 2 epoch checkpoints + 1 final checkpoint


def test_training_manager_zero_epochs() -> None:
    logger: LocalLogger = LocalLogger(**logger_config)
    with pytest.warns(IntervalValueError):
        TrainingManager(
            model=model, dataloader=dataloader, logger=logger, epochs_total=0
        )


def test_training_manager_negative_epochs() -> None:
    logger: LocalLogger = LocalLogger(**logger_config)
    with pytest.warns(IntervalValueError):
        TrainingManager(
            model=model, dataloader=dataloader, logger=logger, epochs_total=-1
        )


def test_training_manager_custom_save_model() -> None:
    logger: LocalLogger = LocalLogger(**logger_config)
    custom_save_model: Callable[[nn.Module, Path], None] = MagicMock()
    training_manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=logger,
        epochs_total=1,
        save_model=custom_save_model,
    )
    training_manager._save_checkpoint()
    # ignoring here because custom_save_model is a mock
    custom_save_model.assert_called_once()  # type: ignore[attr-defined]


def test_training_manager_custom_intervals() -> None:
    logger: LocalLogger = LocalLogger(**logger_config)
    training_manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=logger,
        epochs_total=1,
        checkpoint_interval="0.5 epochs",
        print_metrics_interval="0.25 epochs",
        evals=[("0.1 epochs", lambda m: {"eval_loss": 0.5})],
    )
    assert training_manager.checkpoint_interval == len(dataloader) // 2
    assert training_manager.print_metrics_interval == len(dataloader) // 4
    assert training_manager.evals[0][0] == len(dataloader) // 10


def test_training_manager_custom_model_save_paths() -> None:
    logger: LocalLogger = LocalLogger(**logger_config)
    custom_path: str = "{run_path}/custom_checkpoints/model-{latest_checkpoint}.pt"
    custom_special_path: str = "{run_path}/custom_special/model-{alias}.pt"
    training_manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=logger,
        epochs_total=1,
        model_save_path=custom_path,
        model_save_path_special=custom_special_path,
    )
    assert training_manager.model_save_path == custom_path
    assert training_manager.model_save_path_special == custom_special_path


def test_training_manager_batch_update_no_samples() -> None:
    logger: LocalLogger = LocalLogger(**logger_config)
    training_manager: TrainingManager = TrainingManager(
        model=model, dataloader=dataloader, logger=logger, epochs_total=1
    )
    initial_samples: int = training_manager.samples
    training_manager.batch_update(samples=None, loss=0.5)
    # when samples is None, the batch size is used; the type-ignore is acceptable since a wrong type just fails the test
    assert training_manager.samples == initial_samples + training_manager.batch_size  # type: ignore[operator]


def test_training_manager_multiple_evals() -> None:
    logger: LocalLogger = LocalLogger(**logger_config)
    eval1: Callable[[nn.Module], Dict[str, float]] = lambda m: {"eval1": 0.5}  # noqa: E731
    eval2: Callable[[nn.Module], Dict[str, float]] = lambda m: {"eval2": 0.7}  # noqa: E731
    training_manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=logger,
        epochs_total=1,
        evals=[("1 batch", eval1), ("2 batches", eval2)],
    )
    training_manager.batch_update(samples=10, loss=0.3)
    assert "eval1" in training_manager.logger.metrics_list[-1]
    assert "eval2" not in training_manager.logger.metrics_list[-1]
    training_manager.batch_update(samples=10, loss=0.3)
    assert "eval1" in training_manager.logger.metrics_list[-1]
    assert "eval2" in training_manager.logger.metrics_list[-1]


@pytest.mark.parametrize(
    "interval, expected",
    [
        ("1 epochs", len(dataloader)),
        ("0.5 epochs", len(dataloader) // 2),
        ("10 batches", 10),
        # with batch_size=10, an interval of 10 samples corresponds to 1 batch
        ("10 samples", 1),
    ],
)
def test_training_manager_interval_processing(interval: str, expected: int) -> None:
    logger: LocalLogger = LocalLogger(**logger_config)
    training_manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=logger,
        epochs_total=1,
        checkpoint_interval=interval,
    )
    assert training_manager.checkpoint_interval == expected


def test_training_manager_empty_dataloader() -> None:
    empty_dataloader: DataLoader = DataLoader(
        TensorDataset(torch.Tensor([]), torch.Tensor([])), batch_size=1
    )
    logger: LocalLogger = LocalLogger(**logger_config)
    with pytest.raises(TrainingManagerInitError):
        TrainingManager(
            model=model, dataloader=empty_dataloader, logger=logger, epochs_total=1
        )


def test_training_manager_0_batchsize() -> None:
    with pytest.raises(ValueError):
        empty_dataloader: DataLoader = DataLoader(
            TensorDataset(torch.Tensor([]), torch.Tensor([])), batch_size=0
        )
        logger: LocalLogger = LocalLogger(**logger_config)
        TrainingManager(
            model=model, dataloader=empty_dataloader, logger=logger, epochs_total=1
        )