# Source code for reflectorch.runs.utils

from pathlib import Path
from typing import Tuple

import torch

from reflectorch import *
from reflectorch.runs.config import load_config

# Public API of this module. The lower-level initializers (init_from_conf,
# init_network, load_pretrained, init_dset) are deliberately left out of
# __all__ and treated as internal helpers.
__all__ = [
    "train_from_config",
    "get_trainer_from_config",
    "get_paths_from_config",
    "get_callbacks_from_config",
    "get_trainer_by_name",
    "get_callbacks_by_name",
]


def init_from_conf(conf, **kwargs):
    """Initializes an object with class type, args and kwargs specified in the configuration

    Args:
        conf (dict): configuration dictionary with keys ``'cls'`` (class name looked up
            in this module's globals), optional ``'args'`` (positional arguments) and
            optional ``'kwargs'`` (keyword arguments).
        **kwargs: extra keyword arguments; these override entries in ``conf['kwargs']``.

    Returns:
        Any: the initialized object, or ``None`` if ``conf`` is empty/falsy or has no class name

    Raises:
        ValueError: if the class name is not found in this module's globals
    """
    if not conf:
        return
    cls_name = conf['cls']
    if not cls_name:
        return
    cls = globals().get(cls_name)

    if not cls:
        raise ValueError(f'Unknown class {cls_name}')

    conf_args = conf.get('args', [])
    # copy before updating: otherwise the kwargs dict stored inside the caller's
    # config would be mutated, leaking overrides into later uses of the config
    conf_kwargs = dict(conf.get('kwargs', {}))
    conf_kwargs.update(kwargs)
    return cls(*conf_args, **conf_kwargs)


def train_from_config(config: dict):
    """Train a model from a configuration dictionary

    Args:
        config (dict): configuration dictionary

    Returns:
        Trainer: the trainer object
    """
    # Resolve output paths first (creating the directories), then build the
    # trainer and its callbacks from the same config.
    folder_paths = get_paths_from_config(config, mkdir=True)
    trainer = get_trainer_from_config(config, folder_paths)
    callbacks = get_callbacks_from_config(config, folder_paths)

    training_cfg = config['training']
    trainer.train(
        training_cfg['num_iterations'],
        callbacks,
        disable_tqdm=False,
        update_tqdm_freq=training_cfg['update_tqdm_freq'],
        grad_accumulation_steps=training_cfg.get('grad_accumulation_steps', 1),
    )

    # Persist the loss history together with the paths and the full config so
    # a run can be identified and reproduced later.
    checkpoint = {
        'paths': folder_paths,
        'losses': trainer.losses,
        'params': config,
    }
    torch.save(checkpoint, folder_paths['losses'])

    return trainer
def get_paths_from_config(config: dict, mkdir: bool = False):
    """Get the directory paths from a configuration dictionary

    Args:
        config (dict): configuration dictionary
        mkdir (bool, optional): option to create a new directory for the saved model weights and losses.

    Returns:
        dict: dictionary containing the folder paths

    Raises:
        NotADirectoryError: if the configured root directory does not exist
    """
    # Fall back to the package ROOT_DIR when the config leaves root_dir empty.
    root_dir = Path(config['general']['root_dir'] or ROOT_DIR)
    name = config['general']['name']

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would silently skip this validation.
    if not root_dir.is_dir():
        raise NotADirectoryError(f'Root directory {root_dir} does not exist.')

    saved_models_dir = root_dir / 'saved_models'
    saved_losses_dir = root_dir / 'saved_losses'

    if mkdir:
        saved_models_dir.mkdir(exist_ok=True)
        saved_losses_dir.mkdir(exist_ok=True)

    # NOTE(review): 'model' is returned as an absolute str while 'losses' stays a
    # relative Path — kept as-is since callers may depend on these exact types.
    model_path = str((saved_models_dir / f'model_{name}.pt').absolute())
    losses_path = saved_losses_dir / f'{name}_losses.pt'

    return {
        'name': name,
        'model': model_path,
        'losses': losses_path,
        'root': root_dir,
        'saved_models': saved_models_dir,
        'saved_losses': saved_losses_dir,
    }
def get_callbacks_from_config(config: dict, folder_paths: dict = None) -> Tuple['TrainerCallback', ...]:
    """Initializes the training callbacks from a configuration dictionary

    Args:
        config (dict): configuration dictionary
        folder_paths (dict, optional): folder paths; derived from the config when not provided

    Returns:
        tuple: tuple of callbacks
    """
    callbacks = []
    if not folder_paths:
        folder_paths = get_paths_from_config(config)

    training_cfg = config['training']
    # Shallow-copy the callback section so popping 'save_best_model' below does
    # not modify the original config.
    remaining = dict(training_cfg['callbacks'])

    # The model-saving callback is configured separately because it needs the
    # resolved model path.
    save_cfg = remaining.pop('save_best_model')
    if save_cfg['enable']:
        callbacks.append(SaveBestModel(folder_paths['model'], freq=save_cfg['freq']))

    # All other entries are generic class-name configs; skip any that resolve
    # to nothing (empty/disabled configs).
    for cb_cfg in remaining.values():
        cb = init_from_conf(cb_cfg)
        if cb:
            callbacks.append(cb)

    if training_cfg['logger']['use_neptune']:
        callbacks.append(LogLosses())

    return tuple(callbacks)
def get_trainer_from_config(config: dict, folder_paths: dict = None):
    """Initializes a trainer from a configuration dictionary

    Args:
        config (dict): the configuration dictionary
        folder_paths (dict, optional): dictionary containing the folder paths

    Returns:
        Trainer: the trainer object
    """
    dset = init_dset(config['dset'])
    folder_paths = folder_paths or get_paths_from_config(config)
    model = init_network(config['model']['network'], folder_paths['saved_models'])

    training_cfg = config['training']
    # Optimizer class is looked up by name on torch.optim (e.g. 'Adam').
    optimizer_cls = getattr(torch.optim, training_cfg['optimizer'])
    logger = None
    use_q_input = training_cfg.get('train_with_q_input', False)
    grad_clip_max = training_cfg.get('clip_grad_norm_max', None)

    # Trainer class may be overridden by name in the config; defaults to the
    # point-estimator trainer.
    if 'trainer_cls' in training_cfg:
        trainer_cls = globals().get(training_cfg['trainer_cls'])
    else:
        trainer_cls = PointEstimatorTrainer
    extra_kwargs = training_cfg.get('trainer_kwargs', {})

    return trainer_cls(
        model,
        dset,
        training_cfg['lr'],
        training_cfg['batch_size'],
        clip_grad_norm_max=grad_clip_max,
        logger=logger,
        optim_cls=optimizer_cls,
        train_with_q_input=use_q_input,
        **extra_kwargs,
    )
def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True,
                        inference_device=None):
    """Initializes a trainer object based on a configuration file (i.e. the model name) and optionally loads \
    saved weights into the network

    Args:
        config_name (str): name of the configuration file
        config_dir (str): path of the configuration directory
        model_path (str, optional): path to the network weights.
        load_weights (bool, optional): if True the saved network weights are loaded into the network. Defaults to True.
        inference_device (optional): device override applied to the network, prior sampler and q generator.

    Returns:
        Trainer: the trainer object

    Raises:
        RuntimeError: if the weights file cannot be loaded
    """
    config = load_config(config_name, config_dir)
    # Inference mode: never auto-load pretrained weights via the network config
    # and never log to neptune.
    config['model']['network']['pretrained_name'] = None
    config['training']['logger']['use_neptune'] = False

    if inference_device:
        config['model']['network']['device'] = inference_device
        config['dset']['prior_sampler']['kwargs']['device'] = inference_device
        config['dset']['q_generator']['kwargs']['device'] = inference_device

    trainer = get_trainer_from_config(config)

    num_params = sum(p.numel() for p in trainer.model.parameters())
    print(
        f'Model {config_name} loaded. Number of parameters: {num_params / 10 ** 6:.2f} M',
    )

    if not load_weights:
        return trainer

    if not model_path:
        model_name = f'model_{config_name}.pt'
        model_path = SAVED_MODELS_DIR / model_name

    try:
        # map_location honors inference_device so checkpoints saved on CUDA can
        # be loaded on CPU-only hosts; None preserves torch.load's default.
        state_dict = torch.load(model_path, map_location=inference_device)
    except Exception as err:
        raise RuntimeError(f'Could not load model from {model_path}') from err

    # Checkpoints may store the weights under a 'model' key or be a raw state dict.
    if 'model' in state_dict:
        trainer.model.load_state_dict(state_dict['model'])
    else:
        trainer.model.load_state_dict(state_dict)

    return trainer
def get_callbacks_by_name(config_name, config_dir=None):
    """Initializes the trainer callbacks based on a configuration file

    Args:
        config_name (str): name of the configuration file
        config_dir (str): path of the configuration directory
    """
    config = load_config(config_name, config_dir)
    callbacks = get_callbacks_from_config(config)
    return callbacks


def init_network(config: dict, saved_models_dir: Path):
    """Initializes the network based on the configuration dictionary and optionally
    loads the weights from a pretrained model"""
    device = config.get('device', 'cuda')
    network = init_from_conf(config).to(device)
    # Map pretrained weights onto the configured device so e.g. CUDA checkpoints
    # load on CPU-only hosts.
    network = load_pretrained(network, config.get('pretrained_name', None), saved_models_dir,
                              map_location=device)
    return network


def load_pretrained(model, model_name: str, saved_models_dir: Path, map_location=None):
    """Loads the saved weights into the network

    Args:
        model: the network whose state dict is updated in place
        model_name (str): weights file name; '.pt' is appended when no extension is
            present, and a 'model_' prefix is tried as a fallback
        saved_models_dir (Path): directory containing the saved weights
        map_location (optional): forwarded to ``torch.load``; defaults to ``None``
            (torch's original behavior)

    Returns:
        the model (unchanged when ``model_name`` is falsy)

    Raises:
        FileNotFoundError: if no matching weights file exists
        RuntimeError: if the file cannot be loaded or the state dict does not match
    """
    if not model_name:
        return model
    if '.' not in model_name:
        model_name = model_name + '.pt'
    model_path = saved_models_dir / model_name
    if not model_path.is_file():
        # Fallback: weights saved with the conventional 'model_' prefix.
        model_path = saved_models_dir / f'model_{model_name}'
    if not model_path.is_file():
        raise FileNotFoundError(f'File {str(model_path)} does not exist.')
    try:
        pretrained = torch.load(model_path, map_location=map_location)
    except Exception as err:
        raise RuntimeError(f'Could not load model from {str(model_path)}') from err
    # Checkpoints may wrap the state dict under a 'model' key.
    if 'model' in pretrained:
        pretrained = pretrained['model']
    try:
        model.load_state_dict(pretrained)
    except Exception as err:
        raise RuntimeError(f'Could not update state dict from {str(model_path)}') from err
    return model


def init_dset(config: dict):
    """Initializes the dataset / dataloader object"""
    # Loader class may be overridden by name; defaults to ReflectivityDataLoader.
    dset_cls = globals().get(config['cls']) if 'cls' in config else ReflectivityDataLoader
    prior_sampler = init_from_conf(config['prior_sampler'])
    intensity_noise = init_from_conf(config['intensity_noise'])
    q_generator = init_from_conf(config['q_generator'])
    # Optional components default to None when absent from the config.
    curves_scaler = init_from_conf(config['curves_scaler']) if 'curves_scaler' in config else None
    smearing = init_from_conf(config['smearing']) if 'smearing' in config else None
    q_noise = init_from_conf(config['q_noise']) if 'q_noise' in config else None
    dset = dset_cls(
        q_generator=q_generator,
        prior_sampler=prior_sampler,
        intensity_noise=intensity_noise,
        curves_scaler=curves_scaler,
        smearing=smearing,
        q_noise=q_noise,
    )
    return dset