256 lines
9.4 KiB
Python
256 lines
9.4 KiB
Python
# Copyright 2019 Kakao Brain
|
|
#
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
"""The pipeline parallelism of Pipe."""
|
|
from queue import Queue
|
|
from types import TracebackType
|
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence
|
|
|
|
import torch
|
|
from torch import Tensor, nn
|
|
from torch.autograd.profiler import record_function
|
|
|
|
from .checkpoint import Checkpointing
|
|
from .copy import Copy, Wait
|
|
from .dependency import fork, join
|
|
from .microbatch import Batch
|
|
from .skip.layout import SkipLayout
|
|
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
|
|
from .stream import AbstractStream, current_stream, use_device
|
|
from .worker import Task, create_workers
|
|
|
|
# Public API of this module.
__all__: List[str] = ["Pipeline"]

# Type aliases for partition inputs/outputs: a single tensor or a sequence of them.
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

# The (type, value, traceback) triple of a caught exception. Workers report a
# failure by putting ``(False, ExcInfo)`` on their out-queue; see the
# ``OutQueue`` alias below and ``Pipeline.compute``.
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]

# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
# Type checkers therefore see the parameterized forms, while at runtime the
# plain ``Queue`` class is used.
if TYPE_CHECKING:
    # Tasks to run; the element type admits ``None`` as well as a Task.
    InQueue = Queue[Optional["Task"]]
    # ``(ok, payload)``: payload is ``(task, batch)`` on success, ``ExcInfo`` on
    # failure (``None`` is also admitted by the type).
    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
    InQueue = Queue
    OutQueue = Queue
|
|
|
|
|
|
def _depend(fork_from: Batch, join_to: Batch) -> None:
    """Add an explicit autograd dependency from ``fork_from`` to ``join_to``.

    A phony value is forked off the tensor of ``fork_from`` located by
    ``find_tensor_idx()`` and joined into the corresponding tensor of
    ``join_to``. Both batches are updated in place.
    """
    src_idx = fork_from.find_tensor_idx()
    dst_idx = join_to.find_tensor_idx()

    forked, phony = fork(fork_from[src_idx])
    fork_from[src_idx] = forked
    join_to[dst_idx] = join(join_to[dst_idx], phony)
|
|
|
|
|
|
def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    """Pass ``batch`` through ``Copy.apply`` (``prev_stream`` -> ``next_stream``) in place.

    Afterwards every tensor in the batch that is not floating point is
    detached, since gradients are only supported for float tensors.
    """
    batch[:] = Copy.apply(prev_stream, next_stream, *batch)

    # Gradients are only supported for float Tensors.
    cleaned = []
    for item in batch:
        if torch.is_tensor(item) and not item.is_floating_point():
            item = item.detach()
        cleaned.append(item)
    batch[:] = tuple(cleaned)
|
|
|
|
|
|
def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    """Pass ``batch`` through ``Wait.apply`` (``prev_stream`` -> ``next_stream``) in place.

    ``Pipeline.compute`` uses this to synchronize one stream with another
    around each partition's computation. Afterwards every tensor in the batch
    that is not floating point is detached, since gradients are only
    supported for float tensors.
    """
    synced = Wait.apply(prev_stream, next_stream, *batch)
    batch[:] = synced

    # Gradients are only supported for float Tensors.
    batch[:] = tuple(
        value.detach() if torch.is_tensor(value) and not value.is_floating_point() else value
        for value in batch
    )
|
|
|
|
|
|
def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
    """Generate schedules for each clock cycle.

    ``m`` is the number of micro-batches and ``n`` the number of partitions.
    Clock ``k`` (``0 <= k < m + n - 1``) schedules every pair ``(i, j)`` with
    ``i + j == k``, ``0 <= i < m`` and ``0 <= j < n``, ordered by increasing
    partition index ``j``::

        k (i,j) (i,j) (i,j)
        - ----- ----- -----
        0 (0,0)
        1 (1,0) (0,1)
        2 (2,0) (1,1) (0,2)
        3       (2,1) (1,2)
        4             (2,2)

    (the table above is for ``m = n = 3``).
    """
    for clock in range(m + n - 1):
        # Inclusive bounds on the partition index for this clock: the lower
        # bound keeps i = clock - j below m, the upper keeps i >= 0 and j < n.
        first_part = max(0, clock - m + 1)
        last_part = min(clock, n - 1)
        yield [(clock - part, part) for part in range(first_part, last_part + 1)]
|
|
|
|
|
|
class Pipeline:
    """The pipeline parallelism for Pipe."""

    def __init__(
        self,
        partitions: List[nn.Sequential],
        devices: List[torch.device],
        copy_streams: List[List[AbstractStream]],
        skip_layout: SkipLayout,
        checkpoint_stop: int,
    ) -> None:
        """Store the pipeline configuration and create per-device workers.

        Args:
            partitions: the model partitions, one per pipeline stage.
            devices: the device of each partition (passed to ``create_workers``
                and used for ``current_stream``).
            copy_streams: indexed as ``copy_streams[partition][micro_batch]``.
            skip_layout: provides the skip-connection copy policy per partition.
            checkpoint_stop: micro-batches with index below this value are
                checkpointed (training mode only).
        """
        self.partitions = partitions
        self.devices = devices
        self.copy_streams = copy_streams
        self.skip_layout = skip_layout
        self.checkpoint_stop = checkpoint_stop
        (self.in_queues, self.out_queues) = create_workers(devices)

    def run(self, batches: List[Batch]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place: for every clock cycle produced
        by ``_clock_cycles`` the scheduled micro-batches are first fenced
        (copied between partitions) and then computed.
        """
        m = len(batches)
        n = len(self.partitions)

        # One skip tracker per micro-batch.
        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout) for _ in batches]

        for schedule in _clock_cycles(m, n):
            self.fence(batches, schedule, skip_trackers)
            self.compute(batches, schedule, skip_trackers)

    def fence(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Copy micro-batches after computation for the previous micro-batches.

        For each scheduled ``(i, j)`` this adds the backward-ordering
        dependency between micro-batches, copies the skip connections required
        by ``skip_layout.copy_policy(j)``, and (for ``j != 0``) copies
        ``batches[i]`` from the previous partition's copy stream to this one's.
        """
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0 and j != 0:
                _depend(batches[i - 1], batches[i])

            next_stream = copy_streams[j][i]

            for prev_j, ns, name in skip_layout.copy_policy(j):
                prev_stream = copy_streams[prev_j][i]
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0:
                prev_stream = copy_streams[j - 1][i]
                _copy(batches[i], prev_stream, next_stream)

    def compute(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Run tasks with synchronization to copy streams.

        Each scheduled ``(i, j)`` becomes a ``Task`` that runs
        ``partitions[j]`` on ``batches[i]`` and is put on the in-queue of
        worker ``j``. Results are then read back from the out-queues in
        schedule order and ``batches[i]`` is replaced with the returned batch.

        Raises:
            The first exception reported by any worker. The remaining results
            of the schedule are still drained from the out-queues first.
        """
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        # Disable checkpointing if in eval mode (judged by the first partition).
        if not partitions[0].training:
            checkpoint_stop = 0

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                _wait(batch, copy_streams[j][i], streams[j])

            # Determine whether checkpointing or not.
            checkpoint = i < checkpoint_stop
            if checkpoint:

                # Loop variables are bound as default arguments so the task
                # uses this iteration's values when it is run later.
                def function(
                    *inputs,
                    partition: nn.Module = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return partition(*inputs)

                chk = Checkpointing(function, batch)  # type: ignore[arg-type]
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                # Drop the local references; the task keeps what it needs.
                del function, chk

            else:

                # Same default-argument binding as above.
                def run_partition(
                    batch: Batch = batch,
                    partition: nn.Module = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> Batch:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return batch.call(partition)

                task = Task(streams[j], compute=run_partition, finalize=None)
                del run_partition

            # Compute tasks in parallel. ([2] in the diagram)
            self.in_queues[j].put(task)

        for i, j in schedule:
            ok, payload = self.out_queues[j].get()

            # Hold the first exception; keep draining the remaining results.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n - 1:
                _wait(batch, streams[j], copy_streams[j][i])

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch)

            batches[i] = batch

        # Fail at the first exception, keeping the worker's original traceback.
        if exc_info is not None:
            _, exc, tb = exc_info
            raise exc.with_traceback(tb)
|