449 lines
14 KiB
Python
449 lines
14 KiB
Python
import math
|
|
import operator
|
|
from llvmlite import ir
|
|
from numba.core import types, typing, cgutils, targetconfig
|
|
from numba.core.imputils import Registry
|
|
from numba.types import float32, float64, int64, uint64
|
|
from numba.cuda import libdevice
|
|
from numba import cuda
|
|
|
|
# Registry that collects every lowering implementation defined in this module.
registry = Registry()
# Module-level alias of the registry's ``lower`` decorator; the registrations
# in this module are all made through it.
lower = registry.lower
|
|
|
|
|
|
# Predicates that return a boolean. Each entry is
# (float64 libdevice function name, float32 libdevice function name,
#  math module function being implemented).
booleans = [
    ('isnand', 'isnanf', math.isnan),
    ('isinfd', 'isinff', math.isinf),
    ('isfinited', 'finitef', math.isfinite),
]
|
|
|
|
# Single-argument functions. Each entry is
# (float64 libdevice function name, float32 libdevice function name,
#  math module function being implemented).
unarys = [
    ('ceil', 'ceilf', math.ceil),
    ('floor', 'floorf', math.floor),
    ('fabs', 'fabsf', math.fabs),
    ('exp', 'expf', math.exp),
    ('expm1', 'expm1f', math.expm1),
    ('erf', 'erff', math.erf),
    ('erfc', 'erfcf', math.erfc),
    ('tgamma', 'tgammaf', math.gamma),
    ('lgamma', 'lgammaf', math.lgamma),
    ('sqrt', 'sqrtf', math.sqrt),
    ('log', 'logf', math.log),
    ('log2', 'log2f', math.log2),
    ('log10', 'log10f', math.log10),
    ('log1p', 'log1pf', math.log1p),
    ('acosh', 'acoshf', math.acosh),
    ('acos', 'acosf', math.acos),
    ('cos', 'cosf', math.cos),
    ('cosh', 'coshf', math.cosh),
    ('asinh', 'asinhf', math.asinh),
    ('asin', 'asinf', math.asin),
    ('sin', 'sinf', math.sin),
    ('sinh', 'sinhf', math.sinh),
    ('atan', 'atanf', math.atan),
    ('atanh', 'atanhf', math.atanh),
    ('tan', 'tanf', math.tan),
    ('trunc', 'truncf', math.trunc),
]
|
|
|
|
# Maps the name of a float32 libdevice unary function to the name of the
# libdevice function substituted for it when fastmath is enabled.
unarys_fastmath = {
    'cosf': 'fast_cosf',
    'sinf': 'fast_sinf',
    'tanf': 'fast_tanf',
    'expf': 'fast_expf',
    'log2f': 'fast_log2f',
    'log10f': 'fast_log10f',
    'logf': 'fast_logf',
}
|
|
|
|
# Two-argument functions. Each entry is
# (float64 libdevice function name, float32 libdevice function name,
#  math module function being implemented).
binarys = [
    ('copysign', 'copysignf', math.copysign),
    ('atan2', 'atan2f', math.atan2),
    ('pow', 'powf', math.pow),
    ('fmod', 'fmodf', math.fmod),
    ('hypot', 'hypotf', math.hypot),
    ('remainder', 'remainderf', math.remainder),
]
|
|
|
|
# Maps the name of a float32 libdevice binary function to the name of the
# libdevice function substituted for it when fastmath is enabled.
binarys_fastmath = {'powf': 'fast_powf'}
|
|
|
|
|
|
@lower(math.isinf, types.Integer)
@lower(math.isnan, types.Integer)
def math_isinf_isnan_int(context, builder, sig, args):
    """Lower ``math.isinf`` / ``math.isnan`` on an integer argument.

    The argument is ignored and a constant boolean 0 (false) is returned.
    """
    always_false = context.get_constant(types.boolean, 0)
    return always_false
|
|
|
|
|
|
@lower(operator.truediv, types.float32, types.float32)
def maybe_fast_truediv(context, builder, sig, args):
    """Lower float32 true division, honouring the context's fastmath flag.

    With fastmath enabled the division is delegated to
    ``libdevice.fast_fdividef``. Otherwise a zero-divisor check that reports
    through the context's error model is emitted, followed by a plain
    floating-point division.
    """
    if not context.fastmath:
        # Report a zero divisor through the error model before dividing.
        with cgutils.if_zero(builder, args[1]):
            context.error_model.fp_zero_division(builder, ("division by zero",))
        return builder.fdiv(*args)

    fast_sig = typing.signature(float32, float32, float32)
    fast_div = context.get_function(libdevice.fast_fdividef, fast_sig)
    return fast_div(builder, args)
|
|
|
|
|
|
@lower(math.isfinite, types.Integer)
def math_isfinite_int(context, builder, sig, args):
    """Lower ``math.isfinite`` on an integer argument.

    The argument is ignored and a constant boolean 1 (true) is returned.
    """
    always_true = context.get_constant(types.boolean, 1)
    return always_true
|
|
|
|
|
|
@lower(math.sin, types.float16)
def fp16_sin_impl(context, builder, sig, args):
    """Lower ``math.sin`` on a float16 argument via ``cuda.fp16.hsin``."""
    # Function compiled internally in place of math.sin.
    def fp16_sin(x):
        return cuda.fp16.hsin(x)

    lowered = context.compile_internal(builder, fp16_sin, sig, args)
    return lowered
|
|
|
|
|
|
@lower(math.cos, types.float16)
def fp16_cos_impl(context, builder, sig, args):
    """Lower ``math.cos`` on a float16 argument via ``cuda.fp16.hcos``."""
    # Function compiled internally in place of math.cos.
    def fp16_cos(x):
        return cuda.fp16.hcos(x)

    lowered = context.compile_internal(builder, fp16_cos, sig, args)
    return lowered
|
|
|
|
|
|
@lower(math.log, types.float16)
def fp16_log_impl(context, builder, sig, args):
    """Lower ``math.log`` on a float16 argument via ``cuda.fp16.hlog``."""
    # Function compiled internally in place of math.log.
    def fp16_log(x):
        return cuda.fp16.hlog(x)

    lowered = context.compile_internal(builder, fp16_log, sig, args)
    return lowered
|
|
|
|
|
|
@lower(math.log10, types.float16)
def fp16_log10_impl(context, builder, sig, args):
    """Lower ``math.log10`` on a float16 argument via ``cuda.fp16.hlog10``."""
    # Function compiled internally in place of math.log10.
    def fp16_log10(x):
        return cuda.fp16.hlog10(x)

    lowered = context.compile_internal(builder, fp16_log10, sig, args)
    return lowered
|
|
|
|
|
|
@lower(math.log2, types.float16)
def fp16_log2_impl(context, builder, sig, args):
    """Lower ``math.log2`` on a float16 argument via ``cuda.fp16.hlog2``."""
    # Function compiled internally in place of math.log2.
    def fp16_log2(x):
        return cuda.fp16.hlog2(x)

    lowered = context.compile_internal(builder, fp16_log2, sig, args)
    return lowered
|
|
|
|
|
|
@lower(math.exp, types.float16)
def fp16_exp_impl(context, builder, sig, args):
    """Lower ``math.exp`` on a float16 argument via ``cuda.fp16.hexp``."""
    # Function compiled internally in place of math.exp.
    def fp16_exp(x):
        return cuda.fp16.hexp(x)

    lowered = context.compile_internal(builder, fp16_exp, sig, args)
    return lowered
|
|
|
|
|
|
@lower(math.floor, types.float16)
def fp16_floor_impl(context, builder, sig, args):
    """Lower ``math.floor`` on a float16 argument via ``cuda.fp16.hfloor``."""
    # Function compiled internally in place of math.floor.
    def fp16_floor(x):
        return cuda.fp16.hfloor(x)

    lowered = context.compile_internal(builder, fp16_floor, sig, args)
    return lowered
|
|
|
|
|
|
@lower(math.ceil, types.float16)
def fp16_ceil_impl(context, builder, sig, args):
    """Lower ``math.ceil`` on a float16 argument via ``cuda.fp16.hceil``."""
    # Function compiled internally in place of math.ceil.
    def fp16_ceil(x):
        return cuda.fp16.hceil(x)

    lowered = context.compile_internal(builder, fp16_ceil, sig, args)
    return lowered
|
|
|
|
|
|
@lower(math.sqrt, types.float16)
def fp16_sqrt_impl(context, builder, sig, args):
    """Lower ``math.sqrt`` on a float16 argument via ``cuda.fp16.hsqrt``."""
    # Function compiled internally in place of math.sqrt.
    def fp16_sqrt(x):
        return cuda.fp16.hsqrt(x)

    lowered = context.compile_internal(builder, fp16_sqrt, sig, args)
    return lowered
|
|
|
|
|
|
@lower(math.fabs, types.float16)
def fp16_fabs_impl(context, builder, sig, args):
    """Lower ``math.fabs`` on a float16 argument via ``cuda.fp16.habs``."""
    # Function compiled internally in place of math.fabs.
    def fp16_fabs(x):
        return cuda.fp16.habs(x)

    lowered = context.compile_internal(builder, fp16_fabs, sig, args)
    return lowered
|
|
|
|
|
|
@lower(math.trunc, types.float16)
def fp16_trunc_impl(context, builder, sig, args):
    """Lower ``math.trunc`` on a float16 argument via ``cuda.fp16.htrunc``."""
    # Function compiled internally in place of math.trunc.
    def fp16_trunc(x):
        return cuda.fp16.htrunc(x)

    lowered = context.compile_internal(builder, fp16_trunc, sig, args)
    return lowered
|
|
|
|
|
|
def impl_boolean(key, ty, libfunc):
    """Register a lowering of the predicate ``key`` for argument type ``ty``.

    The lowering calls ``libfunc`` with the signature ``int32(ty)`` and casts
    the int32 result to a boolean.
    """
    def lower_boolean_impl(context, builder, sig, args):
        int_sig = typing.signature(types.int32, ty)
        predicate = context.get_function(libfunc, int_sig)
        as_int = predicate(builder, args)
        return context.cast(builder, as_int, types.int32, types.boolean)

    lower(key, ty)(lower_boolean_impl)
|
|
|
|
|
|
def get_lower_unary_impl(key, ty, libfunc):
    """Build (without registering) a lowering that calls ``libfunc`` as ty(ty).

    When ``ty`` is float32 and the context has fastmath enabled, ``libfunc``
    is swapped for the replacement named in ``unarys_fastmath``, if it has an
    entry there. ``key`` is accepted but not used by the lowering itself.
    """
    def lower_unary_impl(context, builder, sig, args):
        chosen = libfunc
        if ty == float32 and context.fastmath:
            fast_name = unarys_fastmath.get(libfunc.__name__)
            if fast_name is not None:
                chosen = getattr(libdevice, fast_name)

        impl = context.get_function(chosen, typing.signature(ty, ty))
        return impl(builder, args)

    return lower_unary_impl
|
|
|
|
|
|
def get_unary_impl_for_fn_and_ty(fn, ty):
    """Return the unary lowering for math function ``fn`` at type ``ty``.

    ``fn`` is matched against the functions listed in ``unarys`` plus
    ``math.tanh``; ``ty`` must be float32 or float64.

    Raises:
        RuntimeError: if ``fn`` is not a known unary function, or if ``ty``
            is neither float32 nor float64.
    """
    # tanh is a special case - because it is not registered like the other
    # unary implementations, it does not appear in the unarys list. However,
    # its implementation can be looked up by key like the other
    # implementations, so we add it to the list we search here.
    tanh_impls = ('tanh', 'tanhf', math.tanh)
    for fname64, fname32, key in unarys + [tanh_impls]:
        if fn != key:
            continue

        if ty == float32:
            impl = getattr(libdevice, fname32)
        elif ty == float64:
            impl = getattr(libdevice, fname64)
        else:
            # Known function but unsupported type: report it with the
            # RuntimeError below rather than failing on an unbound `impl`.
            break

        return get_lower_unary_impl(key, ty, impl)

    raise RuntimeError(f"Implementation of {fn} for {ty} not found")
|
|
|
|
|
|
def impl_unary(key, ty, libfunc):
    """Register the unary lowering of ``key`` for argument type ``ty``.

    The lowering is the one produced by ``get_lower_unary_impl``.
    """
    register = lower(key, ty)
    register(get_lower_unary_impl(key, ty, libfunc))
|
|
|
|
|
|
def impl_unary_int(key, ty, libfunc):
    """Register a lowering of the unary ``key`` for integer type ``ty``.

    The lowering converts its argument to a double (signed conversion for
    int64, unsigned for uint64) and calls ``libfunc`` as float64(float64).
    It raises TypeError when the signature's argument type is neither int64
    nor uint64.
    """
    def lower_unary_int_impl(context, builder, sig, args):
        argty = sig.args[0]
        if argty == int64:
            to_double = builder.sitofp
        elif argty == uint64:
            to_double = builder.uitofp
        else:
            m = 'Only 64-bit integers are supported for generic unary int ops'
            raise TypeError(m)

        as_double = to_double(args[0], ir.DoubleType())
        f64_sig = typing.signature(float64, float64)
        impl = context.get_function(libfunc, f64_sig)
        return impl(builder, [as_double])

    lower(key, ty)(lower_unary_int_impl)
|
|
|
|
|
|
def get_lower_binary_impl(key, ty, libfunc):
    """Build (without registering) a lowering calling ``libfunc`` as ty(ty, ty).

    When ``ty`` is float32 and the context has fastmath enabled, ``libfunc``
    is swapped for the replacement named in ``binarys_fastmath``, if it has
    an entry there. ``key`` is accepted but not used by the lowering itself.
    """
    def lower_binary_impl(context, builder, sig, args):
        chosen = libfunc
        if ty == float32 and context.fastmath:
            fast_name = binarys_fastmath.get(libfunc.__name__)
            if fast_name is not None:
                chosen = getattr(libdevice, fast_name)

        impl = context.get_function(chosen, typing.signature(ty, ty, ty))
        return impl(builder, args)

    return lower_binary_impl
|
|
|
|
|
|
def get_binary_impl_for_fn_and_ty(fn, ty):
    """Return the binary lowering for math function ``fn`` at type ``ty``.

    ``fn`` is matched against the functions listed in ``binarys``; ``ty``
    must be float32 or float64.

    Raises:
        RuntimeError: if ``fn`` is not a known binary function, or if ``ty``
            is neither float32 nor float64.
    """
    for fname64, fname32, key in binarys:
        if fn != key:
            continue

        if ty == float32:
            impl = getattr(libdevice, fname32)
        elif ty == float64:
            impl = getattr(libdevice, fname64)
        else:
            # Known function but unsupported type: report it with the
            # RuntimeError below rather than failing on an unbound `impl`.
            break

        return get_lower_binary_impl(key, ty, impl)

    raise RuntimeError(f"Implementation of {fn} for {ty} not found")
|
|
|
|
|
|
def impl_binary(key, ty, libfunc):
    """Register the binary lowering of ``key`` for two arguments of ``ty``.

    The lowering is the one produced by ``get_lower_binary_impl``.
    """
    register = lower(key, ty, ty)
    register(get_lower_binary_impl(key, ty, libfunc))
|
|
|
|
|
|
def impl_binary_int(key, ty, libfunc):
    """Register a lowering of the binary ``key`` for integer type ``ty``.

    The lowering converts both arguments to doubles and calls ``libfunc`` as
    float64(float64, float64). The conversion (signed for int64, unsigned
    for uint64) is chosen from the first argument type of the signature; a
    TypeError is raised when that type is neither int64 nor uint64.
    """
    def lower_binary_int_impl(context, builder, sig, args):
        first_argty = sig.args[0]
        if first_argty == int64:
            to_double = builder.sitofp
        elif first_argty == uint64:
            to_double = builder.uitofp
        else:
            m = 'Only 64-bit integers are supported for generic binary int ops'
            raise TypeError(m)

        as_doubles = [to_double(value, ir.DoubleType()) for value in args]
        f64_sig = typing.signature(float64, float64, float64)
        impl = context.get_function(libfunc, f64_sig)
        return impl(builder, as_doubles)

    lower(key, ty, ty)(lower_binary_int_impl)
|
|
|
|
|
|
# Register float32 and float64 lowerings for every boolean predicate.
for fname64, fname32, key in booleans:
    impl32, impl64 = getattr(libdevice, fname32), getattr(libdevice, fname64)
    impl_boolean(key, float32, impl32)
    impl_boolean(key, float64, impl64)
|
|
|
|
|
|
# Register lowerings for every unary function: native float32 and float64
# versions, plus 64-bit integer versions that go through the float64
# libdevice function.
for fname64, fname32, key in unarys:
    impl32, impl64 = getattr(libdevice, fname32), getattr(libdevice, fname64)
    impl_unary(key, float32, impl32)
    impl_unary(key, float64, impl64)
    impl_unary_int(key, int64, impl64)
    impl_unary_int(key, uint64, impl64)
|
|
|
|
|
|
# Register lowerings for every binary function: native float32 and float64
# versions, plus 64-bit integer versions that go through the float64
# libdevice function.
for fname64, fname32, key in binarys:
    impl32, impl64 = getattr(libdevice, fname32), getattr(libdevice, fname64)
    impl_binary(key, float32, impl32)
    impl_binary(key, float64, impl64)
    impl_binary_int(key, int64, impl64)
    impl_binary_int(key, uint64, impl64)
|
|
|
|
|
|
def impl_pow_int(ty, libfunc):
    """Register ``math.pow`` for a ``ty`` base with an int32 exponent.

    The lowering calls ``libfunc`` with the signature ``ty(ty, int32)``.
    """
    def lower_pow_impl_int(context, builder, sig, args):
        impl = context.get_function(
            libfunc, typing.signature(ty, ty, types.int32))
        return impl(builder, args)

    lower(math.pow, ty, types.int32)(lower_pow_impl_int)
|
|
|
|
|
|
# Integer-exponent pow: libdevice.powif for float32, libdevice.powi for
# float64.
impl_pow_int(types.float32, libdevice.powif)
impl_pow_int(types.float64, libdevice.powi)
|
|
|
|
|
|
def impl_modf(ty, libfunc):
    """Register ``math.modf`` for argument type ``ty``.

    The lowering calls ``libfunc`` with a signature that takes a ``ty`` and
    returns a uniform 2-tuple of ``ty``.
    """
    retty = types.UniTuple(ty, 2)

    def lower_modf_impl(context, builder, sig, args):
        impl = context.get_function(libfunc, typing.signature(retty, ty))
        return impl(builder, args)

    lower(math.modf, ty)(lower_modf_impl)
|
|
|
|
|
|
# modf: libdevice.modff for float32, libdevice.modf for float64.
impl_modf(types.float32, libdevice.modff)
impl_modf(types.float64, libdevice.modf)
|
|
|
|
|
|
def impl_frexp(ty, libfunc):
    """Register ``math.frexp`` for argument type ``ty``.

    The lowering calls ``libfunc`` with a signature that takes a ``ty`` and
    returns a ``(ty, int32)`` tuple.
    """
    retty = types.Tuple((ty, types.int32))

    def lower_frexp_impl(context, builder, sig, args):
        impl = context.get_function(libfunc, typing.signature(retty, ty))
        return impl(builder, args)

    lower(math.frexp, ty)(lower_frexp_impl)
|
|
|
|
|
|
# frexp: libdevice.frexpf for float32, libdevice.frexp for float64.
impl_frexp(types.float32, libdevice.frexpf)
impl_frexp(types.float64, libdevice.frexp)
|
|
|
|
|
|
def impl_ldexp(ty, libfunc):
    """Register ``math.ldexp`` for a ``ty`` value and an int32 exponent.

    The lowering calls ``libfunc`` with the signature ``ty(ty, int32)``.
    """
    def lower_ldexp_impl(context, builder, sig, args):
        impl = context.get_function(
            libfunc, typing.signature(ty, ty, types.int32))
        return impl(builder, args)

    lower(math.ldexp, ty, types.int32)(lower_ldexp_impl)
|
|
|
|
|
|
# ldexp: libdevice.ldexpf for float32, libdevice.ldexp for float64.
impl_ldexp(types.float32, libdevice.ldexpf)
impl_ldexp(types.float64, libdevice.ldexp)
|
|
|
|
|
|
def impl_tanh(ty, libfunc):
    """Register ``math.tanh`` for argument type ``ty``.

    By default the lowering calls ``libfunc`` as ty(ty). For float32 with
    fastmath enabled, and a compute capability of at least (7, 5) in the
    flags on top of the target config stack, it instead emits the
    ``tanh.approx.f32`` instruction as inline assembly.
    """
    def lower_tanh_impl(context, builder, sig, args):
        use_approx = False
        if ty == float32 and context.fastmath:
            # The compute capability is only consulted on the fastmath path.
            cc = targetconfig.ConfigStack().top().compute_capability
            use_approx = cc >= (7, 5)

        if use_approx:
            fnty = ir.FunctionType(ir.FloatType(), [ir.FloatType()])
            asm = ir.InlineAsm(fnty, 'tanh.approx.f32 $0, $1;', '=f,f')
            return builder.call(asm, args)

        impl = context.get_function(libfunc, typing.signature(ty, ty))
        return impl(builder, args)

    lower(math.tanh, ty)(lower_tanh_impl)
|
|
|
|
|
|
# tanh: dedicated float32 / float64 lowerings...
impl_tanh(types.float32, libdevice.tanhf)
impl_tanh(types.float64, libdevice.tanh)

# ...and 64-bit integer lowerings that go through the float64 libdevice
# function.
impl_unary_int(math.tanh, int64, libdevice.tanh)
impl_unary_int(math.tanh, uint64, libdevice.tanh)
|
|
|
|
# Complex power implementations - translations of _Py_c_pow from CPython
# https://github.com/python/cpython/blob/a755410e054e1e2390de5830befc08fe80706c66/Objects/complexobject.c#L123-L151
#
# The complex64 variant casts all constants and some variables to ensure that
# as much computation is done in single precision as possible. A small number
# of operations are still done in 64-bit, but these come from libdevice code.
|
|
|
|
|
|
def cpow_implement(fty, cty):
    """Register complex power for two arguments of complex type ``cty``.

    The same lowering is registered for ``operator.pow``, ``operator.ipow``
    and the builtin ``pow``. ``fty`` is the float type used to cast the
    constants compared against the real and imaginary parts; ``cty`` is used
    to cast the complex constants and results.
    """
    def core(context, builder, sig, args):
        def cpow_internal(a, b):
            # Anything raised to the power zero is one.
            if b.real == fty(0.0) and b.imag == fty(0.0):
                return cty(1.0) + cty(0.0j)
            # A zero base gives zero. This tests both parts of the base `a`;
            # testing b.real here instead would wrongly return zero for a
            # purely imaginary base raised to a purely imaginary exponent.
            elif a.real == fty(0.0) and a.imag == fty(0.0):
                return cty(0.0) + cty(0.0j)

            # General case, in polar form: magnitude and phase of a ** b.
            vabs = math.hypot(a.real, a.imag)
            length = math.pow(vabs, b.real)
            at = math.atan2(a.imag, a.real)
            phase = at * b.real
            if b.imag != fty(0.0):
                length /= math.exp(at * b.imag)
                phase += b.imag * math.log(vabs)

            return length * (cty(math.cos(phase)) +
                             cty(math.sin(phase) * cty(1.0j)))

        return context.compile_internal(builder, cpow_internal, sig, args)

    lower(operator.pow, cty, cty)(core)
    lower(operator.ipow, cty, cty)(core)
    lower(pow, cty, cty)(core)
|
|
|
|
|
|
# Complex power for complex64 (float32 parts) and complex128 (float64 parts).
cpow_implement(types.float32, types.complex64)
cpow_implement(types.float64, types.complex128)
|