"""ctypes bindings for NVRTC, used to compile CUDA C/C++ source to PTX."""
from ctypes import byref, c_char, c_char_p, c_int, c_size_t, c_void_p, POINTER
|
|
from enum import IntEnum
|
|
from numba.core import config
|
|
from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
|
|
NvrtcSupportError)
|
|
|
|
import functools
|
|
import os
|
|
import threading
|
|
import warnings
|
|
|
|
# ctypes stand-in for the C ``nvrtcProgram`` type: an opaque handle to a
# compilation unit.
nvrtc_program = c_void_p

# ctypes stand-in for the C ``nvrtcResult`` type: the integer result code.
nvrtc_result = c_int
|
|
|
|
|
|
class NvrtcResult(IntEnum):
    """
    Symbolic names for the integer result codes returned by NVRTC calls.

    Being an ``IntEnum``, members compare equal to the raw integer returned
    through ctypes, and ``NvrtcResult(code)`` raises ``ValueError`` for a code
    that is not listed here.
    """
    # Success
    NVRTC_SUCCESS = 0

    # General and program-level failures
    NVRTC_ERROR_OUT_OF_MEMORY = 1
    NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2
    NVRTC_ERROR_INVALID_INPUT = 3
    NVRTC_ERROR_INVALID_PROGRAM = 4
    NVRTC_ERROR_INVALID_OPTION = 5
    NVRTC_ERROR_COMPILATION = 6
    NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7

    # Name expression / lowered name failures
    NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8
    NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9
    NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10

    # Internal failure
    NVRTC_ERROR_INTERNAL_ERROR = 11
|
|
|
|
|
|
# Module-level lock held while the NVRTC singleton is looked up or created.
_nvrtc_lock = threading.Lock()
|
|
|
|
|
|
class NvrtcProgram:
    """
    Owner of a single nvrtcProgram handle.

    An instance keeps a reference to the NVRTC interface object that created
    it together with the program handle. When the instance is finalized, a
    non-null handle is released by calling ``destroy_program`` on that
    interface object.
    """
    def __init__(self, nvrtc, handle):
        self._nvrtc, self._handle = nvrtc, handle

    @property
    def handle(self):
        """The program handle owned by this instance."""
        return self._handle

    def __del__(self):
        # A falsy (null) handle means there is nothing to release.
        if not self._handle:
            return
        self._nvrtc.destroy_program(self)
|
|
|
|
|
|
class NVRTC:
    """
    Provides a Pythonic interface to the NVRTC APIs, abstracting away the C API
    calls.

    The sole instance of this class is a process-wide singleton, similar to the
    NVVM interface. Initialization is protected by a lock and uses the standard
    (for Numba) open_cudalib function to load the NVRTC library.

    Every entry in ``_PROTOTYPES`` is bound onto the instance under the same
    name as a checked wrapper: the wrapper raises ``NvrtcCompilationError``
    when the call returns ``NVRTC_ERROR_COMPILATION`` and ``NvrtcError`` for
    any other non-success result code.
    """
    # Maps each NVRTC function name to (restype, *argtypes) for ctypes.
    _PROTOTYPES = {
        # nvrtcResult nvrtcVersion(int *major, int *minor)
        'nvrtcVersion': (nvrtc_result, POINTER(c_int), POINTER(c_int)),
        # nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
        #                                const char *src,
        #                                const char *name,
        #                                int numHeaders,
        #                                const char * const *headers,
        #                                const char * const *includeNames)
        'nvrtcCreateProgram': (nvrtc_result, POINTER(nvrtc_program),
                               c_char_p, c_char_p, c_int,
                               POINTER(c_char_p), POINTER(c_char_p)),
        # nvrtcResult nvrtcDestroyProgram(nvrtcProgram *prog);
        'nvrtcDestroyProgram': (nvrtc_result, POINTER(nvrtc_program)),
        # nvrtcResult nvrtcCompileProgram(nvrtcProgram prog,
        #                                 int numOptions,
        #                                 const char * const *options)
        'nvrtcCompileProgram': (nvrtc_result, nvrtc_program, c_int,
                                POINTER(c_char_p)),
        # nvrtcResult nvrtcGetPTXSize(nvrtcProgram prog, size_t *ptxSizeRet);
        'nvrtcGetPTXSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
        # nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx);
        'nvrtcGetPTX': (nvrtc_result, nvrtc_program, c_char_p),
        # nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog,
        #                               size_t *cubinSizeRet);
        'nvrtcGetCUBINSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
        # nvrtcResult nvrtcGetCUBIN(nvrtcProgram prog, char *cubin);
        'nvrtcGetCUBIN': (nvrtc_result, nvrtc_program, c_char_p),
        # nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog,
        #                                    size_t *logSizeRet);
        'nvrtcGetProgramLogSize': (nvrtc_result, nvrtc_program,
                                   POINTER(c_size_t)),
        # nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
        'nvrtcGetProgramLog': (nvrtc_result, nvrtc_program, c_char_p),
    }

    # Singleton reference
    __INSTANCE = None

    def __new__(cls):
        """
        Return the singleton, loading NVRTC and binding its functions on
        first use.

        :raises NvrtcSupportError: if the NVRTC library cannot be loaded.
        """
        with _nvrtc_lock:
            if cls.__INSTANCE is None:
                from numba.cuda.cudadrv.libs import open_cudalib
                inst = object.__new__(cls)
                try:
                    lib = open_cudalib('nvrtc')
                except OSError as e:
                    raise NvrtcSupportError("NVRTC cannot be loaded") from e

                # Find & populate functions
                for name, proto in inst._PROTOTYPES.items():
                    func = getattr(lib, name)
                    func.restype = proto[0]
                    func.argtypes = proto[1:]

                    # func and name are bound as defaults so that each wrapper
                    # keeps the values from its own loop iteration.
                    @functools.wraps(func)
                    def checked_call(*args, func=func, name=name):
                        error = func(*args)
                        if error == NvrtcResult.NVRTC_ERROR_COMPILATION:
                            raise NvrtcCompilationError()
                        elif error != NvrtcResult.NVRTC_SUCCESS:
                            try:
                                error_name = NvrtcResult(error).name
                            except ValueError:
                                error_name = ('Unknown nvrtc_result '
                                              f'(error code: {error})')
                            msg = f'Failed to call {name}: {error_name}'
                            raise NvrtcError(msg)

                    setattr(inst, name, checked_call)

                # Publish the singleton only once every function is bound, so
                # that a failure part-way through the loop above cannot leave
                # a half-initialized instance to be returned by later calls.
                cls.__INSTANCE = inst

        return cls.__INSTANCE

    def get_version(self):
        """
        Get the NVRTC version as a tuple (major, minor).
        """
        major = c_int()
        minor = c_int()
        self.nvrtcVersion(byref(major), byref(minor))
        return major.value, minor.value

    def create_program(self, src, name):
        """
        Create an NVRTC program with managed lifetime.

        :param src: The source code, as ``str`` or ``bytes``
        :param name: The program name, as ``str`` or ``bytes``
        :return: An ``NvrtcProgram`` owning the new program handle
        """
        if isinstance(src, str):
            src = src.encode()
        if isinstance(name, str):
            name = name.encode()

        handle = nvrtc_program()

        # The final three arguments are for passing the contents of headers -
        # this is not supported, so there are 0 headers and the header names
        # and contents are null.
        self.nvrtcCreateProgram(byref(handle), src, name, 0, None, None)
        return NvrtcProgram(self, handle)

    def compile_program(self, program, options):
        """
        Compile an NVRTC program. Compilation may fail due to a user error in
        the source; this function returns ``True`` if there is a compilation
        error and ``False`` on success.

        :param program: The ``NvrtcProgram`` to compile
        :param options: An iterable of compiler options, each ``str`` or
                        ``bytes``
        """
        # We hold a list of encoded options to ensure they can't be collected
        # prior to the call to nvrtcCompileProgram. The option count is taken
        # from this list so that any iterable of options is accepted.
        encoded_options = [opt.encode() if isinstance(opt, str) else opt
                           for opt in options]
        num_options = len(encoded_options)
        c_options = (c_char_p * num_options)(*encoded_options)
        try:
            self.nvrtcCompileProgram(program.handle, num_options, c_options)
            return False
        except NvrtcCompilationError:
            return True

    def destroy_program(self, program):
        """
        Destroy an NVRTC program.
        """
        # nvrtcDestroyProgram takes a pointer to the handle
        self.nvrtcDestroyProgram(byref(program.handle))

    def get_compile_log(self, program):
        """
        Get the compile log as a Python string.
        """
        log_size = c_size_t()
        self.nvrtcGetProgramLogSize(program.handle, byref(log_size))

        # Buffer of the size reported by NVRTC for it to write the log into
        log = (c_char * log_size.value)()
        self.nvrtcGetProgramLog(program.handle, log)

        return log.value.decode()

    def get_ptx(self, program):
        """
        Get the compiled PTX as a Python string.
        """
        ptx_size = c_size_t()
        self.nvrtcGetPTXSize(program.handle, byref(ptx_size))

        # Buffer of the size reported by NVRTC for it to write the PTX into
        ptx = (c_char * ptx_size.value)()
        self.nvrtcGetPTX(program.handle, ptx)

        return ptx.value.decode()
|
|
|
|
|
|
def compile(src, name, cc):
    """
    Compile CUDA C/C++ source code to PTX, targeting the given compute
    capability.

    :param src: The source code to compile
    :type src: str
    :param name: The filename of the source (for information only)
    :type name: str
    :param cc: A tuple ``(major, minor)`` of the compute capability
    :type cc: tuple
    :return: The compiled PTX and compilation log
    :rtype: tuple
    :raises NvrtcError: if compilation fails; the message contains the log
    """
    nvrtc = NVRTC()
    program = nvrtc.create_program(src, name)

    # Options passed to NVRTC:
    # - the target architecture, built from the compute capability,
    # - the CUDA include path from the Numba config,
    # - the numba/cuda directory (the parent of this file's directory) as an
    #   additional include path,
    # - Relocatable Device Code (rdc), needed to prevent device functions
    #   being optimized away.
    major, minor = cc
    this_dir = os.path.dirname(os.path.abspath(__file__))
    numba_cuda_path = os.path.dirname(this_dir)
    options = [
        f'--gpu-architecture=compute_{major}{minor}',
        f'-I{config.CUDA_INCLUDE_PATH}',
        f'-I{numba_cuda_path}',
        '-rdc',
        'true',
    ]

    # compile_program reports True when compilation failed
    failed = nvrtc.compile_program(program, options)

    # The log is fetched whether or not compilation succeeded
    log = nvrtc.get_compile_log(program)

    # A failed compile is reported as an exception that carries the log
    if failed:
        raise NvrtcError(
            f'NVRTC Compilation failure whilst compiling {name}:\n\n{log}'
        )

    # On success, any log content is surfaced as a warning
    if log:
        warnings.warn(f"NVRTC log messages whilst compiling {name}:\n\n{log}")

    return nvrtc.get_ptx(program), log
|