486 lines
15 KiB
Python
486 lines
15 KiB
Python
|
r"""
|
||
|
This package introduces support for the XPU backend, specifically tailored for
|
||
|
Intel GPU optimization.
|
||
|
|
||
|
This package is lazily initialized, so you can always import it, and use
|
||
|
:func:`is_available()` to determine if your system supports XPU.
|
||
|
"""
|
||
|
import threading
|
||
|
import traceback
|
||
|
from functools import lru_cache
|
||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch._C
|
||
|
from .. import device as _device
|
||
|
from .._utils import _dummy_type, _LazySeedTracker
|
||
|
from ._utils import _get_device_index
|
||
|
from .streams import Event, Stream
|
||
|
|
||
|
# Set to True by ``_lazy_init`` once XPU initialization and queued calls have completed.
_initialized = False
# Thread-local storage; ``_lazy_init`` sets ``_tls.is_initializing`` while this
# thread replays queued calls so that re-entrant ``_lazy_init`` calls return early.
_tls = threading.local()
# Held by ``_lazy_init`` so only one thread performs initialization.
_initialization_lock = threading.Lock()
# Calls deferred by ``_lazy_call``: (callable, formatted stack captured when queued).
_queued_calls: List[
    Tuple[Callable[[], None], List[str]]
] = []  # don't invoke these until initialization occurs
# Falls back to "never in a bad fork" when torch._C does not expose ``_xpu_isInBadFork``.
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
# Accepted types for the ``device`` arguments of this module's API.
_device_t = Union[_device, str, int, None]
# Holds seed-related calls queued by ``_lazy_call(..., seed=...)`` / ``seed_all=...``
# until ``_lazy_init`` replays them.
_lazy_seed_tracker = _LazySeedTracker()
# Per-device default generators, indexed by device in ``_get_generator``. Starts
# empty; presumably populated by the backend during initialization — TODO confirm.
default_generators: Tuple[torch._C.Generator, ...] = ()
|
||
|
|
||
|
|
||
|
def _is_compiled() -> bool:
|
||
|
r"""Return true if compile with XPU support."""
|
||
|
return torch._C._has_xpu
|
||
|
|
||
|
|
||
|
if _is_compiled():
    # Bind the implementations exported by the compiled extension. The guards
    # below use ``_exchange_device`` to switch device and keep its return value
    # as the previous index, and ``_maybe_exchange_device`` to switch back.
    _XpuDeviceProperties = torch._C._XpuDeviceProperties
    _exchange_device = torch._C._xpu_exchangeDevice
    _maybe_exchange_device = torch._C._xpu_maybeExchangeDevice
else:
    # Define dummy if PyTorch was compiled without XPU
    _XpuDeviceProperties = _dummy_type("_XpuDeviceProperties")  # type: ignore[assignment, misc]

    def _exchange_device(device: int) -> int:
        """Stand-in for builds without XPU; always raises NotImplementedError."""
        raise NotImplementedError("PyTorch was compiled without XPU support")

    def _maybe_exchange_device(device: int) -> int:
        """Stand-in for builds without XPU; always raises NotImplementedError."""
        raise NotImplementedError("PyTorch was compiled without XPU support")
|
||
|
|
||
|
|
||
|
@lru_cache(maxsize=1)
def device_count() -> int:
    r"""Return the number of available XPU devices.

    Returns 0 when PyTorch was built without XPU support. The result of the
    first call is cached and returned on every later call.
    """
    return torch._C._xpu_getDeviceCount() if _is_compiled() else 0
|
||
|
|
||
|
|
||
|
def is_available() -> bool:
    r"""Return a bool indicating if XPU is currently available."""
    # Intended never to raise: device_count() short-circuits to 0 when XPU
    # support is not compiled in.
    count = device_count()
    return count > 0
|
||
|
|
||
|
|
||
|
def is_bf16_supported() -> bool:
    r"""Return a bool indicating if the current XPU device supports dtype bfloat16.

    This implementation does not query the device; it always returns ``True``.
    """
    return True
|
||
|
|
||
|
|
||
|
def is_initialized():
    r"""Report whether PyTorch's XPU state has been initialized.

    Returns ``False`` before ``_lazy_init`` has completed, and also in a
    process that ``_is_in_bad_fork`` reports as a bad fork.
    """
    if not _initialized:
        return False
    return not _is_in_bad_fork()
|
||
|
|
||
|
|
||
|
def _lazy_call(callable, **kwargs):
    r"""Run ``callable`` now if XPU is initialized; otherwise defer it to ``_lazy_init``.

    Deferred calls carry the stack at which they were queued, so errors raised
    during initialization can point back to the original caller.

    Keyword Args:
        seed_all (bool): route the deferred call to ``_lazy_seed_tracker.queue_seed_all``.
        seed (bool): route the deferred call to ``_lazy_seed_tracker.queue_seed``
            (only consulted when ``seed_all`` is falsy).
    """
    if not is_initialized():
        # Store the formatted stack (strings), not the traceback itself, to
        # avoid a memory cycle.
        stack = traceback.format_stack()
        if kwargs.get("seed_all", False):
            _lazy_seed_tracker.queue_seed_all(callable, stack)
        elif kwargs.get("seed", False):
            _lazy_seed_tracker.queue_seed(callable, stack)
        else:
            _queued_calls.append((callable, stack))
        return
    callable()
|
||
|
|
||
|
|
||
|
def init():
    r"""Initialize PyTorch's XPU state.

    XPU is normally initialized lazily, the first time it is needed; calling
    this forces that initialization up front. It does nothing if the XPU state
    is already initialized.
    """
    return _lazy_init()
|
||
|
|
||
|
|
||
|
def _lazy_init():
    r"""Initialize the XPU backend once, then replay calls deferred by ``_lazy_call``.

    Returns immediately if XPU is already initialized, or if the current thread
    is already inside this function replaying queued calls (re-entrant call).

    Raises:
        RuntimeError: if ``_is_in_bad_fork()`` reports a forked subprocess in
            which XPU cannot be re-initialized.
        AssertionError: if PyTorch was built without XPU support.
        Exception: wrapping any error raised by a queued call; the message
            includes the stack captured when that call was queued.
    """
    global _initialized, _queued_calls
    # Fast path without taking the lock. ``_tls.is_initializing`` exists only
    # while this thread is replaying queued calls below.
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # The test above was only protected by the GIL. Double-check whether XPU
        # has already been initialized (e.g. by another thread while we waited).
        if is_initialized():
            return
        # Stop promptly upon encountering a bad fork error.
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize XPU in forked subprocess. To use XPU with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError("Torch not compiled with XPU enabled")
        # This function inits XPU backend and detects bad fork processing.
        torch._C._xpu_init()
        # Some of the queued calls may reentrantly call _lazy_init(); We need to
        # just return without initializing in that case.
        # NOTE(review): the flag is set before the try/finally below, so if
        # ``_lazy_seed_tracker.get_calls()`` raised, it would never be cleared.
        _tls.is_initializing = True

        # Append the seed-related calls held by the tracker (empty entries are
        # skipped) so they run after the other queued calls.
        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"XPU call failed lazily at initialization with error: {str(e)}\n\n"
                        f"XPU call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise Exception(msg) from e
        finally:
            delattr(_tls, "is_initializing")
        # Only reached when every queued call succeeded.
        _initialized = True
|
||
|
|
||
|
|
||
|
class _DeviceGuard:
    r"""Context manager that makes device ``index`` current for the body of a ``with``.

    On entry, the index returned by ``torch.xpu._exchange_device`` is remembered.
    On exit, it is handed to ``torch.xpu._maybe_exchange_device``.
    """

    def __init__(self, index: int):
        # Remembered previous device; -1 until __enter__ runs.
        self.prev_idx = -1
        self.idx = index

    def __enter__(self):
        previous = torch.xpu._exchange_device(self.idx)
        self.prev_idx = previous

    def __exit__(self, type: Any, value: Any, traceback: Any):
        restored = torch.xpu._maybe_exchange_device(self.prev_idx)
        self.idx = restored
        # Never suppress an exception raised inside the ``with`` body.
        return False
|
||
|
|
||
|
|
||
|
class device:
    r"""Context-manager that changes the selected device.

    Args:
        device (torch.device or int or str): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        # The target index is resolved once, at construction time.
        self.idx = _get_device_index(device, optional=True)
        # Remembered previous device; -1 until __enter__ runs.
        self.prev_idx = -1

    def __enter__(self):
        previous = torch.xpu._exchange_device(self.idx)
        self.prev_idx = previous

    def __exit__(self, type: Any, value: Any, traceback: Any):
        restored = torch.xpu._maybe_exchange_device(self.prev_idx)
        self.idx = restored
        # Never suppress an exception raised inside the ``with`` body.
        return False
|
||
|
|
||
|
|
||
|
class device_of(device):
    r"""Context-manager that changes the current device to that of given object.

    You can use both tensors and storages as arguments. If a given object is
    not allocated on a XPU, this is a no-op.

    Args:
        obj (Tensor or Storage): object allocated on the selected device.
    """

    def __init__(self, obj):
        if obj.is_xpu:
            target = obj.get_device()
        else:
            # A negative index makes the base context manager a no-op.
            target = -1
        super().__init__(target)
|
||
|
|
||
|
|
||
|
def set_device(device: _device_t) -> None:
    r"""Set the current device.

    Initializes XPU lazily if needed.

    Args:
        device (torch.device or int or str): selected device. This function is a
            no-op if this argument is negative.
    """
    _lazy_init()
    index = _get_device_index(device)
    if index < 0:
        return
    torch._C._xpu_setDevice(index)
|
||
|
|
||
|
|
||
|
def get_device_name(device: Optional[_device_t] = None) -> str:
    r"""Get the name of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the name. It uses the current device, given by
            :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).

    Returns:
        str: the name of the device

    Raises:
        AssertionError: if the device index is negative or not less than
            :func:`~torch.xpu.device_count` (raised by
            :func:`~torch.xpu.get_device_properties`).
    """
    return get_device_properties(device).name
|
||
|
|
||
|
|
||
|
def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]:
    r"""Get the xpu capability of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the device capability. It uses the current device, given by
            :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).

    Returns:
        Dict[str, Any]: the xpu capability dictionary of the device, with keys
        ``max_work_group_size``, ``max_num_sub_groups`` and ``sub_group_sizes``
        copied from the device properties.

    Raises:
        AssertionError: if the device index is negative or not less than
            :func:`~torch.xpu.device_count` (raised by
            :func:`~torch.xpu.get_device_properties`).
    """
    props = get_device_properties(device)
    capability_keys = ("max_work_group_size", "max_num_sub_groups", "sub_group_sizes")
    return {key: getattr(props, key) for key in capability_keys}
|
||
|
|
||
|
|
||
|
def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
    r"""Get the properties of a device.

    Args:
        device (torch.device or int or str, optional): device for which to return
            the properties of the device. It uses the current device, given by
            :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).

    Returns:
        _XpuDeviceProperties: the properties of the device

    Raises:
        AssertionError: if the resolved device index is negative or not less
            than :func:`~torch.xpu.device_count`.

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    device = _get_device_index(device, optional=True)
    if device < 0 or device >= device_count():
        raise AssertionError("Invalid device index")
    # ``_get_device_properties`` is not defined in this file; presumably the
    # backend installs it into this module during ``_lazy_init()`` above, hence
    # the lint suppressions — TODO confirm.
    return _get_device_properties(device)  # type: ignore[name-defined] # noqa: F821
|
||
|
|
||
|
|
||
|
def current_device() -> int:
    r"""Return the index of the currently selected XPU device.

    Initializes XPU lazily if needed.
    """
    _lazy_init()
    index = torch._C._xpu_getDevice()
    return index
|
||
|
|
||
|
|
||
|
def _get_device(device: Union[int, str, torch.device]) -> torch.device:
    r"""Normalize ``device`` to a :class:`torch.device`.

    A string is parsed with ``torch.device``, a bare integer is taken as an XPU
    device index, and anything else is returned unchanged.

    Args:
        device (torch.device or int or str): selected device.
    """
    if isinstance(device, int):
        return torch.device("xpu", device)
    if isinstance(device, str):
        return torch.device(device)
    return device
|
||
|
|
||
|
|
||
|
class StreamContext:
    r"""Context-manager that selects a given stream.

    All XPU kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.

    .. note:: Streams are per-device.
    """

    cur_stream: Optional["torch.xpu.Stream"]

    def __init__(self, stream: Optional["torch.xpu.Stream"]):
        self.stream = stream
        resolved = _get_device_index(None, True)
        # -1 marks "no device index could be resolved"; the manager is then a no-op.
        self.idx = -1 if resolved is None else resolved

    def __enter__(self):
        target = self.stream
        if target is None or self.idx == -1:
            return
        self.src_prev_stream = torch.xpu.current_stream(None)
        crosses_device = self.src_prev_stream.device != target.device
        if crosses_device:
            # Remember the target device's current stream so __exit__ can restore it.
            with device(target.device):
                self.dst_prev_stream = torch.xpu.current_stream(target.device)
        torch.xpu.set_stream(target)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        target = self.stream
        if target is None or self.idx == -1:
            return
        # Restore the destination device's previous stream first (when it was
        # recorded), then the stream that was current on entry.
        if self.src_prev_stream.device != target.device:
            torch.xpu.set_stream(self.dst_prev_stream)
        torch.xpu.set_stream(self.src_prev_stream)
|
||
|
|
||
|
|
||
|
def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext:
    r"""Return a :class:`StreamContext` that selects ``stream`` inside a ``with`` block.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's ``None``.
    """
    ctx = StreamContext(stream)
    return ctx
|
||
|
|
||
|
|
||
|
def _set_stream_by_id(stream_id, device_index, device_type):
    r"""Set the stream identified by ``stream_id``, ``device_index`` and ``device_type``.

    Args:
        stream_id (int): internal stream identifier (not user-visible).
        device_index (int): selected device index.
        device_type (int): selected device type.
    """
    torch._C._xpu_setStream(
        device_type=device_type,
        device_index=device_index,
        stream_id=stream_id,
    )
|
||
|
|
||
|
|
||
|
def set_stream(stream: Optional["torch.xpu.Stream"]):
    r"""Set the current stream. This is a wrapper API to set the stream.

    Usage of this function is discouraged in favor of the ``stream``
    context manager.

    Args:
        stream (Stream, optional): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    # Return before _lazy_init() so passing None never initializes XPU.
    if stream is None:
        return
    _lazy_init()
    _set_stream_by_id(
        stream_id=stream.stream_id,
        device_index=stream.device_index,
        device_type=stream.device_type,
    )
|
||
|
|
||
|
|
||
|
def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Initializes XPU lazily if needed.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    index = _get_device_index(device, optional=True)
    # The backend returns (stream_id, device_index, device_type).
    data = torch._C._xpu_getCurrentStream(index)
    return Stream(stream_id=data[0], device_index=data[1], device_type=data[2])
|
||
|
|
||
|
|
||
|
def synchronize(device: _device_t = None) -> None:
    r"""Block until all kernels in all streams on an XPU device have completed.

    Initializes XPU lazily if needed.

    Args:
        device (torch.device or int, optional): device for which to synchronize.
            It uses the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    _lazy_init()
    return torch._C._xpu_synchronize(_get_device_index(device, optional=True))
|
||
|
|
||
|
|
||
|
def empty_cache() -> None:
    r"""Release all unoccupied cached memory held by the caching allocator.

    The released memory can then be used by other XPU applications. Does
    nothing if XPU has not been initialized.

    .. note::
        :func:`~torch.xpu.empty_cache` doesn't increase the amount of XPU
        memory available for PyTorch. However, it may help reduce fragmentation
        of XPU memory in certain cases.
    """
    if not is_initialized():
        return
    torch._C._xpu_emptyCache()
|
||
|
|
||
|
|
||
|
def _get_generator(device: torch.device) -> torch._C.Generator:
    r"""Look up the default XPU generator for ``device``.

    A device without an explicit index resolves to :func:`current_device`.

    Args:
        device (torch.device): selected device.
    """
    explicit = device.index
    idx = current_device() if explicit is None else explicit
    return torch.xpu.default_generators[idx]
|
||
|
|
||
|
|
||
|
def _set_rng_state_offset(
    offset: int, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state offset of the specified XPU device.

    The device is resolved immediately. The offset itself is applied through
    ``_lazy_call``, so it is deferred until XPU is initialized.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    target = _get_device(device)

    def _apply_offset():
        _get_generator(target).set_offset(offset)

    _lazy_call(_apply_offset)
|
||
|
|
||
|
|
||
|
def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
    r"""Return the random number generator state offset of the specified XPU device.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    return _get_generator(_get_device(device)).get_offset()
|
||
|
|
||
|
|
||
|
from .random import * # noqa: F403
|
||
|
|
||
|
|
||
|
# Public names exported by ``from <this package> import *``. Several are not
# defined in this file and come from the ``.streams`` / ``.random`` imports above.
# NOTE(review): no ``get_stream`` is defined in this file or imported by name;
# confirm ``.random`` provides it, otherwise a star-import of this package
# raises AttributeError.
__all__ = [
    "Event",
    "Stream",
    "StreamContext",
    "current_device",
    "current_stream",
    "default_generators",
    "device",
    "device_of",
    "device_count",
    "empty_cache",
    "get_device_capability",
    "get_device_name",
    "get_device_properties",
    "get_rng_state",
    "get_rng_state_all",
    "get_stream",
    "init",
    "initial_seed",
    "is_available",
    "is_bf16_supported",
    "is_initialized",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "set_device",
    "set_rng_state",
    "set_rng_state_all",
    "set_stream",
    "stream",
    "streams",
    "synchronize",
]
|