import unittest
from llvmlite import ir
from llvmlite import binding as llvm
from llvmlite.tests import TestCase

from . import refprune_proto as proto


def _iterate_cases(generate_test):
    """
    Yield ``(test_name, test_function)`` pairs for the prototype cases.

    Every attribute of ``refprune_proto`` whose name starts with ``case`` is
    wrapped into a test method that calls ``generate_test(self, case_fn)``.
    The yielded name is the attribute name prefixed with ``test_``.
    """
    def wrap(fn):
        # Bind ``fn`` through a function argument so each generated test
        # keeps its own case function.
        def wrapped(self):
            return generate_test(self, fn)
        wrapped.__doc__ = f"generated test for {fn.__module__}.{fn.__name__}"
        return wrapped

    for k, case_fn in proto.__dict__.items():
        if k.startswith('case'):
            yield f'test_{k}', wrap(case_fn)


class TestRefPrunePrototype(TestCase):
    """
    Test that the prototype is working.
    """
    def generate_test(self, case_gen):
        """Run the prototype fanout algorithm on one case and compare."""
        nodes, edges, expected = case_gen()
        got = proto.FanoutAlgorithm(nodes, edges).run()
        self.assertEqual(expected, got)

    # Generate tests
    for name, case in _iterate_cases(generate_test):
        locals()[name] = case


ptr_ty = ir.IntType(8).as_pointer()


class TestRefPrunePass(TestCase):
    """
    Test that the C++ implementation matches the behavior expected of the
    prototype.

    This generates an LLVM module for each test case, runs the pruner and
    checks that the expected results are achieved.
    """

    def make_incref(self, m):
        """Declare ``void NRT_incref(i8*)`` in module ``m``."""
        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        return ir.Function(m, fnty, name='NRT_incref')

    def make_decref(self, m):
        """Declare ``void NRT_decref(i8*)`` in module ``m``."""
        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        return ir.Function(m, fnty, name='NRT_decref')

    def make_switcher(self, m):
        """Declare ``i32 switcher()``; its result selects a switch target."""
        fnty = ir.FunctionType(ir.IntType(32), ())
        return ir.Function(m, fnty, name='switcher')

    def make_brancher(self, m):
        """Declare ``i1 brancher()``; its result selects a branch target."""
        fnty = ir.FunctionType(ir.IntType(1), ())
        return ir.Function(m, fnty, name='brancher')

    def generate_ir(self, nodes, edges):
        """
        Build an LLVM IR module containing a ``main`` function for the CFG.

        ``nodes`` maps a basic-block name to its sequence of ``'incref'`` /
        ``'decref'`` actions.  ``edges`` maps a basic-block name to its
        sequence of jump targets.  Raises ``AssertionError`` on any other
        action.  Returns the ``ir.Module``.
        """
        # Build LLVM module for the CFG
        m = ir.Module()

        incref_fn = self.make_incref(m)
        decref_fn = self.make_decref(m)
        switcher_fn = self.make_switcher(m)
        brancher_fn = self.make_brancher(m)

        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        fn = ir.Function(m, fnty, name='main')
        [ptr] = fn.args
        ptr.name = 'mem'
        # populate the BB nodes
        bbmap = {}
        for bb in edges:
            bbmap[bb] = fn.append_basic_block(bb)
        # populate the BB
        builder = ir.IRBuilder()
        for bb, jump_targets in edges.items():
            builder.position_at_end(bbmap[bb])
            # Insert increfs and decrefs
            for action in nodes[bb]:
                if action == 'incref':
                    builder.call(incref_fn, [ptr])
                elif action == 'decref':
                    builder.call(decref_fn, [ptr])
                else:
                    raise AssertionError('unreachable')

            # Insert the terminator.
            # Switch based on the number of jump targets.
            n_targets = len(jump_targets)
            if n_targets == 0:
                builder.ret_void()
            elif n_targets == 1:
                [dst] = jump_targets
                builder.branch(bbmap[dst])
            elif n_targets == 2:
                [left, right] = jump_targets
                sel = builder.call(brancher_fn, ())
                builder.cbranch(sel, bbmap[left], bbmap[right])
            elif n_targets > 2:
                sel = builder.call(switcher_fn, ())
                # The first target is the default; the rest become cases.
                [head, *tail] = jump_targets

                sw = builder.switch(sel, default=bbmap[head])
                for i, dst in enumerate(tail):
                    sw.add_case(sel.type(i), bbmap[dst])
            else:
                raise AssertionError('unreachable')

        return m

    def apply_refprune(self, irmod):
        """Parse ``irmod``, run the refprune pass on it, return the module."""
        mod = llvm.parse_assembly(str(irmod))
        pm = llvm.ModulePassManager()
        pm.add_refprune_pass()
        pm.run(mod)
        return mod

    def check(self, mod, expected, nodes):
        """
        Assert that ``mod`` has the incref/decref calls left that we expect.

        The starting counts per block come from ``nodes``.  For every block
        with a non-empty entry in ``expected``, one incref is expected to be
        removed from that block and one decref from each block listed in
        the entry.  The remaining counts are compared against the number of
        ``NRT_incref`` / ``NRT_decref`` occurrences in each block of
        ``main``.
        """
        # preprocess incref/decref locations
        d = {
            k: {'incref': vs.count('incref'), 'decref': vs.count('decref')}
            for k, vs in nodes.items()
        }

        for k, stats in d.items():
            if expected.get(k):
                stats['incref'] -= 1
                for dec_bb in expected[k]:
                    d[dec_bb]['decref'] -= 1

        # find the main function
        # get_function() fails loudly if 'main' is missing instead of
        # silently checking whichever function happened to come last.
        f = mod.get_function('main')
        # check each BB
        for bb in f.blocks:
            stats = d[bb.name]
            text = str(bb)
            n_incref = text.count('NRT_incref')
            n_decref = text.count('NRT_decref')
            self.assertEqual(stats['incref'], n_incref, msg=f'BB {bb}')
            self.assertEqual(stats['decref'], n_decref, msg=f'BB {bb}')

    def generate_test(self, case_gen):
        """Build IR for one prototype case, prune it and check the result."""
        nodes, edges, expected = case_gen()
        irmod = self.generate_ir(nodes, edges)
        outmod = self.apply_refprune(irmod)
        self.check(outmod, expected, nodes)

    # Generate tests
    for name, case in _iterate_cases(generate_test):
        locals()[name] = case


class BaseTestByIR(TestCase):
    """
    Base class for tests written directly as LLVM IR text.

    Subclasses set ``refprune_bitmask`` to choose which subpass is passed to
    ``add_refprune_pass``.
    """
    refprune_bitmask = 0

    # Declarations prepended to every IR snippet before parsing.
    prologue = r"""
declare void @NRT_incref(i8* %ptr)
declare void @NRT_decref(i8* %ptr)
"""

    def check(self, irmod, subgraph_limit=None):
        """
        Run the refprune pass on ``prologue`` + ``irmod``.

        ``subgraph_limit`` is forwarded to ``add_refprune_pass`` only when it
        is not None.  Returns ``(mod, stats)`` where ``stats`` is the
        difference between ``llvm.dump_refprune_stats()`` taken after and
        before the run.
        """
        mod = llvm.parse_assembly(f"{self.prologue}\n{irmod}")

        pm = llvm.ModulePassManager()
        if subgraph_limit is None:
            pm.add_refprune_pass(self.refprune_bitmask)
        else:
            pm.add_refprune_pass(self.refprune_bitmask,
                                 subgraph_limit=subgraph_limit)
        # The stats are read before and after so that only this run's
        # contribution is reported.
        before = llvm.dump_refprune_stats()
        pm.run(mod)
        after = llvm.dump_refprune_stats()
        return mod, after - before


class TestPerBB(BaseTestByIR):
    refprune_bitmask = llvm.RefPruneSubpasses.PER_BB

    per_bb_ir_1 = r"""
define void @main(i8* %ptr) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_1(self):
        mod, stats = self.check(self.per_bb_ir_1)
        self.assertEqual(stats.basicblock, 2)

    per_bb_ir_2 = r"""
define void @main(i8* %ptr) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_2(self):
        mod, stats = self.check(self.per_bb_ir_2)
        self.assertEqual(stats.basicblock, 4)
        # not pruned
        self.assertIn("call void @NRT_incref(i8* %ptr)", str(mod))

    per_bb_ir_3 = r"""
define void @main(i8* %ptr, i8* %other) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %other)
    ret void
}
"""

    def test_per_bb_3(self):
        mod, stats = self.check(self.per_bb_ir_3)
        self.assertEqual(stats.basicblock, 2)
        # not pruned
        self.assertIn("call void @NRT_decref(i8* %other)", str(mod))

    per_bb_ir_4 = r"""
; reordered
define void @main(i8* %ptr, i8* %other) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %other)
    call void @NRT_incref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_4(self):
        mod, stats = self.check(self.per_bb_ir_4)
        self.assertEqual(stats.basicblock, 4)
        # not pruned
        self.assertIn("call void @NRT_decref(i8* %other)", str(mod))


class TestDiamond(BaseTestByIR):
    refprune_bitmask = llvm.RefPruneSubpasses.DIAMOND

    per_diamond_1 = r"""
define void @main(i8* %ptr) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br label %bb_B
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_1(self):
        mod, stats = self.check(self.per_diamond_1)
        self.assertEqual(stats.diamond, 2)

    per_diamond_2 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_2(self):
        mod, stats = self.check(self.per_diamond_2)
        self.assertEqual(stats.diamond, 2)

    per_diamond_3 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    call void @NRT_decref(i8* %ptr)     ; reject because of decref in diamond
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_3(self):
        mod, stats = self.check(self.per_diamond_3)
        self.assertEqual(stats.diamond, 0)

    per_diamond_4 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_incref(i8* %ptr)     ; extra incref will not affect prune
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_4(self):
        mod, stats = self.check(self.per_diamond_4)
        self.assertEqual(stats.diamond, 2)

    per_diamond_5 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_5(self):
        mod, stats = self.check(self.per_diamond_5)
        self.assertEqual(stats.diamond, 4)


class TestFanout(BaseTestByIR):
    """More complex cases are tested in TestRefPrunePass
    """

    refprune_bitmask = llvm.RefPruneSubpasses.FANOUT

    fanout_1 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_fanout_1(self):
        mod, stats = self.check(self.fanout_1)
        self.assertEqual(stats.fanout, 3)

    fanout_2 = r"""
define void @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    br label %bb_B                      ; illegal jump to other decref
}
"""

    def test_fanout_2(self):
        mod, stats = self.check(self.fanout_2)
        self.assertEqual(stats.fanout, 0)

    fanout_3 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_fanout_3(self):
        mod, stats = self.check(self.fanout_3)
        self.assertEqual(stats.fanout, 6)

    def test_fanout_3_limited(self):
        # With subgraph limit at 1, it is essentially turning off the fanout
        # pruner.
        mod, stats = self.check(self.fanout_3, subgraph_limit=1)
        self.assertEqual(stats.fanout, 0)


class TestFanoutRaise(BaseTestByIR):
    refprune_bitmask = llvm.RefPruneSubpasses.FANOUT_RAISE

    fanout_raise_1 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    ret i32 1
}

!0 = !{i1 true}
"""

    def test_fanout_raise_1(self):
        mod, stats = self.check(self.fanout_raise_1)
        self.assertEqual(stats.fanout_raise, 2)

    fanout_raise_2 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_typo !0      ; bad metadata
    ret i32 1
}

!0 = !{i1 true}
"""

    def test_fanout_raise_2(self):
        # This is ensuring that fanout_raise is not pruning when the metadata
        # is incorrectly named.
        mod, stats = self.check(self.fanout_raise_2)
        self.assertEqual(stats.fanout_raise, 0)

    fanout_raise_3 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    ret i32 1
}

!0 = !{i32 1}       ; ok; use i32
"""

    def test_fanout_raise_3(self):
        mod, stats = self.check(self.fanout_raise_3)
        self.assertEqual(stats.fanout_raise, 2)

    fanout_raise_4 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    ret i32 1       ; BAD; all tails are raising without decref
bb_C:
    ret i32 1       ; BAD; all tails are raising without decref
}

!0 = !{i1 1}
"""

    def test_fanout_raise_4(self):
        mod, stats = self.check(self.fanout_raise_4)
        self.assertEqual(stats.fanout_raise, 0)

    fanout_raise_5 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    br label %common.ret
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    br label %common.ret
common.ret:
    %common.ret.op = phi i32 [ 0, %bb_B ], [ 1, %bb_C ]
    ret i32 %common.ret.op
}

!0 = !{i1 1}
"""

    def test_fanout_raise_5(self):
        mod, stats = self.check(self.fanout_raise_5)
        self.assertEqual(stats.fanout_raise, 2)


if __name__ == '__main__':
    unittest.main()