import numpy as np import unittest from numba import jit, njit from numba.core import types from numba.tests.support import TestCase, MemoryLeakMixin from numba.core.datamodel.testing import test_factory forceobj_flags = {'nopython': False, 'forceobj': True} nopython_flags = {'nopython': True} def make_consumer(gen_func): def consumer(x): res = 0.0 for y in gen_func(x): res += y return res return consumer def gen1(x): for i in range(x): yield i def gen2(x): for i in range(x): yield i for j in range(1, 3): yield i + j def gen3(x): # Polymorphic yield types must be unified yield x yield x + 1.5 yield x + 1j def gen4(x, y, z): for i in range(3): yield z yield y + z return yield x def gen5(): # The bytecode for this generator doesn't contain any YIELD_VALUE # (it's optimized away). We fail typing it, since the yield type # is entirely undefined. if 0: yield 1 def gen6(a, b): # Infinite loop: exercise computation of state variables x = a + 1 while True: y = b + 2 yield x + y def gen7(arr): # Array variable in generator state for i in range(arr.size): yield arr[i] # Optional arguments and boolean state members def gen8(x=1, y=2, b=False): bb = not b yield x if bb: yield y if b: yield x + y def genobj(x): object() yield x def return_generator_expr(x): return (i * 2 for i in x) def gen_ndindex(shape): for ind in np.ndindex(shape): yield ind def gen_flat(arr): for val in arr.flat: yield val def gen_ndenumerate(arr): for tup in np.ndenumerate(arr): yield tup def gen_bool(): yield True def gen_unification_error(): yield None yield 1j def gen_optional_and_type_unification_error(): # yields complex and optional(literalint) i = 0 yield 1j while True: i = yield i def gen_changing_tuple_type(): # https://github.com/numba/numba/issues/7295 yield 1, 2 yield 3, 4 def gen_changing_number_type(): # additional test for https://github.com/numba/numba/issues/7295 yield 1 yield 3.5 yield 67.8j class TestGenerators(MemoryLeakMixin, TestCase): def check_generator(self, pygen, cgen): self.assertEqual(next(cgen), next(pygen)) # Use list comprehensions to make sure we trash the generator's # former C stack. expected = [x for x in pygen] got = [x for x in cgen] self.assertEqual(expected, got) with self.assertRaises(StopIteration): next(cgen) def check_gen1(self, **kwargs): pyfunc = gen1 cr = jit((types.int32,), **kwargs)(pyfunc) pygen = pyfunc(8) cgen = cr(8) self.check_generator(pygen, cgen) def test_gen1(self): self.check_gen1(**nopython_flags) def test_gen1_objmode(self): self.check_gen1(**forceobj_flags) def check_gen2(self, **kwargs): pyfunc = gen2 cr = jit((types.int32,), **kwargs)(pyfunc) pygen = pyfunc(8) cgen = cr(8) self.check_generator(pygen, cgen) def test_gen2(self): self.check_gen2(**nopython_flags) def test_gen2_objmode(self): self.check_gen2(**forceobj_flags) def check_gen3(self, **kwargs): pyfunc = gen3 cr = jit((types.int32,), **kwargs)(pyfunc) pygen = pyfunc(8) cgen = cr(8) self.check_generator(pygen, cgen) def test_gen3(self): self.check_gen3(**nopython_flags) def test_gen3_objmode(self): self.check_gen3(**forceobj_flags) def check_gen4(self, **kwargs): pyfunc = gen4 cr = jit((types.int32,) * 3, **kwargs)(pyfunc) pygen = pyfunc(5, 6, 7) cgen = cr(5, 6, 7) self.check_generator(pygen, cgen) def test_gen4(self): self.check_gen4(**nopython_flags) def test_gen4_objmode(self): self.check_gen4(**forceobj_flags) def test_gen5(self): with self.assertTypingError() as raises: jit((), **nopython_flags)(gen5) self.assertIn("Cannot type generator: it does not yield any value", str(raises.exception)) def test_gen5_objmode(self): cgen = jit((), **forceobj_flags)(gen5)() self.assertEqual(list(cgen), []) with self.assertRaises(StopIteration): next(cgen) def check_gen6(self, **kwargs): cr = jit((types.int32,) * 2, **kwargs)(gen6) cgen = cr(5, 6) l = [] for i in range(3): l.append(next(cgen)) self.assertEqual(l, [14] * 3) def test_gen6(self): self.check_gen6(**nopython_flags) def test_gen6_objmode(self): self.check_gen6(**forceobj_flags) def check_gen7(self, **kwargs): pyfunc = gen7 cr = jit((types.Array(types.float64, 1, 'C'),), **kwargs)(pyfunc) arr = np.linspace(1, 10, 7) pygen = pyfunc(arr.copy()) cgen = cr(arr) self.check_generator(pygen, cgen) def test_gen7(self): self.check_gen7(**nopython_flags) def test_gen7_objmode(self): self.check_gen7(**forceobj_flags) def check_gen8(self, **jit_args): pyfunc = gen8 cfunc = jit(**jit_args)(pyfunc) def check(*args, **kwargs): self.check_generator(pyfunc(*args, **kwargs), cfunc(*args, **kwargs)) check(2, 3) check(4) check(y=5) check(x=6, b=True) def test_gen8(self): self.check_gen8(nopython=True) def test_gen8_objmode(self): self.check_gen8(forceobj=True) def check_gen9(self, **kwargs): pyfunc = gen_bool cr = jit((), **kwargs)(pyfunc) pygen = pyfunc() cgen = cr() self.check_generator(pygen, cgen) def test_gen9(self): self.check_gen9(**nopython_flags) def test_gen9_objmode(self): self.check_gen9(**forceobj_flags) def check_consume_generator(self, gen_func): cgen = jit(nopython=True)(gen_func) cfunc = jit(nopython=True)(make_consumer(cgen)) pyfunc = make_consumer(gen_func) expected = pyfunc(5) got = cfunc(5) self.assertPreciseEqual(got, expected) def test_consume_gen1(self): self.check_consume_generator(gen1) def test_consume_gen2(self): self.check_consume_generator(gen2) def test_consume_gen3(self): self.check_consume_generator(gen3) # Check generator storage of some types def check_ndindex(self, **kwargs): pyfunc = gen_ndindex cr = jit((types.UniTuple(types.intp, 2),), **kwargs)(pyfunc) shape = (2, 3) pygen = pyfunc(shape) cgen = cr(shape) self.check_generator(pygen, cgen) def test_ndindex(self): self.check_ndindex(**nopython_flags) def test_ndindex_objmode(self): self.check_ndindex(**forceobj_flags) def check_np_flat(self, pyfunc, **kwargs): cr = jit((types.Array(types.int32, 2, "C"),), **kwargs)(pyfunc) arr = np.arange(6, dtype=np.int32).reshape((2, 3)) self.check_generator(pyfunc(arr), cr(arr)) crA = jit((types.Array(types.int32, 2, "A"),), **kwargs)(pyfunc) arr = arr.T self.check_generator(pyfunc(arr), crA(arr)) def test_np_flat(self): self.check_np_flat(gen_flat, **nopython_flags) def test_np_flat_objmode(self): self.check_np_flat(gen_flat, **forceobj_flags) def test_ndenumerate(self): self.check_np_flat(gen_ndenumerate, **nopython_flags) def test_ndenumerate_objmode(self): self.check_np_flat(gen_ndenumerate, **forceobj_flags) def test_type_unification_error(self): pyfunc = gen_unification_error with self.assertTypingError() as raises: jit((), **nopython_flags)(pyfunc) msg = ("Can't unify yield type from the following types: complex128, " "none") self.assertIn(msg, str(raises.exception)) def test_optional_expansion_type_unification_error(self): pyfunc = gen_optional_and_type_unification_error with self.assertTypingError() as raises: jit((), **nopython_flags)(pyfunc) msg = ("Can't unify yield type from the following types: complex128, " "int%s, none") self.assertIn(msg % types.intp.bitwidth, str(raises.exception)) def test_changing_tuple_type(self): # test https://github.com/numba/numba/issues/7295 pyfunc = gen_changing_tuple_type expected = list(pyfunc()) got = list(njit(pyfunc)()) self.assertEqual(expected, got) def test_changing_number_type(self): # additional test for https://github.com/numba/numba/issues/7295 pyfunc = gen_changing_number_type expected = list(pyfunc()) got = list(njit(pyfunc)()) self.assertEqual(expected, got) def nrt_gen0(ary): for elem in ary: yield elem def nrt_gen1(ary1, ary2): for e1, e2 in zip(ary1, ary2): yield e1 yield e2 class TestNrtArrayGen(MemoryLeakMixin, TestCase): def test_nrt_gen0(self): pygen = nrt_gen0 cgen = jit(nopython=True)(pygen) py_ary = np.arange(10) c_ary = py_ary.copy() py_res = list(pygen(py_ary)) c_res = list(cgen(c_ary)) np.testing.assert_equal(py_ary, c_ary) self.assertEqual(py_res, c_res) # Check reference count self.assertRefCountEqual(py_ary, c_ary) def test_nrt_gen1(self): pygen = nrt_gen1 cgen = jit(nopython=True)(pygen) py_ary1 = np.arange(10) py_ary2 = py_ary1 + 100 c_ary1 = py_ary1.copy() c_ary2 = py_ary2.copy() py_res = list(pygen(py_ary1, py_ary2)) c_res = list(cgen(c_ary1, c_ary2)) np.testing.assert_equal(py_ary1, c_ary1) np.testing.assert_equal(py_ary2, c_ary2) self.assertEqual(py_res, c_res) # Check reference count self.assertRefCountEqual(py_ary1, c_ary1) self.assertRefCountEqual(py_ary2, c_ary2) def test_combine_gen0_gen1(self): """ Issue #1163 is observed when two generator with NRT object arguments is ran in sequence. The first one does a invalid free and corrupts the NRT memory subsystem. The second generator is likely to segfault due to corrupted NRT data structure (an invalid MemInfo). """ self.test_nrt_gen0() self.test_nrt_gen1() def test_nrt_gen0_stop_iteration(self): """ Test cleanup on StopIteration """ pygen = nrt_gen0 cgen = jit(nopython=True)(pygen) py_ary = np.arange(1) c_ary = py_ary.copy() py_iter = pygen(py_ary) c_iter = cgen(c_ary) py_res = next(py_iter) c_res = next(c_iter) with self.assertRaises(StopIteration): py_res = next(py_iter) with self.assertRaises(StopIteration): c_res = next(c_iter) del py_iter del c_iter np.testing.assert_equal(py_ary, c_ary) self.assertEqual(py_res, c_res) # Check reference count self.assertRefCountEqual(py_ary, c_ary) def test_nrt_gen0_no_iter(self): """ Test cleanup for a initialized but never iterated (never call next()) generator. """ pygen = nrt_gen0 cgen = jit(nopython=True)(pygen) py_ary = np.arange(1) c_ary = py_ary.copy() py_iter = pygen(py_ary) c_iter = cgen(c_ary) del py_iter del c_iter np.testing.assert_equal(py_ary, c_ary) # Check reference count self.assertRefCountEqual(py_ary, c_ary) # TODO: fix nested generator and MemoryLeakMixin class TestNrtNestedGen(TestCase): def test_nrt_nested_gen(self): def gen0(arr): for i in range(arr.size): yield arr def factory(gen0): def gen1(arr): out = np.zeros_like(arr) for x in gen0(arr): out = out + x return out, arr return gen1 py_arr = np.arange(10) c_arr = py_arr.copy() py_res, py_old = factory(gen0)(py_arr) c_gen = jit(nopython=True)(factory(jit(nopython=True)(gen0))) c_res, c_old = c_gen(c_arr) self.assertIsNot(py_arr, c_arr) self.assertIs(py_old, py_arr) self.assertIs(c_old, c_arr) np.testing.assert_equal(py_res, c_res) self.assertRefCountEqual(py_res, c_res) # The below test will fail due to generator finalizer not invoked. # This kept a reference of the c_old. # # self.assertEqual(sys.getrefcount(py_old), # sys.getrefcount(c_old)) @unittest.expectedFailure def test_nrt_nested_gen_refct(self): def gen0(arr): yield arr def factory(gen0): def gen1(arr): for out in gen0(arr): return out return gen1 py_arr = np.arange(10) c_arr = py_arr.copy() py_old = factory(gen0)(py_arr) c_gen = jit(nopython=True)(factory(jit(nopython=True)(gen0))) c_old = c_gen(c_arr) self.assertIsNot(py_arr, c_arr) self.assertIs(py_old, py_arr) self.assertIs(c_old, c_arr) self.assertRefCountEqual(py_old, c_old) def test_nrt_nested_nopython_gen(self): """ Test nesting three generators """ def factory(decor=lambda x: x): @decor def foo(a, n): for i in range(n): yield a[i] a[i] += i @decor def bar(n): a = np.arange(n) for i in foo(a, n): yield i * 2 for i in range(a.size): yield a[i] @decor def cat(n): for i in bar(n): yield i + i return cat py_gen = factory() c_gen = factory(jit(nopython=True)) py_res = list(py_gen(10)) c_res = list(c_gen(10)) self.assertEqual(py_res, c_res) class TestGeneratorWithNRT(MemoryLeakMixin, TestCase): def test_issue_1254(self): """ Missing environment for returning array """ @jit(nopython=True) def random_directions(n): for i in range(n): vec = np.empty(3) vec[:] = 12 yield vec outputs = list(random_directions(5)) self.assertEqual(len(outputs), 5) expect = np.empty(3) expect[:] = 12 for got in outputs: np.testing.assert_equal(expect, got) def test_issue_1265(self): """ Double-free for locally allocated, non escaping NRT objects """ def py_gen(rmin, rmax, nr): a = np.linspace(rmin, rmax, nr) yield a[0] yield a[1] c_gen = jit(nopython=True)(py_gen) py_res = list(py_gen(-2, 2, 100)) c_res = list(c_gen(-2, 2, 100)) self.assertEqual(py_res, c_res) def py_driver(args): rmin, rmax, nr = args points = np.empty(nr, dtype=np.complex128) for i, c in enumerate(py_gen(rmin, rmax, nr)): points[i] = c return points @jit(nopython=True) def c_driver(args): rmin, rmax, nr = args points = np.empty(nr, dtype=np.complex128) for i, c in enumerate(c_gen(rmin, rmax, nr)): points[i] = c return points n = 2 patches = (-2, -1, n) py_res = py_driver(patches) # The error will cause a segfault here c_res = c_driver(patches) np.testing.assert_equal(py_res, c_res) def test_issue_1808(self): """ Incorrect return data model """ magic = 0xdeadbeef @njit def generator(): yield magic @njit def get_generator(): return generator() @njit def main(): out = 0 for x in get_generator(): out += x return out self.assertEqual(main(), magic) class TestGeneratorModel(test_factory()): fe_type = types.Generator(gen_func=None, yield_type=types.int32, arg_types=[types.int64, types.float32], state_types=[types.intp, types.intp[::1]], has_finalizer=False) if __name__ == '__main__': unittest.main()