trnbl -- Training Butler
If you train a lot of models, you may find yourself repeatedly swapping between different loggers and fiddling with piles of `if batch_idx % some_number == 0` statements. This package aims to fix that problem.
Firstly, it provides a universal interface to wandb, tensorboard, and a minimal local logging solution (live demo).
- This interface handles logging, error messages, metrics, and artifacts.
- Swapping from one logger to another requires no changes beyond initializing the new logger and passing it in instead.
- You can even log to multiple loggers at once!
Secondly, a `TrainingManager` class is provided which handles logging, artifacts, checkpointing, evaluations, exceptions, and more, with flexibly customizable intervals.
Rather than specifying all intervals in batches and then adjusting everything by hand whenever you change the batch size, dataset size, or number of epochs, you specify an interval in samples, batches, epochs, or runs. This is converted into the correct number of batches or epochs based on the current dataset and batch size.
- `"1/10 runs"` -- 10 times a run
- `"2.5 epochs"` -- every 2 & 1/2 epochs
- `(100, "batches")` -- every 100 batches
- `"10k samples"` -- every 10,000 samples
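Under the hood these are parsed by `TrainingInterval`. A rough sketch of how an interval resolves to a concrete batch count (the batch size, batches per epoch, and epoch count below are made-up numbers, not defaults):

```python
from trnbl.training_interval import TrainingInterval

# strings, tuples, or TrainingInterval instances are all accepted
interval = TrainingInterval.from_str("2.5 epochs")  # TrainingInterval(2.5, 'epochs')

# resolve to a number of batches for a run with batch size 32,
# 100 batches per epoch, and 10 epochs total
n_batches = TrainingInterval.process_to_batches(
    "1/10 runs", batchsize=32, batches_per_epoch=100, epochs=10
)
assert n_batches == 100  # 0.1 runs * 10 epochs * 100 batches/epoch
```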
- an evaluation function is passed in a tuple with an interval; it takes the model as an argument and returns the metrics as a dictionary
- checkpointing is handled automatically, with the interval specified in the same way as for evaluations
- models are saved at the end of the run, and if an exception is raised, a `model.exception.pt` is saved
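For example, an evaluation function might look like the following (the function name, validation tensors, and `criterion` here are placeholders for whatever you already have set up):

```python
def my_evaluation_function(model: torch.nn.Module) -> dict:
    # run the model on a held-out batch and return a dict of metrics
    model.eval()
    with torch.no_grad():
        outputs = model(val_inputs)
        val_loss = criterion(outputs, val_targets).item()
        val_acc = (outputs.argmax(dim=1) == val_targets).float().mean().item()
    model.train()
    return {"val/loss": val_loss, "val/acc": val_acc}

# paired with an interval when constructing the TrainingManager, e.g.
#   evals=[("1k samples", my_evaluation_function)]
```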
Installation
pip install trnbl
Usage
Also see the `notebooks/` folder:
- `demo_minimal.py` for a minimal example with dummy data
- `demo.ipynb` for an example with all options on the iris dataset
```python
import torch
from torch.utils.data import DataLoader

from trnbl.loggers.local import LocalLogger
from trnbl.training_manager import TrainingManager

# set up your dataset, model, optimizer, etc as usual
dataloader: DataLoader = DataLoader(my_dataset, batch_size=32)
model: torch.nn.Module = MyModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# set up a logger -- swap seamlessly between wandb, tensorboard, and local logging
logger: LocalLogger = LocalLogger(
    project="iris-demo",
    metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
    train_config=dict(
        model=str(model), optimizer=str(optimizer), criterion=str(criterion)
    ),
)

with TrainingManager(
    # pass your model and logger
    model=model,
    logger=logger,
    evals={
        # pass evaluation functions which take a model, and return a dict of metrics
        "1k samples": my_evaluation_function,
        "0.5 epochs": lambda model: logger.get_mem_usage(),
        "100 batches": my_other_eval_function,
    }.items(),
    checkpoint_interval="1/10 run",  # will save a checkpoint 10 times per run
) as tr:
    # wrap the loops, and the length will be automatically calculated
    # and used to figure out when to run evals, checkpoint, etc
    for epoch in tr.epoch_loop(range(120)):
        for inputs, targets in tr.batch_loop(dataloader):
            # your normal training code
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute whatever you want every batch
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log the metrics
            tr.batch_update(
                samples=len(targets),
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )

# a `model.final.pt` checkpoint will be saved at the end of the run,
# or a `model.exception.pt` if something crashes inside the context
```
LocalLogger
Intended as a minimal logging solution for local runs, for when you're too lazy to set up a new wandb project for a quick test and want to be able to easily read the logs. It logs everything as json or jsonl files, and provides a simple web interface for viewing the data. The web interface allows you to:
- enable or disable the visibility of individual runs
- filter and sort runs by various stats via an interactive table
- smooth the data and change axes scales
- move and resize all plots and tables
You can view a [live demo of the web interface here](https://miv.name/trnbl/iris-demo/index.html).
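If you just want to poke at the logger outside of a `TrainingManager`, a minimal sketch (the project name and metric values are illustrative; the import path assumes the `trnbl.loggers` layout used in the usage example above):

```python
from trnbl.loggers.local import LocalLogger

logger = LocalLogger(
    project="quick-test",
    metric_names=["train/loss"],
    train_config={"model": "tiny-mlp", "lr": 1e-3},
)

logger.message("starting a quick local run")  # printed and saved
for step in range(100):
    logger.metrics({"train/loss": 1.0 / (step + 1)})

logger.finish()
print(logger.run_path)  # directory containing the json/jsonl logs
```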
TODOs:
BUG: minifying the html/js code causes things to break?
frontend:
- batch/epoch size to table in config column group
- box to add aliases to runs
- customizable grid snap size?
- display the grid on the background?
deployment:
- demo website for local logger
- CI/CD for website, minification, tests, etc
- migrate to typescript
```python
"""
.. include:: ../README.md
"""

from trnbl.training_interval import TrainingInterval, TrainingIntervalUnit
from trnbl.loggers.base import TrainingLoggerBase
from trnbl.training_manager import TrainingManager

__all__ = [
    "TrainingInterval",
    "TrainingIntervalUnit",
    "TrainingLoggerBase",
    "TrainingManager",
    # submodules
    "loggers",
    "training_interval",
    "training_manager",
]
```
```python
@dataclass(frozen=True)
class TrainingInterval:
    """A training interval, which can be specified in a few different units.

    # Attributes:
    - `quantity: int|float` - the quantity of the interval
    - `unit: TrainingIntervalUnit` - the unit of the interval, one of "runs", "epochs", "batches", or "samples"

    # Methods:
    - `TrainingInterval.from_str(raw: str) -> TrainingInterval` - parse a string into a TrainingInterval object
    - `TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - convert the interval to a raw number of batches
    - `TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - any representation to a number of batches
    - `TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> None` - current interval, with units switched to batches

    Provides methods for reading from a string or tuple, and normalizing to batches.
    """

    quantity: int | float
    unit: TrainingIntervalUnit

    def __iter__(self) -> Generator[int | float | TrainingIntervalUnit, None, None]:
        yield self.quantity
        yield self.unit

    def __getitem__(self, index: int) -> int | float | TrainingIntervalUnit:
        if index == 0:
            return self.quantity
        elif index == 1:
            return self.unit
        else:
            raise IndexError(f"invalid index {index} for TrainingInterval")

    def __post_init__(self) -> None:
        try:
            assert isinstance(self.quantity, (int, float)), (
                "quantity should be an integer or float"
            )
            # TODO: Literal[...].__args__ is not defined??
            if self.unit not in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
                unit_dealised: str | None = _TRAINING_INTERVAL_UNIT_ALIASES.get(
                    self.unit.lower(), None
                )
                if isinstance(unit_dealised, str):
                    self.__dict__["unit"] = unit_dealised
                else:
                    raise ValueError(f"invalid unit {self.unit = }")

            assert self.unit in TrainingIntervalUnit.__args__, (  # type: ignore[attr-defined]
                f"invalid unit {self.unit}"
            )
        except AssertionError as e:
            raise AssertionError(
                f"Error initializing TrainingInterval\n{self}\n{e}"
            ) from e

        # check values in proper ranges
        expected_interval: Interval = _TRAINING_INTERVAL_UNITS_RANGES[self.unit]
        if self.quantity not in expected_interval:
            WhenIntervalLessThanBatch.process(
                f"interval {self} has invalid quantity, expected in interval {expected_interval}, will set to closest bound if not erroring out",
                except_cls=IntervalValueError,
                warn_cls=IntervalValueError,
            )
            self.__dict__["quantity"] = expected_interval.clamp(self.quantity)

        # cast if necessary
        self.__dict__["quantity"] = _TRAINING_INTERVAL_UNITS_CAST[self.unit](
            self.quantity
        )

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, self.__class__):
            raise TypeError(
                f"invalid type {type(other)} for comparison with TrainingInterval"
            )
        return (
            abs(self.quantity - other.quantity) < _EPSILON and self.unit == other.unit
        )

    def as_batch_count(
        self,
        batchsize: int,
        batches_per_epoch: int,
        epochs: int | None = None,
    ) -> int:
        """given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches

        # Parameters:
        - `batchsize: int`
            the size of a batch
        - `batches_per_epoch: int`
            the number of batches in an epoch
        - `epochs: int|None`
            the number of epochs to run (only required if the interval is in "runs")

        # Returns:
        - `int`
            the interval as a number of batches

        # Raises:
        - `ValueError`
            if the interval is less than 1 batch, and the `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR`
            otherwise, will warn or ignore and set the interval to 1 batch
        - `ValueError`
            if the unit is not one of "runs", "epochs", "batches", or "samples"
        """

        output: int | float

        match self.unit:
            case "runs":
                assert epochs is not None, (
                    "epochs must be provided to convert runs to batches"
                )
                output = self.quantity * epochs * batches_per_epoch
            case "epochs":
                output = self.quantity * batches_per_epoch
            case "batches":
                output = self.quantity
            case "samples":
                output = self.quantity / batchsize
            case _:
                raise ValueError(f"invalid unit {self.unit}")

        # check if interval is less than 1 batch
        if output < 1:
            WhenIntervalLessThanBatch.process(
                f"interval {self} is less than 1 batch, will set to 1 batch if not erroring out",
                except_cls=IntervalValueError,
                warn_cls=IntervalValueError,
            )
            output = 1

        return int(round(output))

    def normalized(
        self,
        batchsize: int,
        batches_per_epoch: int,
        epochs: int | None = None,
    ) -> "TrainingInterval":
        """convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to "batches"""
        quantity: int | float = self.as_batch_count(
            batches_per_epoch=batches_per_epoch,
            batchsize=batchsize,
            epochs=epochs,
        )
        unit: TrainingIntervalUnit = "batches"
        return self.__class__(quantity, unit)

    @classmethod
    def from_str(cls, raw: str) -> "TrainingInterval":
        """parse a string into a TrainingInterval object

        # Examples:

        >>> TrainingInterval.from_str("5 epochs")
        TrainingInterval(5, 'epochs')
        >>> TrainingInterval.from_str("100 batches")
        TrainingInterval(100, 'batches')
        >>> TrainingInterval.from_str("0.1 runs")
        TrainingInterval(0.1, 'runs')
        >>> TrainingInterval.from_str("1/5 runs")
        TrainingInterval(0.2, 'runs')
        """
        try:
            # remove prefix and suffix (optionally)
            raw = raw.removeprefix("TrainingInterval(").removesuffix(")")

            # process quantity
            raw_split: list[str]
            quantity_str: str
            if "," in raw:
                raw_split = raw.split(",")
                quantity_str = ",".join(raw_split[:-1])
            else:
                raw_split = raw.split()
                quantity_str = " ".join(raw_split[:-1])

            quantity: int | float = str_to_numeric(quantity_str)

            # process unit
            unit: str = raw_split[-1]
            unit.strip().strip("'\"").strip()

            # unit should be one of the allowed units
            unit_dealised: str | None
            if unit.lower() in TrainingIntervalUnit.__args__:  # type: ignore[attr-defined]
                unit_dealised = unit.lower()
            else:
                unit_dealised = _TRAINING_INTERVAL_UNIT_ALIASES.get(unit.lower(), None)
                if isinstance(unit_dealised, str):
                    unit = unit_dealised
                else:
                    raise ValueError(f"invalid unit {unit}")

            assert unit in TrainingIntervalUnit.__args__  # type: ignore[attr-defined]
        except Exception as e:
            raise ValueError(f"Error parsing {raw} as a TrainingInterval\n{e}") from e

        return cls(quantity, unit)  # type: ignore[arg-type]

    @classmethod
    def from_any(cls, *args, **kwargs) -> "TrainingInterval":
        """parse a string or tuple into a TrainingInterval object"""

        try:
            # no kwargs allowed
            assert len(kwargs) == 0, "no kwargs allowed for from_any"

            # split up args
            data: Any
            match len(args):
                case 1:
                    data = args[0]
                case 2:
                    data = args
                case _:
                    raise ValueError(
                        f"invalid number of args {len(args)} for from_any: {args = }"
                    )

            if isinstance(data, cls):
                return data
            elif isinstance(data, str):
                return cls.from_str(data)
            elif isinstance(data, Sequence):
                assert len(data) == 2, (
                    f"invalid length {len(data)} for TrainingInterval: {data}"
                )
                quantity, unit = data
                if isinstance(quantity, str):
                    quantity = str_to_numeric(quantity)
                return cls(quantity, unit)
            else:
                raise ValueError(f"invalid type {type(data)} for TrainingInterval")

        except AssertionError as e:
            raise ValueError(f"Error parsing {data} as a TrainingInterval\n{e}") from e

    @classmethod
    def process_to_batches(
        cls,
        interval: "CastableToTrainingInterval",
        batchsize: int,
        batches_per_epoch: int,
        epochs: int | None = None,
    ) -> int:
        """directly from any representation to a number of batches"""

        interval_ti: TrainingInterval = cls.from_any(interval)

        return interval_ti.as_batch_count(
            batches_per_epoch=batches_per_epoch,
            batchsize=batchsize,
            epochs=epochs,
        )
```
A training interval, which can be specified in a few different units.

Attributes:
- `quantity: int|float` - the quantity of the interval
- `unit: TrainingIntervalUnit` - the unit of the interval, one of "runs", "epochs", "batches", or "samples"

Methods:
- `TrainingInterval.from_str(raw: str) -> TrainingInterval` - parse a string into a TrainingInterval object
- `TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - convert the interval to a raw number of batches
- `TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int` - convert any representation to a number of batches
- `TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> TrainingInterval` - the current interval, with units switched to batches

Provides methods for reading from a string or tuple, and normalizing to batches.
Given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches.

Parameters:
- `batchsize: int` - the size of a batch
- `batches_per_epoch: int` - the number of batches in an epoch
- `epochs: int|None` - the number of epochs to run (only required if the interval is in "runs")

Returns:
- `int` - the interval as a number of batches

Raises:
- `ValueError` - if the interval is less than 1 batch and `trnbl.training_interval.WhenIntervalLessThanBatch` is set to `muutils.errormode.ErrorMode.ERROR`; otherwise, will warn or ignore and set the interval to 1 batch
- `ValueError` - if the unit is not one of "runs", "epochs", "batches", or "samples"
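A quick worked example of the conversion (numbers chosen to come out round):

```python
from trnbl.training_interval import TrainingInterval

# "epochs" just multiplies by batches_per_epoch
assert TrainingInterval(2, "epochs").as_batch_count(batchsize=32, batches_per_epoch=100) == 200

# "samples" divides by the batch size: 3200 samples / 32 per batch = 100 batches
assert TrainingInterval(3200, "samples").as_batch_count(batchsize=32, batches_per_epoch=100) == 100

# "runs" additionally requires the total number of epochs
assert TrainingInterval(0.5, "runs").as_batch_count(batchsize=32, batches_per_epoch=100, epochs=10) == 500
```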
Convert the units of the interval to batches, by calling `as_batch_count` and setting the `unit` to `"batches"`.
parse a string into a TrainingInterval object
Examples:
>>> TrainingInterval.from_str("5 epochs")
TrainingInterval(5, 'epochs')
>>> TrainingInterval.from_str("100 batches")
TrainingInterval(100, 'batches')
>>> TrainingInterval.from_str("0.1 runs")
TrainingInterval(0.1, 'runs')
>>> TrainingInterval.from_str("1/5 runs")
TrainingInterval(0.2, 'runs')
parse a string or tuple into a TrainingInterval object
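For instance, the following all produce equivalent intervals:

```python
from trnbl.training_interval import TrainingInterval

a = TrainingInterval.from_any("100 batches")                      # from a string
b = TrainingInterval.from_any(100, "batches")                     # from two positional args
c = TrainingInterval.from_any((100, "batches"))                   # from a tuple
d = TrainingInterval.from_any(TrainingInterval(100, "batches"))   # passthrough

assert a == b == c == d
```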
directly from any representation to a number of batches
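In other words, `from_any` followed by `as_batch_count`; for example:

```python
from trnbl.training_interval import TrainingInterval

# any castable representation works: string, tuple, or TrainingInterval
n = TrainingInterval.process_to_batches(
    (6400, "samples"), batchsize=32, batches_per_epoch=100, epochs=10
)
assert n == 200  # 6400 samples / 32 samples per batch
```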
```python
class TrainingLoggerBase(ABC):
    """Base class for training loggers"""

    @abstractmethod
    def debug(self, message: str, **kwargs) -> None:
        """log a debug message which will be saved, but not printed"""
        pass

    @abstractmethod
    def message(self, message: str, **kwargs) -> None:
        """log a progress message, which will be printed to stdout"""
        pass

    def warning(self, message: str, **kwargs) -> None:
        """log a warning message, which will be printed to stderr"""
        self.message(f"WARNING: {message}", __warning__=True, **kwargs)

    def error(self, message: str, **kwargs) -> None:
        """log an error message"""
        self.message(f"ERROR: {message}", __error__=True, **kwargs)

    @abstractmethod
    def metrics(self, data: dict[str, Any]) -> None:
        """Log a dictionary of metrics"""
        pass

    @abstractmethod
    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """log an artifact from a file"""
        pass

    @property
    @abstractmethod
    def url(self) -> str | list[str]:
        """Get the URL for the current logging run"""
        pass

    @property
    @abstractmethod
    def run_path(self) -> Path | list[Path]:
        """Get the path to the current logging run"""
        pass

    @abstractmethod
    def flush(self) -> None:
        """Flush the logger"""
        pass

    @abstractmethod
    def finish(self) -> None:
        """Finish logging"""
        pass

    def get_mem_usage(self) -> dict:
        mem_usage: dict = {}

        try:
            # CPU/Memory usage (if available)
            if PSUTIL_AVAILABLE:
                cpu_percent = psutil.cpu_percent()
                mem_usage["cpu/percent"] = cpu_percent

                # Memory usage
                virtual_mem = psutil.virtual_memory()
                mem_usage["ram/used"] = virtual_mem.used
                mem_usage["ram/percent"] = virtual_mem.percent

            # GPU information (if available)
            if GPU_UTILS_AVAILABLE:
                gpus = GPUtil.getGPUs()
                for gpu in gpus:
                    gpu_id = gpu.id
                    mem_usage[f"gpu:{gpu_id}/load"] = gpu.load
                    mem_usage[f"gpu:{gpu_id}/memory_used"] = gpu.memoryUsed
                    mem_usage[f"gpu:{gpu_id}/temperature"] = gpu.temperature
        except Exception as e:
            self.warning(f"Error getting memory usage: {e}")

        return mem_usage

    def spinner_task(self, **kwargs) -> LoggerSpinner:
        "Create a spinner task. kwargs are passed to `Spinner`."
        return LoggerSpinner(logger=self, **kwargs)

    # def seq_task(self, **kwargs) -> LoggerSpinner:
    #     "Create a sequential task with progress bar. kwargs are passed to `tqdm`."
    #     return LoggerSpinner(message=message, logger=self, **kwargs)
```
Base class for training loggers
log a debug message which will be saved, but not printed
log a progress message, which will be printed to stdout
log a warning message, which will be printed to stderr
log an error message
Log a dictionary of metrics
log an artifact from a file
Get the URL for the current logging run
Get the path to the current logging run
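Since all loggers share this interface, a custom backend only needs to fill in the abstract methods. A rough sketch (not a real backend, just to illustrate the shape of a subclass):

```python
from pathlib import Path
from typing import Any

from trnbl.loggers.base import TrainingLoggerBase


class PrintLogger(TrainingLoggerBase):
    """toy logger that prints messages and keeps metrics in memory"""

    def __init__(self) -> None:
        self.logged_metrics: list[dict[str, Any]] = []

    def debug(self, message: str, **kwargs) -> None:
        pass  # "saved but not printed" messages are simply dropped here

    def message(self, message: str, **kwargs) -> None:
        print(message, kwargs if kwargs else "")

    def metrics(self, data: dict[str, Any]) -> None:
        self.logged_metrics.append(data)

    def artifact(self, path: Path, type: str, aliases=None, metadata=None) -> None:
        print(f"artifact saved: {path} ({type})")

    @property
    def url(self) -> str:
        return "n/a"

    @property
    def run_path(self) -> Path:
        return Path(".")

    def flush(self) -> None:
        pass

    def finish(self) -> None:
        print(f"run finished, {len(self.logged_metrics)} metric rows logged")
```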
````python
class TrainingManager(Generic[TLogger]):
    """context manager for training a model, with logging, evals, and checkpoints

    # Parameters:
    - `model : torch.nn.Module`
        ref to model being trained - used for saving checkpoints
    - `dataloader : torch.utils.data.DataLoader`
        ref to dataloader being used - used for calculating training progress
    - `logger : TrainingLoggerBase`
        logger, which can be local or interface with wandb.
    - `epochs : int`
        number of epochs to train for
        (defaults to `1`)
    - `evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None`
        list of pairs of (interval, eval_fn) to run evals on the model. See `TrainingInterval` for interval options.
        (defaults to `None`)
    - `checkpoint_interval : TrainingInterval | str`
        interval at which to save model checkpoints
        (defaults to `TrainingInterval(1, "epochs")`)
    - `print_metrics_interval : TrainingInterval | str`
        interval at which to print metrics
        (defaults to `TrainingInterval(0.1, "runs")`)
    - `save_model : Callable[[torch.nn.Module, Path], None]`
        function to save the model
        (defaults to `torch.save`)
    - `model_save_path : str`
        format string for saving model checkpoints. uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
        (defaults to `"{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"`)
    - `model_save_path_special : str`
        format string for saving special model checkpoints (final, exception, etc.). uses `_get_format_kwargs` for formatting, along with an `alias` kwarg
        (defaults to `"{run_path}/model.{alias}.pt"`)

    # Usage:
    ```python
    with TrainingManager(
        model=model, dataloader=TRAIN_LOADER, logger=logger, epochs=500,
        evals={
            "1 epochs": eval_func,
            "0.1 runs": lambda model: logger.get_mem_usage(),
        }.items(),
        checkpoint_interval="50 epochs",
    ) as tp:

        # Training loop
        model.train()
        for epoch in range(epochs):
            for inputs, targets in TRAIN_LOADER:
                # the usual
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # compute accuracy
                accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

                # log metrics
                tp.batch_update(
                    # pass in number of samples in your batch (or it will be inferred from the batch size)
                    samples=len(targets),
                    # any other metrics you compute every loop
                    **{"train/loss": loss.item(), "train/acc": accuracy},
                )
                # batch_update will automatically run evals and save checkpoints as needed

            tp.epoch_update()
    ```
    """

    def __init__(
        self,
        model: "torch.nn.Module",
        logger: TLogger,
        # required if you don't wrap the loops
        dataloader: "torch.utils.data.DataLoader|None" = None,
        epochs_total: int | None = None,
        save_model: Callable[["torch.nn.Module", Path], None] = torch.save,
        # everything with intervals
        evals: Iterable[tuple[CastableToTrainingInterval, EvalFunction]] | None = None,
        checkpoint_interval: CastableToTrainingInterval = TrainingInterval(1, "epochs"),
        print_metrics_interval: CastableToTrainingInterval = TrainingInterval(
            0.1, "runs"
        ),
        # everything with paths
        model_save_path: str = "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt",
        model_save_path_special: str = "{run_path}/model.{alias}.pt",
    ):
        # save start time
        self.start_time: float = time.time()
        # non path and non-interval args get copied over directly
        self.model: "torch.nn.Module" = model
        self.logger: TLogger = logger
        self.save_model: Callable[["torch.nn.Module", Path], None] = save_model

        self.logger.message("starting training manager initialization")

        # model save paths
        self.model_save_path: str = model_save_path
        self.model_save_path_special: str = model_save_path_special

        # temp intervals for processing later in `try_compute_counters`
        self._evals: Iterable[tuple[TrainingInterval, EvalFunction]]
        if evals is None:
            self._evals = []
        else:
            self._evals = [
                (TrainingInterval.from_any(interval), eval_fn)
                for interval, eval_fn in evals
            ]
        self._checkpoint_interval: TrainingInterval = TrainingInterval.from_any(
            checkpoint_interval
        )
        self._print_metrics_interval: TrainingInterval = TrainingInterval.from_any(
            print_metrics_interval
        )

        self.evals: list[tuple[int, EvalFunction]] = list()
        self.checkpoint_interval: int | None = None
        self.print_metrics_interval: int | None = None

        # counters for epochs, batches, samples, and checkpoints
        self.epochs: int = 0
        self.batches: int = 0
        self.samples: int = 0
        self.checkpoints: int = 0

        # total numbers of epochs, batches, and samples
        # pass via init kwarg or wrapped epochs loop
        self.epochs_total: int | None = epochs_total
        # from dataloader or dataloader in wrapped loop
        self.batches_per_epoch: int | None = None
        self.batch_size: int | None = None
        self.samples_per_epoch: int | None = None
        # computed dynamically from the above
        self.batches_total: int | None = None
        self.samples_total: int | None = None

        # whether the init is finished
        self.init_complete: bool = False

        # if we have a dataloader, we can compute some of the above
        if dataloader is not None:
            self.batches_per_epoch = len(dataloader)
            self.batch_size = dataloader.batch_size
            self.samples_per_epoch = len(dataloader.dataset)  # type: ignore[arg-type]

        self.try_compute_counters()

    def try_compute_counters(self) -> None:
        # we depend on either the TrainingManager init or the wrapped loops
        # getting the epochs_total and dataloader
        # everything else is computed dynamically

        if any(
            x is None
            for x in [
                self.epochs_total,
                self.batches_per_epoch,
                self.batch_size,
                self.samples_per_epoch,
            ]
        ):
            # if we don't have all the info we need, return early
            return

        # we can safely ignore type check here since we just checked for `None`
        self.batches_total = self.batches_per_epoch * self.epochs_total  # type: ignore[operator]
        self.samples_total = self.samples_per_epoch * self.epochs_total  # type: ignore[operator]

        # check if the dataloader has a finite nonzero length
        if self.samples_per_epoch == 0:
            raise TrainingManagerInitError(
                f"Dataloader has no samples. Please provide a dataloader with a non-zero length. {self.samples_per_epoch = }"
            )

        if self.batches_per_epoch == 0:
            raise TrainingManagerInitError(
                f"Dataloader has no batches. Please provide a dataloader with a non-zero length. {self.batches_per_epoch = }"
            )

        if self.batch_size == 0:
            raise TrainingManagerInitError(
                f"Dataloader has a batch size of 0. Please provide a dataloader with a non-zero batch size. {self.batch_size = }"
            )

        if self.batch_size is None:
            warnings.warn(
                "batch size is None. This is likely because the dataloader passed to `TrainingManager` does not have a `batch_size` attribute."
                + "\nthis should probably be an exception"
            )

        # normalize intervals for checkpoints, metrics printing, and evals
        _batch_info_kwargs: dict[str, int | None] = dict(
            batches_per_epoch=self.batches_per_epoch,
            batchsize=self.batch_size,
            epochs=self.epochs_total,
        )

        # TODO: no idea why we need `type: ignore[arg-type]` here
        self.checkpoint_interval = TrainingInterval.process_to_batches(
            interval=self._checkpoint_interval,
            **_batch_info_kwargs,  # type: ignore[arg-type]
        )
        self.print_metrics_interval = TrainingInterval.process_to_batches(
            interval=self._print_metrics_interval,
            **_batch_info_kwargs,  # type: ignore[arg-type]
        )

        # list[tuple[int, EvalFunction]]
        self.evals = [
            (
                TrainingInterval.process_to_batches(interval, **_batch_info_kwargs),  # type: ignore[arg-type]
                eval_fn,
            )
            for interval, eval_fn in self._evals
        ]

        # log this info
        self.init_complete = True
        self.logger.message(
            "initialized training manager",
            __training_manager_init__=True,
            epochs_total=self.epochs_total,
            batches_per_epoch=self.batches_per_epoch,
            batch_size=self.batch_size,
            samples_per_epoch=self.samples_per_epoch,
            samples_total=self.samples_total,
            checkpoint_interval_batches=self.checkpoint_interval,
            print_metrics_interval_batches=self.print_metrics_interval,
            model_save_path=self.model_save_path,
            model_save_path_special=self.model_save_path_special,
            **self.training_status(),
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type: Type, exc_val: Exception, exc_tb: TracebackType):
        # if error
        if exc_type is not None:
            # add exception info
            self.logger.error(
                str(exc_val),
                exc_type=str(exc_type),
                exc_val=str(exc_val),
                exc_tb=str(exc_tb),
            )
            # save the model
            self._save_checkpoint(alias="exception")

            # close the logger
            self.logger.finish()
        else:
            # if no error, log and save the final model
            self.logger.message(
                "training complete",
                __complete__=True,
            )
            self._save_checkpoint(alias="final")
            self.logger.finish()

    def epoch_loop(
        self,
        epochs: Sequence[int],
        use_tqdm: bool = True,
        **tqdm_kwargs,
    ) -> Generator[int, None, None]:
        return wrapped_iterable(
            sequence=epochs,
            manager=self,
            is_epoch=True,
            use_tqdm=use_tqdm,
            tqdm_kwargs=tqdm_kwargs,
        )

    def batch_loop(
        self,
        batches: Sequence[int],
        use_tqdm: bool = False,
        **tqdm_kwargs,
    ) -> Generator[int, None, None]:
        return wrapped_iterable(
            sequence=batches,
            manager=self,
            is_epoch=False,
            use_tqdm=use_tqdm,
            tqdm_kwargs=tqdm_kwargs,
        )

    def check_is_initialized(self):
        if not self.init_complete:
            raise TrainingManagerInitError(
                "TrainingManager not correctly initialized. ",
                "This is likely due to failing to specify the epoch count, or failing to specify batch size/count. "
                "you must either wrap your epoch loop with `TrainingManager.epoch_loop` or specify `epochs_total`",
                "AND you must either wrap your batch loop with `TrainingManager.batch_loop` or pass a `torch.utils.data.DataLoader` to the TrainingManager constructor.",
                "please note, if not wrapping the epoch loop, you must also call `TrainingManager.epoch_update` at the end of each epoch.",
            )

    def get_elapsed_time(self) -> float:
        """return the elapsed time in seconds since the start of training"""
        return time.time() - self.start_time

    def training_status(self) -> dict[str, int | float]:
        """status of elapsed time, samples, batches, epochs, and checkpoints"""
        return dict(
            # timestamp handled in logger
            elapsed_time=self.get_elapsed_time(),
            samples=self.samples,
            batches=self.batches,
            epochs=self.epochs,
            latest_checkpoint=self.checkpoints,
        )

    def _get_format_kwargs(self) -> dict[str, str | int | float]:
        """keyword args for formatting model save paths. calls `TrainingManager.training_status`

        # Provides:
        - `run_path : str` - path where the run is being logged and artifacts are being saved
        - `elapsed_time : float` - the elapsed time in seconds since the start of training
        - `samples : int` - samples seen so far
        - `batches : int` - batches seen so far
        - `epochs : int` - the latest epoch number
        - `latest_checkpoint : int` - the latest checkpoint number
        """
        return {
            "run_path": (
                self.logger.run_path.as_posix()
                if isinstance(self.logger.run_path, Path)
                # HACK: just return the first one
                else self.logger.run_path[0].as_posix()
            ),
            **self.training_status(),
        }

    def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs):
        """call this at the end of every batch. Pass `samples` or it will be inferred from the batch size, and any other metrics as kwargs

        This function will:
        - update internal counters
        - run evals as needed (based on the intervals passed)
        - log all metrics and training status
        - save a checkpoint as needed (based on the checkpoint interval)
        """
        # check init is finished
        if not self.init_complete:
            self.try_compute_counters()
            self.check_is_initialized()

        # process metrics and kwargs
        if metrics is None:
            metrics = dict()

        metrics.update(kwargs)

        # update counters
        self.batches += 1
        if samples is not None:
            self.samples += samples
        else:
            # TODO: we warn if batch size is None, but don't except
            self.samples += self.batch_size  # type: ignore[operator]

        # run evals if needed
        for interval, eval_fn in self.evals:
            if (self.batches % interval == 0) or (self.batches == self.batches_total):
                metrics.update(eval_fn(self.model))

        # log metrics & training status
        self.logger.metrics({**metrics, **self.training_status()})

        # print metrics if needed

        # save checkpoint if needed
        if self.batches % self.checkpoint_interval == 0:  # type: ignore[operator]
            self._save_checkpoint()

    def epoch_update(self):
        """call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter"""
        self.logger.debug(f"completed epoch {self.epochs + 1}/{self.epochs_total}")
        self.epochs += 1

    def _save_checkpoint(self, alias: str | None = None):
        """wrapper for saving checkpoint as an artifact to the logger and incrementing the checkpoint counter"""
        # if no alias, then it's a regular checkpoint
        no_alias: bool = alias is None
        if no_alias:
            alias = f"checkpoint-{self.checkpoints}"

        # TODO: store training hist with model?

        # put together a path
        checkpoint_path: Path
        if no_alias:
            # format the model save path for a normal checkpoint
            checkpoint_path = Path(
                self.model_save_path.format(
                    **self._get_format_kwargs(),
                    alias=alias,
                )
            )
        else:
            # for a special checkpoint, use the special path
            checkpoint_path = Path(
                self.model_save_path_special.format(
                    **self._get_format_kwargs(),
                    alias=alias,
                )
            )

        # make sure directory exists
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        # save the model
        self.save_model(self.model, checkpoint_path)

        # log the checkpoint as an artifact
        self.logger.artifact(
            checkpoint_path,
            "model",
            metadata=self.training_status(),
            aliases=[alias] if alias else None,
        )

        # increment checkpoint counter
        self.checkpoints += 1
````
context manager for training a model, with logging, evals, and checkpoints
Parameters:
- `model : torch.nn.Module` - ref to model being trained - used for saving checkpoints
- `dataloader : torch.utils.data.DataLoader` - ref to dataloader being used - used for calculating training progress
- `logger : TrainingLoggerBase` - logger, which can be local or interface with wandb
- `epochs : int` - number of epochs to train for (defaults to `1`)
- `evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None` - list of pairs of (interval, eval_fn) to run evals on the model. See `TrainingInterval` for interval options. (defaults to `None`)
- `checkpoint_interval : TrainingInterval | str` - interval at which to save model checkpoints (defaults to `TrainingInterval(1, "epochs")`)
- `print_metrics_interval : TrainingInterval | str` - interval at which to print metrics (defaults to `TrainingInterval(0.1, "runs")`)
- `save_model : Callable[[torch.nn.Module, Path], None]` - function to save the model (defaults to `torch.save`)
- `model_save_path : str` - format string for saving model checkpoints. uses `_get_format_kwargs` for formatting, along with an `alias` kwarg (defaults to `"{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt"`)
- `model_save_path_special : str` - format string for saving special model checkpoints (final, exception, etc.). uses `_get_format_kwargs` for formatting, along with an `alias` kwarg (defaults to `"{run_path}/model.{alias}.pt"`)
Usage:
```python
with TrainingManager(
    model=model, dataloader=TRAIN_LOADER, logger=logger, epochs=500,
    evals={
        "1 epochs": eval_func,
        "0.1 runs": lambda model: logger.get_mem_usage(),
    }.items(),
    checkpoint_interval="50 epochs",
) as tp:

    # Training loop
    model.train()
    for epoch in range(epochs):
        for inputs, targets in TRAIN_LOADER:
            # the usual
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute accuracy
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log metrics
            tp.batch_update(
                # pass in number of samples in your batch (or it will be inferred from the batch size)
                samples=len(targets),
                # any other metrics you compute every loop
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )
            # batch_update will automatically run evals and save checkpoints as needed

        tp.epoch_update()
```
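Note that the constructor itself takes `epochs_total` rather than `epochs`, so when the loops are not wrapped with `epoch_loop`/`batch_loop` (as in the example above), initialization looks roughly like the sketch below, and `epoch_update()` must be called manually at the end of each epoch:

```python
with TrainingManager(
    model=model,
    logger=logger,
    dataloader=TRAIN_LOADER,   # batch size and batches per epoch are read from the dataloader
    epochs_total=500,          # required here, since the epoch loop is not wrapped
    evals=[("1 epochs", eval_func)],
    checkpoint_interval="50 epochs",
) as tp:
    for epoch in range(500):
        for inputs, targets in TRAIN_LOADER:
            ...  # the usual training step, then tp.batch_update(samples=len(targets), ...)
        tp.epoch_update()  # must be called manually when not using epoch_loop
```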
return the elapsed time in seconds since the start of training
status of elapsed time, samples, batches, epochs, and checkpoints
Call this at the end of every batch. Pass `samples` or it will be inferred from the batch size, and pass any other metrics as kwargs.
This function will:
- update internal counters
- run evals as needed (based on the intervals passed)
- log all metrics and training status
- save a checkpoint as needed (based on the checkpoint interval)
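Metrics can be passed either as a dict or as keyword arguments (assuming `tp`, `targets`, and `loss` as in the usage example above):

```python
# passing metrics as a dict
tp.batch_update(samples=len(targets), metrics={"train/loss": loss.item()})

# equivalently, passing metrics as keyword arguments
tp.batch_update(samples=len(targets), **{"train/loss": loss.item()})

# samples=None adds the dataloader's batch size to the sample counter instead
tp.batch_update(samples=None, **{"train/loss": loss.item()})
```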
call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter