Coverage for trnbl\training_interval.py: 94%
118 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
1from typing import Any, Generator, Literal, Callable, Union, Sequence
2from dataclasses import dataclass
4from muutils.misc import str_to_numeric
5from muutils.errormode import ErrorMode
6from muutils.interval import Interval
# absolute tolerance used by `TrainingInterval.__eq__` when comparing quantities
_EPSILON: float = 1e-6

# units of training intervals -- everything is eventually converted to batches
TrainingIntervalUnit = Literal["runs", "epochs", "batches", "samples"]

# allowed range of the quantity for each unit;
# quantities outside the range are clamped to the closest bound in `TrainingInterval.__post_init__`
_TRAINING_INTERVAL_UNITS_RANGES: dict[TrainingIntervalUnit, Interval] = dict(
    # epochs and runs should not actually be closed, but we allow it
    runs=Interval(0, 1, is_closed=True),
    epochs=Interval(0, float("inf"), is_closed=True),
    batches=Interval(1, float("inf"), is_closed=True),
    samples=Interval(1, float("inf"), is_closed=True),
)

# how the quantity is cast for each unit:
# runs and epochs are kept as given, batches and samples are rounded to integers
_TRAINING_INTERVAL_UNITS_CAST: dict[TrainingIntervalUnit, Callable] = dict(
    runs=lambda x: x,
    epochs=lambda x: x,
    batches=lambda x: int(round(x)),
    samples=lambda x: int(round(x)),
)

# singular spellings accepted in place of the canonical (plural) unit names
_TRAINING_INTERVAL_UNIT_ALIASES: dict[str, str] = dict(
    run="runs",
    epoch="epochs",
    batch="batches",
    sample="samples",
)

# what to do if an interval is < 1 batch (also used for out-of-range quantities)
# if WARN or IGNORE, the value is set to the closest valid one (1 batch)
WhenIntervalLessThanBatch: ErrorMode = ErrorMode.WARN
class IntervalValueError(UserWarning):
    """Raised or warned when an interval is invalid.

    Used both when an interval works out to less than 1 batch and when a
    quantity lies outside the range allowed for its unit.
    """
46@dataclass(frozen=True)
47class TrainingInterval:
48 """A training interval, which can be specified in a few different units.
50 # Attributes:
51 - `quantity: int|float` - the quantity of the interval
52 - `unit: TrainingIntervalUnit` - the unit of the interval, one of "runs", "epochs", "batches", or "samples"
54 # Methods:
55 - `TrainingInterval.from_str(raw: str) -> TrainingInterval` - parse a string into a TrainingInterval object
56 - `TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - convert the interval to a raw number of batches
57 - `TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - any representation to a number of batches
58 - `TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> None` - current interval, with units switched to batches
60 Provides methods for reading from a string or tuple, and normalizing to batches.
61 """
63 quantity: int | float
64 unit: TrainingIntervalUnit
66 def __iter__(self) -> Generator[int | float | TrainingIntervalUnit, None, None]:
67 yield self.quantity
68 yield self.unit
70 def __getitem__(self, index: int) -> int | float | TrainingIntervalUnit:
71 if index == 0:
72 return self.quantity
73 elif index == 1:
74 return self.unit
75 else:
76 raise IndexError(f"invalid index {index} for TrainingInterval")
78 def __post_init__(self) -> None:
79 try:
80 assert isinstance(self.quantity, (int, float)), (
81 "quantity should be an integer or float"
82 )
83 # TODO: Literal[...].__args__ is not defined??
84 if self.unit not in TrainingIntervalUnit.__args__: # type: ignore[attr-defined]
85 unit_dealised: str | None = _TRAINING_INTERVAL_UNIT_ALIASES.get(
86 self.unit.lower(), None
87 )
88 if isinstance(unit_dealised, str):
89 self.__dict__["unit"] = unit_dealised
90 else:
91 raise ValueError(f"invalid unit {self.unit = }")
93 assert self.unit in TrainingIntervalUnit.__args__, ( # type: ignore[attr-defined]
94 f"invalid unit {self.unit}"
95 )
96 except AssertionError as e:
97 raise AssertionError(
98 f"Error initializing TrainingInterval\n{self}\n{e}"
99 ) from e
101 # check values in proper ranges
102 expected_interval: Interval = _TRAINING_INTERVAL_UNITS_RANGES[self.unit]
103 if self.quantity not in expected_interval:
104 WhenIntervalLessThanBatch.process(
105 f"interval {self} has invalid quantity, expected in interval {expected_interval}, will set to closest bound if not erroring out",
106 except_cls=IntervalValueError,
107 warn_cls=IntervalValueError,
108 )
109 self.__dict__["quantity"] = expected_interval.clamp(self.quantity)
111 # cast if necessary
112 self.__dict__["quantity"] = _TRAINING_INTERVAL_UNITS_CAST[self.unit](
113 self.quantity
114 )
116 def __eq__(self, other: Any) -> bool:
117 if not isinstance(other, self.__class__):
118 raise TypeError(
119 f"invalid type {type(other)} for comparison with TrainingInterval"
120 )
121 return (
122 abs(self.quantity - other.quantity) < _EPSILON and self.unit == other.unit
123 )
125 def as_batch_count(
126 self,
127 batchsize: int,
128 batches_per_epoch: int,
129 epochs: int | None = None,
130 ) -> int:
131 """given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches
133 # Parameters:
134 - `batchsize: int`
135 the size of a batch
136 - `batches_per_epoch: int`
137 the number of batches in an epoch
138 - `epochs: int|None`
139 the number of epochs to run (only required if the interval is in "runs")
141 # Returns:
142 - `int`
143 the interval as a number of batches
145 # Raises:
146 - `ValueError`
147 if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR`
148 otherwise, will warn or ignore and set the interval to 1 batch
149 - `ValueError`
150 if the unit is not one of "runs", "epochs", "batches", or "samples"
153 """
155 output: int | float
157 match self.unit:
158 case "runs":
159 assert epochs is not None, (
160 "epochs must be provided to convert runs to batches"
161 )
162 output = self.quantity * epochs * batches_per_epoch
163 case "epochs":
164 output = self.quantity * batches_per_epoch
165 case "batches":
166 output = self.quantity
167 case "samples":
168 output = self.quantity / batchsize
169 case _:
170 raise ValueError(f"invalid unit {self.unit}")
172 # check if interval is less than 1 batch
173 if output < 1:
174 WhenIntervalLessThanBatch.process(
175 f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out",
176 except_cls=IntervalValueError,
177 warn_cls=IntervalValueError,
178 )
179 output = 1
181 return int(round(output))
183 def normalized(
184 self,
185 batchsize: int,
186 batches_per_epoch: int,
187 epochs: int | None = None,
188 ) -> "TrainingInterval":
189 """convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to "batches"""
190 quantity: int | float = self.as_batch_count(
191 batches_per_epoch=batches_per_epoch,
192 batchsize=batchsize,
193 epochs=epochs,
194 )
195 unit: TrainingIntervalUnit = "batches"
196 return self.__class__(quantity, unit)
198 @classmethod
199 def from_str(cls, raw: str) -> "TrainingInterval":
200 """parse a string into a TrainingInterval object
202 # Examples:
204 >>> TrainingInterval.from_str("5 epochs")
205 TrainingInterval(5, 'epochs')
206 >>> TrainingInterval.from_str("100 batches")
207 TrainingInterval(100, 'batches')
208 >>> TrainingInterval.from_str("0.1 runs")
209 TrainingInterval(0.1, 'runs')
210 >>> TrainingInterval.from_str("1/5 runs")
211 TrainingInterval(0.2, 'runs')
213 """
214 try:
215 # remove prefix and suffix (optionally)
216 raw = raw.removeprefix("TrainingInterval(").removesuffix(")")
218 # process quantity
219 raw_split: list[str]
220 quantity_str: str
221 if "," in raw:
222 raw_split = raw.split(",")
223 quantity_str = ",".join(raw_split[:-1])
224 else:
225 raw_split = raw.split()
226 quantity_str = " ".join(raw_split[:-1])
228 quantity: int | float = str_to_numeric(quantity_str)
230 # process unit
231 unit: str = raw_split[-1]
232 unit.strip().strip("'\"").strip()
234 # unit should be one of the allowed units
235 unit_dealised: str | None
236 if unit.lower() in TrainingIntervalUnit.__args__: # type: ignore[attr-defined]
237 unit_dealised = unit.lower()
238 else:
239 unit_dealised = _TRAINING_INTERVAL_UNIT_ALIASES.get(unit.lower(), None)
240 if isinstance(unit_dealised, str):
241 unit = unit_dealised
242 else:
243 raise ValueError(f"invalid unit {unit}")
245 assert unit in TrainingIntervalUnit.__args__ # type: ignore[attr-defined]
246 except Exception as e:
247 raise ValueError(f"Error parsing {raw} as a TrainingInterval\n{e}") from e
249 return cls(quantity, unit) # type: ignore[arg-type]
251 @classmethod
252 def from_any(cls, *args, **kwargs) -> "TrainingInterval":
253 """parse a string or tuple into a TrainingInterval object"""
255 try:
256 # no kwargs allowed
257 assert len(kwargs) == 0, "no kwargs allowed for from_any"
259 # split up args
260 data: Any
261 match len(args):
262 case 1:
263 data = args[0]
264 case 2:
265 data = args
266 case _:
267 raise ValueError(
268 f"invalid number of args {len(args)} for from_any: {args = }"
269 )
271 if isinstance(data, cls):
272 return data
273 elif isinstance(data, str):
274 return cls.from_str(data)
275 elif isinstance(data, Sequence):
276 assert len(data) == 2, (
277 f"invalid length {len(data)} for TrainingInterval: {data}"
278 )
279 quantity, unit = data
280 if isinstance(quantity, str):
281 quantity = str_to_numeric(quantity)
282 return cls(quantity, unit)
283 else:
284 raise ValueError(f"invalid type {type(data)} for TrainingInterval")
286 except AssertionError as e:
287 raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e
289 @classmethod
290 def process_to_batches(
291 cls,
292 interval: "CastableToTrainingInterval",
293 batchsize: int,
294 batches_per_epoch: int,
295 epochs: int | None = None,
296 ) -> int:
297 """directly from any representation to a number of batches"""
299 interval_ti: TrainingInterval = cls.from_any(interval)
301 return interval_ti.as_batch_count(
302 batches_per_epoch=batches_per_epoch,
303 batchsize=batchsize,
304 epochs=epochs,
305 )
# the quantity half of a `(quantity, unit)` pair; strings are parsed to numbers
_QuantityLike = Union[int, float, str]

# everything `TrainingInterval.from_any` is annotated to accept
CastableToTrainingInterval = Union[
    str,  # a string of the form "quantity unit"
    tuple[_QuantityLike, str],  # a `(quantity, unit)` tuple
    TrainingInterval,  # an existing TrainingInterval, returned as-is
]