# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch distributed utilities."""

import typing as tp

import torch


def rank():
    """Return the rank of the current worker, or 0 if not distributed."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    """Return the number of workers, or 1 if not distributed."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    """Return True if running with more than one worker."""
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """All-reduce `tensor` in place across workers; a no-op when not distributed."""
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # Utility function to check that the number of params is the same across
    # all workers, and thus avoid a deadlock in the distributed collectives.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If the workers do not all have the same number of params,
        # this check will fail for at least one of them.
        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
                           "at least one worker has a different one.")


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the given tensors from worker `src` to all workers.

    This can be used to ensure that all workers start with the same model.
    Non-float, non-complex tensors (e.g. integer buffers) are skipped.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()
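

# --------------------------------------------------------------------------
# Usage sketch (an illustration, not part of the module's original API): how
# one might broadcast a model's parameters and buffers from rank 0 so every
# worker starts from identical weights. The Linear model below is a stand-in;
# in a single, non-distributed process `broadcast_tensors` is simply a no-op.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    _model = torch.nn.Linear(4, 2)
    # Parameters cover trainable weights; buffers would cover running stats
    # (e.g. BatchNorm). Integer buffers are filtered out by broadcast_tensors.
    broadcast_tensors(list(_model.parameters()) + list(_model.buffers()))
    print(f"rank={rank()} / world_size={world_size()}: model tensors synchronized")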