docs for trnbl v0.1.1
View Source on GitHub

trnbl.training_interval


  1from typing import Any, Generator, Literal, Callable, Union, Sequence
  2from dataclasses import dataclass
  3
  4from muutils.misc import str_to_numeric
  5from muutils.errormode import ErrorMode
  6from muutils.interval import Interval
  7
  8_EPSILON: float = 1e-6
  9
 10# units of training intervals -- we convert this all to batches
 11TrainingIntervalUnit = Literal["runs", "epochs", "batches", "samples"]
 12
 13_TRAINING_INTERVAL_UNITS_RANGES: dict[TrainingIntervalUnit, Interval] = {
 14	# epochs and runs should not actually be closed, but we allow it
 15	"runs": Interval(0, 1, is_closed=True),
 16	"epochs": Interval(0, float("inf"), is_closed=True),
 17	"batches": Interval(1, float("inf"), is_closed=True),
 18	"samples": Interval(1, float("inf"), is_closed=True),
 19}
 20
 21_TRAINING_INTERVAL_UNITS_CAST: dict[TrainingIntervalUnit, Callable] = {
 22	"runs": lambda x: x,
 23	"epochs": lambda x: x,
 24	"batches": lambda x: int(round(x)),
 25	"samples": lambda x: int(round(x)),
 26}
 27
 28_TRAINING_INTERVAL_UNIT_ALIASES: dict[str, str] = {
 29	"run": "runs",
 30	"epoch": "epochs",
 31	"batch": "batches",
 32	"sample": "samples",
 33}
 34
 35# what to do if interval is < 1 batch
 36# if WARN or IGNORE, set it to 1 batch
 37WhenIntervalLessThanBatch: ErrorMode = ErrorMode.WARN
 38
 39
 40class IntervalValueError(UserWarning):
 41	"""Error for when the interval is less than 1 batch"""
 42
 43	pass
 44
 45
 46@dataclass(frozen=True)
 47class TrainingInterval:
 48	"""A training interval, which can be specified in a few different units.
 49
 50	# Attributes:
 51	- `quantity: int|float` - the quantity of the interval
 52	- `unit: TrainingIntervalUnit` - the unit of the interval, one of "runs", "epochs", "batches", or "samples"
 53
 54	# Methods:
 55	- `TrainingInterval.from_str(raw: str) -> TrainingInterval` - parse a string into a TrainingInterval object
 56	- `TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - convert the interval to a raw number of batches
 57	- `TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - any representation to a number of batches
 58	- `TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> None` - current interval, with units switched to batches
 59
 60	Provides methods for reading from a string or tuple, and normalizing to batches.
 61	"""
 62
 63	quantity: int | float
 64	unit: TrainingIntervalUnit
 65
 66	def __iter__(self) -> Generator[int | float | TrainingIntervalUnit, None, None]:
 67		yield self.quantity
 68		yield self.unit
 69
 70	def __getitem__(self, index: int) -> int | float | TrainingIntervalUnit:
 71		if index == 0:
 72			return self.quantity
 73		elif index == 1:
 74			return self.unit
 75		else:
 76			raise IndexError(f"invalid index {index} for TrainingInterval")
 77
 78	def __post_init__(self) -> None:
 79		try:
 80			assert isinstance(self.quantity, (int, float)), (
 81				"quantity should be an integer or float"
 82			)
 83			# TODO: Literal[...].__args__ is not defined??
 84			if self.unit not in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
 85				unit_dealised: str | None = _TRAINING_INTERVAL_UNIT_ALIASES.get(
 86					self.unit.lower(), None
 87				)
 88				if isinstance(unit_dealised, str):
 89					self.__dict__["unit"] = unit_dealised
 90				else:
 91					raise ValueError(f"invalid unit {self.unit = }")
 92
 93			assert self.unit in TrainingIntervalUnit.__args__, (  # type: ignore[attr-defined]
 94				f"invalid unit {self.unit}"
 95			)
 96		except AssertionError as e:
 97			raise AssertionError(
 98				f"Error initializing TrainingInterval\n{self}\n{e}"
 99			) from e
100
101		# check values in proper ranges
102		expected_interval: Interval = _TRAINING_INTERVAL_UNITS_RANGES[self.unit]
103		if self.quantity not in expected_interval:
104			WhenIntervalLessThanBatch.process(
105				f"interval {self} has invalid quantity, expected in interval {expected_interval}, will set to closest bound if not erroring out",
106				except_cls=IntervalValueError,
107				warn_cls=IntervalValueError,
108			)
109			self.__dict__["quantity"] = expected_interval.clamp(self.quantity)
110
111		# cast if necessary
112		self.__dict__["quantity"] = _TRAINING_INTERVAL_UNITS_CAST[self.unit](
113			self.quantity
114		)
115
116	def __eq__(self, other: Any) -> bool:
117		if not isinstance(other, self.__class__):
118			raise TypeError(
119				f"invalid type {type(other)} for comparison with TrainingInterval"
120			)
121		return (
122			abs(self.quantity - other.quantity) < _EPSILON and self.unit == other.unit
123		)
124
125	def as_batch_count(
126		self,
127		batchsize: int,
128		batches_per_epoch: int,
129		epochs: int | None = None,
130	) -> int:
131		"""given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches
132
133		# Parameters:
134		 - `batchsize: int`
135		   the size of a batch
136		 - `batches_per_epoch: int`
137		   the number of batches in an epoch
138		 - `epochs: int|None`
139		   the number of epochs to run (only required if the interval is in "runs")
140
141		# Returns:
142		 - `int`
143		   the interval as a number of batches
144
145		# Raises:
146		 - `ValueError`
147		   if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR`
148		   otherwise, will warn or ignore and set the interval to 1 batch
149		 - `ValueError`
150		   if the unit is not one of "runs", "epochs", "batches", or "samples"
151
152
153		"""
154
155		output: int | float
156
157		match self.unit:
158			case "runs":
159				assert epochs is not None, (
160					"epochs must be provided to convert runs to batches"
161				)
162				output = self.quantity * epochs * batches_per_epoch
163			case "epochs":
164				output = self.quantity * batches_per_epoch
165			case "batches":
166				output = self.quantity
167			case "samples":
168				output = self.quantity / batchsize
169			case _:
170				raise ValueError(f"invalid unit {self.unit}")
171
172		# check if interval is less than 1 batch
173		if output < 1:
174			WhenIntervalLessThanBatch.process(
175				f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out",
176				except_cls=IntervalValueError,
177				warn_cls=IntervalValueError,
178			)
179			output = 1
180
181		return int(round(output))
182
183	def normalized(
184		self,
185		batchsize: int,
186		batches_per_epoch: int,
187		epochs: int | None = None,
188	) -> "TrainingInterval":
189		"""convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to "batches"""
190		quantity: int | float = self.as_batch_count(
191			batches_per_epoch=batches_per_epoch,
192			batchsize=batchsize,
193			epochs=epochs,
194		)
195		unit: TrainingIntervalUnit = "batches"
196		return self.__class__(quantity, unit)
197
198	@classmethod
199	def from_str(cls, raw: str) -> "TrainingInterval":
200		"""parse a string into a TrainingInterval object
201
202		# Examples:
203
204		>>> TrainingInterval.from_str("5 epochs")
205		TrainingInterval(5, 'epochs')
206		>>> TrainingInterval.from_str("100 batches")
207		TrainingInterval(100, 'batches')
208		>>> TrainingInterval.from_str("0.1 runs")
209		TrainingInterval(0.1, 'runs')
210		>>> TrainingInterval.from_str("1/5 runs")
211		TrainingInterval(0.2, 'runs')
212
213		"""
214		try:
215			# remove prefix and suffix (optionally)
216			raw = raw.removeprefix("TrainingInterval(").removesuffix(")")
217
218			# process quantity
219			raw_split: list[str]
220			quantity_str: str
221			if "," in raw:
222				raw_split = raw.split(",")
223				quantity_str = ",".join(raw_split[:-1])
224			else:
225				raw_split = raw.split()
226				quantity_str = " ".join(raw_split[:-1])
227
228			quantity: int | float = str_to_numeric(quantity_str)
229
230			# process unit
231			unit: str = raw_split[-1]
232			unit.strip().strip("'\"").strip()
233
234			# unit should be one of the allowed units
235			unit_dealised: str | None
236			if unit.lower() in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
237				unit_dealised = unit.lower()
238			else:
239				unit_dealised = _TRAINING_INTERVAL_UNIT_ALIASES.get(unit.lower(), None)
240			if isinstance(unit_dealised, str):
241				unit = unit_dealised
242			else:
243				raise ValueError(f"invalid unit {unit}")
244
245			assert unit in TrainingIntervalUnit.__args__  # type: ignore[attr-defined]
246		except Exception as e:
247			raise ValueError(f"Error parsing {raw} as a TrainingInterval\n{e}") from e
248
249		return cls(quantity, unit)  # type: ignore[arg-type]
250
251	@classmethod
252	def from_any(cls, *args, **kwargs) -> "TrainingInterval":
253		"""parse a string or tuple into a TrainingInterval object"""
254
255		try:
256			# no kwargs allowed
257			assert len(kwargs) == 0, "no kwargs allowed for from_any"
258
259			# split up args
260			data: Any
261			match len(args):
262				case 1:
263					data = args[0]
264				case 2:
265					data = args
266				case _:
267					raise ValueError(
268						f"invalid number of args {len(args)} for from_any: {args = }"
269					)
270
271			if isinstance(data, cls):
272				return data
273			elif isinstance(data, str):
274				return cls.from_str(data)
275			elif isinstance(data, Sequence):
276				assert len(data) == 2, (
277					f"invalid length {len(data)} for TrainingInterval: {data}"
278				)
279				quantity, unit = data
280				if isinstance(quantity, str):
281					quantity = str_to_numeric(quantity)
282				return cls(quantity, unit)
283			else:
284				raise ValueError(f"invalid type {type(data)} for TrainingInterval")
285
286		except AssertionError as e:
287			raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e
288
289	@classmethod
290	def process_to_batches(
291		cls,
292		interval: "CastableToTrainingInterval",
293		batchsize: int,
294		batches_per_epoch: int,
295		epochs: int | None = None,
296	) -> int:
297		"""directly from any representation to a number of batches"""
298
299		interval_ti: TrainingInterval = cls.from_any(interval)
300
301		return interval_ti.as_batch_count(
302			batches_per_epoch=batches_per_epoch,
303			batchsize=batchsize,
304			epochs=epochs,
305		)
306
307
# a quantity may be given as a number, or as a numeric string that `str_to_numeric` can parse
_CastableQuantity = Union[int, float, str]

# every value that `TrainingInterval.from_any` (and hence `process_to_batches`) accepts:
#  - a `str`, parsed as "quantity unit", e.g. "5 epochs"
#  - a `(quantity, unit)` tuple
#  - an existing `TrainingInterval`, passed through unchanged
CastableToTrainingInterval = Union[
    str,
    tuple[_CastableQuantity, str],
    TrainingInterval,
]

TrainingIntervalUnit = typing.Literal['runs', 'epochs', 'batches', 'samples']
WhenIntervalLessThanBatch: muutils.errormode.ErrorMode = ErrorMode.Warn
class IntervalValueError(builtins.UserWarning):
class IntervalValueError(UserWarning):
    """Warning category for an invalid training interval.

    Issued through `WhenIntervalLessThanBatch` -- as a warning, or raised as an exception when
    that policy is set to error -- when an interval's quantity lies outside the valid range for
    its unit, or when it converts to less than 1 batch.
    """

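Warning category (raised as an exception when WhenIntervalLessThanBatch is set to ErrorMode.ERROR) for when an interval's quantity is outside the valid range for its unit, or converts to less than 1 batch.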

Inherited Members
builtins.UserWarning
UserWarning
builtins.BaseException
with_traceback
add_note
args
@dataclass(frozen=True)
class TrainingInterval:
 47@dataclass(frozen=True)
 48class TrainingInterval:
 49	"""A training interval, which can be specified in a few different units.
 50
 51	# Attributes:
 52	- `quantity: int|float` - the quantity of the interval
 53	- `unit: TrainingIntervalUnit` - the unit of the interval, one of "runs", "epochs", "batches", or "samples"
 54
 55	# Methods:
 56	- `TrainingInterval.from_str(raw: str) -> TrainingInterval` - parse a string into a TrainingInterval object
 57	- `TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - convert the interval to a raw number of batches
 58	- `TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - any representation to a number of batches
 59	- `TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> None` - current interval, with units switched to batches
 60
 61	Provides methods for reading from a string or tuple, and normalizing to batches.
 62	"""
 63
 64	quantity: int | float
 65	unit: TrainingIntervalUnit
 66
 67	def __iter__(self) -> Generator[int | float | TrainingIntervalUnit, None, None]:
 68		yield self.quantity
 69		yield self.unit
 70
 71	def __getitem__(self, index: int) -> int | float | TrainingIntervalUnit:
 72		if index == 0:
 73			return self.quantity
 74		elif index == 1:
 75			return self.unit
 76		else:
 77			raise IndexError(f"invalid index {index} for TrainingInterval")
 78
 79	def __post_init__(self) -> None:
 80		try:
 81			assert isinstance(self.quantity, (int, float)), (
 82				"quantity should be an integer or float"
 83			)
 84			# TODO: Literal[...].__args__ is not defined??
 85			if self.unit not in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
 86				unit_dealised: str | None = _TRAINING_INTERVAL_UNIT_ALIASES.get(
 87					self.unit.lower(), None
 88				)
 89				if isinstance(unit_dealised, str):
 90					self.__dict__["unit"] = unit_dealised
 91				else:
 92					raise ValueError(f"invalid unit {self.unit = }")
 93
 94			assert self.unit in TrainingIntervalUnit.__args__, (  # type: ignore[attr-defined]
 95				f"invalid unit {self.unit}"
 96			)
 97		except AssertionError as e:
 98			raise AssertionError(
 99				f"Error initializing TrainingInterval\n{self}\n{e}"
100			) from e
101
102		# check values in proper ranges
103		expected_interval: Interval = _TRAINING_INTERVAL_UNITS_RANGES[self.unit]
104		if self.quantity not in expected_interval:
105			WhenIntervalLessThanBatch.process(
106				f"interval {self} has invalid quantity, expected in interval {expected_interval}, will set to closest bound if not erroring out",
107				except_cls=IntervalValueError,
108				warn_cls=IntervalValueError,
109			)
110			self.__dict__["quantity"] = expected_interval.clamp(self.quantity)
111
112		# cast if necessary
113		self.__dict__["quantity"] = _TRAINING_INTERVAL_UNITS_CAST[self.unit](
114			self.quantity
115		)
116
117	def __eq__(self, other: Any) -> bool:
118		if not isinstance(other, self.__class__):
119			raise TypeError(
120				f"invalid type {type(other)} for comparison with TrainingInterval"
121			)
122		return (
123			abs(self.quantity - other.quantity) < _EPSILON and self.unit == other.unit
124		)
125
126	def as_batch_count(
127		self,
128		batchsize: int,
129		batches_per_epoch: int,
130		epochs: int | None = None,
131	) -> int:
132		"""given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches
133
134		# Parameters:
135		 - `batchsize: int`
136		   the size of a batch
137		 - `batches_per_epoch: int`
138		   the number of batches in an epoch
139		 - `epochs: int|None`
140		   the number of epochs to run (only required if the interval is in "runs")
141
142		# Returns:
143		 - `int`
144		   the interval as a number of batches
145
146		# Raises:
147		 - `ValueError`
148		   if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR`
149		   otherwise, will warn or ignore and set the interval to 1 batch
150		 - `ValueError`
151		   if the unit is not one of "runs", "epochs", "batches", or "samples"
152
153
154		"""
155
156		output: int | float
157
158		match self.unit:
159			case "runs":
160				assert epochs is not None, (
161					"epochs must be provided to convert runs to batches"
162				)
163				output = self.quantity * epochs * batches_per_epoch
164			case "epochs":
165				output = self.quantity * batches_per_epoch
166			case "batches":
167				output = self.quantity
168			case "samples":
169				output = self.quantity / batchsize
170			case _:
171				raise ValueError(f"invalid unit {self.unit}")
172
173		# check if interval is less than 1 batch
174		if output < 1:
175			WhenIntervalLessThanBatch.process(
176				f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out",
177				except_cls=IntervalValueError,
178				warn_cls=IntervalValueError,
179			)
180			output = 1
181
182		return int(round(output))
183
184	def normalized(
185		self,
186		batchsize: int,
187		batches_per_epoch: int,
188		epochs: int | None = None,
189	) -> "TrainingInterval":
190		"""convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to "batches"""
191		quantity: int | float = self.as_batch_count(
192			batches_per_epoch=batches_per_epoch,
193			batchsize=batchsize,
194			epochs=epochs,
195		)
196		unit: TrainingIntervalUnit = "batches"
197		return self.__class__(quantity, unit)
198
199	@classmethod
200	def from_str(cls, raw: str) -> "TrainingInterval":
201		"""parse a string into a TrainingInterval object
202
203		# Examples:
204
205		>>> TrainingInterval.from_str("5 epochs")
206		TrainingInterval(5, 'epochs')
207		>>> TrainingInterval.from_str("100 batches")
208		TrainingInterval(100, 'batches')
209		>>> TrainingInterval.from_str("0.1 runs")
210		TrainingInterval(0.1, 'runs')
211		>>> TrainingInterval.from_str("1/5 runs")
212		TrainingInterval(0.2, 'runs')
213
214		"""
215		try:
216			# remove prefix and suffix (optionally)
217			raw = raw.removeprefix("TrainingInterval(").removesuffix(")")
218
219			# process quantity
220			raw_split: list[str]
221			quantity_str: str
222			if "," in raw:
223				raw_split = raw.split(",")
224				quantity_str = ",".join(raw_split[:-1])
225			else:
226				raw_split = raw.split()
227				quantity_str = " ".join(raw_split[:-1])
228
229			quantity: int | float = str_to_numeric(quantity_str)
230
231			# process unit
232			unit: str = raw_split[-1]
233			unit.strip().strip("'\"").strip()
234
235			# unit should be one of the allowed units
236			unit_dealised: str | None
237			if unit.lower() in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
238				unit_dealised = unit.lower()
239			else:
240				unit_dealised = _TRAINING_INTERVAL_UNIT_ALIASES.get(unit.lower(), None)
241			if isinstance(unit_dealised, str):
242				unit = unit_dealised
243			else:
244				raise ValueError(f"invalid unit {unit}")
245
246			assert unit in TrainingIntervalUnit.__args__  # type: ignore[attr-defined]
247		except Exception as e:
248			raise ValueError(f"Error parsing {raw} as a TrainingInterval\n{e}") from e
249
250		return cls(quantity, unit)  # type: ignore[arg-type]
251
252	@classmethod
253	def from_any(cls, *args, **kwargs) -> "TrainingInterval":
254		"""parse a string or tuple into a TrainingInterval object"""
255
256		try:
257			# no kwargs allowed
258			assert len(kwargs) == 0, "no kwargs allowed for from_any"
259
260			# split up args
261			data: Any
262			match len(args):
263				case 1:
264					data = args[0]
265				case 2:
266					data = args
267				case _:
268					raise ValueError(
269						f"invalid number of args {len(args)} for from_any: {args = }"
270					)
271
272			if isinstance(data, cls):
273				return data
274			elif isinstance(data, str):
275				return cls.from_str(data)
276			elif isinstance(data, Sequence):
277				assert len(data) == 2, (
278					f"invalid length {len(data)} for TrainingInterval: {data}"
279				)
280				quantity, unit = data
281				if isinstance(quantity, str):
282					quantity = str_to_numeric(quantity)
283				return cls(quantity, unit)
284			else:
285				raise ValueError(f"invalid type {type(data)} for TrainingInterval")
286
287		except AssertionError as e:
288			raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e
289
290	@classmethod
291	def process_to_batches(
292		cls,
293		interval: "CastableToTrainingInterval",
294		batchsize: int,
295		batches_per_epoch: int,
296		epochs: int | None = None,
297	) -> int:
298		"""directly from any representation to a number of batches"""
299
300		interval_ti: TrainingInterval = cls.from_any(interval)
301
302		return interval_ti.as_batch_count(
303			batches_per_epoch=batches_per_epoch,
304			batchsize=batchsize,
305			epochs=epochs,
306		)

A training interval, which can be specified in a few different units.

Attributes:

  • quantity: int|float - the quantity of the interval
  • unit: TrainingIntervalUnit - the unit of the interval, one of "runs", "epochs", "batches", or "samples"

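Methods:

  • TrainingInterval.from_str(raw: str) -> TrainingInterval - parse a string into a TrainingInterval object
  • TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int - convert the interval to a raw number of batches
  • TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int - any representation to a number of batches
  • TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> TrainingInterval - current interval, with units switched to batches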

Provides methods for reading from a string or tuple, and normalizing to batches.

TrainingInterval( quantity: int | float, unit: Literal['runs', 'epochs', 'batches', 'samples'])
quantity: int | float
unit: Literal['runs', 'epochs', 'batches', 'samples']
def as_batch_count( self, batchsize: int, batches_per_epoch: int, epochs: int | None = None) -> int:
126	def as_batch_count(
127		self,
128		batchsize: int,
129		batches_per_epoch: int,
130		epochs: int | None = None,
131	) -> int:
132		"""given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches
133
134		# Parameters:
135		 - `batchsize: int`
136		   the size of a batch
137		 - `batches_per_epoch: int`
138		   the number of batches in an epoch
139		 - `epochs: int|None`
140		   the number of epochs to run (only required if the interval is in "runs")
141
142		# Returns:
143		 - `int`
144		   the interval as a number of batches
145
146		# Raises:
147		 - `ValueError`
148		   if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR`
149		   otherwise, will warn or ignore and set the interval to 1 batch
150		 - `ValueError`
151		   if the unit is not one of "runs", "epochs", "batches", or "samples"
152
153
154		"""
155
156		output: int | float
157
158		match self.unit:
159			case "runs":
160				assert epochs is not None, (
161					"epochs must be provided to convert runs to batches"
162				)
163				output = self.quantity * epochs * batches_per_epoch
164			case "epochs":
165				output = self.quantity * batches_per_epoch
166			case "batches":
167				output = self.quantity
168			case "samples":
169				output = self.quantity / batchsize
170			case _:
171				raise ValueError(f"invalid unit {self.unit}")
172
173		# check if interval is less than 1 batch
174		if output < 1:
175			WhenIntervalLessThanBatch.process(
176				f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out",
177				except_cls=IntervalValueError,
178				warn_cls=IntervalValueError,
179			)
180			output = 1
181
182		return int(round(output))

given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches

Parameters:

  • batchsize: int the size of a batch
  • batches_per_epoch: int the number of batches in an epoch
  • epochs: int|None the number of epochs to run (only required if the interval is in "runs")

Returns:

  • int the interval as a number of batches

Raises:

  • IntervalValueError (a UserWarning subclass, not a ValueError) if the interval is less than 1 batch, and the WhenIntervalLessThanBatch is set to muutils.errormode.ErrorMode.ERROR; otherwise, will warn or ignore and set the interval to 1 batch
  • AssertionError if the unit is "runs" and epochs is not provided
  • ValueError if the unit is not one of "runs", "epochs", "batches", or "samples"
def normalized( self, batchsize: int, batches_per_epoch: int, epochs: int | None = None) -> TrainingInterval:
184	def normalized(
185		self,
186		batchsize: int,
187		batches_per_epoch: int,
188		epochs: int | None = None,
189	) -> "TrainingInterval":
190		"""convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to "batches"""
191		quantity: int | float = self.as_batch_count(
192			batches_per_epoch=batches_per_epoch,
193			batchsize=batchsize,
194			epochs=epochs,
195		)
196		unit: TrainingIntervalUnit = "batches"
197		return self.__class__(quantity, unit)

convert the units of the interval to batches, by calling as_batch_count and setting the unit to "batches".

@classmethod
def from_str(cls, raw: str) -> TrainingInterval:
@classmethod
def from_str(cls, raw: str) -> "TrainingInterval":
    """parse a string into a TrainingInterval object

    Accepts three forms:
    - `"<quantity> <unit>"`
    - `"<quantity>, <unit>"`
    - the dataclass repr `TrainingInterval(quantity=<quantity>, unit='<unit>')`, so `repr`
      output round-trips

    The unit is matched case-insensitively, and singular aliases ("epoch", "batch", ...) are
    accepted.

    # Examples:

    >>> TrainingInterval.from_str("5 epochs")
    TrainingInterval(quantity=5, unit='epochs')
    >>> TrainingInterval.from_str("100 batches")
    TrainingInterval(quantity=100, unit='batches')
    >>> TrainingInterval.from_str("0.1 runs")
    TrainingInterval(quantity=0.1, unit='runs')
    >>> TrainingInterval.from_str("1/5 runs")
    TrainingInterval(quantity=0.2, unit='runs')
    >>> TrainingInterval.from_str("TrainingInterval(quantity=5, unit='epochs')")
    TrainingInterval(quantity=5, unit='epochs')

    # Raises:
     - `ValueError`
       if `raw` cannot be split into a numeric quantity and a valid unit
    """
    original: str = raw
    try:
        # remove the `TrainingInterval(` prefix and `)` suffix, if present
        raw = raw.strip().removeprefix("TrainingInterval(").removesuffix(")")

        # split into quantity and unit: everything before the last separator is the quantity
        raw_split: list[str]
        quantity_str: str
        if "," in raw:
            raw_split = raw.split(",")
            quantity_str = ",".join(raw_split[:-1])
        else:
            raw_split = raw.split()
            quantity_str = " ".join(raw_split[:-1])

        # drop an optional `quantity=` keyword, as produced by the dataclass repr
        quantity_str = quantity_str.strip().removeprefix("quantity=").strip()
        quantity: int | float = str_to_numeric(quantity_str)

        # clean up the unit: whitespace, an optional `unit=` keyword, and surrounding quotes.
        # `str` methods return new strings, so the result must be reassigned
        unit: str = (
            raw_split[-1].strip().removeprefix("unit=").strip().strip("'\"").strip()
        )

        # unit should be one of the allowed units: canonical names first, then aliases
        unit_lower: str = unit.lower()
        unit_dealised: str | None
        if unit_lower in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
            unit_dealised = unit_lower
        else:
            unit_dealised = _TRAINING_INTERVAL_UNIT_ALIASES.get(unit_lower, None)
        if not isinstance(unit_dealised, str):
            raise ValueError(f"invalid unit {unit}")
        unit = unit_dealised

        assert unit in TrainingIntervalUnit.__args__  # type: ignore[attr-defined]
    except Exception as e:
        # report the caller's original input, not the partially stripped string
        raise ValueError(
            f"Error parsing {original} as a TrainingInterval\n{e}"
        ) from e

    return cls(quantity, unit)  # type: ignore[arg-type]

parse a string into a TrainingInterval object

Examples:

>>> TrainingInterval.from_str("5 epochs")
TrainingInterval(quantity=5, unit='epochs')
>>> TrainingInterval.from_str("100 batches")
TrainingInterval(quantity=100, unit='batches')
>>> TrainingInterval.from_str("0.1 runs")
TrainingInterval(quantity=0.1, unit='runs')
>>> TrainingInterval.from_str("1/5 runs")
TrainingInterval(quantity=0.2, unit='runs')
@classmethod
def from_any(cls, *args, **kwargs) -> TrainingInterval:
252	@classmethod
253	def from_any(cls, *args, **kwargs) -> "TrainingInterval":
254		"""parse a string or tuple into a TrainingInterval object"""
255
256		try:
257			# no kwargs allowed
258			assert len(kwargs) == 0, "no kwargs allowed for from_any"
259
260			# split up args
261			data: Any
262			match len(args):
263				case 1:
264					data = args[0]
265				case 2:
266					data = args
267				case _:
268					raise ValueError(
269						f"invalid number of args {len(args)} for from_any: {args = }"
270					)
271
272			if isinstance(data, cls):
273				return data
274			elif isinstance(data, str):
275				return cls.from_str(data)
276			elif isinstance(data, Sequence):
277				assert len(data) == 2, (
278					f"invalid length {len(data)} for TrainingInterval: {data}"
279				)
280				quantity, unit = data
281				if isinstance(quantity, str):
282					quantity = str_to_numeric(quantity)
283				return cls(quantity, unit)
284			else:
285				raise ValueError(f"invalid type {type(data)} for TrainingInterval")
286
287		except AssertionError as e:
288			raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e

parse a string or tuple into a TrainingInterval object

@classmethod
def process_to_batches( cls, interval: Union[str, tuple[Union[int, float, str], str], TrainingInterval], batchsize: int, batches_per_epoch: int, epochs: int | None = None) -> int:
290	@classmethod
291	def process_to_batches(
292		cls,
293		interval: "CastableToTrainingInterval",
294		batchsize: int,
295		batches_per_epoch: int,
296		epochs: int | None = None,
297	) -> int:
298		"""directly from any representation to a number of batches"""
299
300		interval_ti: TrainingInterval = cls.from_any(interval)
301
302		return interval_ti.as_batch_count(
303			batches_per_epoch=batches_per_epoch,
304			batchsize=batchsize,
305			epochs=epochs,
306		)

directly from any representation to a number of batches

CastableToTrainingInterval = typing.Union[str, tuple[typing.Union[int, float, str], str], TrainingInterval]