154 lines
3.7 KiB
Python
154 lines
3.7 KiB
Python
|
"""
|
||
|
Tests for problems in nested function calls.

These are usually caused by invalid type conversions across function
boundaries.
|
||
|
"""
|
||
|
|
||
|
|
||
|
from numba import int32, int64
|
||
|
from numba import jit
|
||
|
from numba.core import types
|
||
|
from numba.extending import overload
|
||
|
from numba.tests.support import TestCase, tag
|
||
|
import unittest
|
||
|
|
||
|
|
||
|
@jit(nopython=True)
def f_inner(a, b, c):
    """Return the three arguments, in parameter order, as a tuple."""
    packed = (a, b, c)
    return packed
|
||
|
|
||
|
def f(x, y, z):
    """Forward to ``f_inner`` using keyword arguments given out of order.

    ``y`` binds to ``c`` and ``z`` binds to ``b``, so the call yields
    ``(x, z, y)``.
    """
    # The c-before-b keyword order looks deliberate (it exercises binding
    # arguments by name rather than by position) -- keep it as written.
    outcome = f_inner(x, c=y, b=z)
    return outcome
|
||
|
|
||
|
@jit(nopython=True)
def g_inner(a, b=2, c=3):
    """Return ``(a, b, c)``; ``b`` defaults to 2 and ``c`` to 3."""
    triple = (a, b, c)
    return triple
|
||
|
|
||
|
def g(x, y, z):
    """Call ``g_inner`` twice, each time leaving one parameter at its default.

    The first call omits ``c``; the second omits ``b``.
    """
    without_c = g_inner(x, b=y)
    without_b = g_inner(a=z, c=x)
    return without_c, without_b
|
||
|
|
||
|
@jit(nopython=True)
def star_inner(a=5, *b):
    """Return ``(a, b)``, where ``b`` collects any extra positional arguments."""
    pair = (a, b)
    return pair
|
||
|
|
||
|
def star(x, y, z):
    """Call ``star_inner`` once by keyword and once with three positionals.

    The keyword call leaves ``*b`` empty; the positional call packs ``y`` and
    ``z`` into it.
    """
    by_keyword = star_inner(a=x)
    by_position = star_inner(x, y, z)
    return by_keyword, by_position
|
||
|
|
||
|
def star_call(x, y, z):
    """Call ``star_inner`` with ``*``-unpacked sequences at the call site.

    ``y`` is unpacked after a leading positional ``x``; ``z`` is unpacked
    on its own.
    """
    with_leading = star_inner(x, *y)
    unpacked_only = star_inner(*z)
    return with_leading, unpacked_only
|
||
|
|
||
|
@jit(nopython=True)
def argcast_inner(a, b):
    """Return ``a``, or ``int64(0)`` when ``b`` is truthy.

    The argument variable ``a`` is rebound in place instead of a new value
    being returned directly. Per the inline comment below, this makes Numba
    unify ``a`` to int64. Keep this exact shape: it is what the issue #1488
    regression test (``TestNestedCall.test_argcast``) exercises.
    """
    if b:
        # Here `a` is unified to int64 (from int32 originally)
        a = int64(0)
    return a
|
||
|
|
||
|
def argcast(a, b):
    """Call ``argcast_inner`` with ``a`` explicitly converted to int32."""
    # Passing an int32 is what leads argcast_inner to widen its argument
    # variable to int64 internally (see the comment in that function).
    value = argcast_inner(int32(a), b)
    return value
|
||
|
|
||
|
|
||
|
def generated_inner(x, y=5, z=6):
    """Pure-Python stub whose compiled behavior comes from an ``@overload``.

    The body must never run. Raise explicitly rather than using ``assert``:
    an ``assert`` is stripped under ``python -O``, which would make a stray
    pure-Python call silently return ``None``.

    :raises AssertionError: always, with the message ``"unreachable"``.
    """
    raise AssertionError("unreachable")
|
||
|
|
||
|
|
||
|
@overload(generated_inner)
def ol_generated_inner(x, y=5, z=6):
    """Numba overload for ``generated_inner``, chosen by the type of ``x``.

    When ``x`` is a ``types.Complex``, the returned implementation computes
    ``(x + y, z)``. For every other type it computes ``(x - y, z)``. Each
    ``impl`` mirrors the overloaded function's signature, defaults included.
    """
    if not isinstance(x, types.Complex):
        def impl(x, y=5, z=6):
            return x - y, z
        return impl

    def impl(x, y=5, z=6):
        return x + y, z
    return impl
|
||
|
|
||
|
|
||
|
def call_generated(a, b):
    """Call the overloaded ``generated_inner`` without passing ``y``.

    ``y`` is left at the overload's default while ``z`` is passed by keyword.
    """
    res = generated_inner(a, z=b)
    return res
|
||
|
|
||
|
|
||
|
class TestNestedCall(TestCase):
    """Tests for compiled functions that call other jitted functions."""

    def compile_func(self, pyfunc, objmode=False):
        """Compile *pyfunc* with ``jit`` and return ``(cfunc, check)``.

        :param pyfunc: the pure-Python function to compile.
        :param objmode: if true, compile with ``forceobj=True``; otherwise
            compile with ``nopython=True``.
        :returns: the compiled function, plus a ``check(*args, **kwargs)``
            helper. The helper calls both *pyfunc* and the compiled function
            with the same arguments and compares the results using
            ``assertPreciseEqual``.
        """
        flags = dict(forceobj=True) if objmode else dict(nopython=True)
        # Bind the compiled function before defining ``check``, and avoid the
        # name ``f``, so the closure can't be mistaken for a reference to the
        # module-level ``f`` fixture.
        cfunc = jit(**flags)(pyfunc)

        def check(*args, **kwargs):
            expected = pyfunc(*args, **kwargs)
            result = cfunc(*args, **kwargs)
            self.assertPreciseEqual(result, expected)

        return cfunc, check

    def test_boolean_return(self):
        """
        Test that a boolean returned by a jitted function can be used as a
        condition in a jitted caller.
        """
        @jit(nopython=True)
        def inner(x):
            return not x

        @jit(nopython=True)
        def outer(x):
            if inner(x):
                return True
            else:
                return False

        self.assertFalse(outer(True))
        self.assertTrue(outer(False))

    def test_named_args(self, objmode=False):
        """
        Test a nested function call with named (keyword) arguments.
        """
        _, check = self.compile_func(f, objmode)
        check(1, 2, 3)
        check(1, y=2, z=3)

    def test_named_args_objmode(self):
        """Same as test_named_args, with the outer function in object mode."""
        self.test_named_args(objmode=True)

    def test_default_args(self, objmode=False):
        """
        Test a nested function call using default argument values.
        """
        _, check = self.compile_func(g, objmode)
        check(1, 2, 3)
        check(1, y=2, z=3)

    def test_default_args_objmode(self):
        """Same as test_default_args, with the outer function in object mode."""
        self.test_default_args(objmode=True)

    def test_star_args(self):
        """
        Test a nested function call to a function with *args in its signature.
        """
        _, check = self.compile_func(star)
        check(1, 2, 3)

    def test_star_call(self, objmode=False):
        """
        Test a function call with a *args.
        """
        _, check = self.compile_func(star_call, objmode)
        check(1, (2,), (3,))

    def test_star_call_objmode(self):
        """Same as test_star_call, with the outer function in object mode."""
        self.test_star_call(objmode=True)

    def test_argcast(self):
        """
        Issue #1488: implicitly casting an argument variable should not
        break nested calls.
        """
        _, check = self.compile_func(argcast)
        check(1, 0)
        check(1, 1)

    def test_call_generated(self):
        """
        Test a nested function call to a function implemented with
        ``@overload``, leaving one of its defaulted arguments unspecified.
        """
        cfunc = jit(nopython=True)(call_generated)
        # Non-complex x: the overload returns (x - y, z), with y defaulting to 5.
        self.assertPreciseEqual(cfunc(1, 2), (-4, 2))
        # Complex x: the overload returns (x + y, z), with y defaulting to 5.
        self.assertPreciseEqual(cfunc(1j, 2), (1j + 5, 2))
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
|