# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""Utilities for eliminating boilerplate code to handle abstract streams with
the CPU device.
"""
from contextlib import contextmanager
from typing import Generator, List, Union, cast

import torch

__all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream",
                      "use_device", "use_stream", "get_device", "wait_stream", "record_stream",
                      "is_cuda", "as_cuda"]


class CPUStreamType:
    pass


# The placeholder used in place of a stream for the CPU device, instead of a CUDA stream.
CPUStream = CPUStreamType()

# AbstractStream represents either a CUDA stream or the CPU stream.
AbstractStream = Union[torch.cuda.Stream, CPUStreamType]
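
# A hedged sketch (added for exposition, not from the original module) of the convention
# the helpers below rely on: CUDA streams pass through unchanged, while the CPU device is
# represented by the single ``CPUStream`` sentinel, so callers can branch with a plain
# identity check, for example:
#
#     stream = new_stream(torch.device("cpu"))
#     stream is CPUStream  # True -- no CUDA call is ever made for the CPU device.
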
def new_stream(device: torch.device) -> AbstractStream:
    """Creates a new stream for either CPU or CUDA device."""
    if device.type != "cuda":
        return CPUStream
    return torch.cuda.Stream(device)


def current_stream(device: torch.device) -> AbstractStream:
    """:func:`torch.cuda.current_stream` for either CPU or CUDA device."""
    if device.type != "cuda":
        return CPUStream
    return torch.cuda.current_stream(device)


def default_stream(device: torch.device) -> AbstractStream:
    """:func:`torch.cuda.default_stream` for either CPU or CUDA device."""
    if device.type != "cuda":
        return CPUStream
    return torch.cuda.default_stream(device)


@contextmanager
def use_device(device: torch.device) -> Generator[None, None, None]:
    """:func:`torch.cuda.device` for either CPU or CUDA device."""
    if device.type != "cuda":
        yield
        return

    with torch.cuda.device(device):
        yield


@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
    """:func:`torch.cuda.stream` for either CPU or CUDA stream."""
    if not is_cuda(stream):
        yield
        return

    with torch.cuda.stream(as_cuda(stream)):
        yield
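

# Illustrative sketch, added for exposition and not part of the original module: how
# use_device and use_stream are typically combined so the same code path works on both
# CPU and CUDA. The device index and tensor shape are arbitrary choices for this example.
def _example_use_stream() -> None:
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    stream = new_stream(device)
    with use_device(device), use_stream(stream):
        # On CUDA, this work is queued on ``stream``; on CPU, both context managers are
        # no-ops and the code simply runs eagerly.
        torch.ones(16, device=device).sum()

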
def get_device(stream: AbstractStream) -> torch.device:
    """Gets the device of the CPU or CUDA stream."""
    if is_cuda(stream):
        return as_cuda(stream).device
    return torch.device("cpu")


def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
    """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It
    makes the source stream wait until the target stream completes the work
    queued on it.
    """
    if is_cuda(target):
        if is_cuda(source):
            # A CUDA stream waits for another CUDA stream.
            as_cuda(source).wait_stream(as_cuda(target))
        else:
            # The CPU waits for a CUDA stream to finish.
            as_cuda(target).synchronize()

    # If the target is the CPU, no synchronization is required.
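

# Illustrative sketch, added for exposition and not part of the original module: the
# copy-then-compute pattern wait_stream is meant for. The names ``copy_stream`` and
# ``compute_stream`` and the tensor size are arbitrary choices for this example.
def _example_wait_stream() -> None:
    if not torch.cuda.is_available():
        return
    device = torch.device("cuda", 0)
    copy_stream = new_stream(device)
    compute_stream = current_stream(device)
    with use_stream(copy_stream):
        x = torch.empty(1024, device=device).normal_()
    # Block compute_stream until the work queued on copy_stream has finished, so ``x``
    # is fully written before it is read below.
    wait_stream(compute_stream, copy_stream)
    with use_stream(compute_stream):
        x.sum()

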
def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
    """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream."""
    if is_cuda(stream):
        # NOTE(sublee): record_stream() on a shifted view tensor throws
        # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
        # protect the tensor against unexpected reallocation, here we use a
        # temporary tensor associated with the same storage, without shifting,
        # as a workaround.
        #
        # Issue: https://github.com/pytorch/pytorch/issues/27366
        #
        tensor = tensor.new_empty([0]).set_(tensor._typed_storage())

        # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream
        tensor.record_stream(as_cuda(stream))  # type: ignore[arg-type]
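

# Illustrative sketch, added for exposition and not part of the original module:
# record_stream is typically called after a tensor produced on one stream is handed to
# another stream, so the CUDA caching allocator does not reuse its memory while the
# consumer stream may still be reading it. The names ``producer`` and ``consumer`` are
# arbitrary choices for this example.
def _example_record_stream() -> None:
    if not torch.cuda.is_available():
        return
    device = torch.device("cuda", 0)
    producer = current_stream(device)
    consumer = new_stream(device)
    x = torch.ones(1024, device=device)
    wait_stream(consumer, producer)
    with use_stream(consumer):
        x.sum()
    # Mark ``x`` as in use by ``consumer`` so its memory is not reallocated until the
    # work queued on ``consumer`` has completed.
    record_stream(x, consumer)

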
def is_cuda(stream: AbstractStream) -> bool:
    """Returns ``True`` if the given stream is a valid CUDA stream."""
    return stream is not CPUStream


def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
    """Casts the given stream as :class:`torch.cuda.Stream`."""
    return cast(torch.cuda.Stream, stream)