84 lines
2.7 KiB
Python
84 lines
2.7 KiB
Python
from llvmlite import ir
|
|
from numba.core import cgutils, types
|
|
from numba.core.imputils import Registry
|
|
from numba.cuda import libdevice, libdevicefuncs
|
|
|
|
# Registry that collects every lowering implementation defined in this module.
registry = Registry()
# Short alias for the registry's lowering decorator factory.
lower = registry.lower
|
|
|
|
|
|
def libdevice_implement(func, retty, nbargs):
    """Register a lowering that forwards directly to an external function.

    Parameters
    ----------
    func : str
        Symbol name of the external function to declare and call. The first
        five characters are dropped to obtain the name of the matching
        attribute on ``libdevice`` (presumably a ``__nv_`` prefix -- confirm
        against ``libdevicefuncs.functions``).
    retty
        Numba type of the function's return value.
    nbargs
        Sequence of argument descriptors, each exposing ``ty`` (the Numba
        type) and ``is_ptr``.
    """
    def core(context, builder, sig, args):
        # Declare the external function in the current module (or reuse an
        # existing declaration) and call it with the lowered arguments.
        lmod = builder.module
        fretty = context.get_value_type(retty)
        fargtys = [context.get_value_type(arg.ty) for arg in nbargs]
        fnty = ir.FunctionType(fretty, fargtys)
        fn = cgutils.get_or_insert_function(lmod, fnty, func)
        return builder.call(fn, args)

    key = getattr(libdevice, func[5:])

    # Build the registered signature from this function's own ``nbargs``
    # parameter instead of depending on a same-named variable that happens
    # to exist at module level while the registration loop is running.
    argtys = [arg.ty for arg in nbargs if not arg.is_ptr]
    lower(key, *argtys)(core)
|
|
|
|
|
|
def libdevice_implement_multiple_returns(func, retty, prototype_args):
    """Register a lowering for an external function with pointer arguments.

    Arguments flagged ``is_ptr`` in ``prototype_args`` are treated as outputs:
    the lowering allocates a temporary for each, passes its address to the
    external function, and afterwards loads the values back. The direct
    return value (unless ``retty`` is ``types.void``) and the loaded outputs
    are packed together and returned as one aggregate.

    Parameters
    ----------
    func : str
        Symbol name of the external function; the first five characters are
        dropped to look up the matching attribute on ``libdevice``.
    retty
        Numba type returned directly by the external function.
    prototype_args
        Sequence of argument descriptors, each exposing ``ty`` and ``is_ptr``.
    """
    sig = libdevicefuncs.create_signature(retty, prototype_args)
    nb_retty = sig.return_type

    def core(context, builder, sig, args):
        def llvm_param_type(proto):
            # Pointer arguments take a pointer to the value type.
            base = context.get_value_type(proto.ty)
            return base.as_pointer() if proto.is_ptr else base

        module = builder.module
        param_types = [llvm_param_type(proto) for proto in prototype_args]
        return_type = context.get_value_type(retty)
        callee = cgutils.get_or_insert_function(
            module, ir.FunctionType(return_type, param_types), func)

        # Values returned through a pointer need a stack slot whose address
        # is handed to the callee; all other arguments are taken in order
        # from the caller-supplied ``args``.
        call_args = []
        out_slots = []
        next_input = 0
        for proto in prototype_args:
            if not proto.is_ptr:
                call_args.append(args[next_input])
                next_input += 1
                continue
            slot = cgutils.alloca_once(builder,
                                       context.get_value_type(proto.ty))
            call_args.append(slot)
            out_slots.append(slot)

        result = builder.call(callee, call_args)

        # Gather the direct result (if any) followed by every value written
        # through a pointer, in prototype order.
        values = [result] if retty != types.void else []
        values.extend(builder.load(slot) for slot in out_slots)

        # Homogeneous tuples are packed as arrays, everything else as a
        # struct.
        pack = (cgutils.pack_array
                if isinstance(nb_retty, types.UniTuple)
                else cgutils.pack_struct)
        return pack(builder, values)

    key = getattr(libdevice, func[5:])
    lower(key, *sig.args)(core)
|
|
|
|
|
|
# Register a lowering for every function listed in ``libdevicefuncs``.
# Functions with no pointer arguments are forwarded directly; those with at
# least one pointer argument go through the multiple-returns path. The loop
# variables are bound at module level, so their names are kept unchanged.
for func, (retty, args) in libdevicefuncs.functions.items():
    if not any(arg.is_ptr for arg in args):
        libdevice_implement(func, retty, args)
    else:
        libdevice_implement_multiple_returns(func, retty, args)
|