Coverage for trnbl\training_interval.py: 94%

118 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

1from typing import Any, Generator, Literal, Callable, Union, Sequence 

2from dataclasses import dataclass 

3 

4from muutils.misc import str_to_numeric 

5from muutils.errormode import ErrorMode 

6from muutils.interval import Interval 

7 

8_EPSILON: float = 1e-6 

9 

# the units a training interval may be expressed in;
# every one of them is ultimately converted to a number of batches
TrainingIntervalUnit = Literal[
    "runs",
    "epochs",
    "batches",
    "samples",
]

12 

# allowed range of `quantity` for each unit; a quantity outside its range is
# clamped to the closest bound (after going through `WhenIntervalLessThanBatch`)
_TRAINING_INTERVAL_UNITS_RANGES: dict[TrainingIntervalUnit, Interval] = {
    # NOTE: the ranges for epochs and runs should not really be closed,
    # but closed bounds are tolerated here
    "runs": Interval(0, 1, is_closed=True),
    "epochs": Interval(0, float("inf"), is_closed=True),
    # batches and samples must be at least 1
    "batches": Interval(1, float("inf"), is_closed=True),
    "samples": Interval(1, float("inf"), is_closed=True),
}

20 

# per-unit cast applied to `quantity` once it has been validated
_TRAINING_INTERVAL_UNITS_CAST: dict[TrainingIntervalUnit, Callable] = {
    # runs and epochs keep whatever numeric value they were given
    "runs": lambda q: q,
    "epochs": lambda q: q,
    # batches and samples are rounded to an integer
    "batches": lambda q: int(round(q)),
    "samples": lambda q: int(round(q)),
}

27 

28_TRAINING_INTERVAL_UNIT_ALIASES: dict[str, str] = { 

29 "run": "runs", 

30 "epoch": "epochs", 

31 "batch": "batches", 

32 "sample": "samples", 

33} 

34 

# module-level policy for an interval that works out to less than 1 batch
# (also used for quantities outside the allowed range of their unit):
# with WARN or IGNORE the value is set to 1 batch / the closest bound
WhenIntervalLessThanBatch: ErrorMode = ErrorMode.WARN

38 

39 

class IntervalValueError(UserWarning):
    """Raised or warned when an interval quantity is invalid, e.g. less than 1 batch.

    Passed as both `except_cls` and `warn_cls` to `WhenIntervalLessThanBatch.process`.
    """

44 

45 

46@dataclass(frozen=True) 

47class TrainingInterval: 

48 """A training interval, which can be specified in a few different units. 

49 

50 # Attributes: 

51 - `quantity: int|float` - the quantity of the interval 

52 - `unit: TrainingIntervalUnit` - the unit of the interval, one of "runs", "epochs", "batches", or "samples" 

53 

54 # Methods: 

55 - `TrainingInterval.from_str(raw: str) -> TrainingInterval` - parse a string into a TrainingInterval object 

56 - `TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - convert the interval to a raw number of batches 

57 - `TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - any representation to a number of batches 

58 - `TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> None` - current interval, with units switched to batches 

59 

60 Provides methods for reading from a string or tuple, and normalizing to batches. 

61 """ 

62 

63 quantity: int | float 

64 unit: TrainingIntervalUnit 

65 

66 def __iter__(self) -> Generator[int | float | TrainingIntervalUnit, None, None]: 

67 yield self.quantity 

68 yield self.unit 

69 

70 def __getitem__(self, index: int) -> int | float | TrainingIntervalUnit: 

71 if index == 0: 

72 return self.quantity 

73 elif index == 1: 

74 return self.unit 

75 else: 

76 raise IndexError(f"invalid index {index} for TrainingInterval") 

77 

78 def __post_init__(self) -> None: 

79 try: 

80 assert isinstance(self.quantity, (int, float)), ( 

81 "quantity should be an integer or float" 

82 ) 

83 # TODO: Literal[...].__args__ is not defined?? 

84 if self.unit not in TrainingIntervalUnit.__args__: # type: ignore[attr-defined] 

85 unit_dealised: str | None = _TRAINING_INTERVAL_UNIT_ALIASES.get( 

86 self.unit.lower(), None 

87 ) 

88 if isinstance(unit_dealised, str): 

89 self.__dict__["unit"] = unit_dealised 

90 else: 

91 raise ValueError(f"invalid unit {self.unit = }") 

92 

93 assert self.unit in TrainingIntervalUnit.__args__, ( # type: ignore[attr-defined] 

94 f"invalid unit {self.unit}" 

95 ) 

96 except AssertionError as e: 

97 raise AssertionError( 

98 f"Error initializing TrainingInterval\n{self}\n{e}" 

99 ) from e 

100 

101 # check values in proper ranges 

102 expected_interval: Interval = _TRAINING_INTERVAL_UNITS_RANGES[self.unit] 

103 if self.quantity not in expected_interval: 

104 WhenIntervalLessThanBatch.process( 

105 f"interval {self} has invalid quantity, expected in interval {expected_interval}, will set to closest bound if not erroring out", 

106 except_cls=IntervalValueError, 

107 warn_cls=IntervalValueError, 

108 ) 

109 self.__dict__["quantity"] = expected_interval.clamp(self.quantity) 

110 

111 # cast if necessary 

112 self.__dict__["quantity"] = _TRAINING_INTERVAL_UNITS_CAST[self.unit]( 

113 self.quantity 

114 ) 

115 

116 def __eq__(self, other: Any) -> bool: 

117 if not isinstance(other, self.__class__): 

118 raise TypeError( 

119 f"invalid type {type(other)} for comparison with TrainingInterval" 

120 ) 

121 return ( 

122 abs(self.quantity - other.quantity) < _EPSILON and self.unit == other.unit 

123 ) 

124 

125 def as_batch_count( 

126 self, 

127 batchsize: int, 

128 batches_per_epoch: int, 

129 epochs: int | None = None, 

130 ) -> int: 

131 """given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches 

132 

133 # Parameters: 

134 - `batchsize: int` 

135 the size of a batch 

136 - `batches_per_epoch: int` 

137 the number of batches in an epoch 

138 - `epochs: int|None` 

139 the number of epochs to run (only required if the interval is in "runs") 

140 

141 # Returns: 

142 - `int` 

143 the interval as a number of batches 

144 

145 # Raises: 

146 - `ValueError` 

147 if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR` 

148 otherwise, will warn or ignore and set the interval to 1 batch 

149 - `ValueError` 

150 if the unit is not one of "runs", "epochs", "batches", or "samples" 

151 

152 

153 """ 

154 

155 output: int | float 

156 

157 match self.unit: 

158 case "runs": 

159 assert epochs is not None, ( 

160 "epochs must be provided to convert runs to batches" 

161 ) 

162 output = self.quantity * epochs * batches_per_epoch 

163 case "epochs": 

164 output = self.quantity * batches_per_epoch 

165 case "batches": 

166 output = self.quantity 

167 case "samples": 

168 output = self.quantity / batchsize 

169 case _: 

170 raise ValueError(f"invalid unit {self.unit}") 

171 

172 # check if interval is less than 1 batch 

173 if output < 1: 

174 WhenIntervalLessThanBatch.process( 

175 f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out", 

176 except_cls=IntervalValueError, 

177 warn_cls=IntervalValueError, 

178 ) 

179 output = 1 

180 

181 return int(round(output)) 

182 

183 def normalized( 

184 self, 

185 batchsize: int, 

186 batches_per_epoch: int, 

187 epochs: int | None = None, 

188 ) -> "TrainingInterval": 

189 """convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to "batches""" 

190 quantity: int | float = self.as_batch_count( 

191 batches_per_epoch=batches_per_epoch, 

192 batchsize=batchsize, 

193 epochs=epochs, 

194 ) 

195 unit: TrainingIntervalUnit = "batches" 

196 return self.__class__(quantity, unit) 

197 

198 @classmethod 

199 def from_str(cls, raw: str) -> "TrainingInterval": 

200 """parse a string into a TrainingInterval object 

201 

202 # Examples: 

203 

204 >>> TrainingInterval.from_str("5 epochs") 

205 TrainingInterval(5, 'epochs') 

206 >>> TrainingInterval.from_str("100 batches") 

207 TrainingInterval(100, 'batches') 

208 >>> TrainingInterval.from_str("0.1 runs") 

209 TrainingInterval(0.1, 'runs') 

210 >>> TrainingInterval.from_str("1/5 runs") 

211 TrainingInterval(0.2, 'runs') 

212 

213 """ 

214 try: 

215 # remove prefix and suffix (optionally) 

216 raw = raw.removeprefix("TrainingInterval(").removesuffix(")") 

217 

218 # process quantity 

219 raw_split: list[str] 

220 quantity_str: str 

221 if "," in raw: 

222 raw_split = raw.split(",") 

223 quantity_str = ",".join(raw_split[:-1]) 

224 else: 

225 raw_split = raw.split() 

226 quantity_str = " ".join(raw_split[:-1]) 

227 

228 quantity: int | float = str_to_numeric(quantity_str) 

229 

230 # process unit 

231 unit: str = raw_split[-1] 

232 unit.strip().strip("'\"").strip() 

233 

234 # unit should be one of the allowed units 

235 unit_dealised: str | None 

236 if unit.lower() in TrainingIntervalUnit.__args__: # type: ignore[attr-defined] 

237 unit_dealised = unit.lower() 

238 else: 

239 unit_dealised = _TRAINING_INTERVAL_UNIT_ALIASES.get(unit.lower(), None) 

240 if isinstance(unit_dealised, str): 

241 unit = unit_dealised 

242 else: 

243 raise ValueError(f"invalid unit {unit}") 

244 

245 assert unit in TrainingIntervalUnit.__args__ # type: ignore[attr-defined] 

246 except Exception as e: 

247 raise ValueError(f"Error parsing {raw} as a TrainingInterval\n{e}") from e 

248 

249 return cls(quantity, unit) # type: ignore[arg-type] 

250 

251 @classmethod 

252 def from_any(cls, *args, **kwargs) -> "TrainingInterval": 

253 """parse a string or tuple into a TrainingInterval object""" 

254 

255 try: 

256 # no kwargs allowed 

257 assert len(kwargs) == 0, "no kwargs allowed for from_any" 

258 

259 # split up args 

260 data: Any 

261 match len(args): 

262 case 1: 

263 data = args[0] 

264 case 2: 

265 data = args 

266 case _: 

267 raise ValueError( 

268 f"invalid number of args {len(args)} for from_any: {args = }" 

269 ) 

270 

271 if isinstance(data, cls): 

272 return data 

273 elif isinstance(data, str): 

274 return cls.from_str(data) 

275 elif isinstance(data, Sequence): 

276 assert len(data) == 2, ( 

277 f"invalid length {len(data)} for TrainingInterval: {data}" 

278 ) 

279 quantity, unit = data 

280 if isinstance(quantity, str): 

281 quantity = str_to_numeric(quantity) 

282 return cls(quantity, unit) 

283 else: 

284 raise ValueError(f"invalid type {type(data)} for TrainingInterval") 

285 

286 except AssertionError as e: 

287 raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e 

288 

289 @classmethod 

290 def process_to_batches( 

291 cls, 

292 interval: "CastableToTrainingInterval", 

293 batchsize: int, 

294 batches_per_epoch: int, 

295 epochs: int | None = None, 

296 ) -> int: 

297 """directly from any representation to a number of batches""" 

298 

299 interval_ti: TrainingInterval = cls.from_any(interval) 

300 

301 return interval_ti.as_batch_count( 

302 batches_per_epoch=batches_per_epoch, 

303 batchsize=batchsize, 

304 epochs=epochs, 

305 ) 

306 

307 

308CastableToTrainingInterval = Union[ 

309 str, # parse as string "quantity unit" 

310 tuple[Union[int, float, str], str], # parse as tuple (quantity, unit) 

311 TrainingInterval, # already a TrainingInterval 

312]