docs for trnbl v0.1.1


trnbl: Training Butler

If you train a lot of models, you might often find yourself being annoyed at swapping between different loggers and fiddling with a bunch of if batch_idx % some_number == 0 statements. This package aims to fix that problem.

Firstly, a universal interface to wandb, tensorboard, and a minimal local logging solution (live demo) is provided.

Secondly, a TrainingManager class is provided which handles logging, artifacts, checkpointing, evaluations, exceptions, and more, with flexibly customizable intervals.

Installation

pip install trnbl

Usage

also see the notebooks/ folder:
- demo_minimal.py for a minimal example with dummy data
- demo.ipynb for an example with all options on the iris dataset

import torch
from torch.utils.data import DataLoader
from trnbl.loggers.local import LocalLogger
from trnbl.training_manager import TrainingManager

# set up your dataset, model, optimizer, etc as usual
dataloader: DataLoader = DataLoader(my_dataset, batch_size=32)
model: torch.nn.Module = MyModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# set up a logger -- swap seamlessly between wandb, tensorboard, and local logging
logger: LocalLogger = LocalLogger(
    project="iris-demo",
    metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
    train_config=dict(
        model=str(model), optimizer=str(optimizer), criterion=str(criterion)
    ),
)

with TrainingManager(
    # pass your model and logger
    model=model,
    logger=logger,
    evals={
        # pass evaluation functions which take a model, and return a dict of metrics
        "1k samples": my_evaluation_function,
        "0.5 epochs": lambda model: logger.get_mem_usage(),
        "100 batches": my_other_eval_function,
    }.items(),
    checkpoint_interval="1/10 run", # will save a checkpoint 10 times per run
) as tr:

    # wrap the loops, and length will be automatically calculated
    # and used to figure out when to run evals, checkpoint, etc
    for epoch in tr.epoch_loop(range(120)):
        for inputs, targets in tr.batch_loop(dataloader):
            # your normal training code
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute whatever you want every batch
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)
            
            # log the metrics
            tr.batch_update(
                samples=len(targets),
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )

    # a `model.final.pt` checkpoint will be saved at the end of the run,
    # or a `model.exception.pt` if something crashes inside the context

LocalLogger

Intended as a minimal logging solution for local runs, when you’re too lazy to set up a new wandb project for a quick test, and want to be able to easily read the logs. It logs everything as json or jsonl files, and provides a simple web interface for viewing the data. The web interface allows:

You can view a live demo of the web interface here.
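
A minimal sketch of viewing those logs yourself, assuming the default `trnbl-logs` base path and the bundled `start_server` helper documented below (the path and port here are illustrative):

from trnbl.loggers.local.start_server import start_server

# serve the LocalLogger output directory, including the single-file web frontend,
# at http://localhost:8000
start_server("trnbl-logs", port=8000)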

TODOs:

Submodules

API Documentation

View Source on GitHub

trnbl


trnbl: Training Butler

If you train a lot of models, you might often find yourself being annoyed at swapping between different loggers and fiddling with a bunch of if batch_idx % some_number == 0 statements. This package aims to fix that problem.

Firstly, a universal interface to wandb, tensorboard, and a minimal local logging solution (live demo) is provided.

Secondly, a TrainingManager class is provided which handles logging, artifacts, checkpointing, evaluations, exceptions, and more, with flexibly customizable intervals.

Installation

pip install trnbl

Usage

also see the notebooks/ folder:
- demo_minimal.py for a minimal example with dummy data
- demo.ipynb for an example with all options on the iris dataset

import torch
from torch.utils.data import DataLoader
from trnbl.loggers.local import LocalLogger
from trnbl.training_manager import TrainingManager

# set up your dataset, model, optimizer, etc as usual
dataloader: DataLoader = DataLoader(my_dataset, batch_size=32)
model: torch.nn.Module = MyModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# set up a logger -- swap seamlessly between wandb, tensorboard, and local logging
logger: LocalLogger = LocalLogger(
    project="iris-demo",
    metric_names=["train/loss", "train/acc", "val/loss", "val/acc"],
    train_config=dict(
        model=str(model), optimizer=str(optimizer), criterion=str(criterion)
    ),
)

with TrainingManager(
    # pass your model and logger
    model=model,
    logger=logger,
    evals={
        # pass evaluation functions which take a model, and return a dict of metrics
        "1k samples": my_evaluation_function,
        "0.5 epochs": lambda model: logger.get_mem_usage(),
        "100 batches": my_other_eval_function,
    }.items(),
    checkpoint_interval="1/10 run", # will save a checkpoint 10 times per run
) as tr:

    # wrap the loops, and length will be automatically calculated
    # and used to figure out when to run evals, checkpoint, etc
    for epoch in tr.epoch_loop(range(120)):
        for inputs, targets in tr.batch_loop(dataloader):
            # your normal training code
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute whatever you want every batch
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)
            
            # log the metrics
            tr.batch_update(
                samples=len(targets),
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )

    # a `model.final.pt` checkpoint will be saved at the end of the run,
    # or a `model.exception.pt` if something crashes inside the context

LocalLogger

Intended as a minimal logging solution for local runs, when you’re too lazy to set up a new wandb project for a quick test, and want to be able to easily read the logs. It logs everything as json or jsonl files, and provides a simple web interface for viewing the data. The web interface allows:

You can view a live demo of the web interface here.

TODOs:

View Source on GitHub

class TrainingInterval:

View Source on GitHub

A training interval, which can be specified in a few different units.

Attributes:

Methods:

Provides methods for reading from a string or tuple, and normalizing to batches.

TrainingInterval

(
    quantity: int | float,
    unit: Literal['runs', 'epochs', 'batches', 'samples']
)

def as_batch_count

(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> int

View Source on GitHub

given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches

Parameters:

Returns:

Raises:

def normalized

(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> trnbl.training_interval.TrainingInterval

View Source on GitHub

convert the units of the interval to batches, by calling as_batch_count and setting the unit to "batches"

def from_str

(cls, raw: str) -> trnbl.training_interval.TrainingInterval

View Source on GitHub

parse a string into a TrainingInterval object

Examples:

>>> TrainingInterval.from_str("5 epochs")
TrainingInterval(5, 'epochs')
>>> TrainingInterval.from_str("100 batches")
TrainingInterval(100, 'batches')
>>> TrainingInterval.from_str("0.1 runs")
TrainingInterval(0.1, 'runs')
>>> TrainingInterval.from_str("1/5 runs")
TrainingInterval(0.2, 'runs')

def from_any

(cls, *args, **kwargs) -> trnbl.training_interval.TrainingInterval

View Source on GitHub

parse a string or tuple into a TrainingInterval object

def process_to_batches

(
    cls,
    interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval],
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> int

View Source on GitHub

directly from any representation to a number of batches
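
For instance, a sketch of normalizing an interval specification, assuming a "run" spans epochs * batches_per_epoch batches (the numbers are illustrative):

from trnbl.training_interval import TrainingInterval

# 100 batches per epoch, batch size 32, 10 epochs -> one run is 1000 batches
n_batches: int = TrainingInterval.process_to_batches(
    interval="1/10 runs",
    batchsize=32,
    batches_per_epoch=100,
    epochs=10,
)
# expected: 100, i.e. once every tenth of the run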

class TrainingLoggerBase(abc.ABC):

View Source on GitHub

Base class for training loggers

def debug

(self, message: str, **kwargs) -> None

View Source on GitHub

log a debug message which will be saved, but not printed

def message

(self, message: str, **kwargs) -> None

View Source on GitHub

log a progress message, which will be printed to stdout

def warning

(self, message: str, **kwargs) -> None

View Source on GitHub

log a warning message, which will be printed to stderr

def error

(self, message: str, **kwargs) -> None

View Source on GitHub

log an error message

def metrics

(self, data: dict[str, typing.Any]) -> None

View Source on GitHub

Log a dictionary of metrics

def artifact

(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

View Source on GitHub

log an artifact from a file
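
For example, a sketch of logging a saved checkpoint as an artifact, where `logger` is any TrainingLoggerBase instance and the file path and metadata are purely illustrative:

from pathlib import Path

logger.artifact(
    path=Path("trnbl-logs/iris-demo/model.final.pt"),  # hypothetical checkpoint file
    type="model",
    aliases=["final"],
    metadata={"epochs": 120},
)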

View Source on GitHub

Get the URL for the current logging run

View Source on GitHub

Get the path to the current logging run

def flush

(self) -> None

View Source on GitHub

Flush the logger

def finish

(self) -> None

View Source on GitHub

Finish logging

def get_mem_usage

(self) -> dict

View Source on GitHub

def spinner_task

(self, **kwargs) -> trnbl.loggers.base.LoggerSpinner

View Source on GitHub

Create a spinner task. kwargs are passed to Spinner.

class TrainingManager(typing.Generic[~TLogger]):

View Source on GitHub

context manager for training a model, with logging, evals, and checkpoints

Parameters:

Usage:

with TrainingManager(
    model=model, dataloader=TRAIN_LOADER, logger=logger, epochs=500,
    evals={
        "1 epochs": eval_func,
        "0.1 runs": lambda model: logger.get_mem_usage(),
    }.items(),
    checkpoint_interval="50 epochs",
) as tp:

    # Training loop
    model.train()
    for epoch in range(epochs):
        for inputs, targets in TRAIN_LOADER:
            # the usual
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute accuracy
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log metrics
            tp.batch_update(
                # pass in number of samples in your batch (or it will be inferred from the batch size)
                samples=len(targets),
                # any other metrics you compute every loop
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )
            # batch_update will automatically run evals and save checkpoints as needed

        tp.epoch_update()

TrainingManager

(
    model: torch.nn.modules.module.Module,
    logger: ~TLogger,
    dataloader: torch.utils.data.dataloader.DataLoader | None = None,
    epochs_total: int | None = None,
    save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType] = <function save>,
    evals: Optional[Iterable[tuple[Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval], Callable[[torch.nn.modules.module.Module], dict]]]] = None,
    checkpoint_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=1, unit='epochs'),
    print_metrics_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=0.1, unit='runs'),
    model_save_path: str = '{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt',
    model_save_path_special: str = '{run_path}/model.{alias}.pt'
)

View Source on GitHub

def try_compute_counters

(self) -> None

View Source on GitHub

def epoch_loop

(
    self,
    epochs: Sequence[int],
    use_tqdm: bool = True,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]

View Source on GitHub

def batch_loop

(
    self,
    batches: Sequence[int],
    use_tqdm: bool = False,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]

View Source on GitHub

def check_is_initialized

(self)

View Source on GitHub

def get_elapsed_time

(self) -> float

View Source on GitHub

return the elapsed time in seconds since the start of training

def training_status

(self) -> dict[str, int | float]

View Source on GitHub

status of elapsed time, samples, batches, epochs, and checkpoints

def batch_update

(self, samples: int | None, metrics: dict | None = None, **kwargs)

View Source on GitHub

call this at the end of every batch. Pass samples (or it will be inferred from the batch size) and pass any other metrics as kwargs

This function will:
- update internal counters
- run evals as needed (based on the intervals passed)
- log all metrics and training status
- save a checkpoint as needed (based on the checkpoint interval)

def epoch_update

(self)

View Source on GitHub

call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter

docs for trnbl v0.1.1

Submodules

View Source on GitHub

trnbl.loggers

View Source on GitHub

docs for trnbl v0.1.1

API Documentation

View Source on GitHub

trnbl.loggers.base

View Source on GitHub

def rand_syllabic_string

(length: int = 6) -> str

View Source on GitHub

Generate a random string of alternating consonants and vowels to use as a unique identifier

for a length of 2n, there are about 10^{2n} possible strings

default is 6 characters, which gives 10^6 possible strings
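
As a rough sketch of the idea (not necessarily the exact character set or logic used by trnbl):

import random

CONSONANTS = "bcdfghjklmnpqrstvwz"
VOWELS = "aeiou"

def rand_syllabic_string_sketch(length: int = 6) -> str:
    # alternate consonants and vowels so the id is pronounceable, e.g. "temuko";
    # roughly (19 * 5) ** (length / 2), i.e. about 10 ** length, possible strings
    return "".join(
        random.choice(CONSONANTS if i % 2 == 0 else VOWELS) for i in range(length)
    )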

class LoggerSpinner(muutils.spinner.Spinner):

View Source on GitHub

see Spinner for parameters. catches update_value and passes it to the LocalLogger

LoggerSpinner

(*args, logger: trnbl.loggers.base.TrainingLoggerBase, **kwargs)

View Source on GitHub

def update_value

(self, value: Any) -> None

View Source on GitHub

update the value of the spinner and log it

Inherited Members

class TrainingLoggerBase(abc.ABC):

View Source on GitHub

Base class for training loggers

def debug

(self, message: str, **kwargs) -> None

View Source on GitHub

log a debug message which will be saved, but not printed

def message

(self, message: str, **kwargs) -> None

View Source on GitHub

log a progress message, which will be printed to stdout

def warning

(self, message: str, **kwargs) -> None

View Source on GitHub

log a warning message, which will be printed to stderr

def error

(self, message: str, **kwargs) -> None

View Source on GitHub

log an error message

def metrics

(self, data: dict[str, typing.Any]) -> None

View Source on GitHub

Log a dictionary of metrics

def artifact

(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

View Source on GitHub

log an artifact from a file

View Source on GitHub

Get the URL for the current logging run

View Source on GitHub

Get the path to the current logging run

def flush

(self) -> None

View Source on GitHub

Flush the logger

def finish

(self) -> None

View Source on GitHub

Finish logging

def get_mem_usage

(self) -> dict

View Source on GitHub

def spinner_task

(self, **kwargs) -> trnbl.loggers.base.LoggerSpinner

View Source on GitHub

Create a spinner task. kwargs are passed to Spinner.

docs for trnbl v0.1.1

Submodules

API Documentation

View Source on GitHub

trnbl.loggers.local

View Source on GitHub

class FilePaths:

View Source on GitHub

class LocalLogger(trnbl.loggers.base.TrainingLoggerBase):

View Source on GitHub

Base class for training loggers

LocalLogger

(
    project: str,
    metric_names: list[str],
    train_config: dict,
    group: str = '',
    base_path: str | pathlib.Path = pathlib.Path('trnbl-logs'),
    memusage_as_metrics: bool = True,
    console_msg_prefix: str = '# '
)

View Source on GitHub

View Source on GitHub

def get_timestamp

(self) -> str

View Source on GitHub

def debug

(self, message: str, **kwargs) -> None

View Source on GitHub

log a debug message

def message

(self, message: str, **kwargs) -> None

View Source on GitHub

log a progress message

def warning

(self, message: str, **kwargs) -> None

View Source on GitHub

log a warning message

def error

(self, message: str, **kwargs) -> None

View Source on GitHub

log an error message

def metrics

(self, data: dict[str, typing.Any]) -> None

View Source on GitHub

log a dictionary of metrics

def artifact

(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

View Source on GitHub

log an artifact from a file

View Source on GitHub

Get the URL for the current logging run

View Source on GitHub

Get the path to the current logging run

def flush

(self) -> None

View Source on GitHub

Flush the logger

def finish

(self) -> None

View Source on GitHub

Finish logging

Inherited Members

docs for trnbl v0.1.1

API Documentation

View Source on GitHub

trnbl.loggers.local.build_dist

View Source on GitHub

def get_remote

(
    path_or_url: str,
    download_remote: bool = False,
    get_bytes: bool = False,
    allow_remote_fail: bool = True
) -> str | bytes | None

View Source on GitHub

gets a resource from a path or url

Parameters:

Raises:

Returns:

def build_dist

(
    path: pathlib.Path,
    minify: bool = True,
    download_remote: bool = True
) -> str

View Source on GitHub

Build a single file html from a folder

partially from https://stackoverflow.com/questions/44646481/merging-js-css-html-into-single-html
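
A sketch of rebuilding the single-file frontend (the input and output paths here are illustrative):

from pathlib import Path
from trnbl.loggers.local.build_dist import build_dist

# inline the frontend's CSS/JS into one self-contained HTML string and write it out
html: str = build_dist(Path("frontend"), minify=True, download_remote=True)
Path("index.html").write_text(html, encoding="utf-8")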

def main

() -> None

View Source on GitHub

docs for trnbl v0.1.1

API Documentation

View Source on GitHub

trnbl.loggers.local.html_frontend

View Source on GitHub

def get_html_frontend

() -> str

View Source on GitHub

docs for trnbl v0.1.1

API Documentation

View Source on GitHub

trnbl.loggers.local.locallogger

View Source on GitHub

class FilePaths:

View Source on GitHub

class LocalLogger(trnbl.loggers.base.TrainingLoggerBase):

View Source on GitHub

Base class for training loggers

LocalLogger

(
    project: str,
    metric_names: list[str],
    train_config: dict,
    group: str = '',
    base_path: str | pathlib.Path = pathlib.Path('trnbl-logs'),
    memusage_as_metrics: bool = True,
    console_msg_prefix: str = '# '
)

View Source on GitHub

View Source on GitHub

def get_timestamp

(self) -> str

View Source on GitHub

def debug

(self, message: str, **kwargs) -> None

View Source on GitHub

log a debug message

def message

(self, message: str, **kwargs) -> None

View Source on GitHub

log a progress message

def warning

(self, message: str, **kwargs) -> None

View Source on GitHub

log a warning message

def error

(self, message: str, **kwargs) -> None

View Source on GitHub

log an error message

def metrics

(self, data: dict[str, typing.Any]) -> None

View Source on GitHub

log a dictionary of metrics

def artifact

(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

View Source on GitHub

log an artifact from a file

View Source on GitHub

Get the URL for the current logging run

View Source on GitHub

Get the path to the current logging run

def flush

(self) -> None

View Source on GitHub

Flush the logger

def finish

(self) -> None

View Source on GitHub

Finish logging

Inherited Members

docs for trnbl v0.1.1


API Documentation

View Source on GitHub

trnbl.loggers.local.start_server

Usage: python start_server.py path/to/directory [port]

View Source on GitHub

def start_server

(path: str, port: int = 8000) -> None

View Source on GitHub

Starts a server to serve the files in the given path.

docs for trnbl v0.1.1

API Documentation

View Source on GitHub

trnbl.loggers.multi

View Source on GitHub

def maybe_flatten

(lst: list[typing.Union[~T, list[~T]]]) -> list[~T]

View Source on GitHub

flatten a list if it is nested
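
A brief sketch of the assumed behavior (exact handling of mixed nesting is an assumption):

from trnbl.loggers.multi import maybe_flatten

nested = [["train/loss"], ["train/acc", "val/acc"]]
flat = maybe_flatten(nested)  # presumably ["train/loss", "train/acc", "val/acc"]
already_flat = maybe_flatten(["train/loss"])  # presumably returned unchanged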

class MultiLogger(trnbl.loggers.base.TrainingLoggerBase):

View Source on GitHub

use multiple loggers at once

MultiLogger

(loggers: list[trnbl.loggers.base.TrainingLoggerBase])

View Source on GitHub
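
A sketch of fanning out to both the local and TensorBoard backends at once (constructor arguments are illustrative; see the respective logger pages for details):

from trnbl.loggers.local import LocalLogger
from trnbl.loggers.multi import MultiLogger
from trnbl.loggers.tensorboard import TensorBoardLogger

logger = MultiLogger(
    loggers=[
        LocalLogger(
            project="iris-demo",
            metric_names=["train/loss", "train/acc"],
            train_config={"optimizer": "Adam", "lr": 1e-3},
        ),
        TensorBoardLogger(log_dir="tb-logs/iris-demo"),
    ]
)
# calls like logger.metrics(...) or logger.artifact(...) are forwarded to every logger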

def debug

(self, message: str, **kwargs) -> None

View Source on GitHub

log a debug message which will be saved, but not printed

def message

(self, message: str, **kwargs) -> None

View Source on GitHub

log a progress message

def metrics

(self, data: dict[str, typing.Any]) -> None

View Source on GitHub

Log a dictionary of metrics

def artifact

(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

View Source on GitHub

log an artifact from a file

View Source on GitHub

Get the URL for the current logging run

View Source on GitHub

Get the paths to the current logging run

def flush

(self) -> None

View Source on GitHub

Flush the logger

def finish

(self) -> None

View Source on GitHub

Finish logging

Inherited Members

docs for trnbl v0.1.1

API Documentation

View Source on GitHub

trnbl.loggers.tensorboard

View Source on GitHub

class TensorBoardLogger(trnbl.loggers.base.TrainingLoggerBase):

View Source on GitHub

Base class for training loggers

TensorBoardLogger

(
    log_dir: str | pathlib.Path,
    train_config: dict | None = None,
    name: str | None = None,
    **kwargs
)

View Source on GitHub
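
For instance, a sketch with illustrative arguments (extra kwargs are presumably forwarded to the underlying TensorBoard writer):

from trnbl.loggers.tensorboard import TensorBoardLogger

logger = TensorBoardLogger(
    log_dir="tb-logs/iris-demo",
    train_config={"optimizer": "Adam", "lr": 1e-3},
    name="baseline",
)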

def debug

(self, message: str, **kwargs) -> None

View Source on GitHub

log a debug message which will be saved, but not printed

def message

(self, message: str, **kwargs) -> None

View Source on GitHub

log a progress message, which will be printed to stdout

def metrics

(self, data: dict[str, typing.Any]) -> None

View Source on GitHub

Log a dictionary of metrics

def artifact

(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

View Source on GitHub

log an artifact from a file

View Source on GitHub

Get the URL for the current logging run

View Source on GitHub

Get the path to the current logging run

def flush

(self) -> None

View Source on GitHub

Flush the logger

def finish

(self) -> None

View Source on GitHub

Finish logging

Inherited Members

docs for trnbl v0.1.1

API Documentation

View Source on GitHub

trnbl.loggers.wandb

View Source on GitHub

class WandbLogger(trnbl.loggers.base.TrainingLoggerBase):

View Source on GitHub

wrapper around wandb logging for TrainingLoggerBase. create using WandbLogger.create(config, project, job_type)

WandbLogger

(run: wandb.sdk.wandb_run.Run)

View Source on GitHub

def create

(
    cls,
    config: dict,
    project: str | None = None,
    job_type: str | None = None,
    logging_fmt: str = '%(asctime)s %(levelname)s %(message)s',
    logging_datefmt: str = '%Y-%m-%d %H:%M:%S',
    wandb_kwargs: dict | None = None
) -> trnbl.loggers.wandb.WandbLogger

View Source on GitHub
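
A sketch of creating the wandb-backed logger (the config, project, and job_type values are illustrative):

from trnbl.loggers.wandb import WandbLogger

logger = WandbLogger.create(
    config={"model": "MLP", "lr": 1e-3},  # stored as the run config
    project="iris-demo",
    job_type="train",
)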

def debug

(self, message: str, **kwargs) -> None

View Source on GitHub

log a debug message which will be saved, but not printed

def message

(self, message: str, **kwargs) -> None

View Source on GitHub

log a progress message, which will be printed to stdout

def metrics

(self, data: dict[str, typing.Any]) -> None

View Source on GitHub

Log a dictionary of metrics

def artifact

(
    self,
    path: pathlib.Path,
    type: str,
    aliases: list[str] | None = None,
    metadata: dict | None = None
) -> None

View Source on GitHub

log an artifact from a file

View Source on GitHub

Get the URL for the current logging run

View Source on GitHub

Get the path to the current logging run

def flush

(self) -> None

View Source on GitHub

Flush the logger

def finish

(self) -> None

View Source on GitHub

Finish logging

Inherited Members

docs for trnbl v0.1.1

API Documentation

View Source on GitHub

trnbl.training_interval

View Source on GitHub

class IntervalValueError(builtins.UserWarning):

View Source on GitHub

Error for when the interval is less than 1 batch

Inherited Members

class TrainingInterval:

View Source on GitHub

A training interval, which can be specified in a few different units.

Attributes:

Methods:

Provides methods for reading from a string or tuple, and normalizing to batches.

TrainingInterval

(
    quantity: int | float,
    unit: Literal['runs', 'epochs', 'batches', 'samples']
)

def as_batch_count

(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> int

View Source on GitHub

given the batchsize, number of batches per epoch, and number of epochs, return the interval as a number of batches

Parameters:

Returns:

Raises:

def normalized

(
    self,
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> trnbl.training_interval.TrainingInterval

View Source on GitHub

convert the units of the interval to batches, by calling as_batch_count and setting the unit to "batches"

def from_str

(cls, raw: str) -> trnbl.training_interval.TrainingInterval

View Source on GitHub

parse a string into a TrainingInterval object

Examples:

>>> TrainingInterval.from_str("5 epochs")
TrainingInterval(5, 'epochs')
>>> TrainingInterval.from_str("100 batches")
TrainingInterval(100, 'batches')
>>> TrainingInterval.from_str("0.1 runs")
TrainingInterval(0.1, 'runs')
>>> TrainingInterval.from_str("1/5 runs")
TrainingInterval(0.2, 'runs')

def from_any

(cls, *args, **kwargs) -> trnbl.training_interval.TrainingInterval

View Source on GitHub

parse a string or tuple into a TrainingInterval object

def process_to_batches

(
    cls,
    interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval],
    batchsize: int,
    batches_per_epoch: int,
    epochs: int | None = None
) -> int

View Source on GitHub

directly from any representation to a number of batches

docs for trnbl v0.1.1

API Documentation

View Source on GitHub

trnbl.training_manager

View Source on GitHub

class TrainingManagerInitError(builtins.Exception):

View Source on GitHub

Common base class for all non-exit exceptions.

Inherited Members

def wrapped_iterable

(
    sequence: Sequence[~T],
    manager: trnbl.training_manager.TrainingManager,
    is_epoch: bool = False,
    use_tqdm: bool | None = None,
    tqdm_kwargs: dict[str, typing.Any] | None = None
) -> Generator[~T, NoneType, NoneType]

View Source on GitHub

class TrainingManager(typing.Generic[~TLogger]):

View Source on GitHub

context manager for training a model, with logging, evals, and checkpoints

Parameters:

Usage:

with TrainingManager(
    model=model, dataloader=TRAIN_LOADER, logger=logger, epochs=500,
    evals={
        "1 epochs": eval_func,
        "0.1 runs": lambda model: logger.get_mem_usage(),
    }.items(),
    checkpoint_interval="50 epochs",
) as tp:

    # Training loop
    model.train()
    for epoch in range(epochs):
        for inputs, targets in TRAIN_LOADER:
            # the usual
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # compute accuracy
            accuracy = torch.sum(torch.argmax(outputs, dim=1) == targets).item() / len(targets)

            # log metrics
            tp.batch_update(
                # pass in number of samples in your batch (or it will be inferred from the batch size)
                samples=len(targets),
                # any other metrics you compute every loop
                **{"train/loss": loss.item(), "train/acc": accuracy},
            )
            # batch_update will automatically run evals and save checkpoints as needed

        tp.epoch_update()

TrainingManager

(
    model: torch.nn.modules.module.Module,
    logger: ~TLogger,
    dataloader: torch.utils.data.dataloader.DataLoader | None = None,
    epochs_total: int | None = None,
    save_model: Callable[[torch.nn.modules.module.Module, pathlib.Path], NoneType] = <function save>,
    evals: Optional[Iterable[tuple[Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval], Callable[[torch.nn.modules.module.Module], dict]]]] = None,
    checkpoint_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=1, unit='epochs'),
    print_metrics_interval: Union[str, tuple[Union[int, float, str], str], trnbl.training_interval.TrainingInterval] = TrainingInterval(quantity=0.1, unit='runs'),
    model_save_path: str = '{run_path}/checkpoints/model.checkpoint-{latest_checkpoint}.pt',
    model_save_path_special: str = '{run_path}/model.{alias}.pt'
)

View Source on GitHub

def try_compute_counters

(self) -> None

View Source on GitHub

def epoch_loop

(
    self,
    epochs: Sequence[int],
    use_tqdm: bool = True,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]

View Source on GitHub

def batch_loop

(
    self,
    batches: Sequence[int],
    use_tqdm: bool = False,
    **tqdm_kwargs
) -> Generator[int, NoneType, NoneType]

View Source on GitHub

def check_is_initialized

(self)

View Source on GitHub

def get_elapsed_time

(self) -> float

View Source on GitHub

return the elapsed time in seconds since the start of training

def training_status

(self) -> dict[str, int | float]

View Source on GitHub

status of elapsed time, samples, batches, epochs, and checkpoints

def batch_update

(self, samples: int | None, metrics: dict | None = None, **kwargs)

View Source on GitHub

call this at the end of every batch. Pass samples (or it will be inferred from the batch size) and pass any other metrics as kwargs

This function will:
- update internal counters
- run evals as needed (based on the intervals passed)
- log all metrics and training status
- save a checkpoint as needed (based on the checkpoint interval)

def epoch_update

(self)

View Source on GitHub

call this at the end of every epoch. This function will log the completion of the epoch and update the epoch counter