import warnings
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, List, NamedTuple, Optional, Type

import torch
import torch.distributed as dist

__all__ = ['JoinHook', 'Joinable', 'Join']


class JoinHook:
    r"""
    This defines a join hook, which provides two entry points in the join context manager.

    Entry points : a main hook, which is called repeatedly while there exists a non-joined
    process, and a post-hook, which is called once all processes have joined.

    To implement a join hook for the generic join context manager, define a
    class that inherits from :class:`JoinHook` and override ``main_hook()`` and
    ``post_hook()`` as appropriate.
    """

    def main_hook(self) -> None:
        r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration.

        Training iteration i.e., in one forward pass, backward pass, and optimizer step.
        """
        ...

    def post_hook(self, is_last_joiner: bool) -> None:
        r"""
        Call hook after all processes have joined.

        It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join.

        Arguments:
            is_last_joiner (bool): ``True`` if the rank is one of the last to
                join; ``False`` otherwise.
        """
        ...


class Joinable(ABC):
    r"""
    This defines an abstract base class for joinable classes.

    A joinable class
    (inheriting from :class:`Joinable`) should implement :meth:`join_hook`,
    which returns a :class:`JoinHook` instance, in addition to
    :meth:`join_device` and :meth:`join_process_group` that return device and
    process group information, respectively.
    """

    @abstractmethod
    def __init__(self):
        super().__init__()
        # Join logic stays disabled until a ``Join`` context manager
        # overwrites this config in ``Join._set_joinable_configs()``.
        self._join_config = _JoinConfig.construct_disabled_join_config()

    @abstractmethod
    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a :class:`JoinHook` instance for the given :class:`Joinable`.

        Arguments:
            kwargs (dict): a :class:`dict` containing any keyword arguments
                to modify the behavior of the join hook at run time; all
                :class:`Joinable` instances sharing the same join context
                manager are forwarded the same value for ``kwargs``.
        """
        ...

    @property
    @abstractmethod
    def join_device(self) -> torch.device:
        r"""Return the device from which to perform collective communications needed by the join context manager."""
        ...

    @property
    @abstractmethod
    def join_process_group(self) -> Any:
        r"""Return the process group for the collective communications needed by the join context manager itself."""
        ...


class _JoinConfig(NamedTuple):
    r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""

    # Whether join-related logic is active for the owning joinable.
    enable: bool
    # Whether uneven inputs should raise instead of being shadowed.
    throw_on_early_termination: bool
    # Only the first joinable performs the notification collectives.
    is_first_joinable: bool

    @staticmethod
    def construct_disabled_join_config():
        r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.

        e.g. if the caller is not in a join context manager.
        """
        return _JoinConfig(
            enable=False,
            throw_on_early_termination=False,
            is_first_joinable=False
        )


class Join:
    r"""
    This class defines the generic join context manager, which allows custom hooks to be called after a process joins.

    These hooks should shadow the
    collective communications of non-joined processes to prevent hanging and
    erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook`
    for details about the hook definition.

    .. warning::
        The context manager requires each participating :class:`Joinable` to
        call the method :meth:`notify_join_context()` before its own per-
        iteration collective communications to ensure correctness.

    .. warning::
        The context manager requires that the ``join_process_group`` of every
        participating :class:`Joinable` is the same. If there are multiple
        :class:`Joinable` objects, then the ``join_device`` of the first is
        used. The process group and device information is used for checking
        for non-joined processes and for notifying processes to throw an
        exception if ``throw_on_early_termination`` is enabled, both of which
        use an all-reduce.

    Arguments:
        joinables (List[Joinable]): a list of the participating
            :class:`Joinable` s; their hooks are iterated over in the given
            order.

        enable (bool): a flag enabling uneven input detection; setting to
            ``False`` disables the context manager's functionality and should
            only be set when the user knows the inputs will not be uneven
            (default: ``True``).

        throw_on_early_termination (bool): a flag controlling whether to throw an
            exception upon detecting uneven inputs (default: ``False``).

    Example::

        >>> import os
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.multiprocessing as mp
        >>> # xdoctest: +SKIP
        >>> import torch.nn.parallel.DistributedDataParallel as DDP
        >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
        >>> from torch.distributed.algorithms.join import Join
        >>>
        >>> # On each spawned worker
        >>> def worker(rank):
        >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
        >>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
        >>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
        >>>     # Rank 1 gets one more input than rank 0
        >>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
        >>>     with Join([model, optim]):
        >>>         for input in inputs:
        >>>             loss = model(input).sum()
        >>>             loss.backward()
        >>>             optim.step()
        >>>     # All ranks reach here without hanging/erroring
    """

    def __init__(
        self,
        joinables: List[Joinable],
        enable: bool = True,
        throw_on_early_termination: bool = False,
        **kwargs,
    ):
        if len(joinables) == 0:
            raise ValueError("The join context manager requires at least one joinable")
        self._joinables = joinables
        # Every joinable receives the same ``kwargs`` when building its hook.
        self._join_hooks = [joinable.join_hook(**kwargs) for joinable in self._joinables]
        self._enable = enable
        self._throw_on_early_termination = throw_on_early_termination
        self._set_joinable_configs()
        self._extract_dist_info()

    def _set_joinable_configs(self) -> None:
        r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
        assert len(self._joinables) > 0
        is_first_joinable = True
        for joinable in self._joinables:
            joinable._join_config = _JoinConfig(
                enable=self._enable,
                throw_on_early_termination=self._throw_on_early_termination,
                is_first_joinable=is_first_joinable
            )
            # Only the first joinable in the list is flagged as first.
            is_first_joinable = False

    def _extract_dist_info(self) -> None:
        r"""
        Extract the process group and device information from the joinables.

        If there are multiple joinables, then the context manager uses the
        first specified device.

        Preconditions:
            ``self._joinables`` is not ``None`` and is non-empty.

        Raises:
            ValueError
                If there are multiple conflicting ``process_group`` attributes
                among the ``Joinable`` objects.
        """
        process_group = None
        device = None
        for joinable in self._joinables:
            if process_group is None:
                process_group = joinable.join_process_group
            elif process_group != joinable.join_process_group:
                raise ValueError("Using join context manager with multiple process groups")
            if device is None:
                device = joinable.join_device
        self._process_group = process_group
        self._rank = dist.get_rank(self._process_group)
        self._device = device

    def __enter__(self):
        r"""Enter the context; all of the work happens in :meth:`__exit__`."""
        ...

    def __exit__(
        self,
        type: Optional[Type[BaseException]],
        value: Optional[BaseException],
        traceback: Optional[TracebackType]
    ):
        r"""
        Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.

        Nothing is run if the context manager is disabled or if the ``with``
        body raised, in which case the exception propagates unchanged.

        Raises:
            RuntimeError
                If ``throw_on_early_termination=True`` and some process has
                not yet joined when this process exits the context.
        """
        if not self._enable or type:
            return  # propagate the exception directly if one was raised

        all_procs_joined = False
        is_last_joiner = True

        i = 0
        WARN_THRESHOLD = 1000
        # Warn at most once per exit using a local flag. This deliberately
        # avoids ``warnings.simplefilter("once")``, which would permanently
        # prepend a catch-all filter to the process-wide warning filters and
        # override whatever the user has configured.
        warned_about_skew = False

        while not all_procs_joined:
            if i > WARN_THRESHOLD and not warned_about_skew:
                warned_about_skew = True
                warnings.warn(
                    "Detected uneven input skew of greater than "
                    f"{WARN_THRESHOLD}. This means that rank "
                    f"{self._rank} has at least {WARN_THRESHOLD} "
                    f"fewer inputs than other currently-active ranks. "
                    "This level of skew could lead to performance "
                    "degradation during training."
                )
            # Shadow the all-reduce in non-joined processes
            num_nonjoined_procs = self._get_num_nonjoined_procs()
            if num_nonjoined_procs == 0:
                all_procs_joined = True
            else:
                if self._throw_on_early_termination:
                    # Always raises, so the hooks below are skipped.
                    self._notify_procs_to_terminate()

                # Run main hooks
                for join_hook in self._join_hooks:
                    join_hook.main_hook()

                # Some other process was still active after this one joined,
                # so this process cannot be among the last to join.
                is_last_joiner = False
                i += 1

        # Run post-hooks
        for join_hook in self._join_hooks:
            join_hook.post_hook(is_last_joiner)

    def _get_num_nonjoined_procs(self):
        r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""
        # A joined process contributes zero; non-joined processes contribute
        # one each through ``notify_join_context()``.
        num_nonjoined_procs = torch.zeros(1, device=self._device)
        dist.all_reduce(num_nonjoined_procs, group=self._process_group)
        return num_nonjoined_procs.item()

    def _notify_procs_to_terminate(self):
        r"""Schedule an all-reduce to notify non-joined processes to terminate.

        Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs.
        """
        # The nonzero contribution pairs with the ``zeros`` all-reduce in
        # ``notify_join_context()``, which raises on a nonzero result.
        ones = torch.ones(1, device=self._device)
        dist.all_reduce(ones, group=self._process_group)
        raise RuntimeError(f"Rank {self._rank} exhausted all inputs.")

    @staticmethod
    def notify_join_context(joinable: Joinable):
        r"""
        Notifies the join context manager that the calling process has not yet joined.

        Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected
        (i.e. if one process has already joined) and throws an exception if so.

        This method should be called from a :class:`Joinable` object before
        its per-iteration collective communications. For example, this should
        be called at the beginning of the forward pass in
        :class:`DistributedDataParallel`.

        Only the first :class:`Joinable` object passed into the context
        manager performs the collective communications in this method, and
        for the others, this method is vacuous.

        Arguments:
            joinable (Joinable): the :class:`Joinable` object calling this
                method.

        Returns:
            An async work handle for the all-reduce meant to notify the context
            manager that the process has not yet joined if ``joinable`` is the
            first one passed into the context manager; ``None`` otherwise.

        Raises:
            RuntimeError
                If ``throw_on_early_termination=True`` and at least one rank
                has already exhausted its inputs.
        """
        assert hasattr(joinable, "_join_config"), \
            f"Check that the {type(joinable)} constructor calls the " \
            "``Joinable`` constructor"

        join_config = joinable._join_config
        # First joinable is responsible for the collective communications
        if not join_config.is_first_joinable or not join_config.enable:
            return None

        device = joinable.join_device
        process_group = joinable.join_process_group

        # Schedule an all-reduce to indicate that the caller has not yet joined
        ones = torch.ones(1, device=device)
        work = dist.all_reduce(ones, group=process_group, async_op=True)

        if join_config.throw_on_early_termination:
            # Check if uneven inputs have been detected
            zeros = torch.zeros(1, device=device)
            dist.all_reduce(zeros, group=process_group)
            should_throw = zeros.item()
            if should_throw:
                raise RuntimeError(
                    "Detected at least one rank that exhausted inputs. "
                    "Throwing across all ranks."
                )
        return work