# Copyright 2019 Kakao Brain # # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. """The pipeline parallelism of Pipe.""" from queue import Queue from types import TracebackType from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence import torch from torch import Tensor, nn from torch.autograd.profiler import record_function from .checkpoint import Checkpointing from .copy import Copy, Wait from .dependency import fork, join from .microbatch import Batch from .skip.layout import SkipLayout from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker from .stream import AbstractStream, current_stream, use_device from .worker import Task, create_workers __all__: List[str] = ["Pipeline"] Tensors = Sequence[Tensor] TensorOrTensors = Union[Tensor, Tensors] ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] # Queue is generic only in stubs. # https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime if TYPE_CHECKING: InQueue = Queue[Optional["Task"]] OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] else: InQueue = Queue OutQueue = Queue def _depend(fork_from: Batch, join_to: Batch) -> None: fork_from_idx = fork_from.find_tensor_idx() join_to_idx = join_to.find_tensor_idx() fork_from[fork_from_idx], phony = fork(fork_from[fork_from_idx]) join_to[join_to_idx] = join(join_to[join_to_idx], phony) def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: batch[:] = Copy.apply(prev_stream, next_stream, *batch) # Gradients are only supported for float Tensors. batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: batch[:] = Wait.apply(prev_stream, next_stream, *batch) # Gradients are only supported for float Tensors. batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch]) def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]: """Generate schedules for each clock cycle.""" # m: number of micro-batches # n: number of partitions # i: index of micro-batch # j: index of partition # k: clock number # # k (i,j) (i,j) (i,j) # - ----- ----- ----- # 0 (0,0) # 1 (1,0) (0,1) # 2 (2,0) (1,1) (0,2) # 3 (2,1) (1,2) # 4 (2,2) for k in range(m + n - 1): yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] class Pipeline: """The pipeline parallelism for Pipe.""" def __init__( self, partitions: List[nn.Sequential], devices: List[torch.device], copy_streams: List[List[AbstractStream]], skip_layout: SkipLayout, checkpoint_stop: int, ) -> None: self.partitions = partitions self.devices = devices self.copy_streams = copy_streams self.skip_layout = skip_layout self.checkpoint_stop = checkpoint_stop (self.in_queues, self.out_queues) = create_workers(devices) def run(self, batches: List[Batch]) -> None: """Runs pipeline parallelism. It modifies the given batches in place. """ partitions = self.partitions devices = self.devices skip_layout = self.skip_layout m = len(batches) n = len(partitions) skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches] for schedule in _clock_cycles(m, n): self.fence(batches, schedule, skip_trackers) self.compute(batches, schedule, skip_trackers) def fence( self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], ) -> None: """Copy micro-batches after computation for the previous micro-batches.""" copy_streams = self.copy_streams skip_layout = self.skip_layout for i, j in schedule: # Ensure that batches[i-1] is executed after batches[i] in # backpropagation by an explicit dependency. if i != 0 and j != 0: _depend(batches[i - 1], batches[i]) next_stream = copy_streams[j][i] for prev_j, ns, name in skip_layout.copy_policy(j): prev_stream = copy_streams[prev_j][i] skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name) if j != 0: prev_stream = copy_streams[j - 1][i] _copy(batches[i], prev_stream, next_stream) def compute( self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], ) -> None: """Run tasks with synchronization to copy streams.""" partitions = self.partitions devices = self.devices copy_streams = self.copy_streams checkpoint_stop = self.checkpoint_stop # Disable checkpointing if in eval mode. if not self.partitions[0].training: checkpoint_stop = 0 n = len(partitions) streams = [current_stream(d) for d in devices] exc_info: Optional[ExcInfo] = None # With checkpointing, the autograd graph looks like this diagram: # ┌─────┸──────┐ # │ Copy │ # └─────┰──────┘ (fence) # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─ # ┃ (compute) # ┌─────┸──────┐ # │ Wait │ [1] Synchronize the current stream with the copy stream. # └─────┰──────┘ # ┌─────┸──────┐ # │ Checkpoint │ [2] Compute a partition within checkpointing. # └─────┰──────┘ # ┌─────┸──────┐ # │ Wait │ [3] Synchronize the copy stream with the current stream. # └─────┰──────┘ # ┠ ─ ─ ─ ┐ # ┃ ┌─────┴─────┐ # ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation. # ┃ └─────┬─────┘ # ┠ ─ ─ ─ ┘ # ┃ # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─ # ┌─────┸──────┐ (fence) # │ Copy │ # └─────┰──────┘ for i, j in schedule: batch = batches[i] partition = partitions[j] # Synchronize with the copied input. ([1] in the diagram) if j != 0: _wait(batch, copy_streams[j][i], streams[j]) # Determine whether checkpointing or not. checkpoint = i < checkpoint_stop if checkpoint: def function( *inputs, partition: nn.Module = partition, skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], chunk_id: int = i, part_id: int = j, ) -> TensorOrTensors: with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): return partition(*inputs) chk = Checkpointing(function, batch) # type: ignore[arg-type] task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) del function, chk else: def compute( batch: Batch = batch, partition: nn.Module = partition, skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], chunk_id: int = i, part_id: int = j, ) -> Batch: with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): return batch.call(partition) task = Task(streams[j], compute=compute, finalize=None) del compute # Compute tasks in parallel. ([2] in the diagram) self.in_queues[j].put(task) for i, j in schedule: ok, payload = self.out_queues[j].get() # Hold the first exception. if exc_info is not None: continue elif not ok: exc_info = cast(ExcInfo, payload) continue task, batch = cast(Tuple[Task, Batch], payload) # The copy stream synchronizes to copy the output. ([3] in the # diagram) if j != n - 1: _wait(batch, streams[j], copy_streams[j][i]) # Finalize tasks. If checkpointing is enabled, here the # recomputation is scheduled at backpropagation. ([4] in the # diagram) with use_device(devices[j]): task.finalize(batch) batches[i] = batch # Fail at the first exception. if exc_info is not None: raise exc_info[0].with_traceback(exc_info[1], exc_info[2])