108 lines
4.2 KiB
Python
108 lines
4.2 KiB
Python
|
import torch
|
||
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, overload
|
||
|
from ._functions import Scatter, Gather
|
||
|
import warnings
|
||
|
|
||
|
__all__ = ['scatter', 'scatter_kwargs', 'gather']
|
||
|
|
||
|
def is_namedtuple(obj: Any) -> bool:
|
||
|
# Check if type was created from collections.namedtuple or a typing.NamedTuple.
|
||
|
warnings.warn("is_namedtuple is deprecated, please use the python checks instead")
|
||
|
return _is_namedtuple(obj)
|
||
|
|
||
|
def _is_namedtuple(obj: Any) -> bool:
|
||
|
# Check if type was created from collections.namedtuple or a typing.NamedTuple.
|
||
|
return (
|
||
|
isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
|
||
|
)
|
||
|
|
||
|
|
||
|
T = TypeVar("T", dict, list, tuple)
|
||
|
|
||
|
# For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise.
|
||
|
@overload
|
||
|
def scatter(
|
||
|
inputs: torch.Tensor,
|
||
|
target_gpus: Sequence[Union[int, torch.device]],
|
||
|
dim: int = ...,
|
||
|
) -> Tuple[torch.Tensor, ...]:
|
||
|
...
|
||
|
|
||
|
@overload
|
||
|
def scatter(inputs: T, target_gpus: Sequence[Union[int, torch.device]], dim: int = ...) -> List[T]:
|
||
|
...
|
||
|
|
||
|
def scatter(inputs, target_gpus, dim=0):
|
||
|
r"""Slice tensors into approximately equal chunks and distributes them across given GPUs.
|
||
|
|
||
|
Duplicates references to objects that are not tensors.
|
||
|
"""
|
||
|
def scatter_map(obj):
|
||
|
if isinstance(obj, torch.Tensor):
|
||
|
return Scatter.apply(target_gpus, None, dim, obj)
|
||
|
if _is_namedtuple(obj):
|
||
|
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
|
||
|
if isinstance(obj, tuple) and len(obj) > 0:
|
||
|
return list(zip(*map(scatter_map, obj)))
|
||
|
if isinstance(obj, list) and len(obj) > 0:
|
||
|
return [list(i) for i in zip(*map(scatter_map, obj))]
|
||
|
if isinstance(obj, dict) and len(obj) > 0:
|
||
|
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
|
||
|
return [obj for _ in target_gpus]
|
||
|
|
||
|
# After scatter_map is called, a scatter_map cell will exist. This cell
|
||
|
# has a reference to the actual function scatter_map, which has references
|
||
|
# to a closure that has a reference to the scatter_map cell (because the
|
||
|
# fn is recursive). To avoid this reference cycle, we set the function to
|
||
|
# None, clearing the cell
|
||
|
try:
|
||
|
res = scatter_map(inputs)
|
||
|
finally:
|
||
|
scatter_map = None # type: ignore[assignment]
|
||
|
return res
|
||
|
|
||
|
|
||
|
def scatter_kwargs(
|
||
|
inputs: Tuple[Any, ...],
|
||
|
kwargs: Optional[Dict[str, Any]],
|
||
|
target_gpus: Sequence[Union[int, torch.device]],
|
||
|
dim: int = 0,
|
||
|
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
|
||
|
r"""Scatter with support for kwargs dictionary."""
|
||
|
scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
|
||
|
scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
|
||
|
if len(scattered_inputs) < len(scattered_kwargs):
|
||
|
scattered_inputs.extend(() for _ in range(len(scattered_kwargs) - len(scattered_inputs)))
|
||
|
elif len(scattered_kwargs) < len(inputs):
|
||
|
scattered_kwargs.extend({} for _ in range(len(scattered_inputs) - len(scattered_kwargs)))
|
||
|
return tuple(scattered_inputs), tuple(scattered_kwargs)
|
||
|
|
||
|
|
||
|
def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) -> Any:
|
||
|
r"""Gather tensors from different GPUs on a specified device.
|
||
|
|
||
|
Use 'cpu' for CPU to avoid a deprecation warning.
|
||
|
"""
|
||
|
def gather_map(outputs):
|
||
|
out = outputs[0]
|
||
|
if isinstance(out, torch.Tensor):
|
||
|
return Gather.apply(target_device, dim, *outputs)
|
||
|
if out is None:
|
||
|
return None
|
||
|
if isinstance(out, dict):
|
||
|
if not all(len(out) == len(d) for d in outputs):
|
||
|
raise ValueError('All dicts must have the same number of keys')
|
||
|
return type(out)((k, gather_map([d[k] for d in outputs]))
|
||
|
for k in out)
|
||
|
if _is_namedtuple(out):
|
||
|
return type(out)._make(map(gather_map, zip(*outputs)))
|
||
|
return type(out)(map(gather_map, zip(*outputs)))
|
||
|
|
||
|
# Recursive function calls like this create reference cycles.
|
||
|
# Setting the function to None clears the refcycle.
|
||
|
try:
|
||
|
res = gather_map(outputs)
|
||
|
finally:
|
||
|
gather_map = None # type: ignore[assignment]
|
||
|
return res
|