143 lines
4.2 KiB
Python
143 lines
4.2 KiB
Python
|
"""
|
||
|
CUDA Runtime wrapper.
|
||
|
|
||
|
This provides a very minimal set of bindings, since the Runtime API is not
|
||
|
really used in Numba except for querying the Runtime version.
|
||
|
"""
|
||
|
|
||
|
import ctypes
|
||
|
import functools
|
||
|
import sys
|
||
|
|
||
|
from numba.core import config
|
||
|
from numba.cuda.cudadrv.driver import ERROR_MAP, make_logger
|
||
|
from numba.cuda.cudadrv.error import CudaSupportError, CudaRuntimeError
|
||
|
from numba.cuda.cudadrv.libs import open_cudalib
|
||
|
from numba.cuda.cudadrv.rtapi import API_PROTOTYPES
|
||
|
from numba.cuda.cudadrv import enums
|
||
|
|
||
|
|
||
|
class CudaRuntimeAPIError(CudaRuntimeError):
|
||
|
"""
|
||
|
Raised when there is an error accessing a C API from the CUDA Runtime.
|
||
|
"""
|
||
|
def __init__(self, code, msg):
|
||
|
self.code = code
|
||
|
self.msg = msg
|
||
|
super().__init__(code, msg)
|
||
|
|
||
|
def __str__(self):
|
||
|
return "[%s] %s" % (self.code, self.msg)
|
||
|
|
||
|
|
||
|
class Runtime:
|
||
|
"""
|
||
|
Runtime object that lazily binds runtime API functions.
|
||
|
"""
|
||
|
|
||
|
def __init__(self):
|
||
|
self.is_initialized = False
|
||
|
|
||
|
def _initialize(self):
|
||
|
# lazily initialize logger
|
||
|
global _logger
|
||
|
_logger = make_logger()
|
||
|
|
||
|
if config.DISABLE_CUDA:
|
||
|
msg = ("CUDA is disabled due to setting NUMBA_DISABLE_CUDA=1 "
|
||
|
"in the environment, or because CUDA is unsupported on "
|
||
|
"32-bit systems.")
|
||
|
raise CudaSupportError(msg)
|
||
|
self.lib = open_cudalib('cudart')
|
||
|
|
||
|
self.is_initialized = True
|
||
|
|
||
|
def __getattr__(self, fname):
|
||
|
# First request of a runtime API function
|
||
|
try:
|
||
|
proto = API_PROTOTYPES[fname]
|
||
|
except KeyError:
|
||
|
raise AttributeError(fname)
|
||
|
restype = proto[0]
|
||
|
argtypes = proto[1:]
|
||
|
|
||
|
if not self.is_initialized:
|
||
|
self._initialize()
|
||
|
|
||
|
# Find function in runtime library
|
||
|
libfn = self._find_api(fname)
|
||
|
libfn.restype = restype
|
||
|
libfn.argtypes = argtypes
|
||
|
|
||
|
safe_call = self._wrap_api_call(fname, libfn)
|
||
|
setattr(self, fname, safe_call)
|
||
|
return safe_call
|
||
|
|
||
|
def _wrap_api_call(self, fname, libfn):
|
||
|
@functools.wraps(libfn)
|
||
|
def safe_cuda_api_call(*args):
|
||
|
_logger.debug('call runtime api: %s', libfn.__name__)
|
||
|
retcode = libfn(*args)
|
||
|
self._check_error(fname, retcode)
|
||
|
return safe_cuda_api_call
|
||
|
|
||
|
def _check_error(self, fname, retcode):
|
||
|
if retcode != enums.CUDA_SUCCESS:
|
||
|
errname = ERROR_MAP.get(retcode, "cudaErrorUnknown")
|
||
|
msg = "Call to %s results in %s" % (fname, errname)
|
||
|
_logger.error(msg)
|
||
|
raise CudaRuntimeAPIError(retcode, msg)
|
||
|
|
||
|
def _find_api(self, fname):
|
||
|
try:
|
||
|
return getattr(self.lib, fname)
|
||
|
except AttributeError:
|
||
|
pass
|
||
|
|
||
|
# Not found.
|
||
|
# Delay missing function error to use
|
||
|
def absent_function(*args, **kws):
|
||
|
msg = "runtime missing function: %s."
|
||
|
raise CudaRuntimeError(msg % fname)
|
||
|
|
||
|
setattr(self, fname, absent_function)
|
||
|
return absent_function
|
||
|
|
||
|
def get_version(self):
|
||
|
"""
|
||
|
Returns the CUDA Runtime version as a tuple (major, minor).
|
||
|
"""
|
||
|
rtver = ctypes.c_int()
|
||
|
self.cudaRuntimeGetVersion(ctypes.byref(rtver))
|
||
|
# The version is encoded as (1000 * major) + (10 * minor)
|
||
|
major = rtver.value // 1000
|
||
|
minor = (rtver.value - (major * 1000)) // 10
|
||
|
return (major, minor)
|
||
|
|
||
|
def is_supported_version(self):
|
||
|
"""
|
||
|
Returns True if the CUDA Runtime is a supported version.
|
||
|
"""
|
||
|
|
||
|
return self.get_version() in self.supported_versions
|
||
|
|
||
|
@property
|
||
|
def supported_versions(self):
|
||
|
"""A tuple of all supported CUDA toolkit versions. Versions are given in
|
||
|
the form ``(major_version, minor_version)``."""
|
||
|
if sys.platform not in ('linux', 'win32') or config.MACHINE_BITS != 64:
|
||
|
# Only 64-bit Linux and Windows are supported
|
||
|
return ()
|
||
|
return ((11, 0), (11, 1), (11, 2), (11, 3), (11, 4), (11, 5), (11, 6),
|
||
|
(11, 7))
|
||
|
|
||
|
|
||
|
runtime = Runtime()
|
||
|
|
||
|
|
||
|
def get_version():
|
||
|
"""
|
||
|
Return the runtime version as a tuple of (major, minor)
|
||
|
"""
|
||
|
return runtime.get_version()
|