Coverage for tests\unit\test_training_interval.py: 100%

152 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

1from typing import Union, Any 

2import pytest 

3from trnbl.training_interval import ( 

4 TrainingInterval, 

5 IntervalValueError, 

6 TrainingIntervalUnit, 

7) 

8 

9 

def test_as_batch_count():
    """Check that every supported unit converts to the expected batch count."""
    per_epoch = {"batchsize": 32, "batches_per_epoch": 100}
    # (quantity, unit, keyword arguments for as_batch_count, expected batches)
    cases = [
        (1, "runs", {**per_epoch, "epochs": 10}, 1000),
        (5, "epochs", per_epoch, 500),
        (200, "batches", per_epoch, 200),
        (6400, "samples", per_epoch, 200),
    ]
    for quantity, unit, kwargs, expected in cases:
        assert TrainingInterval(quantity, unit).as_batch_count(**kwargs) == expected

35 

36 

def test_normalized():
    """Check that `normalized` re-expresses runs and epochs in batches."""
    from_runs = TrainingInterval(1, "runs").normalized(
        batchsize=32, batches_per_epoch=100, epochs=10
    )
    assert from_runs.quantity == 1000
    assert from_runs.unit == "batches"

    from_epochs = TrainingInterval(5, "epochs").normalized(
        batchsize=32, batches_per_epoch=100
    )
    assert from_epochs.quantity == 500
    assert from_epochs.unit == "batches"

49 

50 

def test_from_str():
    """Check parsing of integer, decimal and fractional quantities from strings."""
    expected_by_text = {
        "5 epochs": (5, "epochs"),
        "100 batches": (100, "batches"),
        "0.1 runs": (0.1, "runs"),
        "1/5 runs": (0.2, "runs"),
    }
    for text, (quantity, unit) in expected_by_text.items():
        assert TrainingInterval.from_str(text) == TrainingInterval(quantity, unit)

56 

57 

def test_from_any():
    """Check that `from_any` accepts every supported input shape.

    Each group is tried as a single string, as two positional arguments,
    as a tuple, as a list, and as an existing `TrainingInterval`.
    """
    # (combined string, quantity string, unit, numeric quantity, expected quantity)
    groups = [
        ("5 epochs", "5", "epochs", 5, 5),
        ("100 batches", "100", "batches", 100, 100),
        ("0.1 runs", "0.1", "runs", 0.1, 0.1),
        ("1/5 runs", "1/5", "runs", 1 / 5, 0.2),
    ]
    for combined, quantity_text, unit, numeric, expected_quantity in groups:
        assert TrainingInterval.from_any(combined) == TrainingInterval(
            expected_quantity, unit
        )
        assert TrainingInterval.from_any(quantity_text, unit) == TrainingInterval(
            expected_quantity, unit
        )
        assert TrainingInterval.from_any((quantity_text, unit)) == TrainingInterval(
            expected_quantity, unit
        )
        assert TrainingInterval.from_any([quantity_text, unit]) == TrainingInterval(
            expected_quantity, unit
        )
        assert TrainingInterval.from_any(
            TrainingInterval(numeric, unit)
        ) == TrainingInterval(expected_quantity, unit)

96 

97 

def test_process_to_batches():
    """Check `process_to_batches` on string, tuple and instance inputs."""
    per_epoch = {"batchsize": 32, "batches_per_epoch": 100}
    per_run = {**per_epoch, "epochs": 10}
    # (interval specification, keyword arguments, expected batches)
    cases = [
        ("5 epochs", per_epoch, 500),
        (("100", "batches"), per_epoch, 100),
        (TrainingInterval(0.1, "runs"), per_run, 100),
        (("1/5", "runs"), per_run, 200),
    ]
    for spec, kwargs, expected in cases:
        assert TrainingInterval.process_to_batches(spec, **kwargs) == expected

126 

127 

def test_edge_cases():
    """Check zero, very large and sub-batch quantities."""
    schedule = {"batches_per_epoch": 100, "epochs": 10}

    # A zero-length run must emit IntervalValueError and count as one batch.
    with pytest.warns(IntervalValueError):
        zero_run = TrainingInterval(0, "runs").as_batch_count(batchsize=32, **schedule)
        assert zero_run == 1

    huge = TrainingInterval(1e6, "batches").as_batch_count(batchsize=32, **schedule)
    assert huge == int(1e6)

    # 14 samples at batch size 10 is asserted to come out as a single batch.
    partial = TrainingInterval(14, "samples").as_batch_count(batchsize=10, **schedule)
    assert partial == 1

145 

146 

def test_invalid_inputs():
    """Check that malformed specifications raise `ValueError`."""
    with pytest.raises(ValueError):
        TrainingInterval.from_str("5 decades")

    # Each entry is the full positional-argument tuple passed to `from_any`.
    bad_argument_sets = [
        ((100,),),
        (123,),
        (("5", "epochs", "lol"),),
        ("5", "epochs", "lol"),
    ]
    for bad_args in bad_argument_sets:
        with pytest.raises(ValueError):
            TrainingInterval.from_any(*bad_args)

162 

163 

def test_boundary_cases():
    """Check conversions at minimal batch size, epoch length and epoch count."""
    # (quantity, unit, keyword arguments for as_batch_count, expected batches)
    cases = [
        (1, "runs", {"batchsize": 1, "batches_per_epoch": 100, "epochs": 10}, 1000),
        (0.9, "runs", {"batchsize": 1, "batches_per_epoch": 100, "epochs": 10}, 900),
        (1, "epochs", {"batchsize": 32, "batches_per_epoch": 1}, 1),
        (1, "runs", {"batchsize": 32, "batches_per_epoch": 100, "epochs": 1}, 100),
    ]
    for quantity, unit, kwargs, expected in cases:
        assert TrainingInterval(quantity, unit).as_batch_count(**kwargs) == expected

187 

188 

def test_unpacking():
    """Check that an interval unpacks into its (quantity, unit) pair."""
    pairs = [
        (5, "epochs"),
        (100, "batches"),
        (0.1, "runs"),
        (1 / 12, "runs"),
    ]
    for expected_quantity, expected_unit in pairs:
        quantity, unit = TrainingInterval(expected_quantity, expected_unit)
        assert quantity == expected_quantity
        assert unit == expected_unit

198 

199 

@pytest.mark.parametrize(
    "quantity, unit",
    [
        (0.1, "runs"),
        (0.1, "epochs"),
        (0.0001, "runs"),
        (1e-10, "epochs"),
    ],
)
def test_very_small_values(
    quantity: Union[int, float], unit: TrainingIntervalUnit
) -> None:
    """Check that small run/epoch quantities are stored unchanged at construction."""
    built = TrainingInterval(quantity, unit)
    assert built.quantity == quantity
    assert built.unit == unit

215 

216 

def test_zero_samples() -> None:
    """Check that constructing a zero-sample interval emits `IntervalValueError`."""
    with pytest.warns(IntervalValueError):
        _ = TrainingInterval(0, "samples")

220 

221 

@pytest.mark.parametrize("quantity", [0.51, 0.9, 1.1, 1.49])
def test_samples_rounding(quantity: float) -> None:
    """Check that fractional sample counts near one are stored as exactly one.

    Quantities below one are additionally expected to emit `IntervalValueError`.
    """
    if quantity >= 1:
        interval = TrainingInterval(quantity, "samples")
    else:
        with pytest.warns(IntervalValueError):
            interval = TrainingInterval(quantity, "samples")

    assert interval.quantity == 1
    assert interval.unit == "samples"

231 

232 

@pytest.mark.parametrize(
    "quantity, unit, batchsize, batches_per_epoch, epochs, expected",
    [
        # a single sample is less than one batch of 32
        (1, "samples", 32, 100, 10, 1),
        # vanishingly small fractions of a run or epoch
        (0.000001, "runs", 32, 100, 10, 1),
        (0.0001, "epochs", 32, 100, 10, 1),
        (1e-10, "runs", 32, 100, 10, 1),
        (1e-10, "epochs", 32, 100, 10, 1),
    ],
)
def test_as_batch_count_edge_cases(
    quantity: Union[int, float],
    unit: TrainingIntervalUnit,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int,
    expected: int,
) -> None:
    """Check that sub-batch intervals warn and convert to the expected count."""
    interval = TrainingInterval(quantity, unit)

    with pytest.warns(IntervalValueError):
        result = interval.as_batch_count(batchsize, batches_per_epoch, epochs)

    assert result == expected, f"Expected {expected}, but got {result} for {interval}"

255 

256 

def test_as_batch_count_without_epochs() -> None:
    """Check that converting a "runs" interval without `epochs` raises `AssertionError`."""
    tenth_of_run = TrainingInterval(0.1, "runs")
    with pytest.raises(AssertionError):
        tenth_of_run.as_batch_count(32, 100)

261 

262 

@pytest.mark.parametrize(
    "input_data, expected",
    [
        # The original list contained ("0.1 runs", (0.1, "runs")) twice; the
        # exact duplicate added no coverage and has been dropped.
        ("0.1 runs", (0.1, "runs")),
        ("0.1 epochs", (0.1, "epochs")),
        ("1 batches", (1, "batches")),
        ("1/1000 epochs", (0.001, "epochs")),
    ],
)
def test_from_str_edge_cases(
    input_data: str, expected: tuple[float | int, TrainingIntervalUnit]
) -> None:
    """Check that `from_str` parses small decimal and fractional quantities.

    `expected` is the `(quantity, unit)` pair used to build the interval
    that the parsed result must equal.
    """
    result = TrainingInterval.from_str(input_data)
    assert result == TrainingInterval(*expected), (
        f"Expected {expected}, but got {result} for input '{input_data}'"
    )

280 

281 

@pytest.mark.parametrize(
    "input_data",
    [
        "invalid unit",
        "1.5.5 epochs",
        "123",
        "1/2/3 batches",
        "0.0.0 batches",
        "ten samples",
        "1/2/3 samples",
    ],
)
def test_from_str_invalid_inputs(input_data: str) -> None:
    """Check that `from_str` raises `ValueError` for each malformed string."""
    with pytest.raises(ValueError):
        _ = TrainingInterval.from_str(input_data)

297 

298 

@pytest.mark.parametrize(
    "input_data, expected",
    [
        ((0.1, "runs"), (0.1, "runs")),
        (["0.1", "epochs"], (0.1, "epochs")),
        ("0.1 runs", (0.1, "runs")),
        (("1/1000", "epochs"), (0.001, "epochs")),
    ],
)
def test_from_any_edge_cases_nowarn(
    input_data: Any, expected: tuple[float | int, TrainingIntervalUnit]
) -> None:
    """Check small run/epoch inputs to `from_any`.

    No warnings are expected here because the batch size is unknown.
    """
    result = TrainingInterval.from_any(input_data)
    assert result == TrainingInterval(*expected), (
        f"Expected {expected}, but got {result} for input {input_data}"
    )

316 

317 

@pytest.mark.parametrize(
    "input_data, expected",
    [
        # The original list repeated the (1e-10, "batches") and (0, "batches")
        # cases verbatim; each is now listed once.
        ((1e-10, "batches"), (1, "batches")),
        ((0, "batches"), (1, "batches")),
        ("1/2 batches", (1, "batches")),
        ("0.0 batches", (1, "batches")),
        ((0, "samples"), (1, "samples")),
    ],
)
def test_from_any_edge_cases_warn(
    input_data: Any, expected: tuple[float | int, TrainingIntervalUnit]
) -> None:
    """Check that sub-unit batch/sample inputs to `from_any` warn.

    Every case must emit `IntervalValueError` and produce an interval equal
    to one built from the `(quantity, unit)` pair in `expected`.
    """
    with pytest.warns(IntervalValueError):
        result = TrainingInterval.from_any(input_data)
    assert result == TrainingInterval(*expected), (
        f"Expected {expected}, but got {result} for input {input_data}"
    )

339 

340 

@pytest.mark.parametrize(
    "input_data",
    [
        (0, "potatoes"),
        "invalid unit",
        (1.5, 5, "epochs"),
        123,
        ("1", "batches", "lol"),
    ],
)
def test_from_any_invalid_inputs(input_data: Any) -> None:
    """Check that `from_any` raises `ValueError` for each malformed input."""
    with pytest.raises(ValueError):
        _ = TrainingInterval.from_any(input_data)

354 

355 

@pytest.mark.parametrize(
    "interval, batchsize, batches_per_epoch, epochs, expected",
    [
        ("0 runs", 32, 100, 10, 1),
        ("1e-10 epochs", 32, 100, 10, 1),
        ("0.1 batches", 32, 100, 10, 1),
        ("1 samples", 32, 100, 10, 1),
    ],
)
def test_process_to_batches_edge_cases(
    interval: Union[str, tuple, TrainingInterval],
    batchsize: int,
    batches_per_epoch: int,
    epochs: int,
    expected: int,
) -> None:
    """Check that sub-batch specifications warn and convert to the expected count."""
    with pytest.warns(IntervalValueError):
        result = TrainingInterval.process_to_batches(
            interval,
            batchsize,
            batches_per_epoch,
            epochs,
        )

    assert result == expected, f"Expected {expected}, but got {result} for {interval}"

377 

378 

def test_normalization_edge_cases() -> None:
    """Check `normalized` for a fractional run and a vanishingly small epoch."""
    tenth_of_run = TrainingInterval(0.1, "runs").normalized(
        batchsize=32, batches_per_epoch=100, epochs=10
    )
    assert tenth_of_run.quantity == 100
    assert tenth_of_run.unit == "batches"

    # Normalizing a tiny fraction of an epoch must warn and yield one batch.
    tiny_epoch = TrainingInterval(1e-10, "epochs")
    with pytest.warns(IntervalValueError):
        tiny_normalized = tiny_epoch.normalized(batchsize=32, batches_per_epoch=100)
    assert tiny_normalized.quantity == 1
    assert tiny_normalized.unit == "batches"

390 

391 

def test_equality_edge_cases() -> None:
    """Check equality across matching, differing-unit and sub-batch intervals."""
    tenth_of_run = TrainingInterval(0.1, "runs")
    assert tenth_of_run == TrainingInterval(0.1, "runs")
    assert tenth_of_run != TrainingInterval(0.1, "epochs")

    # A sub-batch quantity must warn and compare equal to a one-batch interval.
    with pytest.warns(IntervalValueError):
        tiny_batches = TrainingInterval(1e-10, "batches")
        assert tiny_batches == TrainingInterval(1, "batches")

398 

399 

def test_iteration_and_indexing() -> None:
    """Check that an interval iterates and indexes like a (quantity, unit) pair."""
    interval = TrainingInterval(0.1, "runs")

    first, second = interval
    assert first == 0.1
    assert second == "runs"

    for position, value in enumerate((0.1, "runs")):
        assert interval[position] == value

    # Index 2 is asserted to be out of range.
    with pytest.raises(IndexError):
        _ = interval[2]