Coverage for tests\unit\test_tm_integration.py: 100%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

1import torch 

2import torch.nn as nn 

3import torch.optim as optim 

4from torch.utils.data import DataLoader 

5 

6from trnbl.training_manager import TrainingManager 

7from trnbl.loggers.local import LocalLogger 

8 

# Sizing knobs for the mocked training runs below: 50 samples in batches of
# 10 gives 5 batches per epoch, trained for 5 epochs.
N_EPOCHS: int = 5
BATCH_SIZE: int = 10
DATASET_LEN: int = 50

12 

13 

class Model(nn.Module):
    """Minimal trainable module: a single 1-in / 1-out linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.fc: nn.Linear = nn.Linear(1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to *x* and return the result."""
        return self.fc(x)

21 

22 

class MockedDataset(torch.utils.data.Dataset):
    """Random-tensor dataset of shape ``(length, channels, 1)``.

    Each item is the pair ``(sample[0], sample[1])`` — the first channel is
    the input, the second the target (extra channels, if any, are unused).
    """

    def __init__(
        self,
        length: int,
        channels: int = 2,
    ) -> None:
        self.dataset = torch.randn(length, channels, 1)

    def __getitem__(self, idx: int):
        sample = self.dataset[idx]
        return sample[0], sample[1]

    def __len__(self):
        return len(self.dataset)

36 

37 

def test_tm_integration_epoch_wrapped_batch_wrapped():
    """Integration: both the epoch loop and the batch loop go through the
    TrainingManager wrappers (``epoch_loop`` + ``batch_loop``)."""
    net = Model()
    opt = optim.SGD(net.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    local_logger = LocalLogger(
        project="integration-tests",
        metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
        train_config=dict(
            model=str(net),
            dataset="dummy",
            optimizer=str(opt),
            criterion=str(loss_fn),
        ),
        base_path="tests/_temp",
    )

    loader: DataLoader = DataLoader(
        MockedDataset(DATASET_LEN), batch_size=BATCH_SIZE
    )

    with TrainingManager(
        model=net,
        logger=local_logger,
        evals={
            "1 epochs": lambda m: {"wgt_mean": torch.mean(m.fc.weight).item()},
            "1/2 epochs": lambda m: local_logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="2 epochs",
    ) as manager:
        # Training loop — manager wraps both iteration levels.
        for _epoch in manager.epoch_loop(range(N_EPOCHS)):
            for inputs, targets in manager.batch_loop(loader):
                opt.zero_grad()
                outputs = net(inputs)
                loss = loss_fn(outputs, targets)
                loss.backward()
                opt.step()

                # Dummy "accuracy" — outputs are regression values, so this
                # only exercises the metric-logging plumbing.
                accuracy = torch.sum(
                    torch.argmax(outputs, dim=1) == targets
                ).item() / len(targets)

                manager.batch_update(
                    samples=len(targets),
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )

85 

86 

def test_tm_integration_epoch_wrapped_batch_explicit():
    """Integration: epoch loop is wrapped by the TrainingManager, while the
    batch loop iterates the dataloader directly (manager receives the
    dataloader up front via the ``dataloader`` argument)."""
    net = Model()
    opt = optim.SGD(net.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    local_logger = LocalLogger(
        project="integration-tests",
        metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
        train_config=dict(
            model=str(net),
            dataset="dummy",
            optimizer=str(opt),
            criterion=str(loss_fn),
        ),
        base_path="tests/_temp",
    )

    loader: DataLoader = DataLoader(
        MockedDataset(DATASET_LEN), batch_size=BATCH_SIZE
    )

    with TrainingManager(
        model=net,
        dataloader=loader,
        logger=local_logger,
        evals={
            "1 epochs": lambda m: {"wgt_mean": torch.mean(m.fc.weight).item()},
            "1/2 epochs": lambda m: local_logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="2 epochs",
    ) as manager:
        # Training loop — only the epoch level is wrapped.
        for _epoch in manager.epoch_loop(range(N_EPOCHS)):
            for inputs, targets in loader:
                opt.zero_grad()
                outputs = net(inputs)
                loss = loss_fn(outputs, targets)
                loss.backward()
                opt.step()

                # Dummy "accuracy" — outputs are regression values, so this
                # only exercises the metric-logging plumbing.
                accuracy = torch.sum(
                    torch.argmax(outputs, dim=1) == targets
                ).item() / len(targets)

                manager.batch_update(
                    samples=len(targets),
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )

135 

136 

def test_tm_integration_epoch_explicit_batch_wrapped():
    """Integration: plain ``range`` epoch loop, batch loop wrapped by the
    TrainingManager (manager learns the epoch count via ``epochs_total``)."""
    net = Model()
    opt = optim.SGD(net.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    local_logger = LocalLogger(
        project="integration-tests",
        metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
        train_config=dict(
            model=str(net),
            dataset="dummy",
            optimizer=str(opt),
            criterion=str(loss_fn),
        ),
        base_path="tests/_temp",
    )

    loader: DataLoader = DataLoader(
        MockedDataset(DATASET_LEN), batch_size=BATCH_SIZE
    )

    with TrainingManager(
        model=net,
        epochs_total=N_EPOCHS,
        logger=local_logger,
        evals={
            "1 epochs": lambda m: {"wgt_mean": torch.mean(m.fc.weight).item()},
            "1/2 epochs": lambda m: local_logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="2 epochs",
    ) as manager:
        # Training loop — only the batch level is wrapped.
        for _epoch in range(N_EPOCHS):
            for inputs, targets in manager.batch_loop(loader):
                opt.zero_grad()
                outputs = net(inputs)
                loss = loss_fn(outputs, targets)
                loss.backward()
                opt.step()

                # Dummy "accuracy" — outputs are regression values, so this
                # only exercises the metric-logging plumbing.
                accuracy = torch.sum(
                    torch.argmax(outputs, dim=1) == targets
                ).item() / len(targets)

                manager.batch_update(
                    samples=len(targets),
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )

185 

186 

def test_tm_integration_epoch_explicit_batch_explicit():
    """Integration: neither loop is wrapped — the TrainingManager is told
    both ``dataloader`` and ``epochs_total`` up front and only receives
    explicit ``batch_update`` calls."""
    net = Model()
    opt = optim.SGD(net.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    local_logger = LocalLogger(
        project="integration-tests",
        metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
        train_config=dict(
            model=str(net),
            dataset="dummy",
            optimizer=str(opt),
            criterion=str(loss_fn),
        ),
        base_path="tests/_temp",
    )

    loader: DataLoader = DataLoader(
        MockedDataset(DATASET_LEN), batch_size=BATCH_SIZE
    )

    with TrainingManager(
        model=net,
        dataloader=loader,
        epochs_total=N_EPOCHS,
        logger=local_logger,
        evals={
            "1 epochs": lambda m: {"wgt_mean": torch.mean(m.fc.weight).item()},
            "1/2 epochs": lambda m: local_logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="2 epochs",
    ) as manager:
        # Training loop — both iteration levels are plain Python loops.
        for _epoch in range(N_EPOCHS):
            for inputs, targets in loader:
                opt.zero_grad()
                outputs = net(inputs)
                loss = loss_fn(outputs, targets)
                loss.backward()
                opt.step()

                # Dummy "accuracy" — outputs are regression values, so this
                # only exercises the metric-logging plumbing.
                accuracy = torch.sum(
                    torch.argmax(outputs, dim=1) == targets
                ).item() / len(targets)

                manager.batch_update(
                    samples=len(targets),
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )