# Coverage for tests/unit/test_training_manager.py: 98% (216 statements)
# coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
1from typing import Dict, Any, Union, Callable
2import time
3from unittest.mock import MagicMock
4from pathlib import Path
6import torch
7import torch.nn as nn
8import torch.optim as optim
9from torch.utils.data import DataLoader, TensorDataset
10import pytest
12from trnbl.training_interval import IntervalValueError
13from trnbl.training_manager import TrainingManager, TrainingManagerInitError
14from trnbl.loggers.local import LocalLogger
15from trnbl.loggers.base import TrainingLoggerBase
# Scratch directory shared by every test in this module; created eagerly at
# import time so LocalLogger instances can write into it immediately.
TEMP_PATH: Path = Path("tests") / "_temp"
TEMP_PATH.mkdir(parents=True, exist_ok=True)
# Minimal model used as the training target throughout these tests.
class SimpleModel(nn.Module):
    """A single linear layer mapping 10 input features to 1 output."""

    def __init__(self) -> None:
        super().__init__()
        self.fc: nn.Linear = nn.Linear(10, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to `x`."""
        return self.fc(x)
# Dummy regression data: 100 samples of 10 features with scalar targets,
# batched into 10 batches of 10.  These module-level tensors are reused by
# several tests below (including the eval lambdas), so do not rename them.
inputs: torch.Tensor = torch.randn(100, 10)
targets: torch.Tensor = torch.randn(100, 1)
dataset: TensorDataset = TensorDataset(inputs, targets)
dataloader: DataLoader = DataLoader(dataset, batch_size=10)
# Keyword arguments shared by every LocalLogger constructed in this module;
# all logger output lands under TEMP_PATH.
logger_config: Dict[str, Any] = {
    "project": "test_project",
    "metric_names": ["loss"],
    "group": "test_run",
    "train_config": {"batch_size": 10, "learning_rate": 0.001, "epochs": 10},
    "base_path": TEMP_PATH,
}
# Module-level model, loss, and optimizer.  NOTE(review): these are shared
# (and mutated by the full-training-loop test) across all tests in the file.
model: nn.Module = SimpleModel()
criterion: nn.Module = nn.MSELoss()
optimizer: optim.Optimizer = optim.SGD(model.parameters(), lr=0.001)
def test_training_manager_initialization():
    """A freshly constructed TrainingManager exposes its constructor args."""
    local_logger = LocalLogger(**logger_config)
    manager = TrainingManager(
        model=model, dataloader=dataloader, logger=local_logger, epochs_total=10
    )

    assert manager.model == model
    assert manager.logger == local_logger
    assert manager.epochs_total == 10

    manager.logger.finish()
def test_training_manager_batch_update():
    """batch_update logs the train_loss metric and (on interval) checkpoints.

    Fix: the locals were previously named `inputs`/`targets`, shadowing the
    module-level tensors of the same names; they are renamed to avoid the
    shadowing.
    """
    logger = LocalLogger(**logger_config)
    training_manager = TrainingManager(
        model=model, dataloader=dataloader, logger=logger, epochs_total=1
    )
    training_manager._save_checkpoint = MagicMock()

    # Simulate a single training batch update.
    batch_inputs, batch_targets = next(iter(dataloader))
    outputs = model(batch_inputs)
    loss = criterion(outputs, batch_targets)
    training_manager.batch_update(samples=len(batch_targets), train_loss=loss.item())

    # Check if metrics were logged.
    assert len(training_manager.logger.metrics_list) > 0
    assert training_manager.logger.metrics_list[-1]["train_loss"] == loss.item()

    # Check if a checkpoint was saved (based on interval).
    if training_manager.batches % training_manager.checkpoint_interval == 0:
        training_manager._save_checkpoint.assert_called()

    training_manager.logger.finish()
def test_training_manager_epoch_update():
    """epoch_update increments the epoch counter and logs a completion line."""
    manager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=LocalLogger(**logger_config),
        epochs_total=1,
    )

    manager.epoch_update()

    assert manager.epochs == 1
    assert len(manager.logger.log_list) > 0
    assert manager.logger.log_list[-1]["message"] == "completed epoch 1/1"
    manager.logger.finish()
def test_training_manager_checkpoint_saving():
    """_save_checkpoint records an artifact carrying the requested alias."""
    manager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=LocalLogger(**logger_config),
        epochs_total=1,
    )
    manager._save_checkpoint(alias="test_checkpoint")

    # The checkpoint must appear as the most recent logged artifact.
    artifacts = manager.logger.artifacts_list
    assert len(artifacts) > 0
    assert "test_checkpoint" in artifacts[-1]["aliases"]
    manager.logger.finish()
@pytest.fixture
def training_manager() -> TrainingManager:
    """Provide a fresh single-epoch TrainingManager backed by a LocalLogger."""
    local_logger: LocalLogger = LocalLogger(**logger_config)
    manager = TrainingManager(
        model=model, dataloader=dataloader, logger=local_logger, epochs_total=1
    )
    return manager
def test_training_manager_initialization_comprehensive(
    training_manager: TrainingManager,
) -> None:
    """Exhaustively check the counters and derived sizes of a new manager."""
    assert isinstance(training_manager.model, nn.Module)
    assert isinstance(training_manager.logger, TrainingLoggerBase)
    assert training_manager.epochs_total == 1
    assert training_manager.epochs == 0
    assert training_manager.batches_per_epoch == len(dataloader)
    assert training_manager.batch_size == dataloader.batch_size
    assert training_manager.batches_total == len(dataloader)
    assert training_manager.batches == 0
    # TODO: why does it think a `Dataset[Any]` has no length?
    assert training_manager.samples_per_epoch == len(dataloader.dataset)  # type: ignore[arg-type]
    assert training_manager.samples_total == len(dataloader.dataset)  # type: ignore[arg-type]
    assert training_manager.samples == 0
    assert training_manager.checkpoints == 0
def test_training_manager_enter(training_manager: TrainingManager) -> None:
    """__enter__ must return the manager object itself.

    Fix: use `is` (identity) instead of `==` — equality could pass for a
    distinct-but-equal object, which is not what the context protocol promises.
    """
    with training_manager as tm:
        assert tm is training_manager
def test_training_manager_exit_normal(training_manager: TrainingManager) -> None:
    """A clean context exit saves a "final" checkpoint and finishes the logger."""
    tm = training_manager
    tm._save_checkpoint = MagicMock()  # type: ignore[method-assign]
    tm.logger.finish = MagicMock()  # type: ignore[method-assign]

    with tm:
        pass

    tm._save_checkpoint.assert_called_with(alias="final")
    tm.logger.finish.assert_called_once()
def test_training_manager_exit_exception(training_manager: TrainingManager) -> None:
    """An exception inside the context saves an "exception" checkpoint, logs
    the error, and still finishes the logger; the exception propagates."""
    tm = training_manager
    tm._save_checkpoint = MagicMock()  # type: ignore[method-assign]
    tm.logger.error = MagicMock()  # type: ignore[method-assign]
    tm.logger.finish = MagicMock()  # type: ignore[method-assign]

    with pytest.raises(ValueError):
        with tm:
            raise ValueError("Test exception")

    tm._save_checkpoint.assert_called_with(alias="exception")
    tm.logger.error.assert_called_once()
    tm.logger.finish.assert_called_once()
def test_training_manager_get_elapsed_time(training_manager: TrainingManager) -> None:
    """get_elapsed_time measures a non-negative duration from start_time.

    Fix: the original test called the method without asserting anything, so
    it could only catch a crash.  Assert the obvious invariant instead.
    """
    start_time: float = time.time()
    training_manager.start_time = start_time
    elapsed = training_manager.get_elapsed_time()
    # Elapsed time from "now" can only be non-negative (and tiny here).
    assert elapsed >= 0
def test_training_manager_training_status(training_manager: TrainingManager) -> None:
    """training_status reports all expected counters as numeric values."""
    status: Dict[str, Union[int, float]] = training_manager.training_status()

    required_keys = {
        "elapsed_time",
        "samples",
        "batches",
        "epochs",
        "latest_checkpoint",
    }
    assert required_keys <= status.keys()

    for value in status.values():
        assert isinstance(value, (int, float))
def test_training_manager_get_format_kwargs(training_manager: TrainingManager) -> None:
    """_get_format_kwargs exposes run_path as str and every counter as a number."""
    kwargs: Dict[str, Union[str, int, float]] = training_manager._get_format_kwargs()

    required_keys = {
        "run_path",
        "elapsed_time",
        "samples",
        "batches",
        "epochs",
        "latest_checkpoint",
    }
    assert required_keys <= kwargs.keys()

    assert isinstance(kwargs["run_path"], str)
    for key, value in kwargs.items():
        if key != "run_path":
            assert isinstance(value, (int, float))
def test_training_manager_batch_update_new(training_manager: TrainingManager) -> None:
    """batch_update advances the sample/batch counters and logs the loss."""
    training_manager._save_checkpoint = MagicMock()  # type: ignore[method-assign]

    samples_before: int = training_manager.samples
    batches_before: int = training_manager.batches

    training_manager.batch_update(samples=10, loss=0.5)

    assert training_manager.samples == samples_before + 10
    assert training_manager.batches == batches_before + 1
    metrics = training_manager.logger.metrics_list
    assert len(metrics) > 0
    assert metrics[-1]["loss"] == 0.5

    # ok to ignore because the test will just fail?
    if training_manager.batches % training_manager.checkpoint_interval == 0:  # type: ignore[operator]
        training_manager._save_checkpoint.assert_called_once()
def test_training_manager_epoch_update_new(training_manager: TrainingManager) -> None:
    """epoch_update bumps the epoch count and logs a progress message."""
    epochs_before: int = training_manager.epochs

    training_manager.epoch_update()

    assert training_manager.epochs == epochs_before + 1
    log_list = training_manager.logger.log_list
    assert len(log_list) > 0
    expected_fragment = (
        f"completed epoch {epochs_before + 1}/{training_manager.epochs_total}"
    )
    assert expected_fragment in log_list[-1]["message"]
def test_training_manager_save_checkpoint(training_manager: TrainingManager) -> None:
    """_save_checkpoint bumps the counter and records the alias on the artifact."""
    checkpoints_before: int = training_manager.checkpoints

    training_manager._save_checkpoint(alias="test_checkpoint")

    assert training_manager.checkpoints == checkpoints_before + 1
    artifacts = training_manager.logger.artifacts_list
    assert len(artifacts) > 0
    assert "test_checkpoint" in artifacts[-1]["aliases"]
def test_training_manager_full_training_loop() -> None:
    """End-to-end smoke test: two epochs of real SGD through the manager.

    Verifies the cumulative counters, one metric entry per batch, and the
    expected artifact count (one checkpoint per epoch plus the final
    checkpoint written when the context exits).  Uses the module-level
    model/criterion/optimizer, so this test mutates shared state.
    """
    logger: LocalLogger = LocalLogger(**logger_config)
    training_manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=logger,
        epochs_total=2,
        checkpoint_interval="1 epochs",
        evals=[
            ("1 epochs", lambda m: {"eval_loss": criterion(m(inputs), targets).item()})
        ],
    )

    with training_manager as tm:
        for epoch in range(2):
            for batch_inputs, batch_targets in dataloader:
                outputs: torch.Tensor = model(batch_inputs)
                loss: torch.Tensor = criterion(outputs, batch_targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tm.batch_update(samples=len(batch_targets), train_loss=loss.item())

            tm.epoch_update()

    # Counters are checked after the context exits so the final checkpoint
    # (saved in __exit__) is included in the artifact count.
    assert tm.epochs_total == 2
    assert tm.batches == 20  # 2 epochs * 10 batches per epoch
    assert tm.samples == 200  # 2 epochs * 100 samples per epoch
    assert tm.checkpoints == 3  # 1 checkpoint per epoch
    assert len(tm.logger.metrics_list) == 20  # 1 metric log per batch
    assert (
        len(tm.logger.log_list) >= 4
    )  # At least 1 log for init, 2 for epoch completions, 1 for training complete
    assert (
        len(tm.logger.artifacts_list) == 3
    )  # 2 epoch checkpoints + 1 final checkpoint
def test_training_manager_zero_epochs() -> None:
    """epochs_total=0 triggers an IntervalValueError warning at construction."""
    local_logger: LocalLogger = LocalLogger(**logger_config)
    with pytest.warns(IntervalValueError):
        TrainingManager(
            model=model,
            dataloader=dataloader,
            logger=local_logger,
            epochs_total=0,
        )
def test_training_manager_negative_epochs() -> None:
    """A negative epochs_total triggers an IntervalValueError warning."""
    local_logger: LocalLogger = LocalLogger(**logger_config)
    with pytest.warns(IntervalValueError):
        TrainingManager(
            model=model,
            dataloader=dataloader,
            logger=local_logger,
            epochs_total=-1,
        )
def test_training_manager_custom_save_model() -> None:
    """A user-supplied save_model callable is invoked by _save_checkpoint."""
    save_model_mock: Callable[[nn.Module, Path], None] = MagicMock()
    manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=LocalLogger(**logger_config),
        epochs_total=1,
        save_model=save_model_mock,
    )

    manager._save_checkpoint()

    # ignoring here because save_model_mock is a mock
    save_model_mock.assert_called_once()  # type: ignore[attr-defined]
def test_training_manager_custom_intervals() -> None:
    """Epoch-fraction interval strings are converted to whole batch counts."""
    n_batches = len(dataloader)
    manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=LocalLogger(**logger_config),
        epochs_total=1,
        checkpoint_interval="0.5 epochs",
        print_metrics_interval="0.25 epochs",
        evals=[("0.1 epochs", lambda m: {"eval_loss": 0.5})],
    )

    assert manager.checkpoint_interval == n_batches // 2
    assert manager.print_metrics_interval == n_batches // 4
    assert manager.evals[0][0] == n_batches // 10
def test_training_manager_custom_model_save_paths() -> None:
    """Custom checkpoint path templates are stored verbatim on the manager."""
    checkpoint_template: str = (
        "{run_path}/custom_checkpoints/model-{latest_checkpoint}.pt"
    )
    special_template: str = "{run_path}/custom_special/model-{alias}.pt"
    manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=LocalLogger(**logger_config),
        epochs_total=1,
        model_save_path=checkpoint_template,
        model_save_path_special=special_template,
    )

    assert manager.model_save_path == checkpoint_template
    assert manager.model_save_path_special == special_template
def test_training_manager_batch_update_no_samples() -> None:
    """When samples is None, batch_update falls back to the batch size."""
    manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=LocalLogger(**logger_config),
        epochs_total=1,
    )
    samples_before: int = manager.samples

    manager.batch_update(samples=None, loss=0.5)

    # ok to ignore because the test will just fail?
    assert manager.samples == samples_before + manager.batch_size  # type: ignore[operator]
def test_training_manager_multiple_evals() -> None:
    """Evals fire on their own intervals: every batch vs. every other batch."""

    def eval_one(m: nn.Module) -> Dict[str, float]:
        return {"eval1": 0.5}

    def eval_two(m: nn.Module) -> Dict[str, float]:
        return {"eval2": 0.7}

    manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=LocalLogger(**logger_config),
        epochs_total=1,
        evals=[("1 batch", eval_one), ("2 batches", eval_two)],
    )

    # First batch: only the every-batch eval should have run.
    manager.batch_update(samples=10, loss=0.3)
    assert "eval1" in manager.logger.metrics_list[-1]
    assert "eval2" not in manager.logger.metrics_list[-1]

    # Second batch: both evals should have run.
    manager.batch_update(samples=10, loss=0.3)
    assert "eval1" in manager.logger.metrics_list[-1]
    assert "eval2" in manager.logger.metrics_list[-1]
@pytest.mark.parametrize(
    "interval, expected",
    [
        ("1 epochs", len(dataloader)),
        ("0.5 epochs", len(dataloader) // 2),
        ("10 batches", 10),
        ("10 samples", 1),
    ],
)
def test_training_manager_interval_processing(interval: str, expected: int) -> None:
    """Interval strings in epochs/batches/samples all normalize to batch counts."""
    manager: TrainingManager = TrainingManager(
        model=model,
        dataloader=dataloader,
        logger=LocalLogger(**logger_config),
        epochs_total=1,
        checkpoint_interval=interval,
    )
    assert manager.checkpoint_interval == expected
def test_training_manager_empty_dataloader() -> None:
    """An empty dataset is rejected at TrainingManager construction time."""
    no_data: DataLoader = DataLoader(
        TensorDataset(torch.Tensor([]), torch.Tensor([])), batch_size=1
    )
    local_logger: LocalLogger = LocalLogger(**logger_config)
    with pytest.raises(TrainingManagerInitError):
        TrainingManager(
            model=model, dataloader=no_data, logger=local_logger, epochs_total=1
        )
def test_training_manager_0_batchsize() -> None:
    """DataLoader rejects batch_size=0 outright with a ValueError.

    Fix: the original test also constructed a LocalLogger and a
    TrainingManager inside the `pytest.raises` block, but DataLoader raises
    on its first statement, so those lines were unreachable dead code; they
    have been removed and the raises-context narrowed to the raising call.
    """
    with pytest.raises(ValueError):
        DataLoader(TensorDataset(torch.Tensor([]), torch.Tensor([])), batch_size=0)