trnbl.training_interval
1from typing import Any, Generator, Literal, Callable, Union, Sequence 2from dataclasses import dataclass 3 4from muutils.misc import str_to_numeric 5from muutils.errormode import ErrorMode 6from muutils.interval import Interval 7 8_EPSILON: float = 1e-6 9 10# units of training intervals -- we convert this all to batches 11TrainingIntervalUnit = Literal["runs", "epochs", "batches", "samples"] 12 13_TRAINING_INTERVAL_UNITS_RANGES: dict[TrainingIntervalUnit, Interval] = { 14 # epochs and runs should not actually be closed, but we allow it 15 "runs": Interval(0, 1, is_closed=True), 16 "epochs": Interval(0, float("inf"), is_closed=True), 17 "batches": Interval(1, float("inf"), is_closed=True), 18 "samples": Interval(1, float("inf"), is_closed=True), 19} 20 21_TRAINING_INTERVAL_UNITS_CAST: dict[TrainingIntervalUnit, Callable] = { 22 "runs": lambda x: x, 23 "epochs": lambda x: x, 24 "batches": lambda x: int(round(x)), 25 "samples": lambda x: int(round(x)), 26} 27 28_TRAINING_INTERVAL_UNIT_ALIASES: dict[str, str] = { 29 "run": "runs", 30 "epoch": "epochs", 31 "batch": "batches", 32 "sample": "samples", 33} 34 35# what to do if interval is < 1 batch 36# if WARN or IGNORE, set it to 1 batch 37WhenIntervalLessThanBatch: ErrorMode = ErrorMode.WARN 38 39 40class IntervalValueError(UserWarning): 41 """Error for when the interval is less than 1 batch""" 42 43 pass 44 45 46@dataclass(frozen=True) 47class TrainingInterval: 48 """A training interval, which can be specified in a few different units. 
49 50 # Attributes: 51 - `quantity: int|float` - the quantity of the interval 52 - `unit: TrainingIntervalUnit` - the unit of the interval, one of "runs", "epochs", "batches", or "samples" 53 54 # Methods: 55 - `TrainingInterval.from_str(raw: str) -> TrainingInterval` - parse a string into a TrainingInterval object 56 - `TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - convert the interval to a raw number of batches 57 - `TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - any representation to a number of batches 58 - `TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> None` - current interval, with units switched to batches 59 60 Provides methods for reading from a string or tuple, and normalizing to batches. 61 """ 62 63 quantity: int | float 64 unit: TrainingIntervalUnit 65 66 def __iter__(self) -> Generator[int | float | TrainingIntervalUnit, None, None]: 67 yield self.quantity 68 yield self.unit 69 70 def __getitem__(self, index: int) -> int | float | TrainingIntervalUnit: 71 if index == 0: 72 return self.quantity 73 elif index == 1: 74 return self.unit 75 else: 76 raise IndexError(f"invalid index {index} for TrainingInterval") 77 78 def __post_init__(self) -> None: 79 try: 80 assert isinstance(self.quantity, (int, float)), ( 81 "quantity should be an integer or float" 82 ) 83 # TODO: Literal[...].__args__ is not defined?? 
84 if self.unit not in TrainingIntervalUnit.__args__: # type: ignore[attr-defined] 85 unit_dealised: str | None = _TRAINING_INTERVAL_UNIT_ALIASES.get( 86 self.unit.lower(), None 87 ) 88 if isinstance(unit_dealised, str): 89 self.__dict__["unit"] = unit_dealised 90 else: 91 raise ValueError(f"invalid unit {self.unit = }") 92 93 assert self.unit in TrainingIntervalUnit.__args__, ( # type: ignore[attr-defined] 94 f"invalid unit {self.unit}" 95 ) 96 except AssertionError as e: 97 raise AssertionError( 98 f"Error initializing TrainingInterval\n{self}\n{e}" 99 ) from e 100 101 # check values in proper ranges 102 expected_interval: Interval = _TRAINING_INTERVAL_UNITS_RANGES[self.unit] 103 if self.quantity not in expected_interval: 104 WhenIntervalLessThanBatch.process( 105 f"interval {self} has invalid quantity, expected in interval {expected_interval}, will set to closest bound if not erroring out", 106 except_cls=IntervalValueError, 107 warn_cls=IntervalValueError, 108 ) 109 self.__dict__["quantity"] = expected_interval.clamp(self.quantity) 110 111 # cast if necessary 112 self.__dict__["quantity"] = _TRAINING_INTERVAL_UNITS_CAST[self.unit]( 113 self.quantity 114 ) 115 116 def __eq__(self, other: Any) -> bool: 117 if not isinstance(other, self.__class__): 118 raise TypeError( 119 f"invalid type {type(other)} for comparison with TrainingInterval" 120 ) 121 return ( 122 abs(self.quantity - other.quantity) < _EPSILON and self.unit == other.unit 123 ) 124 125 def as_batch_count( 126 self, 127 batchsize: int, 128 batches_per_epoch: int, 129 epochs: int | None = None, 130 ) -> int: 131 """given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches 132 133 # Parameters: 134 - `batchsize: int` 135 the size of a batch 136 - `batches_per_epoch: int` 137 the number of batches in an epoch 138 - `epochs: int|None` 139 the number of epochs to run (only required if the interval is in "runs") 140 141 # Returns: 142 - `int` 143 the 
interval as a number of batches 144 145 # Raises: 146 - `ValueError` 147 if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR` 148 otherwise, will warn or ignore and set the interval to 1 batch 149 - `ValueError` 150 if the unit is not one of "runs", "epochs", "batches", or "samples" 151 152 153 """ 154 155 output: int | float 156 157 match self.unit: 158 case "runs": 159 assert epochs is not None, ( 160 "epochs must be provided to convert runs to batches" 161 ) 162 output = self.quantity * epochs * batches_per_epoch 163 case "epochs": 164 output = self.quantity * batches_per_epoch 165 case "batches": 166 output = self.quantity 167 case "samples": 168 output = self.quantity / batchsize 169 case _: 170 raise ValueError(f"invalid unit {self.unit}") 171 172 # check if interval is less than 1 batch 173 if output < 1: 174 WhenIntervalLessThanBatch.process( 175 f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out", 176 except_cls=IntervalValueError, 177 warn_cls=IntervalValueError, 178 ) 179 output = 1 180 181 return int(round(output)) 182 183 def normalized( 184 self, 185 batchsize: int, 186 batches_per_epoch: int, 187 epochs: int | None = None, 188 ) -> "TrainingInterval": 189 """convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to "batches""" 190 quantity: int | float = self.as_batch_count( 191 batches_per_epoch=batches_per_epoch, 192 batchsize=batchsize, 193 epochs=epochs, 194 ) 195 unit: TrainingIntervalUnit = "batches" 196 return self.__class__(quantity, unit) 197 198 @classmethod 199 def from_str(cls, raw: str) -> "TrainingInterval": 200 """parse a string into a TrainingInterval object 201 202 # Examples: 203 204 >>> TrainingInterval.from_str("5 epochs") 205 TrainingInterval(5, 'epochs') 206 >>> TrainingInterval.from_str("100 batches") 207 TrainingInterval(100, 'batches') 208 >>> 
TrainingInterval.from_str("0.1 runs") 209 TrainingInterval(0.1, 'runs') 210 >>> TrainingInterval.from_str("1/5 runs") 211 TrainingInterval(0.2, 'runs') 212 213 """ 214 try: 215 # remove prefix and suffix (optionally) 216 raw = raw.removeprefix("TrainingInterval(").removesuffix(")") 217 218 # process quantity 219 raw_split: list[str] 220 quantity_str: str 221 if "," in raw: 222 raw_split = raw.split(",") 223 quantity_str = ",".join(raw_split[:-1]) 224 else: 225 raw_split = raw.split() 226 quantity_str = " ".join(raw_split[:-1]) 227 228 quantity: int | float = str_to_numeric(quantity_str) 229 230 # process unit 231 unit: str = raw_split[-1] 232 unit.strip().strip("'\"").strip() 233 234 # unit should be one of the allowed units 235 unit_dealised: str | None 236 if unit.lower() in TrainingIntervalUnit.__args__: # type: ignore[attr-defined] 237 unit_dealised = unit.lower() 238 else: 239 unit_dealised = _TRAINING_INTERVAL_UNIT_ALIASES.get(unit.lower(), None) 240 if isinstance(unit_dealised, str): 241 unit = unit_dealised 242 else: 243 raise ValueError(f"invalid unit {unit}") 244 245 assert unit in TrainingIntervalUnit.__args__ # type: ignore[attr-defined] 246 except Exception as e: 247 raise ValueError(f"Error parsing {raw} as a TrainingInterval\n{e}") from e 248 249 return cls(quantity, unit) # type: ignore[arg-type] 250 251 @classmethod 252 def from_any(cls, *args, **kwargs) -> "TrainingInterval": 253 """parse a string or tuple into a TrainingInterval object""" 254 255 try: 256 # no kwargs allowed 257 assert len(kwargs) == 0, "no kwargs allowed for from_any" 258 259 # split up args 260 data: Any 261 match len(args): 262 case 1: 263 data = args[0] 264 case 2: 265 data = args 266 case _: 267 raise ValueError( 268 f"invalid number of args {len(args)} for from_any: {args = }" 269 ) 270 271 if isinstance(data, cls): 272 return data 273 elif isinstance(data, str): 274 return cls.from_str(data) 275 elif isinstance(data, Sequence): 276 assert len(data) == 2, ( 277 f"invalid 
length {len(data)} for TrainingInterval: {data}" 278 ) 279 quantity, unit = data 280 if isinstance(quantity, str): 281 quantity = str_to_numeric(quantity) 282 return cls(quantity, unit) 283 else: 284 raise ValueError(f"invalid type {type(data)} for TrainingInterval") 285 286 except AssertionError as e: 287 raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e 288 289 @classmethod 290 def process_to_batches( 291 cls, 292 interval: "CastableToTrainingInterval", 293 batchsize: int, 294 batches_per_epoch: int, 295 epochs: int | None = None, 296 ) -> int: 297 """directly from any representation to a number of batches""" 298 299 interval_ti: TrainingInterval = cls.from_any(interval) 300 301 return interval_ti.as_batch_count( 302 batches_per_epoch=batches_per_epoch, 303 batchsize=batchsize, 304 epochs=epochs, 305 ) 306 307 308CastableToTrainingInterval = Union[ 309 str, # parse as string "quantity unit" 310 tuple[Union[int, float, str], str], # parse as tuple (quantity, unit) 311 TrainingInterval, # already a TrainingInterval 312]
TrainingIntervalUnit =
typing.Literal['runs', 'epochs', 'batches', 'samples']
WhenIntervalLessThanBatch: muutils.errormode.ErrorMode =
ErrorMode.Warn
class
IntervalValueError(builtins.UserWarning):
class IntervalValueError(UserWarning):
    """Raised or warned when an interval's quantity is invalid, e.g. less than 1 batch.

    Subclasses `UserWarning` so the same class can serve as a warning category.
    """
Error for when the interval is less than 1 batch
Inherited Members
- builtins.UserWarning
- UserWarning
- builtins.BaseException
- with_traceback
- add_note
- args
@dataclass(frozen=True)
class
TrainingInterval:
47@dataclass(frozen=True) 48class TrainingInterval: 49 """A training interval, which can be specified in a few different units. 50 51 # Attributes: 52 - `quantity: int|float` - the quantity of the interval 53 - `unit: TrainingIntervalUnit` - the unit of the interval, one of "runs", "epochs", "batches", or "samples" 54 55 # Methods: 56 - `TrainingInterval.from_str(raw: str) -> TrainingInterval` - parse a string into a TrainingInterval object 57 - `TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - convert the interval to a raw number of batches 58 - `TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - any representation to a number of batches 59 - `TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> None` - current interval, with units switched to batches 60 61 Provides methods for reading from a string or tuple, and normalizing to batches. 62 """ 63 64 quantity: int | float 65 unit: TrainingIntervalUnit 66 67 def __iter__(self) -> Generator[int | float | TrainingIntervalUnit, None, None]: 68 yield self.quantity 69 yield self.unit 70 71 def __getitem__(self, index: int) -> int | float | TrainingIntervalUnit: 72 if index == 0: 73 return self.quantity 74 elif index == 1: 75 return self.unit 76 else: 77 raise IndexError(f"invalid index {index} for TrainingInterval") 78 79 def __post_init__(self) -> None: 80 try: 81 assert isinstance(self.quantity, (int, float)), ( 82 "quantity should be an integer or float" 83 ) 84 # TODO: Literal[...].__args__ is not defined?? 
85 if self.unit not in TrainingIntervalUnit.__args__: # type: ignore[attr-defined] 86 unit_dealised: str | None = _TRAINING_INTERVAL_UNIT_ALIASES.get( 87 self.unit.lower(), None 88 ) 89 if isinstance(unit_dealised, str): 90 self.__dict__["unit"] = unit_dealised 91 else: 92 raise ValueError(f"invalid unit {self.unit = }") 93 94 assert self.unit in TrainingIntervalUnit.__args__, ( # type: ignore[attr-defined] 95 f"invalid unit {self.unit}" 96 ) 97 except AssertionError as e: 98 raise AssertionError( 99 f"Error initializing TrainingInterval\n{self}\n{e}" 100 ) from e 101 102 # check values in proper ranges 103 expected_interval: Interval = _TRAINING_INTERVAL_UNITS_RANGES[self.unit] 104 if self.quantity not in expected_interval: 105 WhenIntervalLessThanBatch.process( 106 f"interval {self} has invalid quantity, expected in interval {expected_interval}, will set to closest bound if not erroring out", 107 except_cls=IntervalValueError, 108 warn_cls=IntervalValueError, 109 ) 110 self.__dict__["quantity"] = expected_interval.clamp(self.quantity) 111 112 # cast if necessary 113 self.__dict__["quantity"] = _TRAINING_INTERVAL_UNITS_CAST[self.unit]( 114 self.quantity 115 ) 116 117 def __eq__(self, other: Any) -> bool: 118 if not isinstance(other, self.__class__): 119 raise TypeError( 120 f"invalid type {type(other)} for comparison with TrainingInterval" 121 ) 122 return ( 123 abs(self.quantity - other.quantity) < _EPSILON and self.unit == other.unit 124 ) 125 126 def as_batch_count( 127 self, 128 batchsize: int, 129 batches_per_epoch: int, 130 epochs: int | None = None, 131 ) -> int: 132 """given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches 133 134 # Parameters: 135 - `batchsize: int` 136 the size of a batch 137 - `batches_per_epoch: int` 138 the number of batches in an epoch 139 - `epochs: int|None` 140 the number of epochs to run (only required if the interval is in "runs") 141 142 # Returns: 143 - `int` 144 the 
interval as a number of batches 145 146 # Raises: 147 - `ValueError` 148 if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR` 149 otherwise, will warn or ignore and set the interval to 1 batch 150 - `ValueError` 151 if the unit is not one of "runs", "epochs", "batches", or "samples" 152 153 154 """ 155 156 output: int | float 157 158 match self.unit: 159 case "runs": 160 assert epochs is not None, ( 161 "epochs must be provided to convert runs to batches" 162 ) 163 output = self.quantity * epochs * batches_per_epoch 164 case "epochs": 165 output = self.quantity * batches_per_epoch 166 case "batches": 167 output = self.quantity 168 case "samples": 169 output = self.quantity / batchsize 170 case _: 171 raise ValueError(f"invalid unit {self.unit}") 172 173 # check if interval is less than 1 batch 174 if output < 1: 175 WhenIntervalLessThanBatch.process( 176 f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out", 177 except_cls=IntervalValueError, 178 warn_cls=IntervalValueError, 179 ) 180 output = 1 181 182 return int(round(output)) 183 184 def normalized( 185 self, 186 batchsize: int, 187 batches_per_epoch: int, 188 epochs: int | None = None, 189 ) -> "TrainingInterval": 190 """convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to "batches""" 191 quantity: int | float = self.as_batch_count( 192 batches_per_epoch=batches_per_epoch, 193 batchsize=batchsize, 194 epochs=epochs, 195 ) 196 unit: TrainingIntervalUnit = "batches" 197 return self.__class__(quantity, unit) 198 199 @classmethod 200 def from_str(cls, raw: str) -> "TrainingInterval": 201 """parse a string into a TrainingInterval object 202 203 # Examples: 204 205 >>> TrainingInterval.from_str("5 epochs") 206 TrainingInterval(5, 'epochs') 207 >>> TrainingInterval.from_str("100 batches") 208 TrainingInterval(100, 'batches') 209 >>> 
TrainingInterval.from_str("0.1 runs") 210 TrainingInterval(0.1, 'runs') 211 >>> TrainingInterval.from_str("1/5 runs") 212 TrainingInterval(0.2, 'runs') 213 214 """ 215 try: 216 # remove prefix and suffix (optionally) 217 raw = raw.removeprefix("TrainingInterval(").removesuffix(")") 218 219 # process quantity 220 raw_split: list[str] 221 quantity_str: str 222 if "," in raw: 223 raw_split = raw.split(",") 224 quantity_str = ",".join(raw_split[:-1]) 225 else: 226 raw_split = raw.split() 227 quantity_str = " ".join(raw_split[:-1]) 228 229 quantity: int | float = str_to_numeric(quantity_str) 230 231 # process unit 232 unit: str = raw_split[-1] 233 unit.strip().strip("'\"").strip() 234 235 # unit should be one of the allowed units 236 unit_dealised: str | None 237 if unit.lower() in TrainingIntervalUnit.__args__: # type: ignore[attr-defined] 238 unit_dealised = unit.lower() 239 else: 240 unit_dealised = _TRAINING_INTERVAL_UNIT_ALIASES.get(unit.lower(), None) 241 if isinstance(unit_dealised, str): 242 unit = unit_dealised 243 else: 244 raise ValueError(f"invalid unit {unit}") 245 246 assert unit in TrainingIntervalUnit.__args__ # type: ignore[attr-defined] 247 except Exception as e: 248 raise ValueError(f"Error parsing {raw} as a TrainingInterval\n{e}") from e 249 250 return cls(quantity, unit) # type: ignore[arg-type] 251 252 @classmethod 253 def from_any(cls, *args, **kwargs) -> "TrainingInterval": 254 """parse a string or tuple into a TrainingInterval object""" 255 256 try: 257 # no kwargs allowed 258 assert len(kwargs) == 0, "no kwargs allowed for from_any" 259 260 # split up args 261 data: Any 262 match len(args): 263 case 1: 264 data = args[0] 265 case 2: 266 data = args 267 case _: 268 raise ValueError( 269 f"invalid number of args {len(args)} for from_any: {args = }" 270 ) 271 272 if isinstance(data, cls): 273 return data 274 elif isinstance(data, str): 275 return cls.from_str(data) 276 elif isinstance(data, Sequence): 277 assert len(data) == 2, ( 278 f"invalid 
length {len(data)} for TrainingInterval: {data}" 279 ) 280 quantity, unit = data 281 if isinstance(quantity, str): 282 quantity = str_to_numeric(quantity) 283 return cls(quantity, unit) 284 else: 285 raise ValueError(f"invalid type {type(data)} for TrainingInterval") 286 287 except AssertionError as e: 288 raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e 289 290 @classmethod 291 def process_to_batches( 292 cls, 293 interval: "CastableToTrainingInterval", 294 batchsize: int, 295 batches_per_epoch: int, 296 epochs: int | None = None, 297 ) -> int: 298 """directly from any representation to a number of batches""" 299 300 interval_ti: TrainingInterval = cls.from_any(interval) 301 302 return interval_ti.as_batch_count( 303 batches_per_epoch=batches_per_epoch, 304 batchsize=batchsize, 305 epochs=epochs, 306 )
A training interval, which can be specified in a few different units.
Attributes:
- quantity: int|float - the quantity of the interval
- unit: TrainingIntervalUnit - the unit of the interval, one of "runs", "epochs", "batches", or "samples"
Methods:
- TrainingInterval.from_str(raw: str) -> TrainingInterval - parse a string into a TrainingInterval object
- TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int - convert the interval to a raw number of batches
- TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int - any representation to a number of batches
- TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> TrainingInterval - current interval, with units switched to batches
Provides methods for reading from a string or tuple, and normalizing to batches.
def
as_batch_count( self, batchsize: int, batches_per_epoch: int, epochs: int | None = None) -> int:
126 def as_batch_count( 127 self, 128 batchsize: int, 129 batches_per_epoch: int, 130 epochs: int | None = None, 131 ) -> int: 132 """given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches 133 134 # Parameters: 135 - `batchsize: int` 136 the size of a batch 137 - `batches_per_epoch: int` 138 the number of batches in an epoch 139 - `epochs: int|None` 140 the number of epochs to run (only required if the interval is in "runs") 141 142 # Returns: 143 - `int` 144 the interval as a number of batches 145 146 # Raises: 147 - `ValueError` 148 if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR` 149 otherwise, will warn or ignore and set the interval to 1 batch 150 - `ValueError` 151 if the unit is not one of "runs", "epochs", "batches", or "samples" 152 153 154 """ 155 156 output: int | float 157 158 match self.unit: 159 case "runs": 160 assert epochs is not None, ( 161 "epochs must be provided to convert runs to batches" 162 ) 163 output = self.quantity * epochs * batches_per_epoch 164 case "epochs": 165 output = self.quantity * batches_per_epoch 166 case "batches": 167 output = self.quantity 168 case "samples": 169 output = self.quantity / batchsize 170 case _: 171 raise ValueError(f"invalid unit {self.unit}") 172 173 # check if interval is less than 1 batch 174 if output < 1: 175 WhenIntervalLessThanBatch.process( 176 f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out", 177 except_cls=IntervalValueError, 178 warn_cls=IntervalValueError, 179 ) 180 output = 1 181 182 return int(round(output))
given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches
Parameters:
- batchsize: int - the size of a batch
- batches_per_epoch: int - the number of batches in an epoch
- epochs: int|None - the number of epochs to run (only required if the interval is in "runs")
Returns:
int
the interval as a number of batches
Raises:
- IntervalValueError - if the interval is less than 1 batch, and WhenIntervalLessThanBatch is set to the raising mode of muutils.errormode.ErrorMode; otherwise, will warn or ignore and set the interval to 1 batch
- ValueError - if the unit is not one of "runs", "epochs", "batches", or "samples"
def
normalized( self, batchsize: int, batches_per_epoch: int, epochs: int | None = None) -> TrainingInterval:
def normalized(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None,
) -> "TrainingInterval":
    """return a new interval equal to this one, expressed in the `"batches"` unit

    the quantity of the result is whatever `as_batch_count` returns for the same arguments
    """
    batch_count: int | float = self.as_batch_count(
        batchsize=batchsize,
        batches_per_epoch=batches_per_epoch,
        epochs=epochs,
    )
    return self.__class__(batch_count, "batches")
convert the units of the interval to batches, by calling as_batch_count and setting the unit to "batches"
@classmethod
def from_str(cls, raw: str) -> "TrainingInterval":
    """parse a string into a TrainingInterval object

    accepts `"<quantity> <unit>"` or `"<quantity>, <unit>"`, optionally wrapped in
    `TrainingInterval(...)`, with the unit optionally quoted. units are case-insensitive
    and singular aliases ("run", "epoch", "batch", "sample") are accepted.

    # Examples:

    >>> TrainingInterval.from_str("5 epochs")
    TrainingInterval(quantity=5, unit='epochs')
    >>> TrainingInterval.from_str("100 batches")
    TrainingInterval(quantity=100, unit='batches')
    >>> TrainingInterval.from_str("0.1 runs")
    TrainingInterval(quantity=0.1, unit='runs')
    >>> TrainingInterval.from_str("1/5 runs")
    TrainingInterval(quantity=0.2, unit='runs')
    >>> TrainingInterval.from_str("TrainingInterval(5, 'Epoch')")
    TrainingInterval(quantity=5, unit='epochs')

    # Raises:
    - `ValueError`
        if the string cannot be parsed, or the unit is not recognized
    """
    try:
        # remove prefix and suffix (optionally)
        raw = raw.removeprefix("TrainingInterval(").removesuffix(")")

        # process quantity: everything before the last separator
        raw_split: list[str]
        quantity_str: str
        if "," in raw:
            raw_split = raw.split(",")
            quantity_str = ",".join(raw_split[:-1])
        else:
            raw_split = raw.split()
            quantity_str = " ".join(raw_split[:-1])

        # also accept the keyword form produced by the dataclass repr
        quantity_str = quantity_str.strip().removeprefix("quantity=")
        quantity: int | float = str_to_numeric(quantity_str)

        # process unit: drop whitespace, an optional `unit=` keyword, and quotes.
        # `str` methods return new strings, so the result has to be assigned back
        unit: str = raw_split[-1].strip().removeprefix("unit=")
        unit = unit.strip().strip("'\"").strip().lower()

        # unit should be one of the allowed units, or an alias of one
        if unit not in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
            unit_dealised: str | None = _TRAINING_INTERVAL_UNIT_ALIASES.get(
                unit, None
            )
            if isinstance(unit_dealised, str):
                unit = unit_dealised
            else:
                raise ValueError(f"invalid unit {unit}")
    except Exception as e:
        raise ValueError(f"Error parsing {raw} as a TrainingInterval\n{e}") from e

    return cls(quantity, unit)  # type: ignore[arg-type]
parse a string into a TrainingInterval object
Examples:
>>> TrainingInterval.from_str("5 epochs")
TrainingInterval(5, 'epochs')
>>> TrainingInterval.from_str("100 batches")
TrainingInterval(100, 'batches')
>>> TrainingInterval.from_str("0.1 runs")
TrainingInterval(0.1, 'runs')
>>> TrainingInterval.from_str("1/5 runs")
TrainingInterval(0.2, 'runs')
252 @classmethod 253 def from_any(cls, *args, **kwargs) -> "TrainingInterval": 254 """parse a string or tuple into a TrainingInterval object""" 255 256 try: 257 # no kwargs allowed 258 assert len(kwargs) == 0, "no kwargs allowed for from_any" 259 260 # split up args 261 data: Any 262 match len(args): 263 case 1: 264 data = args[0] 265 case 2: 266 data = args 267 case _: 268 raise ValueError( 269 f"invalid number of args {len(args)} for from_any: {args = }" 270 ) 271 272 if isinstance(data, cls): 273 return data 274 elif isinstance(data, str): 275 return cls.from_str(data) 276 elif isinstance(data, Sequence): 277 assert len(data) == 2, ( 278 f"invalid length {len(data)} for TrainingInterval: {data}" 279 ) 280 quantity, unit = data 281 if isinstance(quantity, str): 282 quantity = str_to_numeric(quantity) 283 return cls(quantity, unit) 284 else: 285 raise ValueError(f"invalid type {type(data)} for TrainingInterval") 286 287 except AssertionError as e: 288 raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e
parse a string or tuple into a TrainingInterval object
@classmethod
def
process_to_batches( cls, interval: Union[str, tuple[Union[int, float, str], str], TrainingInterval], batchsize: int, batches_per_epoch: int, epochs: int | None = None) -> int:
@classmethod
def process_to_batches(
    cls,
    interval: "CastableToTrainingInterval",
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None,
) -> int:
    """convert any accepted interval representation straight to a batch count

    `interval` is parsed with `from_any`, and the result is converted with
    `as_batch_count` using the remaining arguments
    """
    parsed: TrainingInterval = cls.from_any(interval)
    batch_count: int = parsed.as_batch_count(
        batchsize=batchsize,
        batches_per_epoch=batches_per_epoch,
        epochs=epochs,
    )
    return batch_count
directly from any representation to a number of batches
CastableToTrainingInterval =
typing.Union[str, tuple[typing.Union[int, float, str], str], TrainingInterval]