ai-content-maker/.venv/Lib/site-packages/numba/cuda/cudadrv/nvrtc.py

261 lines
9.5 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from ctypes import byref, c_char, c_char_p, c_int, c_size_t, c_void_p, POINTER
from enum import IntEnum
from numba.core import config
from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
NvrtcSupportError)
import functools
import os
import threading
import warnings
# ctypes aliases for the NVRTC C types used when declaring function
# prototypes: ``nvrtcProgram`` is an opaque handle to a compilation unit and
# ``nvrtcResult`` is the integer result code returned by the API calls.
nvrtc_program, nvrtc_result = c_void_p, c_int
class NvrtcResult(IntEnum):
    """
    Integer result codes (``nvrtcResult``) returned by NVRTC API functions.

    ``NVRTC_SUCCESS`` (0) denotes success; every other member is an error.
    """

    NVRTC_SUCCESS = 0

    # Resource / input errors
    NVRTC_ERROR_OUT_OF_MEMORY = 1
    NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2
    NVRTC_ERROR_INVALID_INPUT = 3
    NVRTC_ERROR_INVALID_PROGRAM = 4
    NVRTC_ERROR_INVALID_OPTION = 5

    # Reported when the user's source fails to compile
    NVRTC_ERROR_COMPILATION = 6

    NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7

    # Name-expression API misuse
    NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8
    NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9
    NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10

    NVRTC_ERROR_INTERNAL_ERROR = 11
# Non-reentrant lock serializing creation of the process-wide NVRTC singleton.
_nvrtc_lock = threading.Lock()
class NvrtcProgram:
    """
    Owns an ``nvrtcProgram`` handle and ties its lifetime to this object.

    When the instance is finalized, a non-null handle is released by calling
    ``destroy_program`` on the NVRTC interface the program was created with.
    """

    def __init__(self, nvrtc, handle):
        self._nvrtc = nvrtc
        self._handle = handle

    @property
    def handle(self):
        """The underlying ``nvrtcProgram`` handle (read-only)."""
        return self._handle

    def __del__(self):
        # A null (falsy) handle means there is nothing to destroy.
        if not self._handle:
            return
        self._nvrtc.destroy_program(self)
class NVRTC:
    """
    Provides a Pythonic interface to the NVRTC APIs, abstracting away the C API
    calls.

    The sole instance of this class is a process-wide singleton, similar to the
    NVVM interface. Initialization is protected by a lock and uses the standard
    (for Numba) open_cudalib function to load the NVRTC library.

    Every entry of ``_PROTOTYPES`` is bound as an attribute of the same name
    on the singleton. The bound callables check the returned ``nvrtcResult``:
    ``NVRTC_ERROR_COMPILATION`` raises :class:`NvrtcCompilationError`, and any
    other non-success code raises :class:`NvrtcError`.
    """

    _PROTOTYPES = {
        # nvrtcResult nvrtcVersion(int *major, int *minor)
        'nvrtcVersion': (nvrtc_result, POINTER(c_int), POINTER(c_int)),

        # nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
        #                                const char *src,
        #                                const char *name,
        #                                int numHeaders,
        #                                const char * const *headers,
        #                                const char * const *includeNames)
        # ``prog`` is an out-parameter, so it is declared as a pointer to the
        # handle type (matching nvrtcDestroyProgram below).
        'nvrtcCreateProgram': (nvrtc_result, POINTER(nvrtc_program), c_char_p,
                               c_char_p, c_int, POINTER(c_char_p),
                               POINTER(c_char_p)),

        # nvrtcResult nvrtcDestroyProgram(nvrtcProgram *prog);
        'nvrtcDestroyProgram': (nvrtc_result, POINTER(nvrtc_program)),

        # nvrtcResult nvrtcCompileProgram(nvrtcProgram prog,
        #                                 int numOptions,
        #                                 const char * const *options)
        'nvrtcCompileProgram': (nvrtc_result, nvrtc_program, c_int,
                                POINTER(c_char_p)),

        # nvrtcResult nvrtcGetPTXSize(nvrtcProgram prog, size_t *ptxSizeRet);
        'nvrtcGetPTXSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),

        # nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx);
        'nvrtcGetPTX': (nvrtc_result, nvrtc_program, c_char_p),

        # nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog,
        #                               size_t *cubinSizeRet);
        'nvrtcGetCUBINSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),

        # nvrtcResult nvrtcGetCUBIN(nvrtcProgram prog, char *cubin);
        'nvrtcGetCUBIN': (nvrtc_result, nvrtc_program, c_char_p),

        # nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog,
        #                                    size_t *logSizeRet);
        'nvrtcGetProgramLogSize': (nvrtc_result, nvrtc_program,
                                   POINTER(c_size_t)),

        # nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
        'nvrtcGetProgramLog': (nvrtc_result, nvrtc_program, c_char_p),
    }

    # Singleton reference
    __INSTANCE = None

    def __new__(cls):
        with _nvrtc_lock:
            if cls.__INSTANCE is None:
                # Build the instance completely before publishing it. If
                # initialization fails part-way (e.g. a symbol is missing from
                # an older NVRTC library), nothing is cached, so later callers
                # retry instead of receiving a half-initialized object.
                cls.__INSTANCE = cls._create_instance()
        return cls.__INSTANCE

    @classmethod
    def _create_instance(cls):
        """Load the NVRTC library and return a fully bound instance."""
        from numba.cuda.cudadrv.libs import open_cudalib
        try:
            lib = open_cudalib('nvrtc')
        except OSError as e:
            raise NvrtcSupportError("NVRTC cannot be loaded") from e

        inst = object.__new__(cls)
        # Find & populate functions
        for name, proto in cls._PROTOTYPES.items():
            func = getattr(lib, name)
            func.restype = proto[0]
            func.argtypes = proto[1:]
            setattr(inst, name, cls._make_checked_call(func, name))
        return inst

    @staticmethod
    def _make_checked_call(func, name):
        """Wrap ctypes function *func* so non-success results raise."""
        @functools.wraps(func)
        def checked_call(*args):
            error = func(*args)
            if error == NvrtcResult.NVRTC_ERROR_COMPILATION:
                # User-level compile failure; callers inspect the log.
                raise NvrtcCompilationError()
            elif error != NvrtcResult.NVRTC_SUCCESS:
                try:
                    error_name = NvrtcResult(error).name
                except ValueError:
                    error_name = ('Unknown nvrtc_result '
                                  f'(error code: {error})')
                msg = f'Failed to call {name}: {error_name}'
                raise NvrtcError(msg)
        return checked_call

    def get_version(self):
        """
        Get the NVRTC version as a tuple (major, minor).
        """
        major = c_int()
        minor = c_int()
        self.nvrtcVersion(byref(major), byref(minor))
        return major.value, minor.value

    def create_program(self, src, name):
        """
        Create an NVRTC program with managed lifetime.

        :param src: The CUDA source (``str`` is UTF-8 encoded; ``bytes`` is
            passed through).
        :param name: The program name (``str`` or ``bytes``).
        :return: An :class:`NvrtcProgram` owning the new handle.
        """
        if isinstance(src, str):
            src = src.encode()
        if isinstance(name, str):
            name = name.encode()
        handle = nvrtc_program()

        # The final three arguments are for passing the contents of headers -
        # this is not supported, so there are 0 headers and the header names
        # and contents are null.
        self.nvrtcCreateProgram(byref(handle), src, name, 0, None, None)
        return NvrtcProgram(self, handle)

    def compile_program(self, program, options):
        """
        Compile an NVRTC program. Compilation may fail due to a user error in
        the source; this function returns ``True`` if there is a compilation
        error and ``False`` on success.

        :param program: The :class:`NvrtcProgram` to compile.
        :param options: An iterable of compiler options, each ``str``
            (UTF-8 encoded) or ``bytes``.
        """
        # Materialize once (so any iterable works) and keep the encoded
        # options referenced until nvrtcCompileProgram has read them.
        encoded_options = [opt.encode() if isinstance(opt, str) else opt
                           for opt in options]
        num_options = len(encoded_options)
        c_options = (c_char_p * num_options)(*encoded_options)

        try:
            self.nvrtcCompileProgram(program.handle, num_options, c_options)
        except NvrtcCompilationError:
            return True
        return False

    def destroy_program(self, program):
        """
        Destroy an NVRTC program.
        """
        self.nvrtcDestroyProgram(byref(program.handle))

    def _get_string(self, program, get_size, get_contents):
        """Fetch a NUL-terminated string via an NVRTC size/contents pair."""
        size = c_size_t()
        get_size(program.handle, byref(size))
        buf = (c_char * size.value)()
        get_contents(program.handle, buf)
        # ``.value`` stops at the first NUL byte.
        return buf.value.decode()

    def get_compile_log(self, program):
        """
        Get the compile log as a Python string.
        """
        return self._get_string(program, self.nvrtcGetProgramLogSize,
                                self.nvrtcGetProgramLog)

    def get_ptx(self, program):
        """
        Get the compiled PTX as a Python string.
        """
        return self._get_string(program, self.nvrtcGetPTXSize,
                                self.nvrtcGetPTX)
def compile(src, name, cc):
    """
    Compile CUDA C/C++ source to PTX targeting a given compute capability.

    Non-empty compiler logs from a successful compilation are emitted as a
    warning.

    :param src: The source code to compile
    :type src: str
    :param name: The filename of the source (for information only)
    :type name: str
    :param cc: A tuple ``(major, minor)`` of the compute capability
    :type cc: tuple
    :return: The compiled PTX and compilation log
    :rtype: tuple
    :raises NvrtcError: If compilation fails; the message contains the log.
    """
    nvrtc = NVRTC()
    program = nvrtc.create_program(src, name)

    cc_major, cc_minor = cc

    # numba/cuda (the parent of this cudadrv package) is added to the include
    # path alongside the CUDA toolkit headers.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    numba_cuda_dir = os.path.dirname(this_dir)

    options = [
        # Compile for the requested compute capability.
        f'--gpu-architecture=compute_{cc_major}{cc_minor}',
        f'-I{config.CUDA_INCLUDE_PATH}',
        f'-I{numba_cuda_dir}',
        # Relocatable Device Code (rdc) prevents device functions from being
        # optimized away.
        '-rdc', 'true',
    ]

    failed = nvrtc.compile_program(program, options)
    log = nvrtc.get_compile_log(program)

    if failed:
        raise NvrtcError(
            f'NVRTC Compilation failure whilst compiling {name}:\n\n{log}')

    if log:
        warnings.warn(
            f"NVRTC log messages whilst compiling {name}:\n\n{log}")

    return nvrtc.get_ptx(program), log