491 lines
18 KiB
Python
491 lines
18 KiB
Python
|
# Copyright 2019 Kakao Brain
|
||
|
#
|
||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||
|
#
|
||
|
# This source code is licensed under the BSD license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
"""The Pipe interface."""
|
||
|
from collections import OrderedDict
|
||
|
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Union, Sequence, Tuple, cast
|
||
|
|
||
|
import torch
|
||
|
from torch import Tensor, nn
|
||
|
from torch.distributed.rpc import RRef
|
||
|
import torch.autograd
|
||
|
import torch.cuda
|
||
|
|
||
|
from . import microbatch
|
||
|
from .batchnorm import DeferredBatchNorm
|
||
|
from .pipeline import Pipeline
|
||
|
from .skip.layout import inspect_skip_layout
|
||
|
from .skip.skippable import verify_skippables
|
||
|
from .stream import AbstractStream, new_stream
|
||
|
|
||
|
# Public API of this module.
__all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"]


# A single device spec: a ``torch.device``, an integer index, or a string
# such as ``"cuda:1"``.
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]

Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

if not TYPE_CHECKING:
    # At runtime the plain, unparameterized classes are used.
    Module = nn.Module
    NamedModules = OrderedDict
else:
    # Typechecking: nn.Module is not a Generic, hence the ignore below.
    Module = nn.Module[TensorOrTensors]  # type: ignore[type-arg]
    NamedModules = OrderedDict[str, Module]
|
||
|
|
||
|
|
||
|
def _recommend_auto_balance(message: str) -> str:
    """Append a recommendation for automatic balancing to ``message``.

    The appended text recommends ``torch.distributed.pipeline.sync.balance``
    (``balance_by_time``) for choosing a partition balance automatically
    while a model is still changing, and includes a short usage snippet.

    Args:
        message: the text to extend, e.g. an explanation of a balance error.

    Returns:
        ``message`` followed by a blank line and the recommendation text.
    """
    return f"""{message}

If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance' for
naive automatic balancing:

from torch.distributed.pipeline.sync import Pipe
from torch.distributed.pipeline.sync.balance import balance_by_time

partitions = torch.cuda.device_count()
sample = torch.empty(...)
balance = balance_by_time(partitions, model, sample)

model = Pipe(model, balance, ...)
"""
|
||
|
|
||
|
|
||
|
def _verify_module(module: nn.Sequential) -> None:
|
||
|
if not isinstance(module, nn.Sequential):
|
||
|
raise TypeError("module must be nn.Sequential to be partitioned")
|
||
|
|
||
|
named_children = list(module.named_children())
|
||
|
if len(named_children) != len(module):
|
||
|
raise ValueError("module with duplicate children is not supported")
|
||
|
|
||
|
|
||
|
def _verify_splitting(
|
||
|
module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device]
|
||
|
) -> None:
|
||
|
num_parameters = len(list(module.parameters()))
|
||
|
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
|
||
|
if num_parameters == num_child_parameters:
|
||
|
return
|
||
|
|
||
|
for i in range(len(partitions)):
|
||
|
for j in range(i + 1, len(partitions)):
|
||
|
parti = partitions[i]
|
||
|
partj = partitions[j]
|
||
|
if devices[i] == devices[j]:
|
||
|
continue
|
||
|
for p in parti.parameters():
|
||
|
for q in partj.parameters():
|
||
|
if p is q:
|
||
|
raise ValueError("module with duplicate parameters on distinct devices is not supported")
|
||
|
|
||
|
|
||
|
class BalanceError(ValueError):
    """Public exception type for balance-related errors of :class:`Pipe`.

    It derives from :class:`ValueError`, so handlers catching ``ValueError``
    also catch it. (No raise site is visible in this part of the module.)
    """
|
||
|
|
||
|
|
||
|
def _retrieve_device(module: nn.Module) -> torch.device:
|
||
|
"""Validates all parameters in the Module have the same device and returns
|
||
|
the appropriate device.
|
||
|
|
||
|
Args:
|
||
|
An ``nn.Module`` to process.
|
||
|
|
||
|
Returns:
|
||
|
``torch.Device`` for the entire module.
|
||
|
|
||
|
Raises:
|
||
|
ValueError:
|
||
|
If devices for ``nn.Module`` parameters are not all same.
|
||
|
"""
|
||
|
|
||
|
device = None
|
||
|
for parameter in module.parameters():
|
||
|
if device is None:
|
||
|
device = parameter.device
|
||
|
elif device != parameter.device:
|
||
|
raise ValueError(
|
||
|
f'nn.Module: {module}, should have all parameters on a single device,'
|
||
|
' please use .to() to place the module on a single device')
|
||
|
|
||
|
return device if device is not None else torch.device("cpu")
|
||
|
|
||
|
|
||
|
class PipeSequential(nn.Sequential):
    """
    Pipe variant of ``nn.Sequential`` which supports multiple inputs.

    ``forward(*inputs)`` calls each child in order. Whenever the running value
    is a ``tuple`` (the initial positional arguments always are), it is
    unpacked into positional arguments for the next child; any other value
    (e.g. a single ``Tensor`` or a ``list``) is passed as one argument. With
    no children, the positional arguments are returned as a tuple.
    """

    def forward(self, *inputs):
        for module in self:
            if isinstance(inputs, tuple):
                # Multiple values (or the initial *inputs): spread them.
                inputs = module(*inputs)
            else:
                # Don't expand single variables (ex: lists/Tensor)
                inputs = module(inputs)
        return inputs
|
||
|
|
||
|
|
||
|
class WithDevice(nn.Module):
    """
    Wraps an ``nn.Module`` which is part of ``nn.Sequential`` passed into :class:`Pipe`
    that overrides the device for that module. In cases where :class:`Pipe`
    can't implicitly determine the device for the module and places it on CPU,
    this wrapper can be used to override the implicit behavior and explicitly
    specify which device a module should run on.

    The provided module is also moved to the given device via ``.to(device)``
    by :class:`Pipe`.

    Args:
        module(:class:`torch.nn.Module`): The module to be wrapped.
        device(:class:`torch.device`, ``str`` or ``int``): The device to run the
            module on. It is normalized with the :class:`torch.device`
            constructor, so specs such as ``'cuda:1'`` are accepted.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> fc1 = nn.Linear(16, 8).cuda(0)
        >>> fc2 = nn.Linear(8, 4).cuda(1)
        >>> dropout = nn.Dropout()
        >>>
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
        >>> # Dropout does not have any parameters/buffers, but we want to
        >>> # run it on cuda:1 to avoid any GPU to CPU transfers.
        >>> model = nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1'))
        >>> # xdoctest: +SKIP("Needs RPC framework init")
        >>> model = Pipe(model, chunks=8)
    """
    def __init__(self, module: nn.Module, device: Union[torch.device, int, str]):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule.
        self._module = module
        # Normalize str/int device specs into a torch.device.
        self._device = torch.device(device)

    def forward(self, *args, **kwargs):
        """Call the wrapped module with the given arguments unchanged."""
        return self._module(*args, **kwargs)

    @property
    def module(self) -> nn.Module:
        """The wrapped module."""
        return self._module

    @property
    def device(self) -> torch.device:
        """The device the wrapped module should run on."""
        return self._device
|
||
|
|
||
|
|
||
|
def _assemble_partition(modules: List[nn.Module]):
    """Flatten ``modules`` into a single :class:`PipeSequential`.

    Any ``nn.Sequential`` entry is spliced in child by child (one level
    only); every other module is kept as a single element.
    """
    flattened: List[nn.Module] = []
    for entry in modules:
        flattened.extend(entry.children() if isinstance(entry, nn.Sequential) else (entry,))
    return PipeSequential(*flattened)
|
||
|
|
||
|
|
||
|
def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]:
    """Group consecutive children of ``modules`` into per-device partitions.

    A child's device comes from its :class:`WithDevice` wrapper (the wrapped
    module is unwrapped and moved there with ``.to()``), otherwise from
    ``_retrieve_device``. A new partition starts whenever the device changes;
    a child whose device type is ``'cpu'`` always starts a new partition.

    Returns:
        ``(partitions, devices)``: the partitions as an ``nn.ModuleList`` of
        assembled sequential modules, and the device of each partition.
    """
    groups: List[Tuple[torch.device, List[nn.Module]]] = []

    for child in modules.children():
        if isinstance(child, WithDevice):
            # Explicit override: unwrap and move the wrapped module there.
            target = child.device
            child = child.module
            child.to(target)
        else:
            target = _retrieve_device(child)

        starts_new_group = not groups or groups[-1][0] != target or target.type == 'cpu'
        if starts_new_group:
            groups.append((target, [child]))
        else:
            groups[-1][1].append(child)

    partitions = [_assemble_partition(members) for _, members in groups]
    devices = [device for device, _ in groups]

    return cast(List[nn.Sequential], nn.ModuleList(partitions)), devices
|
||
|
|
||
|
|
||
|
# Raised by Pipe.cuda(), Pipe.cpu() and device-changing Pipe.to() calls:
# device placement of the partitions is owned by Pipe.
MOVING_DENIED = TypeError(
    "denied to move parameters and buffers, because Pipe should manage device placement"
)
|
||
|
|
||
|
|
||
|
class Pipe(Module):
    """Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
    to train using synchronous pipeline parallelism. If the module requires
    lots of memory and doesn't fit on a single GPU, pipeline parallelism is a
    useful technique to employ for training.

    The implementation is based on the torchgpipe_ paper.

    .. _torchgpipe: https://arxiv.org/abs/2004.09910

    Pipe combines pipeline parallelism with checkpointing to reduce peak
    memory required to train while minimizing device under-utilization.

    You should place all the modules on the appropriate devices and wrap them
    into an :class:`nn.Sequential <torch.nn.Sequential>` module defining the
    desired order of execution. If a module does not contain any
    parameters, it is assumed this module should be executed on CPU
    and appropriate input tensors to the module are moved to CPU before
    execution (only parameters, not buffers, are used to infer a module's
    device). This behavior can be overridden by the :class:`WithDevice`
    wrapper which can be used to explicitly specify which device a module
    should run on.

    Args:
        module (:class:`nn.Sequential <torch.nn.Sequential>`):
            sequential module to be parallelized using pipelining. Each module
            in the sequence has to have all of its parameters on a single
            device. Each module in the sequence has to either be an nn.Module
            or :class:`nn.Sequential <torch.nn.Sequential>` (to combine multiple
            sequential modules on a single device)
        chunks (int):
            number of micro-batches (default: ``1``)
        checkpoint (str):
            when to enable checkpointing, one of ``'always'``,
            ``'except_last'``, or ``'never'`` (default: ``'except_last'``).
            ``'never'`` disables checkpointing completely, ``'except_last'``
            enables checkpointing for all micro-batches except the last one
            and ``'always'`` enables checkpointing for all micro-batches.
        deferred_batch_norm (bool):
            whether to use deferred ``BatchNorm`` moving statistics (default:
            :data:`False`). If set to :data:`True`, we track statistics across
            multiple micro-batches to update the running statistics per
            mini-batch.

    Raises:
        TypeError:
            the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
        ValueError:
            invalid arguments
        RuntimeError:
            the RPC framework has not been initialized.

    Example::
        Pipeline of two FC layers across GPUs 0 and 1.

        >>> # Need to initialize RPC framework first.
        >>> # xdoctest: +SKIP
        >>> os.environ['MASTER_ADDR'] = 'localhost'
        >>> os.environ['MASTER_PORT'] = '29500'
        >>> torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1)
        >>>
        >>> # Build pipe.
        >>> fc1 = nn.Linear(16, 8).cuda(0)
        >>> fc2 = nn.Linear(8, 4).cuda(1)
        >>> model = nn.Sequential(fc1, fc2)
        >>> model = Pipe(model, chunks=8)
        >>> input = torch.rand(16, 16).cuda(0)
        >>> output_rref = model(input)

    .. note::
        You can wrap a :class:`Pipe` model with
        :class:`torch.nn.parallel.DistributedDataParallel` only when the
        checkpoint parameter of :class:`Pipe` is ``'never'``.

    .. note::
        :class:`Pipe` only supports intra-node pipelining currently, but
        will be expanded to support inter-node pipelining in the future.
        The forward function returns an :class:`~torch.distributed.rpc.RRef`
        to allow for inter-node pipelining in the future, where the output
        might be on a remote host. For intra-node pipelining you can use
        :meth:`~torch.distributed.rpc.RRef.local_value` to retrieve the
        output locally.

    .. warning::
        :class:`Pipe` is experimental and subject to change.
    """

    def __init__(
        self,
        module: nn.Sequential,
        chunks: int = 1,
        checkpoint: str = "except_last",
        deferred_batch_norm: bool = False,
    ) -> None:
        super().__init__()

        # Check if RPC framework is initialized.
        if not torch.distributed.rpc._is_current_rpc_agent_set():
            raise RuntimeError(
                'Please initialize RPC framework for Pipe using '
                'torch.distributed.rpc.init_rpc')

        chunks = int(chunks)
        checkpoint = str(checkpoint)

        if chunks <= 0:
            raise ValueError("number of chunks must be positive integer")
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")

        _verify_module(module)

        # Verify if the underlying skippable modules satisfy integrity. The
        # integrity can be verified before forward() because it is static.
        verify_skippables(module)

        self.chunks = chunks
        self.checkpoint = checkpoint

        if deferred_batch_norm:
            module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)

        self.partitions, self.devices = _split_module(module)
        _verify_splitting(module, self.partitions, self.devices)

        self._copy_streams: List[List[AbstractStream]] = []
        self._skip_layout = inspect_skip_layout(self.partitions)

        # Separate CUDA streams for copy.
        copy_streams = self._ensure_copy_streams()

        # The micro-batch index where the checkpointing stops.
        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]

        self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)

    def __len__(self) -> int:
        """Counts the length of the underlying sequential module."""
        return sum(len(p) for p in self.partitions)

    def __getitem__(self, index: int) -> nn.Module:
        """Gets a layer in the underlying sequential module."""
        partitions = self.partitions
        if index < 0:
            # Negative indices count from the end, so walk partitions backwards.
            partitions = partitions[::-1]

        for partition in partitions:
            try:
                return partition[index]
            except IndexError:
                pass

            # Not in this partition: rebase the index onto the next one.
            shift = len(partition)

            if index < 0:
                index += shift
            else:
                index -= shift

        raise IndexError

    def __iter__(self) -> Iterator[nn.Module]:
        """Iterates over children of the underlying sequential module."""
        for partition in self.partitions:
            yield from partition

    # Pipe should manage the device of each partition.
    # Deny cuda(), cpu(), and to() with device, by TypeError.
    #
    # MOVING_DENIED is one shared exception instance. Re-raising it as-is
    # would chain each new traceback onto the previous one, keeping every
    # raising frame (and its locals, including ``self``) alive; the
    # traceback is therefore reset before each raise.
    def cuda(self, device: Optional[Device] = None) -> "Pipe":
        raise MOVING_DENIED.with_traceback(None)

    def cpu(self) -> "Pipe":
        raise MOVING_DENIED.with_traceback(None)

    def to(self, *args: Any, **kwargs: Any) -> "Pipe":
        # Deny these usages:
        #
        # - to(device[, dtype, non_blocking])
        # - to(tensor[, non_blocking])
        #
        # But allow this:
        #
        # - to(dtype[, non_blocking])
        #
        if "device" in kwargs or "tensor" in kwargs:
            raise MOVING_DENIED.with_traceback(None)

        if args:
            if isinstance(args[0], (torch.device, int, str)):
                raise MOVING_DENIED.with_traceback(None)
            if torch.is_tensor(args[0]):
                raise MOVING_DENIED.with_traceback(None)

        return super().to(*args, **kwargs)

    def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
        """Ensures that :class:`Pipe` caches CUDA streams for copy.

        It's worth caching CUDA streams although PyTorch already manages a
        pool of pre-allocated CUDA streams, because it may reduce GPU memory
        fragmentation when the number of micro-batches is small.

        Returns:
            One list per partition device, each holding ``self.chunks``
            streams created by ``new_stream(device)``. The lists are built on
            the first call and reused afterwards.
        """
        if not self._copy_streams:
            for device in self.devices:
                self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])

        return self._copy_streams

    def forward(self, *inputs) -> RRef:
        """
        Processes a single input mini-batch through the pipe and returns an
        :class:`~torch.distributed.rpc.RRef` pointing to the output.
        :class:`Pipe` is a fairly transparent module wrapper. It doesn't
        modify the input and output signature of the underlying module. But
        there's a type restriction. Input and output have to contain at least
        one tensor. This restriction is applied at partition boundaries too.

        The sequence of inputs is fed into the first stage of the pipeline as
        ``*inputs``. As a result the positional args for this function should
        match the positional args for the first stage of the pipeline. The same
        condition applies for output of one stage of the pipeline which is the
        input for the next stage.

        The input tensor is split into multiple micro-batches based on the
        ``chunks`` parameter used to initialize :class:`Pipe`. The batch size
        is assumed to be the first dimension of the tensor and if the batch
        size is less than ``chunks``, the number of micro-batches is equal to
        the batch size.

        Only tensors are split into multiple micro-batches, non-Tensor inputs
        are just replicated as-is in each micro-batch. For non-Tensor outputs
        in the last stage of the pipeline, they are aggregated as a ``List``
        and returned to the user. For example, if you have 2 micro-batches
        returning the integer 5, the user would receive the consolidated
        output of `[5, 5]`

        All the input tensors need to be on the same device as the first
        partition of the pipeline.

        If a tensor is wrapped with the :class:`NoChunk` wrapper, the tensor
        is not split across micro-batches and is replicated as-is similar to
        non-tensors.

        If the pipe is empty (the wrapped sequential module had no children),
        the input is returned unchanged: a single input as-is, multiple
        inputs as a tuple.

        Args:
            inputs: input mini-batch

        Returns:
            :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch

        Raises:
            TypeError: input doesn't contain at least one tensor

        """
        first_partition_device = self.devices[0] if self.devices else torch.device("cpu")
        microbatch.check(first_partition_device, *inputs)

        if not self.devices:
            # Empty sequential module is not illegal: the input passes through.
            # ``RRef(value, type_hint=None)`` takes exactly one value, so
            # unpacking several inputs into it would misread the second input
            # as a type hint; multiple inputs are wrapped in a tuple instead.
            return RRef(inputs[0] if len(inputs) == 1 else inputs)

        # Divide a mini-batch into micro-batches.
        batches = microbatch.scatter(*inputs, chunks=self.chunks)

        # Run pipeline parallelism.
        self.pipeline.run(batches)

        # Merge the micro-batches into one mini-batch.
        output = microbatch.gather(batches)
        return RRef(output)
|