docs for trnbl v0.1.1

trnbl – Training Butler

If you train a lot of models, you might often find yourself annoyed at swapping between different loggers and fiddling with a bunch of `if batch_idx % some_number == 0` statements. This package aims to fix that problem.
Firstly, a universal interface to wandb, tensorboard, and a minimal local logging solution (live demo) is provided.
Secondly, a TrainingManager class is provided which handles logging, artifacts, checkpointing, evaluations, exceptions, and more, with flexibly customizable intervals.
Rather than having to specify all intervals in batches and then change everything manually when you change the batch size, dataset size, or number of epochs, you specify an interval in samples, batches, epochs, or runs. This is computed into the correct number of batches or epochs based on the current dataset and batch size.
"1/10 runs"
– 10 times a run"2.5 epochs"
– every 2 & 1/2 epochs(100, "batches")
– every 100 batches"10k samples"
– every 10,000 samplesan evaluation function is passed in a tuple with an interval, takes the model as an argument, and returns the metrics as a dictionary
checkpointing is handled automatically, specifying an interval in the same way as evaluations
models are saved at the end of the run, or if an exception is
raised, a model.exception.pt
is saved
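For a rough sense of how these interval specifications resolve, here is a minimal sketch using TrainingInterval.process_to_batches. The run shape (batch size 32, 100 batches per epoch, 10 epochs) is a made-up assumption, and the exact rounding of fractional results is left to the library:

from trnbl.training_interval import TrainingInterval

# hypothetical run shape: 32 samples/batch, 100 batches/epoch, 10 epochs
batchsize, batches_per_epoch, epochs = 32, 100, 10

for interval in ["10k samples", "2.5 epochs", "1/10 runs", (100, "batches")]:
    n_batches = TrainingInterval.process_to_batches(
        interval, batchsize=batchsize, batches_per_epoch=batches_per_epoch, epochs=epochs
    )
    # e.g. "10k samples" -> 10_000 / 32 ≈ 312 batches, "1/10 runs" -> (10 * 100) / 10 = 100 batches
    print(interval, "->", n_batches, "batches")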
pip install trnbl
also see the notebooks/ folder:

- demo_minimal.py for a minimal example with dummy data
- demo.ipynb for an example with all options on the iris dataset
import torch
from torch.utils.data import DataLoader
from trnbl.loggers.local import LocalLogger
from trnbl.training_manager import TrainingManager

# set up your dataset, model, optimizer, etc as usual
dataloader: DataLoader = DataLoader(my_dataset, batch_size=32)
model: torch.nn.Module = MyModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# set up a logger -- swap seamlessly between wandb, tensorboard, and local logging
logger: LocalLogger = LocalLogger(
    project="iris-demo",
    metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
    train_config=dict(
        model=str(model), optimizer=str(optimizer), criterion=str(criterion)
    ),
)

with TrainingManager(
    # pass your model and logger
    model=model,
    logger=logger,
    evals={
        # pass evaluation functions which take a model, and return a dict of metrics
        "1k samples": my_evaluation_function,
        "0.5 epochs": lambda model: logger.get_mem_usage(),
        "100 batches": my_other_eval_function,
    }.items(),
    checkpoint_interval="1/10 run",  # will save a checkpoint 10 times per run
) as tr:
    # wrap the loops, and length will be automatically calculated
    # and used to figure out when to run evals, checkpoint, etc
    for epoch in tr.epoch_loop(range(120)):
        for inputs, targets in tr.batch_loop(TRAIN_LOADER):
            # your normal training code
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute whatever you want every batch
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log the metrics
            tr.batch_update(
                samples=len(targets),
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )

# a `model.final.pt` checkpoint will be saved at the end of the run,
# or a `model.exception.pt` if something crashes inside the context
LocalLogger

Intended as a minimal logging solution for local runs, for when you're too lazy to set up a new wandb project for a quick test and want to be able to easily read the logs. It logs everything as json or jsonl files, and provides a simple web interface for viewing the data. You can view a live demo of the web interface here.
trnbl
class TrainingInterval:
A training interval, which can be specified in a few different units.
- quantity: int|float – the quantity of the interval
- unit: TrainingIntervalUnit – the unit of the interval, one of "runs", "epochs", "batches", or "samples"
- TrainingInterval.from_str(raw: str) -> TrainingInterval – parse a string into a TrainingInterval object
- TrainingInterval.as_batch_count(batchsize: int, batches_per_epoch: int, epochs: int|None) -> int – convert the interval to a raw number of batches
- TrainingInterval.process_to_batches(interval: str|TrainingInterval, batchsize: int, batches_per_epoch: int, epochs: int|None) -> int – any representation to a number of batches
- TrainingInterval.normalized(batchsize: int, batches_per_epoch: int, epochs: int|None) -> None – current interval, with units switched to batches

Provides methods for reading from a string or tuple, and normalizing to batches.
TrainingInterval(
    quantity: int | float,
    unit: Literal['runs', 'epochs', 'batches', 'samples']
)

quantity: int | float
unit: Literal['runs', 'epochs', 'batches', 'samples']
def as_batch_count(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> int
given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches
Parameters:
- batchsize: int – the size of a batch
- batches_per_epoch: int – the number of batches in an epoch
- epochs: int|None – the number of epochs to run (only required if the interval is in "runs")

Returns:
- int – the interval as a number of batches

Raises:
- ValueError – if the interval is less than 1 batch and trnbl.training_interval.WhenIntervalLessThanBatch is set to muutils.errormode.ErrorMode.ERROR; otherwise, will warn or ignore and set the interval to 1 batch
- ValueError – if the unit is not one of "runs", "epochs", "batches", or "samples"

def normalized(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> trnbl.training_interval.TrainingInterval
convert the units of the interval to batches, by calling as_batch_count and setting the unit to "batches"
def from_str(cls, raw: str) -> trnbl.training_interval.TrainingInterval

parse a string into a TrainingInterval object

>>> TrainingInterval.from_str("5 epochs")
TrainingInterval(5, 'epochs')
>>> TrainingInterval.from_str("100 batches")
TrainingInterval(100, 'batches')
>>> TrainingInterval.from_str("0.1 runs")
TrainingInterval(0.1, 'runs')
>>> TrainingInterval.from_str("1/5 runs")
TrainingInterval(0.2, 'runs')
def from_any(cls, *args, **kwargs) -> trnbl.training_interval.TrainingInterval

parse a string or tuple into a TrainingInterval object
def process_to_batches(
    cls,
    interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval],
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> int

directly from any representation to a number of batches
TrainingIntervalUnit = typing.Literal['runs', 'epochs', 'batches', 'samples']
class TrainingLoggerBase(abc.ABC):
Base class for training loggers
def debug(self, message: str, **kwargs) -> None
log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None
log a progress message, which will be printed to stdout

def warning(self, message: str, **kwargs) -> None
log a warning message, which will be printed to stderr

def error(self, message: str, **kwargs) -> None
log an error message

def metrics(self, data: dict[str, typing.Any]) -> None
Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None
log an artifact from a file

url: str | list[str]
Get the URL for the current logging run

run_path: pathlib.Path | list[pathlib.Path]
Get the path to the current logging run

def flush(self) -> None
Flush the logger

def finish(self) -> None
Finish logging

def get_mem_usage(self) -> dict

def spinner_task(self, **kwargs) -> trnbl.loggers.base.LoggerSpinner
Create a spinner task. kwargs are passed to Spinner.
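The logger backends below (local, tensorboard, wandb, multi) all implement this interface. As a minimal sketch of what a custom backend might look like – assuming debug, message, metrics, artifact, url, run_path, flush, and finish are the members to override and that url and run_path are properties; the real base class may require more or fewer overrides:

from pathlib import Path
from typing import Any

from trnbl.loggers.base import TrainingLoggerBase


class PrintLogger(TrainingLoggerBase):
    """hypothetical backend that just prints everything to stdout"""

    def debug(self, message: str, **kwargs) -> None:
        pass  # debug messages are meant to be saved but not printed; here we drop them

    def message(self, message: str, **kwargs) -> None:
        print(message)

    def metrics(self, data: dict[str, Any]) -> None:
        print("metrics:", data)

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        print(f"artifact ({type}): {path}")

    @property
    def url(self) -> str:
        return "stdout"

    @property
    def run_path(self) -> Path:
        return Path(".")

    def flush(self) -> None:
        pass

    def finish(self) -> None:
        pass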
class TrainingManager(typing.Generic[~TLogger]):
context manager for training a model, with logging, evals, and checkpoints
- model : torch.nn.Module – ref to model being trained - used for saving checkpoints
- dataloader : torch.utils.data.DataLoader – ref to dataloader being used - used for calculating training progress
- logger : TrainingLoggerBase – logger, which can be local or interface with wandb
- epochs : int – number of epochs to train for (defaults to 1)
- evals : Iterable[tuple[TrainingInterval | str, EvalFunction]] | None – list of pairs of (interval, eval_fn) to run evals on the model. See TrainingInterval for interval options. (defaults to None)
- checkpoint_interval : TrainingInterval | str – interval at which to save model checkpoints (defaults to TrainingInterval(1, "epochs"))
- print_metrics_interval : TrainingInterval | str – interval at which to print metrics (defaults to TrainingInterval(0.1, "runs"))
- save_model : Callable[[torch.nn.Module, Path], None] – function to save the model (defaults to torch.save)
- model_save_path : str – format string for saving model checkpoints. uses _get_format_kwargs for formatting, along with an alias kwarg (defaults to "{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt")
- model_save_path_special : str – format string for saving special model checkpoints (final, exception, etc.). uses _get_format_kwargs for formatting, along with an alias kwarg (defaults to "{run_path}/model.{alias}.pt")

with TrainingManager(
    model=model, dataloader=TRAIN_LOADER, logger=logger, epochs=500,
    evals={
        "1 epochs": eval_func,
        "0.1 runs": lambda model: logger.get_mem_usage(),
    }.items(),
    checkpoint_interval="50 epochs",
) as tp:
    # Training loop
    model.train()
    for epoch in range(epochs):
        for inputs, targets in TRAIN_LOADER:
            # the usual
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute accuracy
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log metrics
            tp.batch_update(
                # pass in number of samples in your batch (or it will be inferred from the batch size)
                samples=len(targets),
                # any other metrics you compute every loop
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )
            # batch_update will automatically run evals and save checkpoints as needed
        tp.epoch_update()
TrainingManager(
    model: torch.nn.modules.module.Module,
    logger: ~TLogger,
    dataloader: torch.utils.data.dataloader.DataLoader | None = None,
    epochs_total: int | None = None,
    save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType] = <function save>,
    evals: Optional[Iterable[tuple[Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval], Callable[[torch.nn.modules.module.Module], dict]]]] = None,
    checkpoint_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=1, unit='epochs'),
    print_metrics_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=0.1, unit='runs'),
    model_save_path: str = '{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt',
    model_save_path_special: str = '{run_path}/model.{alias}.pt'
)
start_time: float
model: torch.nn.modules.module.Module
logger: ~TLogger
save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType]
model_save_path: str
model_save_path_special: str
evals: list[tuple[int, typing.Callable[[torch.nn.modules.module.Module], dict]]]
checkpoint_interval: int | None
print_metrics_interval: int | None
epochs: int
batches: int
samples: int
checkpoints: int
epochs_total: int | None
batches_per_epoch: int | None
batch_size: int | None
samples_per_epoch: int | None
batches_total: int | None
samples_total: int | None
init_complete: bool
def try_compute_counters(self) -> None

def epoch_loop(
    self,
    epochs: Sequence[int],
    use_tqdm: bool = True,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]

def batch_loop(
    self,
    batches: Sequence[int],
    use_tqdm: bool = False,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]

def check_is_initialized(self)

def get_elapsed_time(self) -> float
return the elapsed time in seconds since the start of training

def training_status(self) -> dict[str, int | float]
status of elapsed time, samples, batches, epochs, and checkpoints

def batch_update(self, samples: int | None, metrics: dict | None = None, **kwargs)
call this at the end of every batch. Pass samples or it will be inferred from the batch size, and any other metrics as kwargs.

This function will:
- update internal counters
- run evals as needed (based on the intervals passed)
- log all metrics and training status
- save a checkpoint as needed (based on the checkpoint interval)

def epoch_update(self)
call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter
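As a sketch of how the save_model and model_save_path parameters compose – model, logger, and TRAIN_LOADER are assumed to be defined as in the examples above, and saving only the state dict is an illustrative choice rather than the library default:

import torch
from trnbl.training_manager import TrainingManager

# illustrative alternative to the default torch.save: persist only the state dict
def save_state_dict(model: torch.nn.Module, path) -> None:
    torch.save(model.state_dict(), path)

with TrainingManager(
    model=model,
    logger=logger,
    dataloader=TRAIN_LOADER,
    save_model=save_state_dict,
    # the format string is filled per checkpoint; {run_path} and {latest_checkpoint}
    # come from the manager, as in the default shown above
    model_save_path="{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt",
    checkpoint_interval="50 epochs",
) as tr:
    ...  # training loop as above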
trnbl.loggers
trnbl.loggers.base
GPU_UTILS_AVAILABLE: bool = True
PSUTIL_AVAILABLE: bool = True
VOWELS: str = 'aeiou'
CONSONANTS: str = 'bcdfghjklmnpqrstvwxyz'
def rand_syllabic_string(length: int = 6) -> str
Generate a random string of alternating consonants and vowels to use as a unique identifier
for a length of 2n, there are about 10^{2n} possible strings
default is 6 characters, which gives 10^6 possible strings
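A sketch of the idea – this is not the library's exact implementation, just an illustration of the alternating consonant/vowel construction:

import random

VOWELS = "aeiou"
CONSONANTS = "bcdfghjklmnpqrstvwxyz"

def syllabic_id(length: int = 6) -> str:
    # alternate consonant, vowel, consonant, vowel, ...
    return "".join(
        random.choice(CONSONANTS if i % 2 == 0 else VOWELS) for i in range(length)
    )

print(syllabic_id())  # e.g. "bakulo"; 21**3 * 5**3 ≈ 1.2 million (~10^6) possibilities at length 6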
class LoggerSpinner(muutils.spinner.Spinner):

see Spinner for parameters. catches update_value and passes it to the LocalLogger

LoggerSpinner(*args, logger: trnbl.loggers.base.TrainingLoggerBase, **kwargs)

logger: trnbl.loggers.base.TrainingLoggerBase

def update_value(self, value: Any) -> None
update the value of the spinner and log it
config
format_string_when_updated
update_interval
message
current_value
format_string
output_stream
start_time
stop_spinner
spinner_thread
value_changed
term_width
state
spin
start
stop
trnbl.loggers.local
class FilePaths:
TRAIN_CONFIG: pathlib.Path = WindowsPath('config.json')
LOGGER_META: pathlib.Path = WindowsPath('meta.json')
TRAIN_CONFIG_YML: pathlib.Path = WindowsPath('config.yml')
LOGGER_META_YML: pathlib.Path = WindowsPath('meta.yml')
ARTIFACTS: pathlib.Path = WindowsPath('artifacts.jsonl')
METRICS: pathlib.Path = WindowsPath('metrics.jsonl')
LOG: pathlib.Path = WindowsPath('log.jsonl')
ERROR_FILE: pathlib.Path = WindowsPath('ERROR.txt')
RUNS_MANIFEST: pathlib.Path = WindowsPath('runs.jsonl')
RUNS_DIR: pathlib.Path = WindowsPath('runs')
HTML_INDEX: pathlib.Path = WindowsPath('index.html')
START_SERVER: pathlib.Path = WindowsPath('start_server.py')
class LocalLogger(trnbl.loggers.base.TrainingLoggerBase):
Base class for training loggers
LocalLogger(
    project: str,
    metric_names: list[str],
    train_config: dict,
    group: str = '',
    base_path: str | pathlib.Path = WindowsPath('trnbl-logs'),
    memusage_as_metrics: bool = True,
    console_msg_prefix: str = '# '
)
log_list: list[dict]
metrics_list: list[dict]
artifacts_list: list[dict]
train_config: dict
project: str
group: str
group_str: str
base_path: pathlib.Path
console_msg_prefix: str
run_init_timestamp: datetime.datetime
run_id: str
project_path: pathlib.Path
log_file: _io.TextIOWrapper
metrics_file: _io.TextIOWrapper
artifacts_file: _io.TextIOWrapper
metric_names: list[str]
logger_meta: dict
syllabic_id: str
def get_timestamp(self) -> str

def debug(self, message: str, **kwargs) -> None
log a debug message

def message(self, message: str, **kwargs) -> None
log a progress message

def warning(self, message: str, **kwargs) -> None
log a warning message

def error(self, message: str, **kwargs) -> None
log an error message

def metrics(self, data: dict[str, typing.Any]) -> None
log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None
log an artifact from a file

url: str
Get the URL for the current logging run

run_path: pathlib.Path
Get the path to the current logging run

def flush(self) -> None
Flush the logger

def finish(self) -> None
Finish logging
trnbl.loggers.local.build_dist
def get_remote(
    path_or_url: str,
    download_remote: bool = False,
    get_bytes: bool = False,
    allow_remote_fail: bool = True
) -> str | bytes | None

gets a resource from a path or url. returns bytes if get_bytes is True, and None if it's from the web and download_remote is False

Parameters:
- path_or_url : str – location of the resource. if it starts with http, it is considered a url
- download_remote : bool – whether to download the resource if it is a url (defaults to False)
- get_bytes : bool – whether to return the resource as bytes (defaults to False)
- allow_remote_fail : bool – if a remote resource fails to download, return None. if this is False, raise an exception (defaults to True)

Raises:
- requests.HTTPError – if the remote resource returns an error, and allow_remote_fail is False

Returns:
- str | bytes | None
def build_dist(
    path: pathlib.Path,
    minify: bool = True,
    download_remote: bool = True
) -> str
Build a single file html from a folder
partially from https://stackoverflow.com/questions/44646481/merging-js-css-html-into-single-html
def main() -> None
trnbl.loggers.local.html_frontend
def get_html_frontend() -> str
trnbl.loggers.local.start_server

Usage: python start_server.py path/to/directory [port]
def start_server(path: str, port: int = 8000) -> None
Starts a server to serve the files in the given path.
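Besides the command line usage above, the same thing can be done from Python – a minimal sketch, assuming the logs live under the default base_path of trnbl-logs:

from trnbl.loggers.local.start_server import start_server

# serve the local logging directory, then open http://localhost:8000 to view the web interface
start_server("trnbl-logs", port=8000)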
trnbl.loggers.multi
def maybe_flatten(lst: list[typing.Union[~T, list[~T]]]) -> list[~T]
flatten a list if it is nested
class MultiLogger(trnbl.loggers.base.TrainingLoggerBase):
use multiple loggers at once
MultiLogger(loggers: list[trnbl.loggers.base.TrainingLoggerBase])
loggers: list[trnbl.loggers.base.TrainingLoggerBase]
def debug(self, message: str, **kwargs) -> None
log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None
log a progress message

def metrics(self, data: dict[str, typing.Any]) -> None
Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None
log an artifact from a file

url: list[str]
Get the URL for the current logging run

run_path: list[pathlib.Path]
Get the paths to the current logging run

def flush(self) -> None
Flush the logger

def finish(self) -> None
Finish logging
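For example, a single MultiLogger can fan metrics out to a local run and a tensorboard run at the same time – a minimal sketch, with illustrative project names and paths:

from trnbl.loggers.local import LocalLogger
from trnbl.loggers.multi import MultiLogger
from trnbl.loggers.tensorboard import TensorBoardLogger

# every call on `logger` (metrics, artifacts, messages, ...) is forwarded to both backends
logger = MultiLogger(
    loggers=[
        LocalLogger(
            project="iris-demo",
            metric_names=["train/loss", "train/acc"],
            train_config={"model": "MyModel"},
        ),
        TensorBoardLogger(log_dir="tensorboard-logs/iris-demo"),
    ]
)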
trnbl.loggers.tensorboard
class TensorBoardLogger(trnbl.loggers.base.TrainingLoggerBase):
Base class for training loggers
TensorBoardLogger(
    log_dir: str | pathlib.Path,
    train_config: dict | None = None,
    name: str | None = None,
    **kwargs
)
def debug(self, message: str, **kwargs) -> None
log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None
log a progress message, which will be printed to stdout

def metrics(self, data: dict[str, typing.Any]) -> None
Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None
log an artifact from a file

url: str
Get the URL for the current logging run

run_path: pathlib.Path
Get the path to the current logging run

def flush(self) -> None
Flush the logger

def finish(self) -> None
Finish logging
trnbl.loggers.wandb
class WandbLogger(trnbl.loggers.base.TrainingLoggerBase):
wrapper around wandb logging for TrainingLoggerBase. create using WandbLogger.create(config, project, job_type)

WandbLogger(run: wandb.sdk.wandb_run.Run)
def create(
    cls,
    config: dict,
    project: str | None = None,
    job_type: str | None = None,
    logging_fmt: str = '%(asctime)s %(levelname)s %(message)s',
    logging_datefmt: str = '%Y-%m-%d %H:%M:%S',
    wandb_kwargs: dict | None = None
) -> trnbl.loggers.wandb.WandbLogger
def debug(self, message: str, **kwargs) -> None
log a debug message which will be saved, but not printed

def message(self, message: str, **kwargs) -> None
log a progress message, which will be printed to stdout

def metrics(self, data: dict[str, typing.Any]) -> None
Log a dictionary of metrics

def artifact(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None
log an artifact from a file

url: str
Get the URL for the current logging run

run_path: pathlib.Path
Get the path to the current logging run

def flush(self) -> None
Flush the logger

def finish(self) -> None
Finish logging
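A minimal creation sketch – it assumes wandb is installed and you are logged in, and the project and config values are illustrative:

from trnbl.loggers.wandb import WandbLogger

logger = WandbLogger.create(
    config={"lr": 1e-3, "batch_size": 32},
    project="iris-demo",
    job_type="train",
)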
trnbl.training_interval
TrainingIntervalUnit = typing.Literal['runs', 'epochs', 'batches', 'samples']
WhenIntervalLessThanBatch: muutils.errormode.ErrorMode = ErrorMode.Warn
class IntervalValueError(builtins.UserWarning):
Error for when the interval is less than 1 batch
CastableToTrainingInterval = typing.Union[str, tuple[typing.Union[int, float, str], str], trnbl.training_interval.TrainingInterval]
trnbl.training_manager
EvalFunction = typing.Callable[[ForwardRef('torch.nn.Module')], dict]
class TrainingManagerInitError(builtins.Exception):
Common base class for all non-exit exceptions.
def wrapped_iterable(
    sequence: Sequence[~T],
    manager: trnbl.training_manager.TrainingManager,
    is_epoch: bool = False,
    use_tqdm: bool | None = None,
    tqdm_kwargs: dict[str, typing.Any] | None = None
) -> Generator[~T, NoneType, NoneType]