import collections import itertools import numpy as np from numba import njit, jit, typeof, literally from numba.core import types, errors, utils from numba.tests.support import TestCase, MemoryLeakMixin, tag import unittest Rect = collections.namedtuple('Rect', ('width', 'height')) Point = collections.namedtuple('Point', ('x', 'y', 'z')) Point2 = collections.namedtuple('Point2', ('x', 'y', 'z')) Empty = collections.namedtuple('Empty', ()) def tuple_return_usecase(a, b): return a, b def tuple_first(tup): a, b = tup return a def tuple_second(tup): a, b = tup return b def tuple_index(tup, idx): return tup[idx] def tuple_index_static(tup): # Note the negative index return tup[-2] def tuple_slice2(tup): return tup[1:-1] def tuple_slice3(tup): return tup[1::2] def len_usecase(tup): return len(tup) def add_usecase(a, b): return a + b def eq_usecase(a, b): return a == b def ne_usecase(a, b): return a != b def gt_usecase(a, b): return a > b def ge_usecase(a, b): return a >= b def lt_usecase(a, b): return a < b def le_usecase(a, b): return a <= b def in_usecase(a, b): return a in b def bool_usecase(tup): return bool(tup), (3 if tup else 2) def getattr_usecase(tup): return tup.z, tup.y, tup.x def make_point(a, b, c): return Point(a, b, c) def make_point_kws(a, b, c): return Point(z=c, y=b, x=a) def make_point_nrt(n): r = Rect(list(range(n)), np.zeros(n + 1)) # This also exercises attribute access p = Point(r, len(r.width), len(r.height)) return p def type_usecase(tup, *args): return type(tup)(*args) def identity(tup): return tup def index_method_usecase(tup, value): return tup.index(value) def tuple_unpack_static_getitem_err(): # see issue3895, `c` is imprecise a, b, c, d = [], [], [], 0.0 a.append(1) b.append(1) return class TestTupleLengthError(unittest.TestCase): def test_tuple_length_error(self): # issue 2195 # raise an error on tuples greater than 1000 in length @njit def eattuple(tup): return len(tup) with self.assertRaises(errors.UnsupportedError) as raises: tup = tuple(range(1001)) eattuple(tup) expected = "Tuple 'tup' length must be smaller than 1000" self.assertIn(expected, str(raises.exception)) class TestTupleTypeNotIterable(unittest.TestCase): ''' issue 4369 raise an error if 'type' is not iterable ''' def test_namedtuple_types_exception(self): with self.assertRaises(errors.TypingError) as raises: types.NamedTuple(types.uint32, 'p') self.assertIn( "Argument 'types' is not iterable", str(raises.exception) ) def test_tuple_types_exception(self): with self.assertRaises(errors.TypingError) as raises: types.Tuple((types.uint32)) self.assertIn( "Argument 'types' is not iterable", str(raises.exception) ) class TestTupleReturn(TestCase): def test_array_tuple(self): aryty = types.Array(types.float64, 1, 'C') cfunc = njit((aryty, aryty))(tuple_return_usecase) a = b = np.arange(5, dtype='float64') ra, rb = cfunc(a, b) self.assertPreciseEqual(ra, a) self.assertPreciseEqual(rb, b) del a, b self.assertPreciseEqual(ra, rb) def test_scalar_tuple(self): scalarty = types.float32 cfunc = njit((scalarty, scalarty))(tuple_return_usecase) a = b = 1 ra, rb = cfunc(a, b) self.assertEqual(ra, a) self.assertEqual(rb, b) def test_hetero_tuple(self): alltypes = [] allvalues = [] alltypes.append((types.int32, types.int64)) allvalues.append((1, 2)) alltypes.append((types.float32, types.float64)) allvalues.append((1.125, .25)) alltypes.append((types.int32, types.float64)) allvalues.append((1231, .5)) for (ta, tb), (a, b) in zip(alltypes, allvalues): cfunc = njit((ta, tb))(tuple_return_usecase) ra, rb = cfunc(a, b) self.assertPreciseEqual((ra, rb), (a, b)) class TestTuplePassing(TestCase): def test_unituple(self): tuple_type = types.UniTuple(types.int32, 2) cf_first = njit((tuple_type,))(tuple_first) cf_second = njit((tuple_type,))(tuple_second) self.assertPreciseEqual(cf_first((4, 5)), 4) self.assertPreciseEqual(cf_second((4, 5)), 5) def test_hetero_tuple(self): tuple_type = types.Tuple((types.int64, types.float32)) cf_first = njit((tuple_type,))(tuple_first) cf_second = njit((tuple_type,))(tuple_second) self.assertPreciseEqual(cf_first((2**61, 1.5)), 2**61) self.assertPreciseEqual(cf_second((2**61, 1.5)), 1.5) def test_size_mismatch(self): # Issue #1638: tuple size should be checked when unboxing tuple_type = types.UniTuple(types.int32, 2) cfunc = njit((tuple_type,))(tuple_first) entry_point = cfunc.overloads[cfunc.signatures[0]].entry_point with self.assertRaises(ValueError) as raises: entry_point((4, 5, 6)) self.assertEqual(str(raises.exception), ("size mismatch for tuple, " "expected 2 element(s) but got 3")) class TestOperations(TestCase): def test_len(self): pyfunc = len_usecase cfunc = njit((types.Tuple((types.int64, types.float32)),))(pyfunc) self.assertPreciseEqual(cfunc((4, 5)), 2) cfunc = njit((types.UniTuple(types.int64, 3),))(pyfunc) self.assertPreciseEqual(cfunc((4, 5, 6)), 3) def test_index_literal(self): # issue #6023, test non-static getitem with IntegerLiteral index def pyfunc(tup, idx): idx = literally(idx) return tup[idx] cfunc = njit(pyfunc) tup = (4, 3.1, 'sss') for i in range(len(tup)): self.assertPreciseEqual(cfunc(tup, i), tup[i]) def test_index(self): pyfunc = tuple_index cfunc = njit((types.UniTuple(types.int64, 3), types.int64),)(pyfunc) tup = (4, 3, 6) for i in range(len(tup)): self.assertPreciseEqual(cfunc(tup, i), tup[i]) # test negative indexing for i in range(len(tup) + 1): self.assertPreciseEqual(cfunc(tup, -i), tup[-i]) # oob indexes, +ve then -ve with self.assertRaises(IndexError) as raises: cfunc(tup, len(tup)) self.assertEqual("tuple index out of range", str(raises.exception)) with self.assertRaises(IndexError) as raises: cfunc(tup, -(len(tup) + 1)) self.assertEqual("tuple index out of range", str(raises.exception)) # Test empty tuple, this is a bit unusual as `njit` will infer the empty # tuple arg as a types.Tuple and not match the compiled signature, this # is essentially because the test originally relied on # `compile_isolated`. args = (types.UniTuple(types.int64, 0), types.int64,) cr = njit(args)(pyfunc).overloads[args] with self.assertRaises(IndexError) as raises: cr.entry_point((), 0) self.assertEqual("tuple index out of range", str(raises.exception)) # test uintp indexing (because, e.g., parfor generates unsigned prange) cfunc = njit((types.UniTuple(types.int64, 3), types.uintp,),)(pyfunc) for i in range(len(tup)): self.assertPreciseEqual(cfunc(tup, types.uintp(i)), tup[i]) # With a compile-time static index (the code generation path is # different) pyfunc = tuple_index_static for typ in (types.UniTuple(types.int64, 4), types.Tuple((types.int64, types.int32, types.int64, types.int32))): cfunc = njit((typ,))(pyfunc) tup = (4, 3, 42, 6) self.assertPreciseEqual(cfunc(tup), pyfunc(tup)) typ = types.UniTuple(types.int64, 1) with self.assertTypingError(): njit((typ,))(pyfunc) # test unpack, staticgetitem with imprecise type (issue #3895) pyfunc = tuple_unpack_static_getitem_err with self.assertTypingError() as raises: njit((),)(pyfunc) msg = ("Cannot infer the type of variable 'c', have imprecise type: " "list(undefined).") self.assertIn(msg, str(raises.exception)) def test_in(self): pyfunc = in_usecase cfunc = njit((types.int64, types.UniTuple(types.int64, 3),),)(pyfunc) tup = (4, 1, 5) for i in range(5): self.assertPreciseEqual(cfunc(i, tup), pyfunc(i, tup)) # Test the empty case cfunc = njit((types.int64, types.Tuple([]),),)(pyfunc) self.assertPreciseEqual(cfunc(1, ()), pyfunc(1, ())) def check_slice(self, pyfunc): tup = (4, 5, 6, 7) cfunc = njit((types.UniTuple(types.int64, 4),),)(pyfunc) self.assertPreciseEqual(cfunc(tup), pyfunc(tup)) args = types.Tuple((types.int64, types.int32, types.int64, types.int32)) cfunc = njit((args,))(pyfunc) self.assertPreciseEqual(cfunc(tup), pyfunc(tup)) def test_slice2(self): self.check_slice(tuple_slice2) def test_slice3(self): self.check_slice(tuple_slice3) def test_bool(self): pyfunc = bool_usecase cfunc = njit((types.Tuple((types.int64, types.int32)),),)(pyfunc) args = ((4, 5),) self.assertPreciseEqual(cfunc(*args), pyfunc(*args)) cfunc = njit((types.UniTuple(types.int64, 3),),)(pyfunc) args = ((4, 5, 6),) self.assertPreciseEqual(cfunc(*args), pyfunc(*args)) cfunc = njit((types.Tuple(()),),)(pyfunc) self.assertPreciseEqual(cfunc(()), pyfunc(())) def test_add(self): pyfunc = add_usecase samples = [(types.Tuple(()), ()), (types.UniTuple(types.int32, 0), ()), (types.UniTuple(types.int32, 1), (42,)), (types.Tuple((types.int64, types.float32)), (3, 4.5)), ] for (ta, a), (tb, b) in itertools.product(samples, samples): cfunc = njit((ta, tb),)(pyfunc) expected = pyfunc(a, b) got = cfunc(a, b) self.assertPreciseEqual(got, expected, msg=(ta, tb)) def _test_compare(self, pyfunc): def eq(pyfunc, cfunc, args): self.assertIs(cfunc(*args), pyfunc(*args), "mismatch for arguments %s" % (args,)) # Same-sized tuples argtypes = [types.Tuple((types.int64, types.float32)), types.UniTuple(types.int32, 2)] for ta, tb in itertools.product(argtypes, argtypes): cfunc = njit((ta, tb),)(pyfunc) for args in [((4, 5), (4, 5)), ((4, 5), (4, 6)), ((4, 6), (4, 5)), ((4, 5), (5, 4))]: eq(pyfunc, cfunc, args) # Different-sized tuples argtypes = [types.Tuple((types.int64, types.float32)), types.UniTuple(types.int32, 3)] cfunc = njit(tuple(argtypes),)(pyfunc) for args in [((4, 5), (4, 5, 6)), ((4, 5), (4, 4, 6)), ((4, 5), (4, 6, 7))]: eq(pyfunc, cfunc, args) def test_eq(self): self._test_compare(eq_usecase) def test_ne(self): self._test_compare(ne_usecase) def test_gt(self): self._test_compare(gt_usecase) def test_ge(self): self._test_compare(ge_usecase) def test_lt(self): self._test_compare(lt_usecase) def test_le(self): self._test_compare(le_usecase) class TestNamedTuple(TestCase, MemoryLeakMixin): def test_unpack(self): def check(p): for pyfunc in tuple_first, tuple_second: cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) # Homogeneous check(Rect(4, 5)) # Heterogeneous check(Rect(4, 5.5)) def test_len(self): def check(p): pyfunc = len_usecase cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) # Homogeneous check(Rect(4, 5)) check(Point(4, 5, 6)) # Heterogeneous check(Rect(4, 5.5)) check(Point(4, 5.5, 6j)) def test_index(self): pyfunc = tuple_index cfunc = jit(nopython=True)(pyfunc) p = Point(4, 5, 6) for i in range(len(p)): self.assertPreciseEqual(cfunc(p, i), pyfunc(p, i)) # test uintp indexing (because, e.g., parfor generates unsigned prange) for i in range(len(p)): self.assertPreciseEqual(cfunc(p, types.uintp(i)), pyfunc(p, i)) def test_bool(self): def check(p): pyfunc = bool_usecase cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) # Homogeneous check(Rect(4, 5)) # Heterogeneous check(Rect(4, 5.5)) check(Empty()) def _test_compare(self, pyfunc): def eq(pyfunc, cfunc, args): self.assertIs(cfunc(*args), pyfunc(*args), "mismatch for arguments %s" % (args,)) cfunc = jit(nopython=True)(pyfunc) # Same-sized named tuples for a, b in [((4, 5), (4, 5)), ((4, 5), (4, 6)), ((4, 6), (4, 5)), ((4, 5), (5, 4))]: eq(pyfunc, cfunc, (Rect(*a), Rect(*b))) # Different-sized named tuples for a, b in [((4, 5), (4, 5, 6)), ((4, 5), (4, 4, 6)), ((4, 5), (4, 6, 7))]: eq(pyfunc, cfunc, (Rect(*a), Point(*b))) def test_eq(self): self._test_compare(eq_usecase) def test_ne(self): self._test_compare(ne_usecase) def test_gt(self): self._test_compare(gt_usecase) def test_ge(self): self._test_compare(ge_usecase) def test_lt(self): self._test_compare(lt_usecase) def test_le(self): self._test_compare(le_usecase) def test_getattr(self): pyfunc = getattr_usecase cfunc = jit(nopython=True)(pyfunc) for args in (4, 5, 6), (4, 5.5, 6j): p = Point(*args) self.assertPreciseEqual(cfunc(p), pyfunc(p)) def test_construct(self): def check(pyfunc): cfunc = jit(nopython=True)(pyfunc) for args in (4, 5, 6), (4, 5.5, 6j): expected = pyfunc(*args) got = cfunc(*args) self.assertIs(type(got), type(expected)) self.assertPreciseEqual(got, expected) check(make_point) check(make_point_kws) def test_type(self): # Test the type() built-in on named tuples pyfunc = type_usecase cfunc = jit(nopython=True)(pyfunc) arg_tuples = [(4, 5, 6), (4, 5.5, 6j)] for tup_args, args in itertools.product(arg_tuples, arg_tuples): tup = Point(*tup_args) expected = pyfunc(tup, *args) got = cfunc(tup, *args) self.assertIs(type(got), type(expected)) self.assertPreciseEqual(got, expected) def test_literal_unification(self): # Test for #3565. @jit(nopython=True) def Data1(value): return Rect(value, -321) @jit(nopython=True) def call(i, j): if j == 0: # In the error, `result` is typed to `Rect(int, LiteralInt)` # because of the `-321` literal. This doesn't match the # `result` type in the other branch. result = Data1(i) else: # `result` is typed to be `Rect(int, int)` result = Rect(i, j) return result r = call(123, 1321) self.assertEqual(r, Rect(width=123, height=1321)) r = call(123, 0) self.assertEqual(r, Rect(width=123, height=-321)) def test_string_literal_in_ctor(self): # Test for issue #3813 @jit(nopython=True) def foo(): return Rect(10, 'somestring') r = foo() self.assertEqual(r, Rect(width=10, height='somestring')) def test_dispatcher_mistreat(self): # Test for issue #5215 that mistreat namedtuple as tuples @jit(nopython=True) def foo(x): return x in1 = (1, 2, 3) out1 = foo(in1) self.assertEqual(in1, out1) in2 = Point(1, 2, 3) out2 = foo(in2) self.assertEqual(in2, out2) # Check the signatures self.assertEqual(len(foo.nopython_signatures), 2) self.assertEqual(foo.nopython_signatures[0].args[0], typeof(in1)) self.assertEqual(foo.nopython_signatures[1].args[0], typeof(in2)) # Differently named in3 = Point2(1, 2, 3) out3 = foo(in3) self.assertEqual(in3, out3) self.assertEqual(len(foo.nopython_signatures), 3) self.assertEqual(foo.nopython_signatures[2].args[0], typeof(in3)) class TestTupleNRT(TestCase, MemoryLeakMixin): def test_tuple_add(self): def pyfunc(x): a = np.arange(3) return (a,) + (x,) cfunc = jit(nopython=True)(pyfunc) x = 123 expect_a, expect_x = pyfunc(x) got_a, got_x = cfunc(x) np.testing.assert_equal(got_a, expect_a) self.assertEqual(got_x, expect_x) class TestNamedTupleNRT(TestCase, MemoryLeakMixin): def test_return(self): # Check returning a namedtuple with a list inside it pyfunc = make_point_nrt cfunc = jit(nopython=True)(pyfunc) for arg in (3, 0): expected = pyfunc(arg) got = cfunc(arg) self.assertIs(type(got), type(expected)) self.assertPreciseEqual(got, expected) class TestConversions(TestCase): """ Test implicit conversions between tuple types. """ def check_conversion(self, fromty, toty, val): pyfunc = identity cfunc = njit(toty(fromty))(pyfunc) res = cfunc(val) self.assertEqual(res, val) def test_conversions(self): check = self.check_conversion fromty = types.UniTuple(types.int32, 2) check(fromty, types.UniTuple(types.float32, 2), (4, 5)) check(fromty, types.Tuple((types.float32, types.int16)), (4, 5)) aty = types.UniTuple(types.int32, 0) bty = types.Tuple(()) check(aty, bty, ()) check(bty, aty, ()) with self.assertRaises(errors.TypingError) as raises: check(fromty, types.Tuple((types.float32,)), (4, 5)) msg = "No conversion from UniTuple(int32 x 2) to UniTuple(float32 x 1)" self.assertIn(msg, str(raises.exception)) class TestMethods(TestCase): def test_index(self): pyfunc = index_method_usecase cfunc = jit(nopython=True)(pyfunc) self.assertEqual(cfunc((1, 2, 3), 2), 1) with self.assertRaises(ValueError) as raises: cfunc((1, 2, 3), 4) msg = 'tuple.index(x): x not in tuple' self.assertEqual(msg, str(raises.exception)) class TestTupleBuild(TestCase): def test_build_unpack(self): def check(p): pyfunc = lambda a: (1, *a) cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) # Homogeneous check((4, 5)) # Heterogeneous check((4, 5.5)) def test_build_unpack_assign_like(self): # see #6534 def check(p): pyfunc = lambda a: (*a,) cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) # Homogeneous check((4, 5)) # Heterogeneous check((4, 5.5)) def test_build_unpack_fail_on_list_assign_like(self): # see #6534 def check(p): pyfunc = lambda a: (*a,) cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) with self.assertRaises(errors.TypingError) as raises: check([4, 5]) # Python 3.9 has a peephole rewrite due to large changes in tuple # unpacking. It results in a tuple + list situation from the above # so the error message reflects that. Catching this specific and # seemingly rare sequence in the peephole rewrite is prohibitively # hard. Should it be reported numerous times, revisit then. msg1 = "No implementation of function" self.assertIn(msg1, str(raises.exception)) msg2 = "tuple(reflected list(" # ignore the rest of reflected list # part, it's repr is quite volatile. self.assertIn(msg2, str(raises.exception)) def test_build_unpack_more(self): def check(p): pyfunc = lambda a: (1, *a, (1, 2), *a) cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) # Homogeneous check((4, 5)) # Heterogeneous check((4, 5.5)) def test_build_unpack_call(self): def check(p): @jit def inner(*args): return args pyfunc = lambda a: inner(1, *a) cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) # Homogeneous check((4, 5)) # Heterogeneous check((4, 5.5)) def test_build_unpack_call_more(self): def check(p): @jit def inner(*args): return args pyfunc = lambda a: inner(1, *a, *(1, 2), *a) cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) # Homogeneous check((4, 5)) # Heterogeneous check((4, 5.5)) def test_tuple_constructor(self): def check(pyfunc, arg): cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(arg), pyfunc(arg)) # empty check(lambda _: tuple(), ()) # Homogeneous check(lambda a: tuple(a), (4, 5)) # Heterogeneous check(lambda a: tuple(a), (4, 5.5)) @unittest.skipIf(utils.PYVERSION < (3, 9), "needs Python 3.9+") def test_unpack_with_predicate_fails(self): # this fails as the list_to_tuple/list_extend peephole bytecode # rewriting needed for Python 3.9+ cannot yet traverse the CFG. @njit def foo(): a = (1,) b = (3,2, 4) return (*(b if a[0] else (5, 6)),) with self.assertRaises(errors.UnsupportedError) as raises: foo() msg = "op_LIST_EXTEND at the start of a block" self.assertIn(msg, str(raises.exception)) def test_build_unpack_with_calls_in_unpack(self): def check(p): def pyfunc(a): z = [1, 2] return (*a, z.append(3), z.extend(a), np.ones(3)), z cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) check((4, 5)) def test_build_unpack_complicated(self): def check(p): def pyfunc(a): z = [1, 2] return (*a, *(*a, a), *(a, (*(a, (1, 2), *(3,), *a), (a, 1, (2, 3), *a, 1), (1,))), *(z.append(4), z.extend(a))), z cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(p), pyfunc(p)) check((10, 20)) if __name__ == '__main__': unittest.main()