ai-content-maker/.venv/Lib/site-packages/numba/tests/test_nested_calls.py

154 lines
3.7 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
"""
Test problems in nested calls.
Usually due to invalid type conversion between function boundaries.
"""
from numba import int32, int64
from numba import jit
from numba.core import types
from numba.extending import overload
from numba.tests.support import TestCase, tag
import unittest
@jit(nopython=True)
def f_inner(a, b, c):
return a, b, c
def f(x, y, z):
return f_inner(x, c=y, b=z)
@jit(nopython=True)
def g_inner(a, b=2, c=3):
return a, b, c
def g(x, y, z):
return g_inner(x, b=y), g_inner(a=z, c=x)
@jit(nopython=True)
def star_inner(a=5, *b):
return a, b
def star(x, y, z):
return star_inner(a=x), star_inner(x, y, z)
def star_call(x, y, z):
return star_inner(x, *y), star_inner(*z)
@jit(nopython=True)
def argcast_inner(a, b):
if b:
# Here `a` is unified to int64 (from int32 originally)
a = int64(0)
return a
def argcast(a, b):
return argcast_inner(int32(a), b)
def generated_inner(x, y=5, z=6):
assert 0, "unreachable"
@overload(generated_inner)
def ol_generated_inner(x, y=5, z=6):
if isinstance(x, types.Complex):
def impl(x, y=5, z=6):
return x + y, z
else:
def impl(x, y=5, z=6):
return x - y, z
return impl
def call_generated(a, b):
return generated_inner(a, z=b)
class TestNestedCall(TestCase):
def compile_func(self, pyfunc, objmode=False):
def check(*args, **kwargs):
expected = pyfunc(*args, **kwargs)
result = f(*args, **kwargs)
self.assertPreciseEqual(result, expected)
flags = dict(forceobj=True) if objmode else dict(nopython=True)
f = jit(**flags)(pyfunc)
return f, check
def test_boolean_return(self):
@jit(nopython=True)
def inner(x):
return not x
@jit(nopython=True)
def outer(x):
if inner(x):
return True
else:
return False
self.assertFalse(outer(True))
self.assertTrue(outer(False))
def test_named_args(self, objmode=False):
"""
Test a nested function call with named (keyword) arguments.
"""
cfunc, check = self.compile_func(f, objmode)
check(1, 2, 3)
check(1, y=2, z=3)
def test_named_args_objmode(self):
self.test_named_args(objmode=True)
def test_default_args(self, objmode=False):
"""
Test a nested function call using default argument values.
"""
cfunc, check = self.compile_func(g, objmode)
check(1, 2, 3)
check(1, y=2, z=3)
def test_default_args_objmode(self):
self.test_default_args(objmode=True)
def test_star_args(self):
"""
Test a nested function call to a function with *args in its signature.
"""
cfunc, check = self.compile_func(star)
check(1, 2, 3)
def test_star_call(self, objmode=False):
"""
Test a function call with a *args.
"""
cfunc, check = self.compile_func(star_call, objmode)
check(1, (2,), (3,))
def test_star_call_objmode(self):
self.test_star_call(objmode=True)
def test_argcast(self):
"""
Issue #1488: implicitly casting an argument variable should not
break nested calls.
"""
cfunc, check = self.compile_func(argcast)
check(1, 0)
check(1, 1)
def test_call_generated(self):
"""
Test a nested function call to a generated jit function.
"""
cfunc = jit(nopython=True)(call_generated)
self.assertPreciseEqual(cfunc(1, 2), (-4, 2))
self.assertPreciseEqual(cfunc(1j, 2), (1j + 5, 2))
if __name__ == '__main__':
unittest.main()