ai-content-maker/.venv/Lib/site-packages/numba/cuda/libdeviceimpl.py

84 lines
2.7 KiB
Python

from llvmlite import ir
from numba.core import cgutils, types
from numba.core.imputils import Registry
from numba.cuda import libdevice, libdevicefuncs
registry = Registry()
lower = registry.lower
def libdevice_implement(func, retty, nbargs):
def core(context, builder, sig, args):
lmod = builder.module
fretty = context.get_value_type(retty)
fargtys = [context.get_value_type(arg.ty) for arg in nbargs]
fnty = ir.FunctionType(fretty, fargtys)
fn = cgutils.get_or_insert_function(lmod, fnty, func)
return builder.call(fn, args)
key = getattr(libdevice, func[5:])
argtys = [arg.ty for arg in args if not arg.is_ptr]
lower(key, *argtys)(core)
def libdevice_implement_multiple_returns(func, retty, prototype_args):
sig = libdevicefuncs.create_signature(retty, prototype_args)
nb_retty = sig.return_type
def core(context, builder, sig, args):
lmod = builder.module
fargtys = []
for arg in prototype_args:
ty = context.get_value_type(arg.ty)
if arg.is_ptr:
ty = ty.as_pointer()
fargtys.append(ty)
fretty = context.get_value_type(retty)
fnty = ir.FunctionType(fretty, fargtys)
fn = cgutils.get_or_insert_function(lmod, fnty, func)
# For returned values that are returned through a pointer, we need to
# allocate variables on the stack and pass a pointer to them.
actual_args = []
virtual_args = []
arg_idx = 0
for arg in prototype_args:
if arg.is_ptr:
# Allocate space for return value and add to args
tmp_arg = cgutils.alloca_once(builder,
context.get_value_type(arg.ty))
actual_args.append(tmp_arg)
virtual_args.append(tmp_arg)
else:
actual_args.append(args[arg_idx])
arg_idx += 1
ret = builder.call(fn, actual_args)
# Following the call, we need to assemble the returned values into a
# tuple for returning back to the caller.
tuple_args = []
if retty != types.void:
tuple_args.append(ret)
for arg in virtual_args:
tuple_args.append(builder.load(arg))
if isinstance(nb_retty, types.UniTuple):
return cgutils.pack_array(builder, tuple_args)
else:
return cgutils.pack_struct(builder, tuple_args)
key = getattr(libdevice, func[5:])
lower(key, *sig.args)(core)
for func, (retty, args) in libdevicefuncs.functions.items():
if any([arg.is_ptr for arg in args]):
libdevice_implement_multiple_returns(func, retty, args)
else:
libdevice_implement(func, retty, args)