Coverage for tests\unit\test_training_interval.py: 100%
152 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
1from typing import Union, Any
2import pytest
3from trnbl.training_interval import (
4 TrainingInterval,
5 IntervalValueError,
6 TrainingIntervalUnit,
7)
def test_as_batch_count():
    """Every supported unit converts to the expected absolute batch count."""
    # (quantity, unit, conversion keyword arguments, expected batches)
    table = [
        (1, "runs", {"batchsize": 32, "batches_per_epoch": 100, "epochs": 10}, 1000),
        (5, "epochs", {"batchsize": 32, "batches_per_epoch": 100}, 500),
        (200, "batches", {"batchsize": 32, "batches_per_epoch": 100}, 200),
        (6400, "samples", {"batchsize": 32, "batches_per_epoch": 100}, 200),
    ]
    for amount, unit_name, conversion_kwargs, expected_batches in table:
        interval = TrainingInterval(amount, unit_name)
        assert interval.as_batch_count(**conversion_kwargs) == expected_batches
def test_normalized():
    """normalized() re-expresses run- and epoch-based intervals as batches."""
    from_runs = TrainingInterval(1, "runs").normalized(
        batchsize=32, batches_per_epoch=100, epochs=10
    )
    assert from_runs.quantity == 1000
    assert from_runs.unit == "batches"

    from_epochs = TrainingInterval(5, "epochs").normalized(
        batchsize=32, batches_per_epoch=100
    )
    assert from_epochs.quantity == 500
    assert from_epochs.unit == "batches"
def test_from_str():
    """from_str parses integer, decimal and ratio quantities."""
    parsed_expectations = {
        "5 epochs": (5, "epochs"),
        "100 batches": (100, "batches"),
        "0.1 runs": (0.1, "runs"),
        "1/5 runs": (0.2, "runs"),
    }
    for text, (amount, unit_name) in parsed_expectations.items():
        assert TrainingInterval.from_str(text) == TrainingInterval(amount, unit_name)
def test_from_any():
    """from_any accepts every supported input shape and yields the same interval."""
    # (single-string spec, quantity text, quantity value, unit, expected quantity)
    groups = (
        ("5 epochs", "5", 5, "epochs", 5),
        ("100 batches", "100", 100, "batches", 100),
        ("0.1 runs", "0.1", 0.1, "runs", 0.1),
        ("1/5 runs", "1/5", 1 / 5, "runs", 0.2),
    )
    for spec, qty_text, qty_value, unit_name, expected_qty in groups:
        # Positional-argument tuples for from_any, in the original check order:
        # single string, (quantity, unit) args, tuple, list, existing interval.
        accepted_call_args = (
            (spec,),
            (qty_text, unit_name),
            ((qty_text, unit_name),),
            ([qty_text, unit_name],),
            (TrainingInterval(qty_value, unit_name),),
        )
        for call_args in accepted_call_args:
            assert TrainingInterval.from_any(*call_args) == TrainingInterval(
                expected_qty, unit_name
            )
def test_process_to_batches():
    """process_to_batches accepts strings, tuples and TrainingInterval objects."""
    base_kwargs = {"batchsize": 32, "batches_per_epoch": 100}
    with_epochs = {**base_kwargs, "epochs": 10}
    scenarios = [
        ("5 epochs", base_kwargs, 500),
        (("100", "batches"), base_kwargs, 100),
        (TrainingInterval(0.1, "runs"), with_epochs, 100),
        (("1/5", "runs"), with_epochs, 200),
    ]
    for raw_interval, call_kwargs, expected_batches in scenarios:
        outcome = TrainingInterval.process_to_batches(raw_interval, **call_kwargs)
        assert outcome == expected_batches
def test_edge_cases():
    """Zero, very large and sub-batch intervals map to sane batch counts."""
    # Zero runs must emit IntervalValueError and still come out as one batch.
    with pytest.warns(IntervalValueError):
        zero_run_batches = TrainingInterval(0, "runs").as_batch_count(
            batchsize=32, batches_per_epoch=100, epochs=10
        )
        assert zero_run_batches == 1

    # A large batch count is returned unchanged.
    huge_batches = TrainingInterval(1e6, "batches").as_batch_count(
        batchsize=32, batches_per_epoch=100, epochs=10
    )
    assert huge_batches == int(1e6)

    # 14 samples at batchsize 10 is 1.4 batches; expected to come out as 1.
    partial_batches = TrainingInterval(14, "samples").as_batch_count(
        batchsize=10, batches_per_epoch=100, epochs=10
    )
    assert partial_batches == 1
def test_invalid_inputs():
    """Unknown units, wrong arities and unsupported types raise ValueError."""
    rejected_calls = [
        lambda: TrainingInterval.from_str("5 decades"),
        lambda: TrainingInterval.from_any((100,)),
        lambda: TrainingInterval.from_any(123),
        lambda: TrainingInterval.from_any(("5", "epochs", "lol")),
        lambda: TrainingInterval.from_any("5", "epochs", "lol"),
    ]
    for attempt in rejected_calls:
        with pytest.raises(ValueError):
            attempt()
def test_boundary_cases():
    """Batchsize 1, a single batch per epoch and a single epoch convert exactly."""
    # (quantity, unit, conversion keyword arguments, expected batches)
    boundaries = [
        (1, "runs", {"batchsize": 1, "batches_per_epoch": 100, "epochs": 10}, 1000),
        (0.9, "runs", {"batchsize": 1, "batches_per_epoch": 100, "epochs": 10}, 900),
        (1, "epochs", {"batchsize": 32, "batches_per_epoch": 1}, 1),
        (1, "runs", {"batchsize": 32, "batches_per_epoch": 100, "epochs": 1}, 100),
    ]
    for amount, unit_name, conversion_kwargs, expected_batches in boundaries:
        converted = TrainingInterval(amount, unit_name).as_batch_count(
            **conversion_kwargs
        )
        assert converted == expected_batches
def test_unpacking():
    """A TrainingInterval unpacks into its (quantity, unit) pair."""
    samples = [(5, "epochs"), (100, "batches"), (0.1, "runs"), (1 / 12, "runs")]
    for expected_quantity, expected_unit in samples:
        quantity, unit = TrainingInterval(expected_quantity, expected_unit)
        assert quantity == expected_quantity
        assert unit == expected_unit
@pytest.mark.parametrize(
    "quantity, unit",
    [
        (0.1, "runs"),
        (0.1, "epochs"),
        (0.0001, "runs"),
        (1e-10, "epochs"),
    ],
)
def test_very_small_values(
    quantity: Union[int, float], unit: TrainingIntervalUnit
) -> None:
    """Small fractional run/epoch quantities are stored exactly as given."""
    stored = TrainingInterval(quantity, unit)
    assert (stored.quantity, stored.unit) == (quantity, unit)
def test_zero_samples() -> None:
    """Constructing a zero-sample interval emits IntervalValueError."""
    with pytest.warns(IntervalValueError):
        _ = TrainingInterval(0, "samples")
@pytest.mark.parametrize("quantity", [0.51, 0.9, 1.1, 1.49])
def test_samples_rounding(quantity: float) -> None:
    """Fractional sample counts near one round to exactly one sample."""
    # Only quantities below one sample are expected to warn.
    if quantity >= 1:
        interval = TrainingInterval(quantity, "samples")
    else:
        with pytest.warns(IntervalValueError):
            interval = TrainingInterval(quantity, "samples")
    assert interval.quantity == 1
    assert interval.unit == "samples"
@pytest.mark.parametrize(
    "quantity, unit, batchsize, batches_per_epoch, epochs, expected",
    [
        (1, "samples", 32, 100, 10, 1),
        (0.000001, "runs", 32, 100, 10, 1),
        (0.0001, "epochs", 32, 100, 10, 1),
        (1e-10, "runs", 32, 100, 10, 1),
        (1e-10, "epochs", 32, 100, 10, 1),
    ],
)
def test_as_batch_count_edge_cases(
    quantity: Union[int, float],
    unit: TrainingIntervalUnit,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int,
    expected: int,
) -> None:
    """Intervals smaller than one batch warn on conversion and yield the expected count."""
    interval = TrainingInterval(quantity, unit)
    with pytest.warns(IntervalValueError):
        result = interval.as_batch_count(
            batchsize=batchsize,
            batches_per_epoch=batches_per_epoch,
            epochs=epochs,
        )
    assert result == expected, f"Expected {expected}, but got {result} for {interval}"
def test_as_batch_count_without_epochs() -> None:
    """Converting a fraction of a run without ``epochs`` fails an assertion."""
    fractional_run = TrainingInterval(0.1, "runs")
    # Construction stays outside the raises-block, so only the conversion is checked.
    with pytest.raises(AssertionError):
        fractional_run.as_batch_count(32, 100)
@pytest.mark.parametrize(
    "input_data, expected",
    [
        # Each case is listed once; a duplicated "0.1 runs" entry was removed
        # because it added no coverage and only produced a suffixed test ID.
        ("0.1 runs", (0.1, "runs")),
        ("0.1 epochs", (0.1, "epochs")),
        ("1 batches", (1, "batches")),
        ("1/1000 epochs", (0.001, "epochs")),
    ],
)
def test_from_str_edge_cases(
    input_data: str, expected: tuple[float | int, TrainingIntervalUnit]
) -> None:
    """from_str parses decimal, integer and ratio quantities across units.

    Args:
        input_data: the textual interval, e.g. ``"1/1000 epochs"``.
        expected: ``(quantity, unit)`` that the parsed interval must equal.
    """
    result = TrainingInterval.from_str(input_data)
    assert result == TrainingInterval(*expected), (
        f"Expected {expected}, but got {result} for input '{input_data}'"
    )
@pytest.mark.parametrize(
    "input_data",
    [
        "invalid unit",
        "1.5.5 epochs",
        "123",
        "1/2/3 batches",
        "0.0.0 batches",
        "ten samples",
        "1/2/3 samples",
    ],
)
def test_from_str_invalid_inputs(input_data: str) -> None:
    """Malformed quantities, missing units and unknown units are rejected."""
    pytest.raises(ValueError, TrainingInterval.from_str, input_data)
@pytest.mark.parametrize(
    "input_data, expected",
    [
        ((0.1, "runs"), (0.1, "runs")),
        (["0.1", "epochs"], (0.1, "epochs")),
        ("0.1 runs", (0.1, "runs")),
        (("1/1000", "epochs"), (0.001, "epochs")),
    ],
)
def test_from_any_edge_cases_nowarn(
    input_data: Any, expected: tuple[float | int, TrainingIntervalUnit]
) -> None:
    """Fractional runs/epochs parse via from_any without an IntervalValueError.

    No warning is expected because batchsize is unknown at parse time. The
    warning category is escalated to an error for the duration of the call,
    so the "nowarn" claim is checked rather than merely stated.

    Args:
        input_data: any input shape accepted by ``TrainingInterval.from_any``.
        expected: ``(quantity, unit)`` that the parsed interval must equal.
    """
    import warnings

    with warnings.catch_warnings():
        # Turn only this library's warning into an exception; other warnings are unaffected.
        warnings.simplefilter("error", IntervalValueError)
        result = TrainingInterval.from_any(input_data)
    assert result == TrainingInterval(*expected), (
        f"Expected {expected}, but got {result} for input {input_data}"
    )
@pytest.mark.parametrize(
    "input_data, expected",
    [
        ((1e-10, "batches"), (1, "batches")),
        ((0, "batches"), (1, "batches")),
        ("1/2 batches", (1, "batches")),
        ("0.0 batches", (1, "batches")),
        ((0, "samples"), (1, "samples")),
    ],
)
def test_from_any_edge_cases_warn(
    input_data: Any, expected: tuple[float | int, TrainingIntervalUnit]
) -> None:
    """Sub-unit batch/sample quantities warn and are clamped to 1.

    Each input here is expected to emit ``IntervalValueError`` from
    ``from_any`` alone, with no batchsize supplied, and to equal an interval
    of exactly one batch or sample.

    Args:
        input_data: any input shape accepted by ``TrainingInterval.from_any``.
        expected: ``(quantity, unit)`` that the clamped interval must equal.
    """
    with pytest.warns(IntervalValueError):
        result = TrainingInterval.from_any(input_data)
    assert result == TrainingInterval(*expected), (
        f"Expected {expected}, but got {result} for input {input_data}"
    )
@pytest.mark.parametrize(
    "input_data",
    [
        (0, "potatoes"),
        "invalid unit",
        (1.5, 5, "epochs"),
        123,
        ("1", "batches", "lol"),
    ],
)
def test_from_any_invalid_inputs(input_data: Any) -> None:
    """Unknown units, wrong arities and unsupported types raise ValueError."""
    pytest.raises(ValueError, TrainingInterval.from_any, input_data)
@pytest.mark.parametrize(
    "interval, batchsize, batches_per_epoch, epochs, expected",
    [
        ("0 runs", 32, 100, 10, 1),
        ("1e-10 epochs", 32, 100, 10, 1),
        ("0.1 batches", 32, 100, 10, 1),
        ("1 samples", 32, 100, 10, 1),
    ],
)
def test_process_to_batches_edge_cases(
    interval: Union[str, tuple, TrainingInterval],
    batchsize: int,
    batches_per_epoch: int,
    epochs: int,
    expected: int,
) -> None:
    """Intervals worth less than one batch warn and resolve to the expected count."""
    with pytest.warns(IntervalValueError):
        result = TrainingInterval.process_to_batches(
            interval,
            batchsize=batchsize,
            batches_per_epoch=batches_per_epoch,
            epochs=epochs,
        )
    assert result == expected, f"Expected {expected}, but got {result} for {interval}"
def test_normalization_edge_cases() -> None:
    """normalized() handles fractional runs and vanishingly small epoch counts."""
    tenth_of_run = TrainingInterval(0.1, "runs").normalized(
        batchsize=32, batches_per_epoch=100, epochs=10
    )
    assert (tenth_of_run.quantity, tenth_of_run.unit) == (100, "batches")

    tiny_epochs = TrainingInterval(1e-10, "epochs")
    with pytest.warns(IntervalValueError):
        clamped = tiny_epochs.normalized(batchsize=32, batches_per_epoch=100)
    assert (clamped.quantity, clamped.unit) == (1, "batches")
def test_equality_edge_cases() -> None:
    """Equality compares quantity and unit; a near-zero batch count equals one batch."""
    tenth_run = TrainingInterval(0.1, "runs")
    assert tenth_run == TrainingInterval(0.1, "runs")
    assert tenth_run != TrainingInterval(0.1, "epochs")

    with pytest.warns(IntervalValueError):
        near_zero = TrainingInterval(1e-10, "batches")
        assert near_zero == TrainingInterval(1, "batches")
def test_iteration_and_indexing() -> None:
    """A TrainingInterval acts as a two-item sequence: unpackable and indexable."""
    interval = TrainingInterval(0.1, "runs")

    unpacked_quantity, unpacked_unit = interval
    assert unpacked_quantity == 0.1
    assert unpacked_unit == "runs"

    assert interval[0] == 0.1
    assert interval[1] == "runs"

    # There is no third element.
    with pytest.raises(IndexError):
        _ = interval[2]