"""
Tests for the lifetime of objects inside Numba-compiled functions.

Object-mode functions are driven with a ``RefRecorder`` whose dummy
objects log their own deletion, so the tests can check *when* each
object is released relative to explicit markers.  A second test class
inspects where ``ir.Del`` nodes are placed in the nopython IR when
variable lifetimes are extended.
"""
import collections
import weakref
import gc
import operator
from itertools import takewhile
import unittest

from numba import njit, jit
from numba.core.compiler import CompilerBase, DefaultPassBuilder
from numba.core.untyped_passes import PreserveIR
from numba.core.typed_passes import IRLegalization
from numba.core import types, ir
from numba.tests.support import TestCase, override_config, SerialMixin


class _Dummy(object):
    """
    Placeholder object whose deletion is logged by a RefRecorder.

    Every instance registers itself with *recorder* on creation; when it
    is deleted, *name* is appended to the recorder's events.  No
    ``__bool__``/``__len__`` is defined, so instances are always truthy.
    """

    def __init__(self, recorder, name):
        self.recorder = recorder
        self.name = name
        recorder._add_dummy(self)

    def __add__(self, other):
        # ``a + b`` creates a new dummy named "<a> + <b>"; the operands
        # themselves are left untouched.
        assert isinstance(other, _Dummy)
        return _Dummy(self.recorder, "%s + %s" % (self.name, other.name))

    def __iter__(self):
        # Iterating a dummy yields three fresh dummies (see _DummyIterator).
        return _DummyIterator(self.recorder, "iter(%s)" % self.name)


class _DummyIterator(_Dummy):
    """
    Iterator returned by ``iter(dummy)``; itself a recorded _Dummy.

    Yields exactly three new dummies named "<name>#1", "<name>#2" and
    "<name>#3", then raises StopIteration.
    """

    # Class-level default; ``self.count += 1`` below creates the
    # per-instance counter on first use.
    count = 0

    def __next__(self):
        if self.count >= 3:
            raise StopIteration
        self.count += 1
        return _Dummy(self.recorder, "%s#%s" % (self.name, self.count))

    # Python 2-style alias of the iterator protocol method.
    next = __next__


class RefRecorder(object):
    """
    An object which records events when instances created through it
    are deleted.  Custom events can also be recorded to aid in
    diagnosis.
    """

    def __init__(self):
        # Per-event-template counters used to fill ``{count}`` in mark().
        self._counts = collections.defaultdict(int)
        # Ordered log of deletion names and manual markers.
        self._events = []
        # Maps each live dummy's weakref to the dummy's name.
        self._wrs = {}

    def make_dummy(self, name):
        """
        Make an object whose deletion will be recorded as *name*.
        """
        return _Dummy(self, name)

    def _add_dummy(self, dummy):
        # Only a weak reference is kept, so the recorder itself does not
        # extend the dummy's lifetime; _on_disposal runs when it dies.
        wr = weakref.ref(dummy, self._on_disposal)
        self._wrs[wr] = dummy.name

    # Allows ``rec('a')`` as shorthand for ``rec.make_dummy('a')``.
    __call__ = make_dummy

    def mark(self, event):
        """
        Manually append *event* to the recorded events.

        *event* is passed through str.format() with ``count`` set to the
        1-based number of times this same (unformatted) *event* has been
        marked, so e.g. ``'--loop #{count}--'`` records ``'--loop #1--'``,
        ``'--loop #2--'``, ...
        """
        count = self._counts[event] + 1
        self._counts[event] = count
        self._events.append(event.format(count=count))

    def _on_disposal(self, wr):
        # Weakref callback: record the dead dummy's name and forget it.
        name = self._wrs.pop(wr)
        self._events.append(name)

    @property
    def alive(self):
        """
        A list of objects which haven't been deleted yet.

        Weakrefs are removed from ``_wrs`` by _on_disposal, so the
        remaining ones refer to dummies that have not been collected.
        """
        return [wr() for wr in self._wrs]

    @property
    def recorded(self):
        """
        A list of recorded events (deletion names and manual marks),
        in the order they happened.
        """
        return self._events


# Straight-line code: dead values should be released as soon as they are
# no longer needed.
def simple_usecase1(rec):
    a = rec('a')
    b = rec('b')
    c = rec('c')
    # Rebinding ``a`` drops the original 'a'; ``b`` and ``c`` are dead
    # after this statement.
    a = b + c
    rec.mark('--1--')
    # ``a`` (named 'b + c') is dead after this statement.
    d = a + a  # b + c + b + c
    rec.mark('--2--')
    return d


# Aliasing: ``b`` is never used, while the 'a' object escapes through ``y``
# even after ``a`` itself is cleared.
def simple_usecase2(rec):
    a = rec('a')
    b = rec('b')
    rec.mark('--1--')
    x = a
    y = x
    a = None
    return y


# Single loop: ``a`` and ``b`` are not needed after the loop, and each
# item ``y`` is dead once it has been added to ``x``.
def looping_usecase1(rec):
    a = rec('a')
    b = rec('b')
    c = rec('c')
    x = b
    for y in a:
        x = x + y
        rec.mark('--loop bottom--')
    rec.mark('--loop exit--')
    x = x + c
    return x


# Nested loops with a loop ``else`` clause and a (never taken) ``break``.
# Markers use ``{count}`` so each iteration's marker is distinguishable.
def looping_usecase2(rec):
    a = rec('a')
    b = rec('b')
    cum = rec('cum')
    for x in a:
        rec.mark('--outer loop top--')
        cum = cum + x
        z = x + x
        rec.mark('--inner loop entry #{count}--')
        for y in b:
            rec.mark('--inner loop top #{count}--')
            cum = cum + y
            rec.mark('--inner loop bottom #{count}--')
        rec.mark('--inner loop exit #{count}--')
        if cum:
            cum = y + z
        else:
            # Never gets here, but let the Numba compiler see a `break` opcode
            break
        rec.mark('--outer loop bottom #{count}--')
    else:
        rec.mark('--outer loop else--')
    rec.mark('--outer loop exit--')
    return cum


# Generators: objects must be released whether the generator is exhausted,
# never started, or abandoned midway (see exercise_generator).
def generator_usecase1(rec):
    a = rec('a')
    b = rec('b')
    yield a
    yield b


def generator_usecase2(rec):
    a = rec('a')
    b = rec('b')
    for x in a:
        yield x
    yield b


# Exception raised deliberately by the raising_usecase* functions.
class MyError(RuntimeError):
    pass


def do_raise(x):
    # Raise from a separate (non-inlined) call frame.
    raise MyError(x)


# Raise from a called function: ``a`` is a truthy dummy, so do_raise()
# always fires and ``c`` is never created; ``d`` is never used.
def raising_usecase1(rec):
    a = rec('a')
    b = rec('b')
    d = rec('d')
    if a:
        do_raise("foo")
    c = rec('c')
    c + a
    c + b


# Raise from a called function inside the branch that defines ``c``;
# the trailing ``a + c`` is never reached because ``a`` is truthy.
def raising_usecase2(rec):
    a = rec('a')
    b = rec('b')
    if a:
        c = rec('c')
        do_raise(b)
    a + c


# Raise directly with a ``raise`` statement inside a branch.
def raising_usecase3(rec):
    a = rec('a')
    b = rec('b')
    if a:
        raise MyError(b)


def del_before_definition(rec):
    """
    Regression case for a bug where a del was emitted on a variable that
    had not been initialized yet.

    Because ``n`` is 0, the inner ``for j`` loop never runs and its
    ``else`` clause always executes.  Execution marks '0', '1' and '2',
    then returns 2 from the innermost loop; the 'FAILED' and 'OK' marks
    are never reached.
    """
    n = 5
    for i in range(n):
        rec.mark(str(i))
        n = 0
        for j in range(n):
            return 0
        else:
            if i < 2:
                continue
            elif i == 2:
                for j in range(i):
                    return i
                rec.mark('FAILED')
            rec.mark('FAILED')
        rec.mark('FAILED')
    rec.mark('OK')
    return -1


def inf_loop_multiple_back_edge(rec):
    """
    Regression case for invalid liveness analysis when an infinite loop
    has multiple back edges.

    The ``if p:`` branch gives the loop body two paths back to the
    ``while`` header; ``p`` must be released on each iteration.
    """
    while True:
        rec.mark("yield")
        yield
        p = rec('p')
        if p:
            rec.mark('bra')
            pass


class TestObjLifetime(TestCase):
    """
    Test lifetime of Python objects inside jit-compiled functions.
    """

    def compile(self, pyfunc):
        """
        Compile *pyfunc* in object mode for a single pyobject argument.
        """
        # Note: looplift must be disabled. The tests require the function
        # control-flow to be unchanged.
        cfunc = jit((types.pyobject,), forceobj=True, looplift=False)(pyfunc)
        return cfunc

    def compile_and_record(self, pyfunc, raises=None):
        """
        Compile *pyfunc*, call it with a fresh RefRecorder and return the
        recorder.  If *raises* is given, the call must raise that
        exception type.
        """
        rec = RefRecorder()
        cfunc = self.compile(pyfunc)
        if raises is not None:
            with self.assertRaises(raises):
                cfunc(rec)
        else:
            cfunc(rec)
        return rec

    def assertRecordOrder(self, rec, expected):
        """
        Check that the *expected* markers occur in that order in *rec*'s
        recorded events.  Events not listed in *expected* are ignored.
        """
        actual = []
        recorded = rec.recorded
        remaining = list(expected)
        # Find out in which order, if any, the expected events were recorded
        for d in recorded:
            if d in remaining:
                actual.append(d)
                # User may or may not expect duplicates: remove only one
                # matching expectation so repeated markers in *expected*
                # are matched against successive occurrences.
                remaining.remove(d)
        self.assertEqual(actual, expected,
                         "the full list of recorded events is: %r" % (recorded,))

    def test_simple1(self):
        rec = self.compile_and_record(simple_usecase1)
        self.assertFalse(rec.alive)
        # Inputs are released before the first marker; the intermediate
        # 'b + c' is released between the two markers.
        self.assertRecordOrder(rec, ['a', 'b', '--1--'])
        self.assertRecordOrder(rec, ['a', 'c', '--1--'])
        self.assertRecordOrder(rec, ['--1--', 'b + c', '--2--'])

    def test_simple2(self):
        rec = self.compile_and_record(simple_usecase2)
        self.assertFalse(rec.alive)
        # Unused 'b' goes first; 'a' lives on through the return value.
        self.assertRecordOrder(rec, ['b', '--1--', 'a'])

    def test_looping1(self):
        rec = self.compile_and_record(looping_usecase1)
        self.assertFalse(rec.alive)
        # a and b are unneeded after the loop, check they were disposed of
        self.assertRecordOrder(rec, ['a', 'b', '--loop exit--', 'c'])
        # check disposal order of iterator items and iterator
        self.assertRecordOrder(rec, ['iter(a)#1', '--loop bottom--',
                                     'iter(a)#2', '--loop bottom--',
                                     'iter(a)#3', '--loop bottom--',
                                     'iter(a)', '--loop exit--',
                                     ])

    def test_looping2(self):
        rec = self.compile_and_record(looping_usecase2)
        self.assertFalse(rec.alive)
        # `a` is disposed of after its iterator is taken
        self.assertRecordOrder(rec, ['a', '--outer loop top--'])
        # Check disposal of iterators
        self.assertRecordOrder(rec, ['iter(a)', '--outer loop else--',
                                     '--outer loop exit--'])
        self.assertRecordOrder(rec, ['iter(b)', '--inner loop exit #1--',
                                     'iter(b)', '--inner loop exit #2--',
                                     'iter(b)', '--inner loop exit #3--',
                                     ])
        # Disposal of in-loop variable `x`
        self.assertRecordOrder(rec, ['iter(a)#1', '--inner loop entry #1--',
                                     'iter(a)#2', '--inner loop entry #2--',
                                     'iter(a)#3', '--inner loop entry #3--',
                                     ])
        # Disposal of in-loop variable `z`
        self.assertRecordOrder(rec, ['iter(a)#1 + iter(a)#1',
                                     '--outer loop bottom #1--',
                                     ])

    def exercise_generator(self, genfunc):
        """
        Check that the objects created by generator function *genfunc*
        are all released when the generator is exhausted, never iterated,
        or dropped before exhaustion.

        Each scenario runs under ``assertRefCount(rec)`` (from
        numba.tests.support), presumably to check that the recorder's
        reference count is unchanged afterwards -- TODO confirm.
        """
        cfunc = self.compile(genfunc)
        # Exhaust the generator
        rec = RefRecorder()
        with self.assertRefCount(rec):
            gen = cfunc(rec)
            next(gen)
            self.assertTrue(rec.alive)
            list(gen)
            self.assertFalse(rec.alive)
        # Instantiate the generator but never iterate
        rec = RefRecorder()
        with self.assertRefCount(rec):
            gen = cfunc(rec)
            del gen
            gc.collect()
            self.assertFalse(rec.alive)
        # Stop iterating before exhaustion
        rec = RefRecorder()
        with self.assertRefCount(rec):
            gen = cfunc(rec)
            next(gen)
            self.assertTrue(rec.alive)
            del gen
            gc.collect()
            self.assertFalse(rec.alive)

    def test_generator1(self):
        self.exercise_generator(generator_usecase1)

    def test_generator2(self):
        self.exercise_generator(generator_usecase2)

    def test_del_before_definition(self):
        rec = self.compile_and_record(del_before_definition)
        self.assertEqual(rec.recorded, ['0', '1', '2'])

    def test_raising1(self):
        # NOTE(review): assertRefCount presumably checks that the raise
        # path leaves do_raise's reference count unchanged -- confirm in
        # numba.tests.support.
        with self.assertRefCount(do_raise):
            rec = self.compile_and_record(raising_usecase1, raises=MyError)
            self.assertFalse(rec.alive)

    def test_raising2(self):
        with self.assertRefCount(do_raise):
            rec = self.compile_and_record(raising_usecase2, raises=MyError)
            self.assertFalse(rec.alive)

    def test_raising3(self):
        with self.assertRefCount(MyError):
            rec = self.compile_and_record(raising_usecase3, raises=MyError)
            self.assertFalse(rec.alive)

    def test_inf_loop_multiple_back_edge(self):
        cfunc = self.compile(inf_loop_multiple_back_edge)
        rec = RefRecorder()
        iterator = iter(cfunc(rec))
        # After each resumption, the 'p' created in the previous iteration
        # must already have been released.
        next(iterator)
        self.assertEqual(rec.alive, [])
        next(iterator)
        self.assertEqual(rec.alive, [])
        next(iterator)
        self.assertEqual(rec.alive, [])
        self.assertEqual(rec.recorded,
                         ['yield', 'p', 'bra', 'yield', 'p', 'bra', 'yield'])


class TestExtendingVariableLifetimes(SerialMixin, TestCase):
    # Test for `numba.config.EXTEND_VARIABLE_LIFETIMES` which moves the ir.Del
    # nodes to just before a block's terminator, i.e. their lifetime is extended
    # beyond the point of last use.

    def test_lifetime_basic(self):

        def get_ir(extend_lifetimes):
            """
            Compile a small nopython function with EXTEND_VARIABLE_LIFETIMES
            set to *extend_lifetimes* and return the IR stored by PreserveIR.
            """
            class IRPreservingCompiler(CompilerBase):
                # Default nopython pipeline plus PreserveIR after
                # IRLegalization, so the IR is available afterwards in
                # ``metadata['preserved_ir']``.

                def define_pipelines(self):
                    pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
                    pm.add_pass_after(PreserveIR, IRLegalization)
                    pm.finalize()
                    return [pm]

            @njit(pipeline_class=IRPreservingCompiler)
            def foo():
                a = 10
                b = 20
                c = a + b
                # a and b are now unused, standard behaviour is ir.Del for them here
                d = c / c
                return d

            with override_config('EXTEND_VARIABLE_LIFETIMES',
                                 extend_lifetimes):
                foo()
                cres = foo.overloads[foo.signatures[0]]
                func_ir = cres.metadata['preserved_ir']

            return func_ir

        def check(func_ir, expect):
            """
            Assert that *func_ir* has a single block and compare its
            statements pairwise against the classes in *expect*.
            """
            # assert single block
            self.assertEqual(len(func_ir.blocks), 1)
            blk = next(iter(func_ir.blocks.values()))

            # check sequencing
            # NOTE(review): zip() stops at the shorter sequence, so extra or
            # missing trailing statements go undetected; consider also
            # asserting len(blk.body) == len(expect).
            for expect_class, got_stmt in zip(expect, blk.body):
                self.assertIsInstance(got_stmt, expect_class)

        del_after_use_ir = get_ir(False)
        # should be 3 assigns (a, b, c), 2 del (a, b), assign (d), del (c)
        # assign for cast d to return, del (d), return
        expect = [*((ir.Assign,) * 3), ir.Del, ir.Del, ir.Assign, ir.Del,
                  ir.Assign, ir.Del, ir.Return]
        check(del_after_use_ir, expect)

        del_at_block_end_ir = get_ir(True)
        # should be 4 assigns (a, b, c, d), assign for cast d to return,
        # 4 dels (a, b, c, d) then the return.
        expect = [*((ir.Assign,) * 4), ir.Assign, *((ir.Del,) * 4),
                  ir.Return]
        check(del_at_block_end_ir, expect)

    def test_dbg_extend_lifetimes(self):

        def get_ir(**options):
            """
            Compile a small nopython function with the given njit *options*
            and return the IR stored by PreserveIR.
            """
            class IRPreservingCompiler(CompilerBase):
                # Default nopython pipeline plus PreserveIR after
                # IRLegalization, so the IR is available afterwards in
                # ``metadata['preserved_ir']``.

                def define_pipelines(self):
                    pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
                    pm.add_pass_after(PreserveIR, IRLegalization)
                    pm.finalize()
                    return [pm]

            @njit(pipeline_class=IRPreservingCompiler, **options)
            def foo():
                a = 10
                b = 20
                c = a + b
                # a and b are now unused, standard behaviour is ir.Del for them here
                d = c / c
                return d

            foo()
            cres = foo.overloads[foo.signatures[0]]
            func_ir = cres.metadata['preserved_ir']
            return func_ir

        # _dbg_extend_lifetimes is on when debug=True
        ir_debug = get_ir(debug=True)
        # explicitly turn on _dbg_extend_lifetimes
        ir_debug_ext = get_ir(debug=True, _dbg_extend_lifetimes=True)
        # explicitly turn off _dbg_extend_lifetimes
        ir_debug_no_ext = get_ir(debug=True, _dbg_extend_lifetimes=False)

        def is_del_grouped_at_the_end(fir):
            """
            Return True if the ir.Del statements of *fir*'s only block form
            one non-empty contiguous run that ends immediately before the
            block's last statement.  Unpacking raises ValueError if *fir*
            has more than one block.
            """
            [blk] = fir.blocks.values()
            # Mark all statements that are ir.Del
            inst_is_del = [isinstance(stmt, ir.Del) for stmt in blk.body]
            # Get the leading segment that are not dels
            not_dels = list(takewhile(operator.not_, inst_is_del))
            # Compute the starting position of the dels
            begin = len(not_dels)
            # Get the remaining segment that are all dels
            all_dels = list(takewhile(operator.truth, inst_is_del[begin:]))
            # Compute the ending position of the dels
            end = begin + len(all_dels)
            # If the dels are all grouped at the end (before the terminator),
            # the end position will be the last position of the list
            return end == len(inst_is_del) - 1

        self.assertTrue(is_del_grouped_at_the_end(ir_debug))
        self.assertTrue(is_del_grouped_at_the_end(ir_debug_ext))
        self.assertFalse(is_del_grouped_at_the_end(ir_debug_no_ext))


if __name__ == "__main__":
    unittest.main()