ai-content-maker/.venv/Lib/site-packages/torch/distributed/pipeline/sync/microbatch.py

235 lines
7.3 KiB
Python

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
from typing import Any, Callable, List, Union, cast, Sequence
import torch
from torch import Tensor
import torch.cuda.comm
__all__: List[str] = [
    "NoChunk",
    "Batch",
    "check",
    "scatter",
    "gather",
]

# A sequence of tensors.
Tensors = Sequence[Tensor]
# Either a single tensor or a sequence of tensors.
TensorOrTensors = Union[Tensor, Tensors]
# Signature of the callables accepted by ``Batch.call``: the result is either
# a single tensor or a list of values, which is wrapped in a new Batch.
Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]]
class NoChunk:
    """
    Marks a tensor given to :meth:`Pipe.forward` as one that must NOT be split
    along the batch dimension. Instead, the very same tensor is handed to every
    micro-batch. Use it for tensors that carry no per-sample ('batch') meaning
    for the model.
    """

    def __init__(self, inp: Tensor):
        # Only real tensors may be wrapped; anything else is rejected eagerly.
        if torch.is_tensor(inp):
            self._tensor = inp
            return
        raise TypeError(f'NoChunk only supported for tensors, found: {inp}')

    @property
    def tensor(self):
        """The wrapped tensor, returned unchanged."""
        return self._tensor
class Batch:
    """
    An abstraction representing a microbatch in the pipeline.

    A batch is *atomic* when it wraps a single tensor. Otherwise it wraps a
    sequence of values (a list or a tuple) that must contain at least one
    tensor; non-tensor values are carried along unchanged.
    """

    def __init__(self, values: Union[List[Any], Tensor]) -> None:
        """
        Args:
            values: a single tensor (atomic batch) or a sequence of values.

        Raises:
            TypeError: ``values`` is not a tensor and contains no tensor.
        """
        self._values = values
        self.atomic = torch.is_tensor(values)

        # Verify at least one tensor
        if not self.atomic:
            if not any(torch.is_tensor(value) for value in self._values):
                raise TypeError(f'No tensors found in batch: {self._values}')

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor.

        Raises:
            AttributeError: the batch is not atomic.
        """
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self._values)

    @property
    def values(self):
        """Retrieves the underlying values for the batch"""
        return self._values

    def find_tensor_idx(self):
        """
        Retrieves the index of first tensor found.

        Returns 0 for an atomic batch.

        Raises:
            TypeError: no value in the batch is a tensor.
        """
        if self.atomic:
            return 0
        for i, value in enumerate(self._values):
            if torch.is_tensor(value):
                return i
        raise TypeError("No tensor found!")

    def get_device(self):
        """
        Retrieves the device for this microbatch.

        For a non-atomic batch this is the device of the first tensor value;
        ``None`` is returned if no value is a tensor.
        """
        if self.atomic:
            return self._values.device  # type: ignore[union-attr]
        for value in self._values:
            if torch.is_tensor(value):
                return value.device
        return None

    def call(self, function: "Function") -> "Batch":
        """Calls a function on the microbatch. It also wraps
        the output with :class:`Batch`.

        An atomic batch passes its tensor as the single argument; otherwise
        the values are unpacked as positional arguments.
        """
        if self.atomic:
            return Batch(function(self._values))
        else:
            return Batch(function(*self._values))

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self._values!r})"

    def __iter__(self):
        # An atomic batch behaves like a one-element sequence.
        if self.atomic:
            yield self._values
        else:
            yield from self._values

    def __len__(self) -> int:
        return 1 if self.atomic else len(self._values)

    def __getitem__(self, index: int):
        if not self.atomic:
            return self._values[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self._values

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: "Tensors") -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value) -> None:
        if isinstance(index, int):
            self._setitem_by_index(index, value)
        else:
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value) -> None:
        """Replaces the value at ``index``.

        Raises:
            IndexError: ``index`` is out of range (for an atomic batch, any
                index other than 0).
        """
        if not self.atomic:
            # Rebuild rather than mutate in place: the values may be an
            # immutable tuple, or a list the caller still references.
            # Concatenating slices with a tuple breaks for list-backed batches
            # and silently grows the batch for negative/out-of-range indices,
            # so use plain sequence assignment instead.
            updated = list(self._values)
            updated[index] = value  # raises IndexError when out of range
            self._values = tuple(updated) if isinstance(self._values, tuple) else updated
            return

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        self._values = value

    def _setitem_by_slice(self, index: slice, value) -> None:
        """Replaces all values; only the full slice ``[:]`` is supported.

        Raises:
            NotImplementedError: the slice is not ``[:]``.
            IndexError: an atomic batch is given more or fewer than one value.
        """
        if not (index.start is index.stop is index.step is None):  # noqa: E714
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self._values = value
            return

        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")

        self._values = value[0]
def check(first_device, *inputs) -> None:
    """
    Checks whether the input contains at least one tensor and each tensor is
    on the same device as the first partition.

    Args:
        first_device: device of the first partition; every tensor in
            ``inputs`` must have ``tensor.device == first_device``.
        *inputs: positional inputs. Values that are not tensors (including
            ``NoChunk`` wrappers, which are plain objects) are skipped by the
            device check.

    Raises:
        TypeError: input does not contain at least one tensor
        ValueError: a tensor in the input is not on ``first_device``
    """
    if not any(torch.is_tensor(item) for item in inputs):
        raise TypeError(f'inputs do not have any tensors: {inputs}')
    if any(torch.is_tensor(item) and item.device != first_device for item in inputs):
        raise ValueError('All inputs should be on the same device as the first partition')
def scatter(*inputs, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches.

    Every tensor input is split with ``Tensor.chunk(chunks)``. Every other
    input is replicated into each micro-batch; ``NoChunk`` wrappers are
    replaced by the tensor they wrap.

    Args:
        *inputs: the mini-batch inputs.
        chunks: requested number of micro-batches. ``Tensor.chunk`` may
            produce fewer, in which case the result is truncated to that count.

    Returns:
        One :class:`Batch` per micro-batch.

    Raises:
        RuntimeError: tensor inputs were split into different numbers of chunks.
        TypeError: (raised by :class:`Batch`) a micro-batch holds no tensor.
    """
    if len(inputs) == 1 and isinstance(inputs[0], Tensor):
        return [Batch(x) for x in inputs[0].chunk(chunks)]

    batches: List[Any] = [[] for _ in range(chunks)]
    # Actual number of chunks produced; -1 until a tensor has been chunked.
    num_chunks = -1

    for item in inputs:
        if torch.is_tensor(item):
            # Chunk only tensors.
            tensors = item.chunk(chunks)

            # Validate number of chunks equal across all inputs.
            if num_chunks != -1 and num_chunks != len(tensors):
                raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}')
            num_chunks = len(tensors)

            for batch, tensor in zip(batches, tensors):
                batch.append(tensor)
        else:
            # Replicate non-tensors or tensors wrapped with 'NoChunk'.
            replica = item.tensor if isinstance(item, NoChunk) else item
            for batch in batches:
                batch.append(replica)

    # Truncate to the actual number of chunks. Skip this when no tensor was
    # chunked: slicing with the -1 sentinel would silently drop the last
    # micro-batch.
    if num_chunks != -1:
        batches = batches[:num_chunks]
    return [Batch(x) for x in batches]
def gather(outputs: List[Batch]):
    """Concatenates output micro-batches into a mini-batch.

    Args:
        outputs: micro-batch outputs in micro-batch order. If the first one is
            atomic, all must be atomic; otherwise all must hold the same number
            of values.

    Returns:
        For atomic batches, ``torch.cat`` of their tensors. Otherwise a tuple
        with one entry per value position: tensors at that position are joined
        with ``torch.cat``, and any other values are collected into a list with
        one element per micro-batch.

    Raises:
        ValueError: ``outputs`` is empty.
        TypeError: micro-batches disagree on their number of values, or on the
            type of the value at some position.
    """
    if not outputs:
        raise ValueError('gather requires at least one micro-batch output')

    first = outputs[0]
    if first.atomic:
        return torch.cat(tuple(b.tensor for b in outputs))

    # Validate up front: a longer micro-batch would otherwise have its extra
    # values silently dropped, and a shorter one would fail with an opaque
    # IndexError partway through.
    num_values = len(first)
    for batch in outputs:
        if len(batch) != num_values:
            raise TypeError(f'Number of values in microbatch outputs do not match, found: {num_values} and {len(batch)}')

    output_buf: List[Any] = []
    for i in range(num_values):
        output_type = type(first[i])
        current_outputs = []
        for batch in outputs:
            if output_type != type(batch[i]):
                raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}')
            current_outputs.append(batch[i])

        if torch.is_tensor(first[i]):
            output_buf.append(torch.cat(current_outputs))
        else:
            output_buf.append(current_outputs)

    return tuple(output_buf)