ai-content-maker/.venv/Lib/site-packages/numba/tests/test_tuples.py

766 lines
23 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
import collections
import itertools
import numpy as np
from numba import njit, jit, typeof, literally
from numba.core import types, errors, utils
from numba.tests.support import TestCase, MemoryLeakMixin, tag
import unittest
# Namedtuple fixtures shared by the tests below.  ``Point`` and ``Point2``
# intentionally declare the same fields yet are distinct classes, so the
# dispatcher can be checked for typing them separately (issue #5215);
# ``Empty`` has no fields and is therefore falsy.
Rect = collections.namedtuple('Rect', 'width height')
Point = collections.namedtuple('Point', 'x y z')
Point2 = collections.namedtuple('Point2', 'x y z')
Empty = collections.namedtuple('Empty', '')
def tuple_return_usecase(a, b):
    """Return both arguments packed into a 2-tuple (tuple-return path)."""
    return a, b
def tuple_first(tup):
    """Unpack a 2-element tuple and return its first element."""
    # Explicit unpacking (not ``tup[0]``) is deliberate: the tests use this
    # function to exercise tuple unpacking in compiled code.
    a, b = tup
    return a
def tuple_second(tup):
    """Unpack a 2-element tuple and return its second element."""
    # Explicit unpacking (not ``tup[1]``) is deliberate: the tests use this
    # function to exercise tuple unpacking in compiled code.
    a, b = tup
    return b
def tuple_index(tup, idx):
    """Return ``tup[idx]`` for an index only known at run time."""
    return tup[idx]
def tuple_index_static(tup):
    """Index with a compile-time constant, exercising the static-getitem path."""
    # Note the negative index
    return tup[-2]
def tuple_slice2(tup):
    """Return ``tup[1:-1]``: a slice with start and (negative) stop, no step."""
    return tup[1:-1]
def tuple_slice3(tup):
    """Return ``tup[1::2]``: a slice with start and step, no stop."""
    return tup[1::2]
def len_usecase(tup):
    """Return ``len(tup)``."""
    return len(tup)
def add_usecase(a, b):
    """Return ``a + b`` (tuple concatenation when called with tuples)."""
    return a + b
def eq_usecase(a, b):
    """Return ``a == b``."""
    return a == b
def ne_usecase(a, b):
    """Return ``a != b``."""
    return a != b
def gt_usecase(a, b):
    """Return ``a > b``."""
    return a > b
def ge_usecase(a, b):
    """Return ``a >= b``."""
    return a >= b
def lt_usecase(a, b):
    """Return ``a < b``."""
    return a < b
def le_usecase(a, b):
    """Return ``a <= b``."""
    return a <= b
def in_usecase(a, b):
    """Return whether ``a`` is contained in ``b`` (``a in b``)."""
    return a in b
def bool_usecase(tup):
    """Return ``bool(tup)`` and a value selected by ``tup``'s truthiness.

    Both the explicit ``bool()`` call and the implicit truth test in the
    conditional expression are exercised.
    """
    return bool(tup), (3 if tup else 2)
def getattr_usecase(tup):
    """Return fields ``z``, ``y``, ``x`` of a namedtuple, by attribute access.

    Expects an object with ``x``/``y``/``z`` attributes, such as ``Point``.
    """
    return tup.z, tup.y, tup.x
def make_point(a, b, c):
    """Construct a ``Point`` from positional arguments."""
    return Point(a, b, c)
def make_point_kws(a, b, c):
    """Construct a ``Point`` from keyword arguments given out of field order."""
    return Point(z=c, y=b, x=a)
def make_point_nrt(n):
    """Build a ``Point`` that contains a list and a NumPy array.

    ``x`` is a ``Rect`` holding ``list(range(n))`` and ``np.zeros(n + 1)``;
    ``y`` and ``z`` are the lengths of those two containers.
    """
    r = Rect(list(range(n)), np.zeros(n + 1))
    # This also exercises attribute access
    p = Point(r, len(r.width), len(r.height))
    return p
def type_usecase(tup, *args):
    """Build a new instance of ``type(tup)`` from ``args``.

    For a namedtuple argument this constructs a new namedtuple of the same
    class via the ``type()`` built-in.
    """
    return type(tup)(*args)
def identity(tup):
    """Return ``tup`` unchanged.

    TestConversions compiles this with an explicit ``toty(fromty)``
    signature so that the return value is converted between tuple types.
    """
    return tup
def index_method_usecase(tup, value):
    """Return ``tup.index(value)``; raises ValueError if ``value`` is absent."""
    return tup.index(value)
def tuple_unpack_static_getitem_err():
    """Typing-failure fixture: compiling this must raise a typing error.

    ``c`` is bound to an empty list that is never appended to, so its
    element type stays undefined ("imprecise") and type inference fails.
    The variable name ``c`` appears in the error message asserted by the
    tests, so it must not be renamed.
    """
    # see issue3895, `c` is imprecise
    a, b, c, d = [], [], [], 0.0
    a.append(1)
    b.append(1)
    return
class TestTupleLengthError(unittest.TestCase):
    """Passing an oversized tuple to a compiled function must be rejected."""

    def test_tuple_length_error(self):
        # issue 2195: tuples longer than 1000 elements are unsupported.
        @njit
        def tuple_length(tup):
            return len(tup)

        # The parameter name ``tup`` is part of the expected message.
        oversized = tuple(range(1001))
        with self.assertRaises(errors.UnsupportedError) as ctx:
            tuple_length(oversized)
        self.assertIn("Tuple 'tup' length must be smaller than 1000",
                      str(ctx.exception))
class TestTupleTypeNotIterable(unittest.TestCase):
    """Issue 4369: tuple type constructors must reject a non-iterable
    ``types`` argument with a TypingError."""

    def _assert_not_iterable(self, build):
        # ``build`` is a zero-argument callable that attempts the bad
        # construction; it is invoked inside the assertRaises context.
        with self.assertRaises(errors.TypingError) as ctx:
            build()
        self.assertIn(
            "Argument 'types' is not iterable",
            str(ctx.exception)
        )

    def test_namedtuple_types_exception(self):
        self._assert_not_iterable(
            lambda: types.NamedTuple(types.uint32, 'p'))

    def test_tuple_types_exception(self):
        # Deliberately a bare type, not a 1-tuple.
        self._assert_not_iterable(lambda: types.Tuple(types.uint32))
class TestTupleReturn(TestCase):
    """Returning tuples (of arrays, scalars and mixed types) from compiled
    functions."""

    def test_array_tuple(self):
        arr_ty = types.Array(types.float64, 1, 'C')
        compiled = njit((arr_ty, arr_ty))(tuple_return_usecase)
        a = b = np.arange(5, dtype='float64')
        first, second = compiled(a, b)
        self.assertPreciseEqual(first, a)
        self.assertPreciseEqual(second, b)
        # Drop the caller's references before the final comparison;
        # presumably this checks the returned arrays remain usable on
        # their own.
        del a, b
        self.assertPreciseEqual(first, second)

    def test_scalar_tuple(self):
        f32 = types.float32
        compiled = njit((f32, f32))(tuple_return_usecase)
        a = b = 1
        first, second = compiled(a, b)
        self.assertEqual(first, a)
        self.assertEqual(second, b)

    def test_hetero_tuple(self):
        # Each case pairs an argument signature with matching values.
        cases = [
            ((types.int32, types.int64), (1, 2)),
            ((types.float32, types.float64), (1.125, .25)),
            ((types.int32, types.float64), (1231, .5)),
        ]
        for (ta, tb), (a, b) in cases:
            compiled = njit((ta, tb))(tuple_return_usecase)
            ra, rb = compiled(a, b)
            self.assertPreciseEqual((ra, rb), (a, b))
class TestTuplePassing(TestCase):
    """Passing tuples as arguments to compiled functions."""

    def _check_first_second(self, tuple_type, value, first, second):
        # Compile both unpacking helpers for ``tuple_type`` and check each
        # returns the expected element of ``value``.
        get_first = njit((tuple_type,))(tuple_first)
        get_second = njit((tuple_type,))(tuple_second)
        self.assertPreciseEqual(get_first(value), first)
        self.assertPreciseEqual(get_second(value), second)

    def test_unituple(self):
        self._check_first_second(types.UniTuple(types.int32, 2),
                                 (4, 5), 4, 5)

    def test_hetero_tuple(self):
        self._check_first_second(types.Tuple((types.int64, types.float32)),
                                 (2**61, 1.5), 2**61, 1.5)

    def test_size_mismatch(self):
        # Issue #1638: tuple size should be checked when unboxing.
        pair_ty = types.UniTuple(types.int32, 2)
        compiled = njit((pair_ty,))(tuple_first)
        # Call the compiled entry point directly, bypassing the dispatcher's
        # argument-type matching, so the unboxing check is what fires.
        only_sig = compiled.signatures[0]
        entry = compiled.overloads[only_sig].entry_point
        with self.assertRaises(ValueError) as ctx:
            entry((4, 5, 6))
        expected = ("size mismatch for tuple, "
                    "expected 2 element(s) but got 3")
        self.assertEqual(str(ctx.exception), expected)
class TestOperations(TestCase):
    """Core tuple operations compiled with explicit signatures.

    Covers len, runtime/literal/static indexing, slicing, ``in``, truthiness,
    concatenation and the six rich comparisons.
    """

    def test_len(self):
        pyfunc = len_usecase
        cfunc = njit((types.Tuple((types.int64, types.float32)),))(pyfunc)
        self.assertPreciseEqual(cfunc((4, 5)), 2)
        cfunc = njit((types.UniTuple(types.int64, 3),))(pyfunc)
        self.assertPreciseEqual(cfunc((4, 5, 6)), 3)

    def test_index_literal(self):
        # issue #6023, test non-static getitem with IntegerLiteral index
        def pyfunc(tup, idx):
            # literally() requests that idx be typed as an integer literal.
            idx = literally(idx)
            return tup[idx]
        cfunc = njit(pyfunc)
        # Heterogeneous tuple: each position holds a different type.
        tup = (4, 3.1, 'sss')
        for i in range(len(tup)):
            self.assertPreciseEqual(cfunc(tup, i), tup[i])

    def test_index(self):
        pyfunc = tuple_index
        cfunc = njit((types.UniTuple(types.int64, 3), types.int64),)(pyfunc)
        tup = (4, 3, 6)
        for i in range(len(tup)):
            self.assertPreciseEqual(cfunc(tup, i), tup[i])

        # test negative indexing
        for i in range(len(tup) + 1):
            self.assertPreciseEqual(cfunc(tup, -i), tup[-i])

        # oob indexes, +ve then -ve
        with self.assertRaises(IndexError) as raises:
            cfunc(tup, len(tup))
        self.assertEqual("tuple index out of range", str(raises.exception))
        with self.assertRaises(IndexError) as raises:
            cfunc(tup, -(len(tup) + 1))
        self.assertEqual("tuple index out of range", str(raises.exception))

        # Test empty tuple, this is a bit unusual as `njit` will infer the empty
        # tuple arg as a types.Tuple and not match the compiled signature, this
        # is essentially because the test originally relied on
        # `compile_isolated`.
        args = (types.UniTuple(types.int64, 0), types.int64,)
        cr = njit(args)(pyfunc).overloads[args]
        with self.assertRaises(IndexError) as raises:
            cr.entry_point((), 0)
        self.assertEqual("tuple index out of range", str(raises.exception))

        # test uintp indexing (because, e.g., parfor generates unsigned prange)
        cfunc = njit((types.UniTuple(types.int64, 3), types.uintp,),)(pyfunc)
        for i in range(len(tup)):
            self.assertPreciseEqual(cfunc(tup, types.uintp(i)), tup[i])

        # With a compile-time static index (the code generation path is
        # different)
        pyfunc = tuple_index_static
        for typ in (types.UniTuple(types.int64, 4),
                    types.Tuple((types.int64, types.int32, types.int64,
                                 types.int32))):
            cfunc = njit((typ,))(pyfunc)
            tup = (4, 3, 42, 6)
            self.assertPreciseEqual(cfunc(tup), pyfunc(tup))

        # tup[-2] is out of range for a 1-tuple, so typing must fail.
        typ = types.UniTuple(types.int64, 1)
        with self.assertTypingError():
            njit((typ,))(pyfunc)

        # test unpack, staticgetitem with imprecise type (issue #3895)
        pyfunc = tuple_unpack_static_getitem_err
        with self.assertTypingError() as raises:
            njit((),)(pyfunc)
        msg = ("Cannot infer the type of variable 'c', have imprecise type: "
               "list(undefined)<iv=None>.")
        self.assertIn(msg, str(raises.exception))

    def test_in(self):
        pyfunc = in_usecase
        cfunc = njit((types.int64, types.UniTuple(types.int64, 3),),)(pyfunc)
        tup = (4, 1, 5)
        # Values 0..4 cover both members and non-members of tup.
        for i in range(5):
            self.assertPreciseEqual(cfunc(i, tup), pyfunc(i, tup))

        # Test the empty case
        cfunc = njit((types.int64, types.Tuple([]),),)(pyfunc)
        self.assertPreciseEqual(cfunc(1, ()), pyfunc(1, ()))

    def check_slice(self, pyfunc):
        """Compare slicing by ``pyfunc`` of a 4-tuple against CPython, for
        both a UniTuple and a heterogeneous Tuple signature."""
        tup = (4, 5, 6, 7)
        cfunc = njit((types.UniTuple(types.int64, 4),),)(pyfunc)
        self.assertPreciseEqual(cfunc(tup), pyfunc(tup))
        # ``args`` is the heterogeneous tuple *type* used as the argument.
        args = types.Tuple((types.int64, types.int32, types.int64, types.int32))
        cfunc = njit((args,))(pyfunc)
        self.assertPreciseEqual(cfunc(tup), pyfunc(tup))

    def test_slice2(self):
        self.check_slice(tuple_slice2)

    def test_slice3(self):
        self.check_slice(tuple_slice3)

    def test_bool(self):
        pyfunc = bool_usecase
        cfunc = njit((types.Tuple((types.int64, types.int32)),),)(pyfunc)
        args = ((4, 5),)
        self.assertPreciseEqual(cfunc(*args), pyfunc(*args))
        cfunc = njit((types.UniTuple(types.int64, 3),),)(pyfunc)
        args = ((4, 5, 6),)
        self.assertPreciseEqual(cfunc(*args), pyfunc(*args))
        # The empty tuple is falsy.
        cfunc = njit((types.Tuple(()),),)(pyfunc)
        self.assertPreciseEqual(cfunc(()), pyfunc(()))

    def test_add(self):
        pyfunc = add_usecase
        # (numba type, matching value) pairs; every ordered pair is tried,
        # including both spellings of the empty tuple type.
        samples = [(types.Tuple(()), ()),
                   (types.UniTuple(types.int32, 0), ()),
                   (types.UniTuple(types.int32, 1), (42,)),
                   (types.Tuple((types.int64, types.float32)), (3, 4.5)),
                   ]
        for (ta, a), (tb, b) in itertools.product(samples, samples):
            cfunc = njit((ta, tb),)(pyfunc)
            expected = pyfunc(a, b)
            got = cfunc(a, b)
            self.assertPreciseEqual(got, expected, msg=(ta, tb))

    def _test_compare(self, pyfunc):
        """Check comparison ``pyfunc`` compiled for tuple argument pairs
        returns the identical bool object CPython does."""
        def eq(pyfunc, cfunc, args):
            self.assertIs(cfunc(*args), pyfunc(*args),
                          "mismatch for arguments %s" % (args,))

        # Same-sized tuples
        argtypes = [types.Tuple((types.int64, types.float32)),
                    types.UniTuple(types.int32, 2)]
        for ta, tb in itertools.product(argtypes, argtypes):
            cfunc = njit((ta, tb),)(pyfunc)
            for args in [((4, 5), (4, 5)),
                         ((4, 5), (4, 6)),
                         ((4, 6), (4, 5)),
                         ((4, 5), (5, 4))]:
                eq(pyfunc, cfunc, args)

        # Different-sized tuples
        argtypes = [types.Tuple((types.int64, types.float32)),
                    types.UniTuple(types.int32, 3)]
        cfunc = njit(tuple(argtypes),)(pyfunc)
        for args in [((4, 5), (4, 5, 6)),
                     ((4, 5), (4, 4, 6)),
                     ((4, 5), (4, 6, 7))]:
            eq(pyfunc, cfunc, args)

    def test_eq(self):
        self._test_compare(eq_usecase)

    def test_ne(self):
        self._test_compare(ne_usecase)

    def test_gt(self):
        self._test_compare(gt_usecase)

    def test_ge(self):
        self._test_compare(ge_usecase)

    def test_lt(self):
        self._test_compare(lt_usecase)

    def test_le(self):
        self._test_compare(le_usecase)
class TestNamedTuple(MemoryLeakMixin, TestCase):
    """Namedtuple support in nopython mode.

    Covers unpacking, len, indexing, truthiness, comparisons, attribute
    access, construction, ``type()`` and dispatcher typing.

    MemoryLeakMixin is listed *before* TestCase on purpose: unittest's
    ``TestCase.setUp``/``tearDown`` do not call ``super()``, so a mixin
    placed after TestCase in the bases never has its own setUp/tearDown run
    and its leak check would be silently skipped.
    """

    def test_unpack(self):
        def check(p):
            for pyfunc in tuple_first, tuple_second:
                cfunc = jit(nopython=True)(pyfunc)
                self.assertPreciseEqual(cfunc(p), pyfunc(p))

        # Homogeneous
        check(Rect(4, 5))
        # Heterogeneous
        check(Rect(4, 5.5))

    def test_len(self):
        def check(p):
            pyfunc = len_usecase
            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

        # Homogeneous
        check(Rect(4, 5))
        check(Point(4, 5, 6))
        # Heterogeneous
        check(Rect(4, 5.5))
        check(Point(4, 5.5, 6j))

    def test_index(self):
        pyfunc = tuple_index
        cfunc = jit(nopython=True)(pyfunc)
        p = Point(4, 5, 6)
        for i in range(len(p)):
            self.assertPreciseEqual(cfunc(p, i), pyfunc(p, i))

        # test uintp indexing (because, e.g., parfor generates unsigned prange)
        for i in range(len(p)):
            self.assertPreciseEqual(cfunc(p, types.uintp(i)), pyfunc(p, i))

    def test_bool(self):
        def check(p):
            pyfunc = bool_usecase
            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

        # Homogeneous
        check(Rect(4, 5))
        # Heterogeneous
        check(Rect(4, 5.5))
        # Zero-field namedtuple (falsy)
        check(Empty())

    def _test_compare(self, pyfunc):
        """Check comparison ``pyfunc`` on namedtuple pairs returns the
        identical bool object CPython does."""
        def eq(pyfunc, cfunc, args):
            self.assertIs(cfunc(*args), pyfunc(*args),
                          "mismatch for arguments %s" % (args,))

        cfunc = jit(nopython=True)(pyfunc)

        # Same-sized named tuples
        for a, b in [((4, 5), (4, 5)),
                     ((4, 5), (4, 6)),
                     ((4, 6), (4, 5)),
                     ((4, 5), (5, 4))]:
            eq(pyfunc, cfunc, (Rect(*a), Rect(*b)))

        # Different-sized named tuples
        for a, b in [((4, 5), (4, 5, 6)),
                     ((4, 5), (4, 4, 6)),
                     ((4, 5), (4, 6, 7))]:
            eq(pyfunc, cfunc, (Rect(*a), Point(*b)))

    def test_eq(self):
        self._test_compare(eq_usecase)

    def test_ne(self):
        self._test_compare(ne_usecase)

    def test_gt(self):
        self._test_compare(gt_usecase)

    def test_ge(self):
        self._test_compare(ge_usecase)

    def test_lt(self):
        self._test_compare(lt_usecase)

    def test_le(self):
        self._test_compare(le_usecase)

    def test_getattr(self):
        pyfunc = getattr_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for args in (4, 5, 6), (4, 5.5, 6j):
            p = Point(*args)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

    def test_construct(self):
        def check(pyfunc):
            cfunc = jit(nopython=True)(pyfunc)
            for args in (4, 5, 6), (4, 5.5, 6j):
                expected = pyfunc(*args)
                got = cfunc(*args)
                # The result must be a Point, not merely an equal tuple.
                self.assertIs(type(got), type(expected))
                self.assertPreciseEqual(got, expected)

        check(make_point)
        check(make_point_kws)

    def test_type(self):
        # Test the type() built-in on named tuples
        pyfunc = type_usecase
        cfunc = jit(nopython=True)(pyfunc)

        arg_tuples = [(4, 5, 6), (4, 5.5, 6j)]
        for tup_args, args in itertools.product(arg_tuples, arg_tuples):
            tup = Point(*tup_args)
            expected = pyfunc(tup, *args)
            got = cfunc(tup, *args)
            self.assertIs(type(got), type(expected))
            self.assertPreciseEqual(got, expected)

    def test_literal_unification(self):
        # Test for #3565.
        @jit(nopython=True)
        def Data1(value):
            return Rect(value, -321)

        @jit(nopython=True)
        def call(i, j):
            if j == 0:
                # In the error, `result` is typed to `Rect(int, LiteralInt)`
                # because of the `-321` literal. This doesn't match the
                # `result` type in the other branch.
                result = Data1(i)
            else:
                # `result` is typed to be `Rect(int, int)`
                result = Rect(i, j)
            return result

        r = call(123, 1321)
        self.assertEqual(r, Rect(width=123, height=1321))
        r = call(123, 0)
        self.assertEqual(r, Rect(width=123, height=-321))

    def test_string_literal_in_ctor(self):
        # Test for issue #3813

        @jit(nopython=True)
        def foo():
            return Rect(10, 'somestring')

        r = foo()
        self.assertEqual(r, Rect(width=10, height='somestring'))

    def test_dispatcher_mistreat(self):
        # Test for issue #5215 that mistreat namedtuple as tuples
        @jit(nopython=True)
        def foo(x):
            return x

        in1 = (1, 2, 3)
        out1 = foo(in1)
        self.assertEqual(in1, out1)

        in2 = Point(1, 2, 3)
        out2 = foo(in2)
        self.assertEqual(in2, out2)

        # Check the signatures: a plain tuple and an equal-valued namedtuple
        # must each get their own compiled specialization.
        self.assertEqual(len(foo.nopython_signatures), 2)
        self.assertEqual(foo.nopython_signatures[0].args[0], typeof(in1))
        self.assertEqual(foo.nopython_signatures[1].args[0], typeof(in2))

        # Differently named
        in3 = Point2(1, 2, 3)
        out3 = foo(in3)
        self.assertEqual(in3, out3)
        self.assertEqual(len(foo.nopython_signatures), 3)
        self.assertEqual(foo.nopython_signatures[2].args[0], typeof(in3))
class TestTupleNRT(MemoryLeakMixin, TestCase):
    """Tuples containing a NumPy array created in compiled code.

    MemoryLeakMixin is listed *before* TestCase on purpose: unittest's
    ``TestCase.setUp``/``tearDown`` do not call ``super()``, so with the
    mixin second its setUp/tearDown (and thus its leak check) never run.
    """

    def test_tuple_add(self):
        def pyfunc(x):
            a = np.arange(3)
            return (a,) + (x,)

        cfunc = jit(nopython=True)(pyfunc)
        x = 123
        expect_a, expect_x = pyfunc(x)
        got_a, got_x = cfunc(x)
        np.testing.assert_equal(got_a, expect_a)
        self.assertEqual(got_x, expect_x)
class TestNamedTupleNRT(MemoryLeakMixin, TestCase):
    """Returning namedtuples that contain lists/arrays from compiled code.

    MemoryLeakMixin is listed *before* TestCase on purpose: unittest's
    ``TestCase.setUp``/``tearDown`` do not call ``super()``, so with the
    mixin second its setUp/tearDown (and thus its leak check) never run.
    """

    def test_return(self):
        # Check returning a namedtuple with a list inside it
        pyfunc = make_point_nrt
        cfunc = jit(nopython=True)(pyfunc)
        # n == 0 covers the empty list / length-1 array case.
        for arg in (3, 0):
            expected = pyfunc(arg)
            got = cfunc(arg)
            self.assertIs(type(got), type(expected))
            self.assertPreciseEqual(got, expected)
class TestConversions(TestCase):
    """
    Implicit conversions between tuple types at a compiled function's
    signature boundary.
    """

    def check_conversion(self, fromty, toty, val):
        # identity() is compiled with an explicit ``toty(fromty)`` signature,
        # so its returned value must be converted from ``fromty`` to ``toty``.
        compiled = njit(toty(fromty))(identity)
        self.assertEqual(compiled(val), val)

    def test_conversions(self):
        check = self.check_conversion
        int32_pair = types.UniTuple(types.int32, 2)
        check(int32_pair, types.UniTuple(types.float32, 2), (4, 5))
        check(int32_pair, types.Tuple((types.float32, types.int16)), (4, 5))

        # Both spellings of the empty tuple type convert into each other.
        empty_uni = types.UniTuple(types.int32, 0)
        empty_het = types.Tuple(())
        for src, dst in ((empty_uni, empty_het), (empty_het, empty_uni)):
            check(src, dst, ())

        # Changing the tuple length is not a valid conversion.
        with self.assertRaises(errors.TypingError) as ctx:
            check(int32_pair, types.Tuple((types.float32,)), (4, 5))
        expected = ("No conversion from UniTuple(int32 x 2) "
                    "to UniTuple(float32 x 1)")
        self.assertIn(expected, str(ctx.exception))
class TestMethods(TestCase):
    """Tuple methods available in nopython mode."""

    def test_index(self):
        compiled = jit(nopython=True)(index_method_usecase)
        haystack = (1, 2, 3)
        self.assertEqual(compiled(haystack, 2), 1)
        # A value that is not present raises ValueError with this message.
        with self.assertRaises(ValueError) as ctx:
            compiled(haystack, 4)
        expected = 'tuple.index(x): x not in tuple'
        self.assertEqual(expected, str(ctx.exception))
class TestTupleBuild(TestCase):
    """Tuple construction via star-unpacking (``(1, *a)``), unpacking into
    call arguments, and the ``tuple()`` constructor.

    Several tests depend on the bytecode CPython emits for these exact
    expressions (see the peephole/LIST_EXTEND notes below), so the
    expression forms must be kept as written.
    """

    def test_build_unpack(self):
        def check(p):
            pyfunc = lambda a: (1, *a)
            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

        # Homogeneous
        check((4, 5))
        # Heterogeneous
        check((4, 5.5))

    def test_build_unpack_assign_like(self):
        # see #6534
        def check(p):
            pyfunc = lambda a: (*a,)
            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

        # Homogeneous
        check((4, 5))
        # Heterogeneous
        check((4, 5.5))

    def test_build_unpack_fail_on_list_assign_like(self):
        # see #6534: star-unpacking a list argument into a tuple must fail
        # to type in nopython mode.
        def check(p):
            pyfunc = lambda a: (*a,)
            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

        with self.assertRaises(errors.TypingError) as raises:
            check([4, 5])

        # Python 3.9 has a peephole rewrite due to large changes in tuple
        # unpacking. It results in a tuple + list situation from the above
        # so the error message reflects that. Catching this specific and
        # seemingly rare sequence in the peephole rewrite is prohibitively
        # hard. Should it be reported numerous times, revisit then.
        msg1 = "No implementation of function"
        self.assertIn(msg1, str(raises.exception))
        msg2 = "tuple(reflected list(" # ignore the rest of reflected list
                                       # part, its repr is quite volatile.
        self.assertIn(msg2, str(raises.exception))

    def test_build_unpack_more(self):
        def check(p):
            pyfunc = lambda a: (1, *a, (1, 2), *a)
            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

        # Homogeneous
        check((4, 5))
        # Heterogeneous
        check((4, 5.5))

    def test_build_unpack_call(self):
        def check(p):
            @jit
            def inner(*args):
                return args
            pyfunc = lambda a: inner(1, *a)
            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

        # Homogeneous
        check((4, 5))
        # Heterogeneous
        check((4, 5.5))

    def test_build_unpack_call_more(self):
        def check(p):
            @jit
            def inner(*args):
                return args
            pyfunc = lambda a: inner(1, *a, *(1, 2), *a)
            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

        # Homogeneous
        check((4, 5))
        # Heterogeneous
        check((4, 5.5))

    def test_tuple_constructor(self):
        def check(pyfunc, arg):
            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(arg), pyfunc(arg))

        # empty
        check(lambda _: tuple(), ())
        # Homogeneous
        check(lambda a: tuple(a), (4, 5))
        # Heterogeneous
        check(lambda a: tuple(a), (4, 5.5))

    @unittest.skipIf(utils.PYVERSION < (3, 9), "needs Python 3.9+")
    def test_unpack_with_predicate_fails(self):
        # this fails as the list_to_tuple/list_extend peephole bytecode
        # rewriting needed for Python 3.9+ cannot yet traverse the CFG.
        @njit
        def foo():
            a = (1,)
            b = (3,2, 4)
            # The conditional inside the star-unpack is what introduces
            # the control flow the rewrite cannot handle.
            return (*(b if a[0] else (5, 6)),)

        with self.assertRaises(errors.UnsupportedError) as raises:
            foo()
        msg = "op_LIST_EXTEND at the start of a block"
        self.assertIn(msg, str(raises.exception))

    def test_build_unpack_with_calls_in_unpack(self):
        # Side-effecting calls (list append/extend) and an array
        # constructor appear inside the tuple display; the mutated list is
        # returned too, so evaluation effects are compared as well.
        def check(p):
            def pyfunc(a):
                z = [1, 2]
                return (*a, z.append(3), z.extend(a), np.ones(3)), z

            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

        check((4, 5))

    def test_build_unpack_complicated(self):
        # Deeply nested mix of star-unpacking, nested tuples and
        # side-effecting calls.
        def check(p):
            def pyfunc(a):
                z = [1, 2]
                return (*a, *(*a, a), *(a, (*(a, (1, 2), *(3,), *a),
                                           (a, 1, (2, 3), *a, 1), (1,))),
                        *(z.append(4), z.extend(a))), z

            cfunc = jit(nopython=True)(pyfunc)
            self.assertPreciseEqual(cfunc(p), pyfunc(p))

        check((10, 20))
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()