# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
from typing import Any, Callable, List, Union, cast, Sequence

import torch
from torch import Tensor
import torch.cuda.comm

__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"]


Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]]


class NoChunk:
    """
    Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor
    should not be chunked on the batch dimension and instead be replicated
    as-is across all micro-batches. This is useful for tensors which might
    not have any 'batch' semantics for the model.
    """
    def __init__(self, inp: Tensor):
        if not torch.is_tensor(inp):
            raise TypeError(f'NoChunk only supported for tensors, found: {inp}')
        self._tensor = inp

    @property
    def tensor(self):
        return self._tensor


class Batch:
    """
    An abstraction representing a microbatch in the pipeline.
    """

    def __init__(self, values: Union[List[Any], Tensor]) -> None:
        self._values = values
        self.atomic = torch.is_tensor(values)

        # Verify there is at least one tensor
        if not self.atomic:
            if not any(torch.is_tensor(value) for value in self._values):
                raise TypeError(f'No tensors found in batch: {self._values}')

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor."""
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self._values)

    @property
    def values(self):
        """Retrieves the underlying values for the batch."""
        return self._values

    def find_tensor_idx(self):
        """
        Retrieves the index of the first tensor found.
        """
        if self.atomic:
            return 0
        for i, value in enumerate(self._values):
            if torch.is_tensor(value):
                return i

        raise TypeError("No tensor found!")

    def get_device(self):
        """
        Retrieves the device for this microbatch.
        """
        if self.atomic:
            return self._values.device  # type: ignore[union-attr]

        for value in self._values:
            if torch.is_tensor(value):
                return value.device

    def call(self, function: Function) -> "Batch":
        """Calls a function on the microbatch. It also wraps
        the output with :class:`Batch`.
        """
        if self.atomic:
            return Batch(function(self._values))
        else:
            return Batch(function(*self._values))

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self._values!r})"

    def __iter__(self):
        if self.atomic:
            yield self._values
        else:
            yield from self._values

    def __len__(self) -> int:
        return 1 if self.atomic else len(self._values)

    def __getitem__(self, index: int):
        if not self.atomic:
            return self._values[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self._values

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: Tensors) -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value) -> None:
        if isinstance(index, int):
            self._setitem_by_index(index, value)
        else:
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value) -> None:
        if not self.atomic:
            # Rebuild the sequence with the new value spliced in at ``index``.
            i = index
            self._values = self._values[:i] + (value,) + self._values[i + 1 :]  # type: ignore[operator]
            return

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        self._values = value

    def _setitem_by_slice(self, index: slice, value) -> None:
        if not (index.start is index.stop is index.step is None):  # noqa: E714
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self._values = value
            return

        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")

        self._values = value[0]
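
# A minimal sketch of how ``Batch`` behaves (illustrative comments only; the
# snippet below reflects the semantics above but is an assumption about
# typical usage, not part of the public API):
#
#     b = Batch(torch.zeros(4))              # atomic: wraps a single tensor
#     len(b), b[0]                           # 1, the tensor itself
#     b = Batch([torch.zeros(4), "meta"])    # non-atomic: a sequence of values
#     b[1]                                   # "meta"
#     b[:] = (torch.ones(4), "meta")         # only the full slice [:] is supported
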
def check(first_device, *inputs) -> None:
    """
    Checks whether the input contains at least one tensor and each tensor is
    on the same device as the first partition.

    Raises:
        TypeError: input does not contain at least one tensor
        ValueError: a tensor is not on the same device as the first partition

    """
    if not any(torch.is_tensor(input) for input in inputs):
        raise TypeError(f'inputs do not have any tensors: {inputs}')
    if any(torch.is_tensor(input) and input.device != first_device for input in inputs):
        raise ValueError('All inputs should be on the same device as the first partition')


def scatter(*inputs, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches."""
    if len(inputs) == 1 and isinstance(inputs[0], Tensor):
        return [Batch(x) for x in inputs[0].chunk(chunks)]

    batches: List[Any] = [[] for _ in range(chunks)]
    # Actual number of chunks produced (``Tensor.chunk`` may return fewer than
    # ``chunks`` pieces when the batch dimension is too small).
    num_chunks = -1
    for input in inputs:
        if torch.is_tensor(input):
            # Chunk only tensors.
            tensors = input.chunk(chunks)

            # Validate that the number of chunks is equal across all inputs.
            if num_chunks != -1 and num_chunks != len(tensors):
                raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}')
            num_chunks = len(tensors)

            for i, tensor in enumerate(tensors):
                batches[i].append(tensor)
        else:
            # Replicate non-tensors or tensors wrapped with 'NoChunk'.
            for i in range(chunks):
                if isinstance(input, NoChunk):
                    # Extract the tensor out.
                    batches[i].append(input.tensor)
                else:
                    batches[i].append(input)

    # Truncate to the actual number of chunks.
    batches = batches[:num_chunks]
    return [Batch(x) for x in batches]


def gather(outputs: List[Batch]):
    """Concatenates output micro-batches into a mini-batch."""
    output: Any

    if outputs[0].atomic:
        tensors = tuple(b.tensor for b in outputs)
        output = torch.cat(tensors)
    else:
        output_buf: List[Any] = []
        for i in range(len(outputs[0])):
            output_type = type(outputs[0][i])
            current_outputs = []
            for batch in outputs:
                if output_type != type(batch[i]):
                    raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}')
                current_outputs.append(batch[i])

            # Tensors are concatenated back along the batch dimension;
            # non-tensor values are collected into a list per position.
            if torch.is_tensor(outputs[0][i]):
                output_buf.append(torch.cat(current_outputs))
            else:
                output_buf.append(current_outputs)

        output = tuple(output_buf)

    return output
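

# A small self-check of the scatter/call/gather round trip. This is a hedged
# demonstration added for illustration (it assumes CPU tensors and is not part
# of the original module); run the file directly to execute it.
if __name__ == "__main__":
    x = torch.arange(8.0).reshape(8, 1)
    # Split the mini-batch into 4 atomic micro-batches of 2 samples each.
    batches = scatter(x, chunks=4)
    assert len(batches) == 4 and batches[0].atomic
    # Apply a function to every micro-batch, then stitch the results back.
    outputs = [b.call(lambda t: t * 2) for b in batches]
    merged = gather(outputs)
    assert torch.equal(merged, x * 2)

    # NoChunk replicates a tensor across micro-batches instead of chunking it.
    mask = torch.ones(3)
    parts = scatter(x, NoChunk(mask), chunks=4)
    assert all(torch.equal(p.values[1], mask) for p in parts)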