Coverage for trnbl\training_manager.py: 92%

146 statements  

coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

import time
from types import TracebackType
from typing import Any, Callable, Generic, Iterable, Type, TypeVar, Generator, Sequence
from pathlib import Path
import warnings

import tqdm  # type: ignore[import-untyped]

# torch
try:
    import torch
except ImportError:
    warnings.warn("PyTorch not found, this might break things!")

# trnbl
from trnbl.loggers.base import TrainingLoggerBase
from trnbl.training_interval import TrainingInterval, CastableToTrainingInterval

# evaluation function should take a model and return some metrics
EvalFunction = Callable[["torch.nn.Module"], dict]
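
# --- example (editor's addition, not part of the original module) -------------
# a minimal sketch of a function matching the `EvalFunction` signature above: it
# takes the model and returns a dict of metric names to values. the metric name
# used here ("eval/n_params") is an arbitrary illustration, not a trnbl convention.
def _example_eval_param_count(model: "torch.nn.Module") -> dict:
    return {"eval/n_params": sum(p.numel() for p in model.parameters())}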


class TrainingManagerInitError(Exception):
    pass


T = TypeVar("T")


def wrapped_iterable(
    sequence: Sequence[T],
    manager: "TrainingManager",
    is_epoch: bool = False,
    use_tqdm: bool | None = None,
    tqdm_kwargs: dict[str, Any] | None = None,
) -> Generator[T, None, None]:
    length: int = len(sequence)

    # update the manager if it's not fully initialized
    # ------------------------------------------------------------
    if not manager.init_complete:
        if is_epoch:
            # if epoch loop, set the total epochs
            manager.epochs_total = length
        else:
            # if batch loop, set other things
            manager.batches_per_epoch = length
            try:
                manager.batch_size = sequence.batch_size  # type: ignore[attr-defined]
                manager.samples_per_epoch = len(sequence.dataset)  # type: ignore[attr-defined]
            except AttributeError as e:
                raise TrainingManagerInitError(
                    "could not get the batch size or dataset size from the dataloader passed to `TrainingManager().batch_loop()`. "
                    "pass either a `torch.utils.data.DataLoader` "
                    "or an iterable with a `batch_size: int` attribute and a `dataset: Iterable` attribute."
                ) from e

        # try to compute counters and finish init of TrainingManager
        manager.try_compute_counters()

    # set up progress bar with tqdm
    # ------------------------------------------------------------
    use_tqdm = (
        use_tqdm
        if use_tqdm is not None  # do what the user says
        else is_epoch  # otherwise, use tqdm if we are the epoch loop
    )

    if use_tqdm:
        # tqdm kwargs with defaults
        _tqdm_kwargs: dict[str, Any] = dict(
            desc="training run"
            if is_epoch
            else f"epoch {manager.epochs + 1}/{manager.epochs_total}",
            unit=" epochs" if is_epoch else " batches",
            total=length,
        )
        if tqdm_kwargs is not None:
            _tqdm_kwargs.update(tqdm_kwargs)

        # wrap with tqdm
        sequence = tqdm.tqdm(sequence, **_tqdm_kwargs)

    # yield the items, and update the manager
    # ------------------------------------------------------------
    for item in sequence:
        yield item
        if is_epoch:
            manager.epoch_update()
        # no need to call batch_update, since the user has to call batch_update to log metrics
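
# --- example (editor's addition) -----------------------------------------------
# `wrapped_iterable` only needs `len()`, and -- for the batch loop -- a
# `batch_size: int` attribute and a `dataset` with a length (see the error message
# above). a `torch.utils.data.DataLoader` satisfies this; so does a minimal
# stand-in like the sketch below (the class name here is illustrative, not trnbl API).
class _ExampleBatchedData:
    def __init__(self, dataset: Sequence, batch_size: int) -> None:
        self.dataset = dataset
        self.batch_size = batch_size

    def __len__(self) -> int:
        # number of batches per pass over the dataset
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for i in range(0, len(self.dataset), self.batch_size):
            yield self.dataset[i : i + self.batch_size]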


TLogger = TypeVar("TLogger", bound=TrainingLoggerBase)


class TrainingManager(Generic[TLogger]):
    """context manager for training a model, with logging, evals, and checkpoints

    # Parameters:
    - `model : torch.nn.Module`
        ref to the model being trained - used for saving checkpoints
    - `dataloader : torch.utils.data.DataLoader`
        ref to the dataloader being used - used for calculating training progress
        (optional if you wrap your batch loop with `batch_loop`)
    - `logger : TrainingLoggerBase`
        logger, which can be local or interface with wandb
    - `epochs_total : int | None`
        number of epochs to train for. if not passed here, it is inferred from the
        wrapped epoch loop (see `epoch_loop`)
        (defaults to `None`)
    - `evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None`
        list of pairs of (interval, eval_fn) to run evals on the model. See `TrainingInterval` for interval options.
        (defaults to `None`)
    - `checkpoint_interval : TrainingInterval | str`
        interval at which to save model checkpoints
        (defaults to `TrainingInterval(1, "epochs")`)
    - `print_metrics_interval : TrainingInterval | str`
        interval at which to print metrics
        (defaults to `TrainingInterval(0.1, "runs")`)
    - `save_model : Callable[[torch.nn.Module, Path], None]`
        function used to save the model
        (defaults to `torch.save`)
    - `model_save_path : str`
        format string for saving model checkpoints. uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
        (defaults to `"{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"`)
    - `model_save_path_special : str`
        format string for saving special model checkpoints (final, exception, etc.). uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
        (defaults to `"{run_path}/model.{alias}.pt"`)

    # Usage:
    ```python
    with TrainingManager(
        model=model, dataloader=TRAIN_LOADER, logger=logger, epochs_total=500,
        evals={
            "1 epochs": eval_func,
            "0.1 runs": lambda model: logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="50 epochs",
    ) as tp:

        # Training loop
        model.train()
        for epoch in range(500):
            for inputs, targets in TRAIN_LOADER:
                # the usual
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # compute accuracy
                accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

                # log metrics
                tp.batch_update(
                    # pass in the number of samples in your batch (or it will be inferred from the batch size)
                    samples=len(targets),
                    # any other metrics you compute every loop
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )
                # batch_update will automatically run evals and save checkpoints as needed

            tp.epoch_update()
    ```
    """
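
    # --- example (editor's addition) -------------------------------------------
    # intervals may be given either as `TrainingInterval` objects or as strings;
    # the forms that appear in this module are e.g. "1 epochs", "50 epochs", and
    # "0.1 runs" (a fraction of the whole training run). an `evals` argument can
    # therefore be built as a plain list of (interval, eval_fn) pairs, e.g.:
    #
    #     evals = [
    #         ("1 epochs", _example_eval_param_count),
    #         (TrainingInterval(0.1, "runs"), lambda model: logger.get_mem_usage()),
    #     ]
    #
    # (`logger.get_mem_usage()` is taken from the docstring example above.)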

    def __init__(
        self,
        model: "torch.nn.Module",
        logger: TLogger,
        # required if you don't wrap the loops
        dataloader: "torch.utils.data.DataLoader|None" = None,
        epochs_total: int | None = None,
        save_model: Callable[["torch.nn.Module", Path], None] = torch.save,
        # everything with intervals
        evals: Iterable[tuple[CastableToTrainingInterval, EvalFunction]] | None = None,
        checkpoint_interval: CastableToTrainingInterval = TrainingInterval(1, "epochs"),
        print_metrics_interval: CastableToTrainingInterval = TrainingInterval(
            0.1, "runs"
        ),
        # everything with paths
        model_save_path: str = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt",
        model_save_path_special: str = "{run_path}/model.{alias}.pt",
    ):
        # save start time
        self.start_time: float = time.time()
        # non-path and non-interval args get copied over directly
        self.model: "torch.nn.Module" = model
        self.logger: TLogger = logger
        self.save_model: Callable[["torch.nn.Module", Path], None] = save_model

        self.logger.message("starting training manager initialization")

        # model save paths
        self.model_save_path: str = model_save_path
        self.model_save_path_special: str = model_save_path_special

        # temp intervals for processing later in `try_compute_counters`
        self._evals: Iterable[tuple[TrainingInterval, EvalFunction]]
        if evals is None:
            self._evals = []
        else:
            self._evals = [
                (TrainingInterval.from_any(interval), eval_fn)
                for interval, eval_fn in evals
            ]
        self._checkpoint_interval: TrainingInterval = TrainingInterval.from_any(
            checkpoint_interval
        )
        self._print_metrics_interval: TrainingInterval = TrainingInterval.from_any(
            print_metrics_interval
        )

        self.evals: list[tuple[int, EvalFunction]] = list()
        self.checkpoint_interval: int | None = None
        self.print_metrics_interval: int | None = None

        # counters for epochs, batches, samples, and checkpoints
        self.epochs: int = 0
        self.batches: int = 0
        self.samples: int = 0
        self.checkpoints: int = 0

        # total numbers of epochs, batches, and samples
        # pass via init kwarg or wrapped epochs loop
        self.epochs_total: int | None = epochs_total
        # from dataloader or dataloader in wrapped loop
        self.batches_per_epoch: int | None = None
        self.batch_size: int | None = None
        self.samples_per_epoch: int | None = None
        # computed dynamically from the above
        self.batches_total: int | None = None
        self.samples_total: int | None = None

        # whether the init is finished
        self.init_complete: bool = False

        # if we have a dataloader, we can compute some of the above
        if dataloader is not None:
            self.batches_per_epoch = len(dataloader)
            self.batch_size = dataloader.batch_size
            self.samples_per_epoch = len(dataloader.dataset)  # type: ignore[arg-type]

        self.try_compute_counters()

    def try_compute_counters(self) -> None:
        # we depend on either the TrainingManager init or the wrapped loops
        # getting the epochs_total and dataloader
        # everything else is computed dynamically

        if any(
            x is None
            for x in [
                self.epochs_total,
                self.batches_per_epoch,
                self.batch_size,
                self.samples_per_epoch,
            ]
        ):
            # if we don't have all the info we need, return early
            return

        # we can safely ignore type check here since we just checked for `None`
        self.batches_total = self.batches_per_epoch * self.epochs_total  # type: ignore[operator]
        self.samples_total = self.samples_per_epoch * self.epochs_total  # type: ignore[operator]

        # check if the dataloader has a finite nonzero length
        if self.samples_per_epoch == 0:
            raise TrainingManagerInitError(
                f"Dataloader has no samples. Please provide a dataloader with a non-zero length. {self.samples_per_epoch = }"
            )

        if self.batches_per_epoch == 0:
            raise TrainingManagerInitError(
                f"Dataloader has no batches. Please provide a dataloader with a non-zero length. {self.batches_per_epoch = }"
            )

        if self.batch_size == 0:
            raise TrainingManagerInitError(
                f"Dataloader has a batch size of 0. Please provide a dataloader with a non-zero batch size. {self.batch_size = }"
            )

        if self.batch_size is None:
            warnings.warn(
                "batch size is None. This is likely because the dataloader passed to `TrainingManager` does not have a `batch_size` attribute."
                + "\nthis should probably be an exception"
            )

        # normalize intervals for checkpoints, metrics printing, and evals
        _batch_info_kwargs: dict[str, int | None] = dict(
            batches_per_epoch=self.batches_per_epoch,
            batchsize=self.batch_size,
            epochs=self.epochs_total,
        )

        # TODO: no idea why we need `type: ignore[arg-type]` here
        self.checkpoint_interval = TrainingInterval.process_to_batches(
            interval=self._checkpoint_interval,
            **_batch_info_kwargs,  # type: ignore[arg-type]
        )
        self.print_metrics_interval = TrainingInterval.process_to_batches(
            interval=self._print_metrics_interval,
            **_batch_info_kwargs,  # type: ignore[arg-type]
        )

        # list[tuple[int, EvalFunction]]
        self.evals = [
            (
                TrainingInterval.process_to_batches(interval, **_batch_info_kwargs),  # type: ignore[arg-type]
                eval_fn,
            )
            for interval, eval_fn in self._evals
        ]

        # log this info
        self.init_complete = True
        self.logger.message(
            "initialized training manager",
            __training_manager_init__=True,
            epochs_total=self.epochs_total,
            batches_per_epoch=self.batches_per_epoch,
            batch_size=self.batch_size,
            samples_per_epoch=self.samples_per_epoch,
            samples_total=self.samples_total,
            checkpoint_interval_batches=self.checkpoint_interval,
            print_metrics_interval_batches=self.print_metrics_interval,
            model_save_path=self.model_save_path,
            model_save_path_special=self.model_save_path_special,
            **self.training_status(),
        )
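
    # --- example (editor's addition) -------------------------------------------
    # a worked example of the counters above, assuming `epochs_total=10`,
    # `batches_per_epoch=100`, and `batch_size=32`:
    #
    #     batches_total = 100 * 10  = 1000
    #     samples_total = 3200 * 10 = 32000    (samples_per_epoch = 100 * 32 = 3200)
    #
    # the intervals are then normalized to batch counts, presumably so that
    # `TrainingInterval(1, "epochs")` becomes 100 batches and
    # `TrainingInterval(0.1, "runs")` becomes 0.1 * 1000 = 100 batches
    # (the exact conversion lives in `TrainingInterval.process_to_batches`).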

    def __enter__(self):
        return self

    def __exit__(self, exc_type: Type, exc_val: Exception, exc_tb: TracebackType):
        # if error
        if exc_type is not None:
            # add exception info
            self.logger.error(
                str(exc_val),
                exc_type=str(exc_type),
                exc_val=str(exc_val),
                exc_tb=str(exc_tb),
            )
            # save the model
            self._save_checkpoint(alias="exception")

            # close the logger
            self.logger.finish()
        else:
            # if no error, log and save the final model
            self.logger.message(
                "training complete",
                __complete__=True,
            )
            self._save_checkpoint(alias="final")
            self.logger.finish()
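
    # --- example (editor's addition) -------------------------------------------
    # with the default `model_save_path_special` of "{run_path}/model.{alias}.pt",
    # exiting the context manager saves the model to (run_path shown as a placeholder):
    #
    #     <run_path>/model.final.pt        on a clean exit
    #     <run_path>/model.exception.pt    if the block raised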

    def epoch_loop(
        self,
        epochs: Sequence[int],
        use_tqdm: bool = True,
        **tqdm_kwargs,
    ) -> Generator[int, None, None]:
        return wrapped_iterable(
            sequence=epochs,
            manager=self,
            is_epoch=True,
            use_tqdm=use_tqdm,
            tqdm_kwargs=tqdm_kwargs,
        )

    def batch_loop(
        self,
        batches: Sequence[int],
        use_tqdm: bool = False,
        **tqdm_kwargs,
    ) -> Generator[int, None, None]:
        return wrapped_iterable(
            sequence=batches,
            manager=self,
            is_epoch=False,
            use_tqdm=use_tqdm,
            tqdm_kwargs=tqdm_kwargs,
        )
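
    # --- example (editor's addition) -------------------------------------------
    # the wrapped-loop style, as an alternative to the docstring example above.
    # `n_epochs`, `train_loader`, `loss`, etc. are placeholders:
    #
    #     with TrainingManager(model=model, logger=logger) as tp:
    #         for epoch in tp.epoch_loop(range(n_epochs)):
    #             for inputs, targets in tp.batch_loop(train_loader):
    #                 ...
    #                 tp.batch_update(samples=len(targets), **{"train/loss": loss.item()})
    #
    # here `epochs_total` and the dataloader geometry are inferred inside
    # `wrapped_iterable`, and `epoch_update` is called automatically by the
    # wrapped epoch loop.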

    def check_is_initialized(self):
        if not self.init_complete:
            raise TrainingManagerInitError(
                "TrainingManager not correctly initialized. "
                "This is likely due to failing to specify the epoch count or the batch size/count. "
                "You must either wrap your epoch loop with `TrainingManager.epoch_loop` or specify `epochs_total`, "
                "AND you must either wrap your batch loop with `TrainingManager.batch_loop` or pass a `torch.utils.data.DataLoader` to the TrainingManager constructor. "
                "Please note: if not wrapping the epoch loop, you must also call `TrainingManager.epoch_update` at the end of each epoch."
            )

    def get_elapsed_time(self) -> float:
        """return the elapsed time in seconds since the start of training"""
        return time.time() - self.start_time

    def training_status(self) -> dict[str, int | float]:
        """status of elapsed time, samples, batches, epochs, and checkpoints"""
        return dict(
            # timestamp handled in logger
            elapsed_time=self.get_elapsed_time(),
            samples=self.samples,
            batches=self.batches,
            epochs=self.epochs,
            latest_checkpoint=self.checkpoints,
        )

    def _get_format_kwargs(self) -> dict[str, str | int | float]:
        """keyword args for formatting model save paths. calls `TrainingManager.training_status`

        # Provides:
        - `run_path : str` - path where the run is being logged and artifacts are being saved
        - `elapsed_time : float` - the elapsed time in seconds since the start of training
        - `samples : int` - samples seen so far
        - `batches : int` - batches seen so far
        - `epochs : int` - the latest epoch number
        - `latest_checkpoint : int` - the latest checkpoint number
        """
        return {
            "run_path": (
                self.logger.run_path.as_posix()
                if isinstance(self.logger.run_path, Path)
                # HACK: just return the first one
                else self.logger.run_path[0].as_posix()
            ),
            **self.training_status(),
        }
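
    # --- example (editor's addition) -------------------------------------------
    # how `_get_format_kwargs` feeds the save-path format strings: with the default
    # `model_save_path` and, say, `latest_checkpoint=3`, the regular checkpoint path
    # becomes (run_path shown as a placeholder):
    #
    #     "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"
    #         -> "<run_path>/checkpoints/model.checkpoint-3.pt"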

    def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs):
        """call this at the end of every batch. Pass `samples` (or it will be inferred from the batch size), plus any other metrics as kwargs

        This function will:
        - update internal counters
        - run evals as needed (based on the intervals passed)
        - log all metrics and training status
        - save a checkpoint as needed (based on the checkpoint interval)
        """
        # check init is finished
        if not self.init_complete:
            self.try_compute_counters()
            self.check_is_initialized()

        # process metrics and kwargs
        if metrics is None:
            metrics = dict()

        metrics.update(kwargs)

        # update counters
        self.batches += 1
        if samples is not None:
            self.samples += samples
        else:
            # TODO: we warn if batch size is None, but don't except
            self.samples += self.batch_size  # type: ignore[operator]

        # run evals if needed
        for interval, eval_fn in self.evals:
            if (self.batches % interval == 0) or (self.batches == self.batches_total):
                metrics.update(eval_fn(self.model))

        # log metrics & training status
        self.logger.metrics({**metrics, **self.training_status()})

        # print metrics if needed

        # save checkpoint if needed
        if self.batches % self.checkpoint_interval == 0:  # type: ignore[operator]
            self._save_checkpoint()
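
    # --- example (editor's addition) -------------------------------------------
    # scheduling in `batch_update` is plain modular arithmetic on the batch counter:
    # an eval registered at an interval of 100 batches runs at batches 100, 200, ...
    # and also on the very last batch (`self.batches == self.batches_total`);
    # a checkpoint is saved whenever `batches % checkpoint_interval == 0`.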

    def epoch_update(self):
        """call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter"""
        self.logger.debug(f"completed epoch {self.epochs + 1}/{self.epochs_total}")
        self.epochs += 1

    def _save_checkpoint(self, alias: str | None = None):
        """wrapper for saving checkpoint as an artifact to the logger and incrementing the checkpoint counter"""
        # if no alias, then it's a regular checkpoint
        no_alias: bool = alias is None
        if no_alias:
            alias = f"checkpoint-{self.checkpoints}"

        # TODO: store training hist with model?

        # put together a path
        checkpoint_path: Path
        if no_alias:
            # format the model save path for a normal checkpoint
            checkpoint_path = Path(
                self.model_save_path.format(
                    **self._get_format_kwargs(),
                    alias=alias,
                )
            )
        else:
            # for a special checkpoint, use the special path
            checkpoint_path = Path(
                self.model_save_path_special.format(
                    **self._get_format_kwargs(),
                    alias=alias,
                )
            )

        # make sure directory exists
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        # save the model
        self.save_model(self.model, checkpoint_path)

        # log the checkpoint as an artifact
        self.logger.artifact(
            checkpoint_path,
            "model",
            metadata=self.training_status(),
            aliases=[alias] if alias else None,
        )

        # increment checkpoint counter
        self.checkpoints += 1