347 lines
13 KiB
Python
347 lines
13 KiB
Python
|
import warnings
|
||
|
from abc import ABC, abstractmethod
|
||
|
from types import TracebackType
|
||
|
from typing import Any, List, NamedTuple, Optional, Type
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
|
||
|
__all__ = ['JoinHook', 'Joinable', 'Join']
|
||
|
|
||
|
class JoinHook:
|
||
|
r"""
|
||
|
This defines a join hook, which provides two entry points in the join context manager.
|
||
|
|
||
|
Entry points : a main hook, which is called repeatedly while there exists a non-joined
|
||
|
process, and a post-hook, which is called once all processes have joined.
|
||
|
|
||
|
To implement a join hook for the generic join context manager, define a
|
||
|
class that inherits from :class:`JoinHook` and override ``main_hook()`` and
|
||
|
``post_hook()`` as appropriate.
|
||
|
"""
|
||
|
|
||
|
def main_hook(self) -> None:
|
||
|
r"""Call this hook while there exists a non-joined process to shadow collective communications in a training iteration.
|
||
|
|
||
|
Training iteration i.e., in one forward pass, backward pass, and optimizer step.
|
||
|
"""
|
||
|
...
|
||
|
|
||
|
def post_hook(self, is_last_joiner: bool) -> None:
|
||
|
r"""
|
||
|
Call hook after all processes have joined.
|
||
|
|
||
|
It is passed an additional ``bool`` argument ``is_last_joiner``, which indicates if the rank is one of the last to join.
|
||
|
|
||
|
Arguments:
|
||
|
is_last_joiner (bool): ``True`` if the rank is one of the last to
|
||
|
join; ``False`` otherwise.
|
||
|
"""
|
||
|
...
|
||
|
|
||
|
|
||
|
class Joinable(ABC):
|
||
|
r"""
|
||
|
This defines an abstract base class for joinable classes.
|
||
|
|
||
|
A joinable class
|
||
|
(inheriting from :class:`Joinable`) should implement :meth:`join_hook`,
|
||
|
which returns a :class:`JoinHook` instance, in addition to
|
||
|
:meth:`join_device` and :meth:`join_process_group` that return device and
|
||
|
process group information, respectively.
|
||
|
"""
|
||
|
|
||
|
@abstractmethod
|
||
|
def __init__(self):
|
||
|
super().__init__()
|
||
|
self._join_config = _JoinConfig.construct_disabled_join_config()
|
||
|
|
||
|
@abstractmethod
|
||
|
def join_hook(self, **kwargs) -> JoinHook:
|
||
|
r"""
|
||
|
Return a :class:`JoinHook` instance for the given :class:`Joinable`.
|
||
|
|
||
|
Arguments:
|
||
|
kwargs (dict): a :class:`dict` containing any keyword arguments
|
||
|
to modify the behavior of the join hook at run time; all
|
||
|
:class:`Joinable` instances sharing the same join context
|
||
|
manager are forwarded the same value for ``kwargs``.
|
||
|
"""
|
||
|
...
|
||
|
|
||
|
@property
|
||
|
@abstractmethod
|
||
|
def join_device(self) -> torch.device:
|
||
|
r"""Return the device from which to perform collective communications needed by the join context manager."""
|
||
|
...
|
||
|
|
||
|
@property
|
||
|
@abstractmethod
|
||
|
def join_process_group(self) -> Any:
|
||
|
r"""Returns the process group for the collective communications needed by the join context manager itself."""
|
||
|
...
|
||
|
|
||
|
|
||
|
class _JoinConfig(NamedTuple):
|
||
|
r"""This includes all fields needed from a :class:`Joinable` instance for the join context manager side."""
|
||
|
|
||
|
enable: bool
|
||
|
throw_on_early_termination: bool
|
||
|
is_first_joinable: bool
|
||
|
|
||
|
@staticmethod
|
||
|
def construct_disabled_join_config():
|
||
|
r"""Return a :class:`_JoinConfig` instance indicating that join-related logic should be disabled.
|
||
|
|
||
|
e.g. if the caller is not in a join context manager.
|
||
|
"""
|
||
|
return _JoinConfig(
|
||
|
enable=False,
|
||
|
throw_on_early_termination=False,
|
||
|
is_first_joinable=False
|
||
|
)
|
||
|
|
||
|
|
||
|
|
||
|
class Join:
|
||
|
r"""
|
||
|
This class defines the generic join context manager, which allows custom hooks to be called after a process joins.
|
||
|
|
||
|
These hooks should shadow the
|
||
|
collective communications of non-joined processes to prevent hanging and
|
||
|
erroring and to ensure algorithmic correctness. Refer to :class:`JoinHook`
|
||
|
for details about the hook definition.
|
||
|
|
||
|
.. warning::
|
||
|
The context manager requires each participating :class:`Joinable` to
|
||
|
call the method :meth:`notify_join_context()` before its own per-
|
||
|
iteration collective communications to ensure correctness.
|
||
|
|
||
|
.. warning::
|
||
|
The context manager requires that all ``process_group`` attributes in
|
||
|
the :class:`JoinHook` objects are the same. If there are multiple
|
||
|
:class:`JoinHook` objects, then the ``device`` of the first is used.
|
||
|
The process group and device information is used for checking for non-
|
||
|
joined processes and for notifying processes to throw an exception if
|
||
|
``throw_on_early_termination`` is enabled, both of which using an all-
|
||
|
reduce.
|
||
|
|
||
|
Arguments:
|
||
|
joinables (List[Joinable]): a list of the participating
|
||
|
:class:`Joinable` s; their hooks are iterated over in the given
|
||
|
order.
|
||
|
|
||
|
enable (bool): a flag enabling uneven input detection; setting to
|
||
|
``False`` disables the context manager's functionality and should
|
||
|
only be set when the user knows the inputs will not be uneven
|
||
|
(default: ``True``).
|
||
|
|
||
|
throw_on_early_termination (bool): a flag controlling whether to throw an
|
||
|
exception upon detecting uneven inputs (default: ``False``).
|
||
|
|
||
|
Example::
|
||
|
|
||
|
>>> import os
|
||
|
>>> import torch
|
||
|
>>> import torch.distributed as dist
|
||
|
>>> import torch.multiprocessing as mp
|
||
|
>>> # xdoctest: +SKIP
|
||
|
>>> import torch.nn.parallel.DistributedDataParallel as DDP
|
||
|
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
|
||
|
>>> from torch.distributed.algorithms.join import Join
|
||
|
>>>
|
||
|
>>> # On each spawned worker
|
||
|
>>> def worker(rank):
|
||
|
>>> dist.init_process_group("nccl", rank=rank, world_size=2)
|
||
|
>>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
|
||
|
>>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
|
||
|
>>> # Rank 1 gets one more input than rank 0
|
||
|
>>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
|
||
|
>>> with Join([model, optim]):
|
||
|
>>> for input in inputs:
|
||
|
>>> loss = model(input).sum()
|
||
|
>>> loss.backward()
|
||
|
>>> optim.step()
|
||
|
>>> # All ranks reach here without hanging/erroring
|
||
|
"""
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
joinables: List[Joinable],
|
||
|
enable: bool = True,
|
||
|
throw_on_early_termination: bool = False,
|
||
|
**kwargs,
|
||
|
):
|
||
|
if len(joinables) == 0:
|
||
|
raise ValueError("The join context manager requires at least one joinable")
|
||
|
self._joinables = joinables
|
||
|
self._join_hooks = [joinable.join_hook(**kwargs) for joinable in self._joinables]
|
||
|
self._enable = enable
|
||
|
self._throw_on_early_termination = throw_on_early_termination
|
||
|
self._set_joinable_configs()
|
||
|
self._extract_dist_info()
|
||
|
|
||
|
def _set_joinable_configs(self) -> None:
|
||
|
r"""Set the :class:`_JoinConfig` of each participating :class:`Joinable`."""
|
||
|
assert len(self._joinables) > 0
|
||
|
is_first_joinable = True
|
||
|
for joinable in self._joinables:
|
||
|
joinable._join_config = _JoinConfig(
|
||
|
enable=self._enable,
|
||
|
throw_on_early_termination=self._throw_on_early_termination,
|
||
|
is_first_joinable=is_first_joinable
|
||
|
)
|
||
|
is_first_joinable = False
|
||
|
|
||
|
def _extract_dist_info(self) -> None:
|
||
|
r"""
|
||
|
Extract the process group and device information from the joinables.
|
||
|
|
||
|
If there are multiple joinables, then the context manager uses the
|
||
|
first specified device.
|
||
|
|
||
|
Preconditions:
|
||
|
``self._joinables`` is not ``None`` and is non-empty.
|
||
|
|
||
|
Raises:
|
||
|
ValueError
|
||
|
If there are multiple conflicting ``process_group`` attributes
|
||
|
among the ``Joinable`` objects.
|
||
|
"""
|
||
|
process_group = None
|
||
|
device = None
|
||
|
for joinable in self._joinables:
|
||
|
if process_group is None:
|
||
|
process_group = joinable.join_process_group
|
||
|
elif process_group != joinable.join_process_group:
|
||
|
raise ValueError("Using join context manager with multiple process groups")
|
||
|
if device is None:
|
||
|
device = joinable.join_device
|
||
|
self._process_group = process_group
|
||
|
self._rank = dist.get_rank(self._process_group)
|
||
|
self._device = device
|
||
|
|
||
|
def __enter__(self):
|
||
|
...
|
||
|
|
||
|
def __exit__(
|
||
|
self,
|
||
|
type: Optional[Type[BaseException]],
|
||
|
value: Optional[BaseException],
|
||
|
traceback: Optional[TracebackType]
|
||
|
):
|
||
|
r"""
|
||
|
Repeatedly runs the main hooks until all processes join; then, runs the post-hooks.
|
||
|
|
||
|
Raises:
|
||
|
RuntimeError
|
||
|
If ``throw_on_early_termination=True``.
|
||
|
"""
|
||
|
if not self._enable or type:
|
||
|
return # propagate the exception directly if one was raised
|
||
|
|
||
|
all_procs_joined = False
|
||
|
is_last_joiner = True
|
||
|
|
||
|
i = 0
|
||
|
WARN_THRESHOLD = 1000
|
||
|
warnings.simplefilter("once")
|
||
|
|
||
|
while not all_procs_joined:
|
||
|
if i > WARN_THRESHOLD:
|
||
|
warnings.warn(
|
||
|
"Detected uneven input skew of greater than "
|
||
|
f"{WARN_THRESHOLD}. This means that rank "
|
||
|
f"{self._rank} has at least {WARN_THRESHOLD} "
|
||
|
f"fewer inputs than other currently-active ranks. "
|
||
|
"This level of skew could lead to performance "
|
||
|
"degradation during training."
|
||
|
)
|
||
|
# Shadow the all-reduce in non-joined processes
|
||
|
num_nonjoined_procs = self._get_num_nonjoined_procs()
|
||
|
if num_nonjoined_procs == 0:
|
||
|
all_procs_joined = True
|
||
|
else:
|
||
|
if self._throw_on_early_termination:
|
||
|
self._notify_procs_to_terminate()
|
||
|
|
||
|
# Run main hooks
|
||
|
for join_hook in self._join_hooks:
|
||
|
join_hook.main_hook()
|
||
|
|
||
|
is_last_joiner = False
|
||
|
i += 1
|
||
|
|
||
|
# Run post-hooks
|
||
|
for join_hook in self._join_hooks:
|
||
|
join_hook.post_hook(is_last_joiner)
|
||
|
|
||
|
def _get_num_nonjoined_procs(self):
|
||
|
r"""Return the number of non-joined processes by shadowing an all-reduce in the non-joined processes."""
|
||
|
num_nonjoined_procs = torch.zeros(1, device=self._device)
|
||
|
dist.all_reduce(num_nonjoined_procs, group=self._process_group)
|
||
|
return num_nonjoined_procs.item()
|
||
|
|
||
|
def _notify_procs_to_terminate(self):
|
||
|
r"""Schedule an all-reduce to notify non-joined processes to terminate.
|
||
|
|
||
|
Also raise a ``RuntimeError`` indicating that the current process has exhausted its inputs.
|
||
|
"""
|
||
|
ones = torch.ones(1, device=self._device)
|
||
|
dist.all_reduce(ones, group=self._process_group)
|
||
|
raise RuntimeError(f"Rank {self._rank} exhausted all inputs.")
|
||
|
|
||
|
@staticmethod
|
||
|
def notify_join_context(joinable: Joinable):
|
||
|
r"""
|
||
|
Notifies the join context manager that the calling process has not yet joined.
|
||
|
|
||
|
Then, if ``throw_on_early_termination=True``, checks if uneven inputs have been detected
|
||
|
(i.e. if one process has already joined) and throws an exception if so.
|
||
|
|
||
|
This method should be called from a :class:`Joinable` object before
|
||
|
its per-iteration collective communications. For example, this should
|
||
|
be called at the beginning of the forward pass in
|
||
|
:class:`DistributedDataParallel`.
|
||
|
|
||
|
Only the first :class:`Joinable` object passed into the context
|
||
|
manager performs the collective communications in this method, and
|
||
|
for the others, this method is vacuous.
|
||
|
|
||
|
Arguments:
|
||
|
joinable (Joinable): the :class:`Joinable` object calling this
|
||
|
method.
|
||
|
|
||
|
Returns:
|
||
|
An async work handle for the all-reduce meant to notify the context
|
||
|
manager that the process has not yet joined if ``joinable`` is the
|
||
|
first one passed into the context manager; ``None`` otherwise.
|
||
|
"""
|
||
|
assert hasattr(joinable, "_join_config"), \
|
||
|
f"Check that the {type(joinable)} constructor calls the " \
|
||
|
"``Joinable`` constructor"
|
||
|
|
||
|
join_config = joinable._join_config
|
||
|
# First joinable is responsible for the collective communications
|
||
|
if not join_config.is_first_joinable or not join_config.enable:
|
||
|
return None
|
||
|
|
||
|
device = joinable.join_device
|
||
|
process_group = joinable.join_process_group
|
||
|
|
||
|
# Schedule an all-reduce to indicate that the caller has not yet joined
|
||
|
ones = torch.ones(1, device=device)
|
||
|
work = dist.all_reduce(ones, group=process_group, async_op=True)
|
||
|
|
||
|
if join_config.throw_on_early_termination:
|
||
|
# Check if uneven inputs have been detected
|
||
|
zeros = torch.zeros(1, device=device)
|
||
|
dist.all_reduce(zeros, group=process_group)
|
||
|
should_throw = zeros.item()
|
||
|
if should_throw:
|
||
|
raise RuntimeError(
|
||
|
"Detected at least one rank that exhausted inputs. "
|
||
|
"Throwing across all ranks."
|
||
|
)
|
||
|
return work
|