# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch distributed utilities."""
|
||
|
|
||
|
import typing as tp
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
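# These helpers assume that, when running distributed, the caller has already initialized
# the default process group, e.g. (hypothetical launcher code, typically driven by
# `torchrun` setting the usual environment variables):
#
#     torch.distributed.init_process_group(backend="nccl", init_method="env://")
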
def rank():
    """Return the rank of the current process, or 0 if torch.distributed is not initialized."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    """Return the number of distributed workers, or 1 if torch.distributed is not initialized."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    """Return True when running with more than one worker."""
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """All-reduce `tensor` in place across workers. This is a no-op when not distributed."""
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)
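
# A minimal usage sketch (hypothetical, not part of the original module): averaging a
# scalar metric across workers with the SUM all-reduce above, then dividing by the
# world size. `loss_value` and `device` are placeholders.
#
#     metric = torch.tensor([loss_value], device=device)
#     all_reduce(metric)
#     metric /= world_size()
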
def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # Utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with the distributed all-reduce.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If the workers do not all have the same number of params, this inequality
        # will hold on at least one of them (e.g. with workers reporting 10, 10 and 9
        # params, the reduced sum is 29, which matches neither 3 * 10 nor 3 * 9).
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()
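
# Example (hypothetical, not part of the original module): after building the same model
# on every worker, its parameters can be synchronised from rank 0 like so. `MyModel` and
# `device` are placeholders; integer buffers are skipped by the float/complex filter above.
#
#     model = MyModel().to(device)
#     broadcast_tensors(model.parameters())
#     broadcast_tensors(model.buffers())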