127 lines
4.7 KiB
Python
127 lines
4.7 KiB
Python
import warnings
|
|
|
|
import torch
|
|
from . import comm
|
|
from torch.autograd import Function
|
|
from torch._utils import _get_device_index
|
|
from typing import List, Optional
|
|
|
|
|
|
class Broadcast(Function):
    """Autograd function that copies tensors onto a list of target devices.

    ``forward`` delegates the copies to ``comm.broadcast_coalesced`` and
    ``backward`` hands the incoming gradients to ``ReduceAddCoalesced`` so
    they are reduced onto the device of the first input.
    """

    @staticmethod
    def forward(ctx, target_gpus, *inputs):
        """Broadcast ``inputs`` to every device in ``target_gpus``.

        Args:
            ctx: Autograd context. ``target_gpus`` is always stored on it;
                ``num_inputs`` and ``input_device`` are stored only when at
                least one input is given.
            target_gpus: Iterable of device specifiers; each is normalized
                with ``_get_device_index(x, True)``.
            *inputs: Tensors to broadcast. None of them may live on CPU.

        Returns:
            A flat tuple containing every tensor of every group returned by
            ``comm.broadcast_coalesced``, in the order returned. An empty
            tuple when no inputs are given.

        Raises:
            AssertionError: If any input is a CPU tensor.
        """
        assert all(i.device.type != 'cpu' for i in inputs), (
            'Broadcast function not implemented for CPU tensors'
        )
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.target_gpus = target_gpus
        if not inputs:
            return tuple()
        ctx.num_inputs = len(inputs)
        ctx.input_device = inputs[0].get_device()
        outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
        # needs_input_grad[0] belongs to ``target_gpus``; the remaining
        # entries line up with ``inputs``. Every copy of an input that does
        # not require grad is marked non-differentiable.
        non_differentiables = [
            output[idx]
            for idx, needs_grad in enumerate(ctx.needs_input_grad[1:])
            if not needs_grad
            for output in outputs
        ]
        ctx.mark_non_differentiable(*non_differentiables)
        return tuple(t for tensors in outputs for t in tensors)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Reduce the gradients of all copies back to the input device.

        Returns:
            ``None`` for ``target_gpus`` followed by the tuple produced by
            ``ReduceAddCoalesced.apply``.
        """
        return (None,) + ReduceAddCoalesced.apply(ctx.input_device, ctx.num_inputs, *grad_outputs)
|
|
|
|
|
|
class ReduceAddCoalesced(Function):
    """Autograd function that reduces groups of gradients onto one device.

    ``forward`` delegates to ``comm.reduce_add_coalesced``; ``backward``
    broadcasts the incoming gradients back with ``Broadcast``.
    """

    @staticmethod
    def forward(ctx, destination, num_inputs, *grads):
        """Split ``grads`` into groups and reduce them onto ``destination``.

        Args:
            ctx: Autograd context; receives ``target_gpus``, the device index
                of the first gradient in each group.
            destination: Device passed through to
                ``comm.reduce_add_coalesced``.
            num_inputs: Number of consecutive gradients that form one group.
            *grads: Flat sequence of gradients, sliced into consecutive
                groups of ``num_inputs``.

        Returns:
            Whatever ``comm.reduce_add_coalesced`` returns for the grouped
            gradients.
        """
        # Start offset of every group inside the flat ``grads`` sequence.
        group_starts = range(0, len(grads), num_inputs)
        ctx.target_gpus = [grads[start].get_device() for start in group_starts]
        grads_ = [grads[start:start + num_inputs] for start in group_starts]
        return comm.reduce_add_coalesced(grads_, destination)

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Broadcast the gradients to the devices recorded in ``forward``.

        Returns:
            ``None`` for ``destination`` and ``num_inputs`` followed by the
            tuple produced by ``Broadcast.apply``.
        """
        return (None, None,) + Broadcast.apply(ctx.target_gpus, *grad_outputs)
|
|
|
|
|
|
class Gather(Function):
    """Autograd function that gathers tensors from several devices into one.

    ``forward`` delegates to ``comm.gather``; ``backward`` splits the
    gradient back to the original devices with ``Scatter``.
    """

    @staticmethod
    def forward(ctx, target_device, dim, *inputs):
        """Gather ``inputs`` along ``dim`` onto ``target_device``.

        Args:
            ctx: Autograd context; receives ``target_device``, ``dim``,
                ``input_gpus``, ``unsqueezed_scalar`` and ``input_sizes``.
            target_device: The string ``'cpu'`` (kept as is) or a device
                specifier normalized with ``_get_device_index(x, True)``.
            dim: Dimension along which to gather.
            *inputs: Tensors to gather. None of them may live on CPU.

        Returns:
            The result of ``comm.gather``.

        Raises:
            AssertionError: If any input is a CPU tensor.
        """
        assert all(i.device.type != 'cpu' for i in inputs), (
            'Gather function not implemented for CPU tensors'
        )
        if target_device == 'cpu':
            ctx.target_device = 'cpu'
        else:
            target_device = _get_device_index(target_device, True)
            ctx.target_device = target_device
        ctx.dim = dim
        ctx.input_gpus = tuple(i.get_device() for i in inputs)
        if all(t.dim() == 0 for t in inputs) and dim == 0:
            # 0-dim tensors are viewed as 1-element vectors before gathering;
            # backward undoes this by indexing element 0 of each gradient.
            inputs = tuple(t.view(1) for t in inputs)
            warnings.warn('Was asked to gather along dimension 0, but all '
                          'input tensors were scalars; will instead unsqueeze '
                          'and return a vector.')
            ctx.unsqueezed_scalar = True
        else:
            ctx.unsqueezed_scalar = False
        # Sizes are recorded after the optional view above, so they describe
        # the tensors actually handed to comm.gather.
        ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
        return comm.gather(inputs, ctx.dim, ctx.target_device)

    @staticmethod
    def backward(ctx, grad_output):
        """Scatter ``grad_output`` back to the devices of the inputs.

        Returns:
            ``None`` for ``target_device`` and ``dim`` followed by one
            gradient per input.
        """
        scattered_grads = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
        if ctx.unsqueezed_scalar:
            scattered_grads = tuple(g[0] for g in scattered_grads)
        return (None, None) + scattered_grads
|
|
|
|
|
|
class Scatter(Function):
    """Autograd function that splits one tensor across several devices.

    ``forward`` delegates to ``comm.scatter``; ``backward`` gathers the
    per-device gradients back with ``Gather``.
    """

    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        """Scatter ``input`` along ``dim`` onto ``target_gpus``.

        Args:
            ctx: Autograd context; receives ``dim`` and ``input_device``
                (``-1`` when ``input`` is a CPU tensor).
            target_gpus: Iterable of device specifiers; each is normalized
                with ``_get_device_index(x, True)``.
            chunk_sizes: Passed through to ``comm.scatter``.
            dim: Dimension along which to split.
            input: Tensor to scatter.

        Returns:
            The outputs of ``comm.scatter``.
        """
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.dim = dim
        ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
        streams = None
        if torch.cuda.is_available() and ctx.input_device == -1:
            # Perform CPU to GPU copies in a background stream
            streams = [_get_stream(torch.device("cuda", device)) for device in target_gpus]
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    main_stream = torch.cuda.current_stream()
                    # The device's current stream waits for the copy stream,
                    # and the output is recorded as used on the current stream.
                    main_stream.wait_stream(streams[i])
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        """Gather the per-device gradients onto the original input device.

        Returns:
            ``None`` for ``target_gpus``, ``chunk_sizes`` and ``dim``
            followed by the result of ``Gather.apply``.
        """
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
|
|
|
|
|
|
# Background streams used for copying, indexed by device index. The list is
# created on the first non-CPU request and its entries are filled on demand.
_streams: Optional[List[Optional[torch.Stream]]] = None


def _get_stream(device: torch.device):
    """Get a background stream for copying between CPU and target device.

    Args:
        device: Device the stream is requested for.

    Returns:
        ``None`` when ``device`` is a CPU device or when ``torch`` has no
        attribute named after ``device.type``; otherwise the stream cached
        for ``device.index``, created on first use.
    """
    global _streams
    if device.type == "cpu":
        return None
    device_mod = getattr(torch, device.type, None)
    if device_mod is None:
        return None
    if _streams is None:
        # Sized from the device count of whichever device module asks first.
        _streams = [None] * device_mod.device_count()
    index = device.index
    if _streams[index] is None:
        _streams[index] = device_mod.Stream(index)
    return _streams[index]
|