Module flashy.distrib
Expand source code
from contextlib import contextmanager
from functools import partial, wraps
import typing as tp
import torch
from torch import distributed
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from dora import distrib as dora_distrib
rank = 0
world_size = 1
def init(backend='nccl'):
    global rank, world_size
    dora_distrib.init(backend)
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()
def rank_zero_only(fn: tp.Callable) -> tp.Callable:
    """Decorator that ensures the wrapped function/method is only called on rank 0."""
    @wraps(fn)
    def wrapped_fn(*args: tp.Any, **kwargs: tp.Any) -> tp.Optional[tp.Any]:
        if is_rank_zero():
            return fn(*args, **kwargs)
        return None
    return wrapped_fn
def is_rank_zero():
    return rank == 0
def is_distributed():
    return world_size > 1
def all_reduce(tensor: torch.Tensor, op=distributed.ReduceOp.SUM):
    if is_distributed():
        return distributed.all_reduce(tensor, op)
def average_metrics(metrics: tp.Dict[str, float], count=1.):
    if not is_distributed():
        return metrics
    keys, values = zip(*sorted(metrics.items()))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
def wrap(model):
    if is_distributed():
        return DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())
    else:
        return model
def _check_number_of_params(params: tp.List[torch.Tensor]):
    # Utility function to check that all workers have the same number of params,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size:
        # If the workers do not all have the same number of params,
        # this inequality will hold for at least one of them.
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")
def broadcast_weights(params: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the weights from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    params = list(params)
    _check_number_of_params(params)
    handles = []
    for param in params:
        handle = distributed.broadcast(param.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()
def sync_grad(params: tp.Iterable[torch.Tensor]):
    """
    Simpler alternative to DistributedDataParallel that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward.
    .. warning:: This only synchronizes the given params. When using this with `model.parameters()`,
        this will not synchronize buffers. For most cases, this should be alright. When using
        BatchNorm, this can lead to small differences in how the model is evaluated at validation time.
    """
    if not is_distributed():
        return
    params = [p for p in params if p.grad is not None]
    _check_number_of_params(params)
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(
                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size
@contextmanager
def eager_sync_grad(params: tp.Iterable[torch.Tensor]):
    """Similar to `sync_grad`, except this is a context manager that will start syncing
    gradients as soon as they become available. This can be faster, but requires backward to be
    called no more than once!
    .. warning:: This only synchronizes the given params. When using this with `model.parameters()`,
        this will not synchronize buffers. For most cases, this should be alright. When using
        BatchNorm, this can lead to small differences in how the model is evaluated at validation time.
    """
    if not is_distributed():
        yield
        return
    params = list([p for p in params if p.requires_grad])
    _check_number_of_params(params)
    hooks = []
    handles = []
    waiting_params = set(params)
    def _callback(param, grad):
        if param not in waiting_params:
            raise RuntimeError(f"We got a gradient twice for parameter {param}.")
        handle = torch.distributed.all_reduce(grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
        handles.append((param, handle))
        waiting_params.remove(param)
    for param in params:
        hooks.append(param.register_hook(partial(_callback, param)))
    try:
        yield
    finally:
        for hook in hooks:
            hook.remove()
        _check_number_of_params(list(waiting_params))  # verify all workers have the same nb of remaining params.
        for param, handle in handles:
            handle.wait()
            assert param.grad is not None
            param.grad.data /= world_size
def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed, you must set `shuffle=True`.
    """
    if not is_distributed():
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
    if shuffle:
        # Training means we will compute backward, so we use DistributedSampler.
        sampler = DistributedSampler(dataset)
        # We ignore shuffle; DistributedSampler already shuffles.
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler would otherwise replicate some examples.
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
Functions
def all_reduce(tensor: torch.Tensor, op=ReduceOp.SUM)
-
Expand source code
def all_reduce(tensor: torch.Tensor, op=distributed.ReduceOp.SUM):
    if is_distributed():
        return distributed.all_reduce(tensor, op)
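For example, a minimal sketch of summing a per-worker count across all workers (assuming init() has already been called; n_batches is a hypothetical per-worker value):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_batches = torch.tensor([128.], device=device)  # hypothetical per-worker count
all_reduce(n_batches)                            # in-place SUM across workers; no-op when not distributed
total_batches = int(n_batches.item())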
def average_metrics(metrics: Dict[str, float], count=1.0)
-
Expand source code
def average_metrics(metrics: tp.Dict[str, float], count=1.):
    if not is_distributed():
        return metrics
    keys, values = zip(*sorted(metrics.items()))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
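A minimal usage sketch, assuming init() has been called; the metric values and num_local_examples are hypothetical:
metrics = {'loss': 0.42, 'accuracy': 0.91}                     # computed on this worker
metrics = average_metrics(metrics, count=num_local_examples)   # weight by examples seen locally
if is_rank_zero():
    print(metrics)   # every worker gets the same averaged values; only rank 0 prints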
def broadcast_weights(params: Iterable[torch.Tensor], src: int = 0)
-
Broadcast the weights from the given parameters to all workers. This can be used to ensure that all workers have the same model to start with.
Expand source code
def broadcast_weights(params: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the weights from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    params = list(params)
    _check_number_of_params(params)
    handles = []
    for param in params:
        handle = distributed.broadcast(param.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()
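A minimal sketch, assuming init() has been called and MyModel is a hypothetical module built identically on every worker:
model = MyModel()                      # each worker starts with its own random init
broadcast_weights(model.parameters())  # rank 0's weights overwrite every other worker's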
def eager_sync_grad(params: Iterable[torch.Tensor])
-
Similar to sync_grad(), except this is a context manager that will start syncing gradients as soon as they become available. This can be faster, but requires backward to be called no more than once!
Warning: This only synchronizes the given params. When using this with model.parameters(), this will not synchronize buffers. For most cases, this should be alright. When using BatchNorm, this can lead to small differences in how the model is evaluated at validation time.
Expand source code
@contextmanager
def eager_sync_grad(params: tp.Iterable[torch.Tensor]):
    """Similar to `sync_grad`, except this is a context manager that will start syncing
    gradients as soon as they become available. This can be faster, but requires backward to be
    called no more than once!
    .. warning:: This only synchronizes the given params. When using this with `model.parameters()`,
        this will not synchronize buffers. For most cases, this should be alright. When using
        BatchNorm, this can lead to small differences in how the model is evaluated at validation time.
    """
    if not is_distributed():
        yield
        return
    params = list([p for p in params if p.requires_grad])
    _check_number_of_params(params)
    hooks = []
    handles = []
    waiting_params = set(params)
    def _callback(param, grad):
        if param not in waiting_params:
            raise RuntimeError(f"We got a gradient twice for parameter {param}.")
        handle = torch.distributed.all_reduce(grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
        handles.append((param, handle))
        waiting_params.remove(param)
    for param in params:
        hooks.append(param.register_hook(partial(_callback, param)))
    try:
        yield
    finally:
        for hook in hooks:
            hook.remove()
        _check_number_of_params(list(waiting_params))  # verify all workers have the same nb of remaining params.
        for param, handle in handles:
            handle.wait()
            assert param.grad is not None
            param.grad.data /= world_size
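A minimal training-step sketch, assuming model, optimizer, criterion, inputs and targets already exist (hypothetical names), and that backward is called exactly once inside the context:
optimizer.zero_grad()
with eager_sync_grad(model.parameters()):
    loss = criterion(model(inputs), targets)  # hypothetical forward pass
    loss.backward()                           # grads start syncing as the hooks fire
optimizer.step()                              # on exit, grads are summed and divided by world_size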
def init(backend='nccl')
-
Expand source code
def init(backend='nccl'):
    global rank, world_size
    dora_distrib.init(backend)
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()
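Typically called once at the start of training, before using any other function in this module; a minimal sketch assuming the distributed environment is provided by Dora:
from flashy import distrib
distrib.init(backend='nccl')   # populates distrib.rank and distrib.world_size via Dora
if distrib.is_rank_zero():
    print(f"running on {distrib.world_size} worker(s)")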
def is_distributed()
-
Expand source code
def is_distributed():
    return world_size > 1
def is_rank_zero()
-
Expand source code
def is_rank_zero():
    return rank == 0
def loader(dataset, *args, shuffle=False, klass=torch.utils.data.dataloader.DataLoader, **kwargs)
-
Create a dataloader properly in case of distributed training. If a gradient is going to be computed, you must set shuffle=True.
Expand source code
def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed, you must set `shuffle=True`.
    """
    if not is_distributed():
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
    if shuffle:
        # Training means we will compute backward, so we use DistributedSampler.
        sampler = DistributedSampler(dataset)
        # We ignore shuffle; DistributedSampler already shuffles.
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler would otherwise replicate some examples.
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
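A minimal sketch of the two cases, with hypothetical train_set and valid_set datasets:
train_loader = loader(train_set, batch_size=32, shuffle=True, num_workers=4)   # DistributedSampler
valid_loader = loader(valid_set, batch_size=32, shuffle=False)                 # manual per-rank shard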
def rank_zero_only(fn: Callable) -> Callable
-
Decorator that ensures the wrapped function/method is only called on rank 0.
Expand source code
def rank_zero_only(fn: tp.Callable) -> tp.Callable:
    """Decorator that ensures the wrapped function/method is only called on rank 0."""
    @wraps(fn)
    def wrapped_fn(*args: tp.Any, **kwargs: tp.Any) -> tp.Optional[tp.Any]:
        if is_rank_zero():
            return fn(*args, **kwargs)
        return None
    return wrapped_fn
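A minimal sketch of decorating a logging helper so that only rank 0 writes to stdout:
@rank_zero_only
def log(msg: str):
    print(msg)           # runs on rank 0 only; other ranks get None back
log("starting epoch 3")  # hypothetical message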
def sync_grad(params: Iterable[torch.Tensor])
-
Simpler alternative to DistributedDataParallel that doesn't rely on any black magic. For simple models it can also be as fast. Just call this on your model parameters after the call to backward.
Warning: This only synchronizes the given params. When using this with model.parameters(), this will not synchronize buffers. For most cases, this should be alright. When using BatchNorm, this can lead to small differences in how the model is evaluated at validation time.
Expand source code
def sync_grad(params: tp.Iterable[torch.Tensor]):
    """
    Simpler alternative to DistributedDataParallel that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward.
    .. warning:: This only synchronizes the given params. When using this with `model.parameters()`,
        this will not synchronize buffers. For most cases, this should be alright. When using
        BatchNorm, this can lead to small differences in how the model is evaluated at validation time.
    """
    if not is_distributed():
        return
    params = [p for p in params if p.grad is not None]
    _check_number_of_params(params)
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(
                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size
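A minimal training-step sketch, assuming model, optimizer, criterion, inputs and targets already exist (hypothetical names):
optimizer.zero_grad()
loss = criterion(model(inputs), targets)  # hypothetical forward pass
loss.backward()
sync_grad(model.parameters())             # all-reduce the grads, then divide by world_size
optimizer.step()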
def wrap(model)
-
Expand source code
def wrap(model):
    if is_distributed():
        return DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())
    else:
        return model
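A minimal sketch, assuming init() has been called and MyModel is a hypothetical module already moved to the current CUDA device:
model = MyModel().cuda()
dmodel = wrap(model)   # DistributedDataParallel when distributed, otherwise the model itself
# Use dmodel for forward/backward; keep saving and loading state_dict from the original model.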