Coverage for tests\unit\test_tm_integration.py: 100%
86 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
1import torch
2import torch.nn as nn
3import torch.optim as optim
4from torch.utils.data import DataLoader
6from trnbl.training_manager import TrainingManager
7from trnbl.loggers.local import LocalLogger
# Sizing for the mocked training runs: 50 samples consumed in batches of 10,
# repeated for 5 epochs (so 5 batches per epoch).
DATASET_LEN: int = 50
BATCH_SIZE: int = 10
N_EPOCHS: int = 5
class Model(nn.Module):
    """Minimal model for integration tests: a single 1 -> 1 linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.fc: nn.Linear = nn.Linear(1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pure pass-through to the linear layer
        return self.fc(x)
class MockedDataset(torch.utils.data.Dataset):
    """Synthetic dataset yielding random (input, target) tensor pairs.

    A single tensor of shape (length, channels, 1) is drawn once up front;
    each item is the pair of its first two channel slices (shape (1,) each).
    """

    def __init__(
        self,
        length: int,
        channels: int = 2,
    ) -> None:
        # one random draw at construction; items are views into this tensor
        self.dataset = torch.randn(length, channels, 1)

    def __getitem__(self, idx: int):
        item = self.dataset[idx]
        return item[0], item[1]

    def __len__(self):
        return len(self.dataset)
def test_tm_integration_epoch_wrapped_batch_wrapped():
    """Integration test: both the epoch and batch loops are wrapped.

    Uses ``tr.epoch_loop(...)`` and ``tr.batch_loop(...)`` so the
    TrainingManager tracks progress itself; neither ``dataloader`` nor
    ``epochs_total`` is passed to the constructor.
    """
    net = Model()
    opt = optim.SGD(net.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    logger = LocalLogger(
        project="integration-tests",
        metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
        train_config=dict(
            model=str(net),
            dataset="dummy",
            optimizer=str(opt),
            criterion=str(loss_fn),
        ),
        base_path="tests/_temp",
    )

    loader: DataLoader = DataLoader(
        MockedDataset(DATASET_LEN), batch_size=BATCH_SIZE
    )

    with TrainingManager(
        model=net,
        logger=logger,
        evals={
            "1 epochs": lambda m: {"wgt_mean": torch.mean(m.fc.weight).item()},
            "1/2 epochs": lambda m: logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="2 epochs",
    ) as tr:
        # training loop: both loops delegated to the manager
        for _epoch in tr.epoch_loop(range(N_EPOCHS)):
            for x_batch, y_batch in tr.batch_loop(loader):
                opt.zero_grad()
                preds = net(x_batch)
                batch_loss = loss_fn(preds, y_batch)
                batch_loss.backward()
                opt.step()

                acc = (torch.argmax(preds, dim=1) == y_batch).sum().item() / len(
                    y_batch
                )

                tr.batch_update(
                    samples=len(y_batch),
                    **{"train/loss": batch_loss.item(), "train/acc": acc},
                )
def test_tm_integration_epoch_wrapped_batch_explicit():
    """Integration test: wrapped epoch loop, explicit batch loop.

    The dataloader is handed to the TrainingManager constructor (so it can
    size the run) while the batch loop iterates the loader directly.
    """
    net = Model()
    opt = optim.SGD(net.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    logger = LocalLogger(
        project="integration-tests",
        metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
        train_config=dict(
            model=str(net),
            dataset="dummy",
            optimizer=str(opt),
            criterion=str(loss_fn),
        ),
        base_path="tests/_temp",
    )

    loader: DataLoader = DataLoader(
        MockedDataset(DATASET_LEN), batch_size=BATCH_SIZE
    )

    with TrainingManager(
        model=net,
        dataloader=loader,
        logger=logger,
        evals={
            "1 epochs": lambda m: {"wgt_mean": torch.mean(m.fc.weight).item()},
            "1/2 epochs": lambda m: logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="2 epochs",
    ) as tr:
        # training loop: epochs via the manager, batches iterated directly
        for _epoch in tr.epoch_loop(range(N_EPOCHS)):
            for x_batch, y_batch in loader:
                opt.zero_grad()
                preds = net(x_batch)
                batch_loss = loss_fn(preds, y_batch)
                batch_loss.backward()
                opt.step()

                acc = (torch.argmax(preds, dim=1) == y_batch).sum().item() / len(
                    y_batch
                )

                tr.batch_update(
                    samples=len(y_batch),
                    **{"train/loss": batch_loss.item(), "train/acc": acc},
                )
def test_tm_integration_epoch_explicit_batch_wrapped():
    """Integration test: explicit epoch loop, wrapped batch loop.

    ``epochs_total`` is given to the TrainingManager constructor while the
    epoch loop is a plain ``range``; batches go through ``tr.batch_loop``.
    """
    net = Model()
    opt = optim.SGD(net.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    logger = LocalLogger(
        project="integration-tests",
        metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
        train_config=dict(
            model=str(net),
            dataset="dummy",
            optimizer=str(opt),
            criterion=str(loss_fn),
        ),
        base_path="tests/_temp",
    )

    loader: DataLoader = DataLoader(
        MockedDataset(DATASET_LEN), batch_size=BATCH_SIZE
    )

    with TrainingManager(
        model=net,
        epochs_total=N_EPOCHS,
        logger=logger,
        evals={
            "1 epochs": lambda m: {"wgt_mean": torch.mean(m.fc.weight).item()},
            "1/2 epochs": lambda m: logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="2 epochs",
    ) as tr:
        # training loop: epochs iterated directly, batches via the manager
        for _epoch in range(N_EPOCHS):
            for x_batch, y_batch in tr.batch_loop(loader):
                opt.zero_grad()
                preds = net(x_batch)
                batch_loss = loss_fn(preds, y_batch)
                batch_loss.backward()
                opt.step()

                acc = (torch.argmax(preds, dim=1) == y_batch).sum().item() / len(
                    y_batch
                )

                tr.batch_update(
                    samples=len(y_batch),
                    **{"train/loss": batch_loss.item(), "train/acc": acc},
                )
def test_tm_integration_epoch_explicit_batch_explicit():
    """Integration test: both loops are explicit.

    The TrainingManager receives both ``dataloader`` and ``epochs_total`` at
    construction; the training loop itself uses plain ``range``/loader
    iteration with manual ``tr.batch_update`` calls only.
    """
    net = Model()
    opt = optim.SGD(net.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    logger = LocalLogger(
        project="integration-tests",
        metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
        train_config=dict(
            model=str(net),
            dataset="dummy",
            optimizer=str(opt),
            criterion=str(loss_fn),
        ),
        base_path="tests/_temp",
    )

    loader: DataLoader = DataLoader(
        MockedDataset(DATASET_LEN), batch_size=BATCH_SIZE
    )

    with TrainingManager(
        model=net,
        dataloader=loader,
        epochs_total=N_EPOCHS,
        logger=logger,
        evals={
            "1 epochs": lambda m: {"wgt_mean": torch.mean(m.fc.weight).item()},
            "1/2 epochs": lambda m: logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="2 epochs",
    ) as tr:
        # training loop: fully explicit iteration on both levels
        for _epoch in range(N_EPOCHS):
            for x_batch, y_batch in loader:
                opt.zero_grad()
                preds = net(x_batch)
                batch_loss = loss_fn(preds, y_batch)
                batch_loss.backward()
                opt.step()

                acc = (torch.argmax(preds, dim=1) == y_batch).sum().item() / len(
                    y_batch
                )

                tr.batch_update(
                    samples=len(y_batch),
                    **{"train/loss": batch_loss.item(), "train/acc": acc},
                )