Module flashy.distrib

from contextlib import contextmanager
from functools import partial, wraps
import typing as tp

import torch
from torch import distributed
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset

from dora import distrib as dora_distrib


rank = 0
world_size = 1


def init(backend='nccl'):
    global rank, world_size
    dora_distrib.init(backend)
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()


def rank_zero_only(fn: tp.Callable) -> tp.Callable:
    """Function that can be used as a decorator to enable a
    function/method being called only on rank 0."""

    @wraps(fn)
    def wrapped_fn(*args: tp.Any, **kwargs: tp.Any) -> tp.Optional[tp.Any]:
        if is_rank_zero():
            return fn(*args, **kwargs)
        return None

    return wrapped_fn


def is_rank_zero():
    return rank == 0


def is_distributed():
    return world_size > 1


def all_reduce(tensor: torch.Tensor, op=distributed.ReduceOp.SUM):
    if is_distributed():
        return distributed.all_reduce(tensor, op)


def average_metrics(metrics: tp.Dict[str, float], count=1.):
    if not is_distributed():
        return metrics
    keys, values = zip(*sorted(metrics.items()))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))


def wrap(model):
    if is_distributed():
        return DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())
    else:
        return model


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size:
        # If not all the workers have the same number, for at least one of them,
        # this inequality will be verified.
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")


def broadcast_weights(params: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the weights from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    params = list(params)
    _check_number_of_params(params)
    handles = []
    for param in params:
        handle = distributed.broadcast(param.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()


def sync_grad(params: tp.Iterable[torch.Tensor]):
    """
    Simpler alternative to DistributedDataParallel that doesn't rely
    on any black magic. For simple models it can be just as fast.
    Just call this on your model parameters after the call to backward.

    .. warning:: This only synchronizes the given params. When using this with `model.parameters()`,
        this will not synchronize buffers. For most cases, this should be alright. When using
        BatchNorm, this can lead to small differences in how the model is evaluated at validation time.
    """
    if not is_distributed():
        return
    params = [p for p in params if p.grad is not None]
    _check_number_of_params(params)
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(
                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size


@contextmanager
def eager_sync_grad(params: tp.Iterable[torch.Tensor]):
    """Similar to `sync_grad`, except this is a context manager that will start syncing
    gradient as soon as they become available. This can be faster, but requires backward to be
    called no more than once!

    ..Warning:: This only synchronize the given params. When using this with `model.parameters()`,
        this will not synchronize buffers. For most cases, this should be alright. When using
        BatchNorm, this can lead to small differences in how the model is evaluated at valid time.
    """
    if not is_distributed():
        yield
        return
    params = [p for p in params if p.requires_grad]
    _check_number_of_params(params)
    hooks = []
    handles = []
    waiting_params = set(params)

    def _callback(param, grad):
        if param not in waiting_params:
            raise RuntimeError(f"We got a gradient twice for parameter {param}.")
        handle = torch.distributed.all_reduce(grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
        handles.append((param, handle))
        waiting_params.remove(param)

    for param in params:
        hooks.append(param.register_hook(partial(_callback, param)))

    try:
        yield
    finally:
        for hook in hooks:
            hook.remove()
        _check_number_of_params(list(waiting_params))  # verify all workers have the same nb of remaining params.
        for param, handle in handles:
            handle.wait()
            assert param.grad is not None
            param.grad.data /= world_size


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.
    """
    if not is_distributed():
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # train means we will compute backward, we use DistributedSampler
        sampler = DistributedSampler(dataset)
        # We ignore shuffle, DistributedSampler already shuffles
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler would otherwise replicate some examples
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
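
Taken together, these helpers are meant to compose into a simple data-parallel training script. The following is a hypothetical end-to-end sketch, not part of the module: the model, data and hyperparameters are placeholders, and it assumes the script is started through a launcher that dora recognizes, so that init() can set up the process group (on a plain single process, everything should degrade to the non-distributed code paths).

import torch
from flashy import distrib

distrib.init()                                        # fills in rank / world_size through dora
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = distrib.wrap(torch.nn.Linear(16, 4).to(device))   # placeholder model; DDP only when distributed
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

dataset = torch.utils.data.TensorDataset(torch.randn(1024, 16))     # placeholder data
train_loader = distrib.loader(dataset, batch_size=32, shuffle=True)

for (batch,) in train_loader:
    optimizer.zero_grad()
    loss = model(batch.to(device)).pow(2).mean()      # placeholder loss
    loss.backward()                                   # DDP averages gradients during backward
    optimizer.step()

metrics = distrib.average_metrics({'loss': loss.item()}, count=len(train_loader))
if distrib.is_rank_zero():
    print(metrics)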

Functions

def all_reduce(tensor: torch.Tensor, op=distributed.ReduceOp.SUM)
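
Sums the tensor in place across all workers (SUM is the default op) and is a no-op when running on a single process. A minimal sketch, with a made-up local statistic:

import torch
from flashy import distrib

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Count how many samples this worker processed, then sum over all workers.
local_count = torch.tensor([128.0], device=device)    # made-up local value
distrib.all_reduce(local_count)                       # in-place SUM; no-op on a single process
total = local_count.item()
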
def average_metrics(metrics: Dict[str, float], count=1.0)
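
Averages a dictionary of scalar metrics across workers, weighting each worker by count; on a single process the dictionary is returned unchanged. A sketch with made-up values:

from flashy import distrib

local_metrics = {'loss': 0.42, 'accuracy': 0.91}      # made-up values for this worker
num_examples = 256                                    # examples seen by this worker
metrics = distrib.average_metrics(local_metrics, count=num_examples)
# metrics now holds the example-weighted average over all workers.
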
def broadcast_weights(params: Iterable[torch.Tensor], src: int = 0)

Broadcast the weights from the given parameters to all workers. This can be used to ensure that all workers have the same model to start with.

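A minimal sketch with a placeholder model, typically run once right after the model is built so that random initializations do not diverge across workers (with the default NCCL backend the parameters are assumed to live on GPU):

import torch
from flashy import distrib

model = torch.nn.Linear(16, 4)                        # placeholder model
if torch.cuda.is_available():
    model.cuda()
distrib.broadcast_weights(model.parameters())         # rank 0's weights overwrite every other worker's
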
def eager_sync_grad(params: Iterable[torch.Tensor])

Similar to sync_grad(), except this is a context manager that will start syncing gradients as soon as they become available. This can be faster, but requires backward to be called no more than once!

Warning: This only synchronizes the given params. When using this with model.parameters(), this will not synchronize buffers. For most cases this should be alright. When using BatchNorm, this can lead to small differences in how the model is evaluated at validation time.
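
A sketch of the intended use, with a placeholder model and loss: wrap the single call to backward in the context manager, so each gradient starts reducing as soon as it is produced.

import torch
from flashy import distrib

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(16, 4).to(device)             # placeholder model
loss = model(torch.randn(8, 16, device=device)).pow(2).mean()

with distrib.eager_sync_grad(model.parameters()):
    loss.backward()                                   # exactly one backward inside the context
# Gradients are now summed across workers and divided by world_size.
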
def init(backend='nccl')
def is_distributed()
def is_rank_zero()
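
These three go together: init() is called once at startup (it delegates to dora, so the script is assumed to run under a launcher dora recognizes), and the two predicates are then used to branch on rank or on whether more than one worker is running. A minimal sketch:

from flashy import distrib

distrib.init()                        # backend defaults to 'nccl' and is forwarded to dora
print(f"rank {distrib.rank} of {distrib.world_size} worker(s)")
if distrib.is_rank_zero():
    print("this line only appears on the main worker")
if not distrib.is_distributed():
    print("single-process run: the collectives in this module become no-ops")
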
def loader(dataset, *args, shuffle=False, klass=torch.utils.data.dataloader.DataLoader, **kwargs)

Create a dataloader that behaves properly with distributed training. If a gradient is going to be computed (i.e. for training), you must set shuffle=True.

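A sketch of typical use with placeholder datasets: the training loader gets shuffle=True, which switches to a DistributedSampler, while the validation loader keeps shuffle=False and is sharded manually so that no example is evaluated twice.

import torch
from flashy import distrib

train_set = torch.utils.data.TensorDataset(torch.randn(1000, 16))   # placeholder data
valid_set = torch.utils.data.TensorDataset(torch.randn(200, 16))

train_loader = distrib.loader(train_set, batch_size=32, shuffle=True, num_workers=4)
valid_loader = distrib.loader(valid_set, batch_size=32, shuffle=False)
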
def rank_zero_only(fn: Callable) -> Callable

Function that can be used as a decorator to ensure a function/method is called only on rank 0; on every other rank the call returns None.

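For example, a checkpointing helper can be decorated so that only the main process writes to disk (the helper below is hypothetical, not part of the module):

import torch
from flashy import distrib

@distrib.rank_zero_only
def save_checkpoint(state, path):
    # Only rank 0 executes this body; every other rank gets None back.
    torch.save(state, path)
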
def sync_grad(params: Iterable[torch.Tensor])

Simpler alternative to DistributedDataParallel that doesn't rely on any black magic. For simple models it can be just as fast. Just call this on your model parameters after the call to backward.

Warning: This only synchronizes the given params. When using this with model.parameters(), this will not synchronize buffers. For most cases this should be alright. When using BatchNorm, this can lead to small differences in how the model is evaluated at validation time.

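A sketch of a manual training step with a placeholder model and optimizer: call sync_grad between backward() and optimizer.step() so that every worker steps with the same averaged gradients.

import torch
from flashy import distrib

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(16, 4).to(device)             # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad()
loss = model(torch.randn(8, 16, device=device)).pow(2).mean()
loss.backward()
distrib.sync_grad(model.parameters())                 # sum grads over workers, then divide by world_size
optimizer.step()
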
def wrap(model)
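
A minimal sketch with a placeholder model, assuming the current CUDA device has already been set (e.g. by the launcher): in a distributed run the model comes back wrapped in DistributedDataParallel, otherwise it is returned unchanged.

import torch
from flashy import distrib

model = torch.nn.Linear(16, 4)                        # placeholder model
if torch.cuda.is_available():
    model.cuda()                                      # wrap() relies on torch.cuda.current_device()
model = distrib.wrap(model)                           # DistributedDataParallel only when world_size > 1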