527 lines
14 KiB
Python
527 lines
14 KiB
Python
import unittest
|
|
from llvmlite import ir
|
|
from llvmlite import binding as llvm
|
|
from llvmlite.tests import TestCase
|
|
|
|
from . import refprune_proto as proto
|
|
|
|
|
|
def _iterate_cases(generate_test):
    """Yield ``(test_name, test_method)`` for each ``case*`` in ``proto``.

    Every attribute of the ``refprune_proto`` module whose name starts with
    ``'case'`` produces one pair: the name is ``'test_' + attribute name``
    and the method, when invoked, returns ``generate_test(self, case_fn)``.

    :param generate_test: callable taking ``(testcase, case_fn)``.
    """
    def wrap(fn):
        # A dedicated closure per case, so each method keeps its own ``fn``.
        def wrapped(self):
            return generate_test(self, fn)

        wrapped.__doc__ = (
            f"generated test for {fn.__module__}.{fn.__name__}"
        )
        return wrapped

    for attr_name, case_fn in proto.__dict__.items():
        if not attr_name.startswith('case'):
            continue
        yield f'test_{attr_name}', wrap(case_fn)
|
|
|
|
|
|
class TestRefPrunePrototype(TestCase):
    """Check the pure-Python prototype pruner (``proto.FanoutAlgorithm``).

    One ``test_case*`` method is generated per ``case*`` function found in
    ``refprune_proto``; each compares the prototype's result against the
    expected value returned by that case function.
    """

    def generate_test(self, case_gen):
        """Run the prototype on the case built by *case_gen* and compare."""
        nodes, edges, expected = case_gen()
        algorithm = proto.FanoutAlgorithm(nodes, edges)
        self.assertEqual(expected, algorithm.run())

    # Install one generated test method per prototype case.  Inside a class
    # body ``locals()`` is the class namespace, so these assignments define
    # class attributes.
    for name, case in _iterate_cases(generate_test):
        locals()[name] = case
|
|
|
|
|
|
# ``i8*`` type used by TestRefPrunePass for the argument of the declared
# NRT_incref / NRT_decref functions and of the generated ``main`` function.
ptr_ty = ir.IntType(8).as_pointer()
|
|
|
|
|
|
class TestRefPrunePass(TestCase):
    """
    Test that the C++ implementation matches the expected behavior as for
    the prototype.

    This generates a LLVM module for each test case, runs the pruner and checks
    that the expected results are achieved.
    """

    def make_incref(self, m):
        """Declare ``void NRT_incref(i8*)`` in module *m* and return it."""
        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        return ir.Function(m, fnty, name='NRT_incref')

    def make_decref(self, m):
        """Declare ``void NRT_decref(i8*)`` in module *m* and return it."""
        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        return ir.Function(m, fnty, name='NRT_decref')

    def make_switcher(self, m):
        """Declare ``i32 switcher()``; its result selects a ``switch`` arm."""
        fnty = ir.FunctionType(ir.IntType(32), ())
        return ir.Function(m, fnty, name='switcher')

    def make_brancher(self, m):
        """Declare ``i1 brancher()``; its result drives a ``cbranch``."""
        fnty = ir.FunctionType(ir.IntType(1), ())
        return ir.Function(m, fnty, name='brancher')

    def generate_ir(self, nodes, edges):
        """Build an ``ir.Module`` whose ``main`` function has the given CFG.

        :param nodes: maps each basic-block name to a sequence of
            ``'incref'`` / ``'decref'`` actions, emitted in order as calls on
            ``main``'s single argument (named ``mem``).
        :param edges: maps each basic-block name to its list of jump targets.
            Blocks are appended in ``edges`` iteration order, so the first
            key becomes the entry block.
        :returns: the constructed ``ir.Module``.
        :raises AssertionError: if *nodes* contains an unknown action.
        """
        # Build LLVM module for the CFG
        m = ir.Module()

        incref_fn = self.make_incref(m)
        decref_fn = self.make_decref(m)
        switcher_fn = self.make_switcher(m)
        brancher_fn = self.make_brancher(m)

        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        fn = ir.Function(m, fnty, name='main')
        [ptr] = fn.args
        ptr.name = 'mem'
        # populate the BB nodes
        bbmap = {}
        for bb in edges:
            bbmap[bb] = fn.append_basic_block(bb)
        # populate the BB
        builder = ir.IRBuilder()
        for bb, jump_targets in edges.items():
            builder.position_at_end(bbmap[bb])
            # Insert increfs and decrefs
            for action in nodes[bb]:
                if action == 'incref':
                    builder.call(incref_fn, [ptr])
                elif action == 'decref':
                    builder.call(decref_fn, [ptr])
                else:
                    # Reachable with a malformed case, so name the culprit.
                    raise AssertionError(
                        f'unknown action {action!r} in block {bb!r}'
                    )

            # Insert the terminator.
            # Switch based on the number of jump targets.
            n_targets = len(jump_targets)
            if n_targets == 0:
                builder.ret_void()
            elif n_targets == 1:
                [dst] = jump_targets
                builder.branch(bbmap[dst])
            elif n_targets == 2:
                [left, right] = jump_targets
                sel = builder.call(brancher_fn, ())
                builder.cbranch(sel, bbmap[left], bbmap[right])
            else:
                # More than two targets: the first target is the default
                # destination and the rest become cases 0, 1, 2, ...
                sel = builder.call(switcher_fn, ())
                [head, *tail] = jump_targets

                sw = builder.switch(sel, default=bbmap[head])
                for i, dst in enumerate(tail):
                    sw.add_case(sel.type(i), bbmap[dst])

        return m

    def apply_refprune(self, irmod):
        """Parse ``str(irmod)``, run the refprune pass on it, return it.

        The pass is added with its default arguments.
        """
        mod = llvm.parse_assembly(str(irmod))
        pm = llvm.ModulePassManager()
        pm.add_refprune_pass()
        pm.run(mod)
        return mod

    def check(self, mod, expected, nodes):
        """Assert the per-block incref/decref call counts after pruning.

        Each block starts from its incref/decref counts in *nodes*.  For every
        block ``k`` with a truthy ``expected.get(k)``, one incref in ``k`` and
        one decref in each block listed in ``expected[k]`` must have been
        removed.  The remaining counts must match the number of
        ``NRT_incref`` / ``NRT_decref`` occurrences in each block of ``main``.
        """
        # preprocess incref/decref locations
        d = {}
        for k, vs in nodes.items():
            n_incref = vs.count('incref')
            n_decref = vs.count('decref')
            d[k] = {'incref': n_incref, 'decref': n_decref}
        for k, stats in d.items():
            # Only the inner dicts are mutated, so iterating ``d`` is safe.
            if expected.get(k):
                stats['incref'] -= 1
                for dec_bb in expected[k]:
                    d[dec_bb]['decref'] -= 1

        # Look ``main`` up by name.  ``get_function`` raises NameError if it
        # is missing, rather than silently checking whichever function a
        # search loop happened to visit last.
        f = mod.get_function('main')
        # check each BB
        for bb in f.blocks:
            stats = d[bb.name]
            text = str(bb)
            n_incref = text.count('NRT_incref')
            n_decref = text.count('NRT_decref')
            self.assertEqual(stats['incref'], n_incref, msg=f'BB {bb}')
            self.assertEqual(stats['decref'], n_decref, msg=f'BB {bb}')

    def generate_test(self, case_gen):
        """Build IR for one prototype case, prune it and verify the result."""
        nodes, edges, expected = case_gen()
        irmod = self.generate_ir(nodes, edges)
        outmod = self.apply_refprune(irmod)
        self.check(outmod, expected, nodes)

    # Generate tests
    for name, case in _iterate_cases(generate_test):
        locals()[name] = case
|
|
|
|
|
|
class BaseTestByIR(TestCase):
    """Base class for tests that run the refprune pass on hand-written IR.

    Subclasses override ``refprune_bitmask`` to choose the enabled sub-passes.
    """

    # Sub-pass selection forwarded to ``add_refprune_pass``.
    refprune_bitmask = 0

    # Declarations prepended to every IR snippet passed to ``check``.
    prologue = r"""
declare void @NRT_incref(i8* %ptr)
declare void @NRT_decref(i8* %ptr)
"""

    def check(self, irmod, subgraph_limit=None):
        """Parse *irmod*, run the refprune pass on it and report statistics.

        :param irmod: LLVM IR text; ``prologue`` is prepended before parsing.
        :param subgraph_limit: forwarded to ``add_refprune_pass`` when it is
            not ``None``; otherwise the pass's own default applies.
        :returns: ``(module, stats)`` where ``stats`` is the difference of
            ``llvm.dump_refprune_stats()`` taken after and before the run.
        """
        module = llvm.parse_assembly(f"{self.prologue}\n{irmod}")

        pass_kwargs = {}
        if subgraph_limit is not None:
            pass_kwargs['subgraph_limit'] = subgraph_limit
        pass_manager = llvm.ModulePassManager()
        pass_manager.add_refprune_pass(self.refprune_bitmask, **pass_kwargs)

        stats_before = llvm.dump_refprune_stats()
        pass_manager.run(module)
        stats_after = llvm.dump_refprune_stats()
        return module, stats_after - stats_before
|
|
|
|
|
|
class TestPerBB(BaseTestByIR):
    """Exercise the per-basic-block sub-pass (``RefPruneSubpasses.PER_BB``)."""

    refprune_bitmask = llvm.RefPruneSubpasses.PER_BB

    per_bb_ir_1 = r"""
define void @main(i8* %ptr) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_1(self):
        # One incref/decref pair in a single block: both calls pruned.
        _, stats = self.check(self.per_bb_ir_1)
        self.assertEqual(stats.basicblock, 2)

    per_bb_ir_2 = r"""
define void @main(i8* %ptr) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_2(self):
        # Three increfs but only two decrefs: two pairs go, one incref stays.
        module, stats = self.check(self.per_bb_ir_2)
        self.assertEqual(stats.basicblock, 4)
        ir_text = str(module)
        # not pruned
        self.assertIn("call void @NRT_incref(i8* %ptr)", ir_text)

    per_bb_ir_3 = r"""
define void @main(i8* %ptr, i8* %other) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %other)
    ret void
}
"""

    def test_per_bb_3(self):
        # The decref on %other has no matching incref on %other.
        module, stats = self.check(self.per_bb_ir_3)
        self.assertEqual(stats.basicblock, 2)
        ir_text = str(module)
        # not pruned
        self.assertIn("call void @NRT_decref(i8* %other)", ir_text)

    per_bb_ir_4 = r"""
; reordered
define void @main(i8* %ptr, i8* %other) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %other)
    call void @NRT_incref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_4(self):
        # Two increfs and two decrefs on %ptr, interleaved; %other untouched.
        module, stats = self.check(self.per_bb_ir_4)
        self.assertEqual(stats.basicblock, 4)
        ir_text = str(module)
        # not pruned
        self.assertIn("call void @NRT_decref(i8* %other)", ir_text)
|
|
|
|
|
|
class TestDiamond(BaseTestByIR):
    """Exercise the diamond sub-pass (``RefPruneSubpasses.DIAMOND``)."""

    refprune_bitmask = llvm.RefPruneSubpasses.DIAMOND

    per_diamond_1 = r"""
define void @main(i8* %ptr) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br label %bb_B
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_1(self):
        _, stats = self.check(self.per_diamond_1)
        self.assertEqual(stats.diamond, 2)

    per_diamond_2 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_2(self):
        _, stats = self.check(self.per_diamond_2)
        self.assertEqual(stats.diamond, 2)

    per_diamond_3 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    call void @NRT_decref(i8* %ptr) ; reject because of decref in diamond
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_3(self):
        # A decref inside the diamond body prevents pruning.
        _, stats = self.check(self.per_diamond_3)
        self.assertEqual(stats.diamond, 0)

    per_diamond_4 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_incref(i8* %ptr) ; extra incref will not affect prune
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_4(self):
        _, stats = self.check(self.per_diamond_4)
        self.assertEqual(stats.diamond, 2)

    per_diamond_5 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_5(self):
        # Two incref/decref pairs across the diamond: four calls pruned.
        _, stats = self.check(self.per_diamond_5)
        self.assertEqual(stats.diamond, 4)
|
|
|
|
|
|
class TestFanout(BaseTestByIR):
    """Simple checks of the fanout sub-pass (``RefPruneSubpasses.FANOUT``).

    The more complex CFG shapes are covered by TestRefPrunePass.
    """

    refprune_bitmask = llvm.RefPruneSubpasses.FANOUT

    fanout_1 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_fanout_1(self):
        # One incref plus one decref on each of the two exits.
        _, stats = self.check(self.fanout_1)
        self.assertEqual(stats.fanout, 3)

    fanout_2 = r"""
define void @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    br label %bb_B ; illegal jump to other decref
}
"""

    def test_fanout_2(self):
        # bb_C falls through into bb_B's decref, so nothing may be pruned.
        _, stats = self.check(self.fanout_2)
        self.assertEqual(stats.fanout, 0)

    fanout_3 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_fanout_3(self):
        _, stats = self.check(self.fanout_3)
        self.assertEqual(stats.fanout, 6)

    def test_fanout_3_limited(self):
        # A subgraph limit of 1 effectively disables the fanout pruner.
        _, stats = self.check(self.fanout_3, subgraph_limit=1)
        self.assertEqual(stats.fanout, 0)
|
|
|
|
|
|
class TestFanoutRaise(BaseTestByIR):
    """Exercise the fanout-raise sub-pass (``RefPruneSubpasses.FANOUT_RAISE``).

    Raising exits are marked by a store carrying ``!numba_exception_output``
    metadata.
    """

    refprune_bitmask = llvm.RefPruneSubpasses.FANOUT_RAISE

    fanout_raise_1 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    ret i32 1
}
!0 = !{i1 true}
"""

    def test_fanout_raise_1(self):
        _, stats = self.check(self.fanout_raise_1)
        self.assertEqual(stats.fanout_raise, 2)

    fanout_raise_2 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_typo !0 ; bad metadata
    ret i32 1
}

!0 = !{i1 true}
"""

    def test_fanout_raise_2(self):
        # Misspelled metadata name: the raising exit is not recognised, so
        # no pruning may happen.
        _, stats = self.check(self.fanout_raise_2)
        self.assertEqual(stats.fanout_raise, 0)

    fanout_raise_3 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    ret i32 1
}

!0 = !{i32 1} ; ok; use i32
"""

    def test_fanout_raise_3(self):
        _, stats = self.check(self.fanout_raise_3)
        self.assertEqual(stats.fanout_raise, 2)

    fanout_raise_4 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    ret i32 1 ; BAD; all tails are raising without decref
bb_C:
    ret i32 1 ; BAD; all tails are raising without decref
}

!0 = !{i1 1}
"""

    def test_fanout_raise_4(self):
        _, stats = self.check(self.fanout_raise_4)
        self.assertEqual(stats.fanout_raise, 0)

    fanout_raise_5 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    br label %common.ret
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    br label %common.ret
common.ret:
    %common.ret.op = phi i32 [ 0, %bb_B ], [ 1, %bb_C ]
    ret i32 %common.ret.op
}
!0 = !{i1 1}
"""

    def test_fanout_raise_5(self):
        # Both paths merge into a shared return block via a phi.
        _, stats = self.check(self.fanout_raise_5)
        self.assertEqual(stats.fanout_raise, 2)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
|