# Listing metadata (from the source viewer): 274 lines, 9.2 KiB, Python
import numba
|
|
from numba.tests.support import TestCase, unittest
|
|
from numba.core.registry import cpu_target
|
|
from numba.core.compiler import CompilerBase, Flags
|
|
from numba.core.compiler_machinery import PassManager
|
|
from numba.core import types, ir, bytecode, compiler, ir_utils, registry
|
|
from numba.core.untyped_passes import (ExtractByteCode, TranslateByteCode,
|
|
FixupArgs, IRProcessing,)
|
|
|
|
from numba.core.typed_passes import (NopythonTypeInference,
|
|
type_inference_stage, DeadCodeElimination)
|
|
from numba.experimental import jitclass
|
|
|
|
# Global constant for testing find_const: test_find_const_global binds it to a
# local inside the function under test and checks that ir_utils.find_const
# resolves that local back to this value.
GLOBAL_B = 11
|
|
|
|
|
|
@jitclass([('val', types.List(numba.intp))])
class Dummy:
    """Minimal jitclass holding a single field ``val``.

    ``val`` is declared in the jitclass spec as a list of ``intp``; the tests
    in this module call ``append`` on it to obtain an object-method call in
    the IR.
    """

    def __init__(self, val):
        # Store the list unchanged on the instance.
        self.val = val
|
|
|
|
|
|
class TestIrUtils(TestCase):
    """
    Tests ir handling utility functions like find_callname.
    """

    def test_obj_func_match(self):
        """Test matching of an object method (other than Array see #3449)
        """

        def test_func():
            d = Dummy([1])
            d.val.append(2)

        test_ir = compiler.run_frontend(test_func)
        typingctx = cpu_target.typing_context
        targetctx = cpu_target.target_context
        typing_res = type_inference_stage(
            typingctx, targetctx, test_ir, (), None)
        # NOTE(review): statement index 7 of the first block is assumed to be
        # the assignment whose value is the `d.val.append(2)` call; this
        # depends on the exact IR generated for `test_func`.
        matched_call = ir_utils.find_callname(
            test_ir, test_ir.blocks[0].body[7].value, typing_res.typemap)
        # Separate assertions so a failure reports which property was wrong.
        self.assertIsInstance(matched_call, tuple)
        self.assertEqual(len(matched_call), 2)
        self.assertEqual(matched_call[0], 'append')

    def test_dead_code_elimination(self):
        """Check that the DeadCodeElimination pass removes unused assignments.

        The function under test assigns two constants to variables that are
        never read (`dead` and `deaddead`).  It is compiled to IR once without
        and once with the DCE pass, and the two results are compared.
        """

        class Tester(CompilerBase):
            """Compiler that runs a short custom pipeline and returns the IR."""

            @classmethod
            def mk_pipeline(cls, args, return_type=None, flags=None,
                            locals=None, library=None, typing_context=None,
                            target_context=None):
                """Build a Tester, filling in defaults for omitted arguments.

                ``flags`` defaults to a new ``Flags()``; ``nrt`` is always
                switched on.  The typing and target contexts default to those
                of ``registry.cpu_target``.  ``locals`` defaults to a fresh
                empty dict on every call.
                """
                if not flags:
                    flags = Flags()
                flags.nrt = True
                if locals is None:
                    # Avoid a shared mutable default across calls.
                    locals = {}
                if typing_context is None:
                    typing_context = registry.cpu_target.typing_context
                if target_context is None:
                    target_context = registry.cpu_target.target_context
                return cls(typing_context, target_context, library, args,
                           return_type, flags, locals)

            def compile_to_ir(self, func, DCE=False):
                """
                Compile and return IR

                Runs bytecode extraction followed by a pass manager holding
                TranslateByteCode, FixupArgs, IRProcessing and
                NopythonTypeInference.  DeadCodeElimination is appended only
                when ``DCE`` is exactly ``True``.
                """
                func_id = bytecode.FunctionIdentity.from_function(func)
                self.state.func_id = func_id
                ExtractByteCode().run_pass(self.state)
                state = self.state

                name = "DCE_testing"
                pm = PassManager(name)
                pm.add_pass(TranslateByteCode, "analyzing bytecode")
                pm.add_pass(FixupArgs, "fix up args")
                pm.add_pass(IRProcessing, "processing IR")
                pm.add_pass(NopythonTypeInference, "nopython frontend")
                if DCE is True:
                    pm.add_pass(DeadCodeElimination, "DCE after typing")
                pm.finalize()
                pm.run(state)
                return state.func_ir

        def check_initial_ir(the_ir):
            """Validate the IR produced without DCE; return the dead assigns."""
            # dead stuff:
            # a const int value 0xdead
            # an assign of above into to variable `dead`
            # a const int above 0xdeaddead
            # an assign of said int to variable `deaddead`
            # this is 2 statements to remove

            self.assertEqual(len(the_ir.blocks), 1)
            block = the_ir.blocks[0]
            deads = []
            for x in block.find_insts(ir.Assign):
                if isinstance(getattr(x, 'target', None), ir.Var):
                    if 'dead' in getattr(x.target, 'name', ''):
                        deads.append(x)

            self.assertEqual(len(deads), 2)
            for d in deads:
                # check the ir.Const is the definition and the value is
                # expected: the variable name itself spells the hex literal.
                # This must be assertEqual; assertTrue would treat the second
                # argument as the failure message and compare nothing.
                const_val = the_ir.get_definition(d.value)
                self.assertEqual(int(d.target.name, 16), const_val.value)

            return deads

        def check_dce_ir(the_ir):
            """Validate that the IR produced with DCE has no dead leftovers."""
            self.assertEqual(len(the_ir.blocks), 1)
            block = the_ir.blocks[0]
            deads = []
            consts = []
            for x in block.find_insts(ir.Assign):
                if isinstance(getattr(x, 'target', None), ir.Var):
                    if 'dead' in getattr(x.target, 'name', ''):
                        deads.append(x)
                if isinstance(getattr(x, 'value', None), ir.Const):
                    consts.append(x)
            self.assertEqual(len(deads), 0)

            # check the consts to make sure there's no reference to 0xdead or
            # 0xdeaddead
            for x in consts:
                self.assertNotIn(x.value.value, [0xdead, 0xdeaddead])

        def foo(x):
            y = x + 1
            dead = 0xdead  # noqa
            z = y + 2
            deaddead = 0xdeaddead  # noqa
            ret = z * z
            return ret

        test_pipeline = Tester.mk_pipeline((types.intp,))
        no_dce = test_pipeline.compile_to_ir(foo)
        removed = check_initial_ir(no_dce)

        test_pipeline = Tester.mk_pipeline((types.intp,))
        w_dce = test_pipeline.compile_to_ir(foo, DCE=True)
        check_dce_ir(w_dce)

        # check that the count of initial - removed = dce
        self.assertEqual(len(no_dce.blocks[0].body) - len(removed),
                         len(w_dce.blocks[0].body))

    def test_find_const_global(self):
        """
        Test find_const() for values in globals (ir.Global) and freevars
        (ir.FreeVar) that are considered constants for compilation.
        """
        FREEVAR_C = 12

        def foo(a):
            b = GLOBAL_B
            c = FREEVAR_C
            return a + b + c

        f_ir = compiler.run_frontend(foo)
        block = f_ir.blocks[0]
        const_b = None
        const_c = None

        # Locate the assignments to `b` and `c` and resolve each target to a
        # constant; guard() is used so a failed lookup does not propagate out
        # of the loop and the assertions below report the mismatch instead.
        for inst in block.body:
            if isinstance(inst, ir.Assign) and inst.target.name == 'b':
                const_b = ir_utils.guard(
                    ir_utils.find_const, f_ir, inst.target)
            if isinstance(inst, ir.Assign) and inst.target.name == 'c':
                const_c = ir_utils.guard(
                    ir_utils.find_const, f_ir, inst.target)

        self.assertEqual(const_b, GLOBAL_B)
        self.assertEqual(const_c, FREEVAR_C)

    def test_flatten_labels(self):
        """ tests flatten_labels """
        # foo and bar share the same control flow (bar only adds straight-line
        # statements on `z`), whereas baz has an extra branch.  The fixture
        # bodies must not be restructured: the test compares their CFGs.
        def foo(a):
            acc = 0
            if a > 3:
                acc += 1
                if a > 19:
                    return 53
            elif a < 1000:
                if a >= 12:
                    acc += 1
                for x in range(10):
                    acc -= 1
                    if acc < 2:
                        break
                else:
                    acc += 7
            else:
                raise ValueError("some string")
            # prevents inline of return on py310
            py310_defeat1 = 1  # noqa
            py310_defeat2 = 2  # noqa
            py310_defeat3 = 3  # noqa
            py310_defeat4 = 4  # noqa
            return acc

        def bar(a):
            acc = 0
            z = 12
            if a > 3:
                acc += 1
                z += 12
                if a > 19:
                    z += 12
                    return 53
            elif a < 1000:
                if a >= 12:
                    z += 12
                    acc += 1
                for x in range(10):
                    z += 12
                    acc -= 1
                    if acc < 2:
                        break
                else:
                    z += 12
                    acc += 7
            else:
                raise ValueError("some string")
            py310_defeat1 = 1  # noqa
            py310_defeat2 = 2  # noqa
            py310_defeat3 = 3  # noqa
            py310_defeat4 = 4  # noqa
            return acc

        def baz(a):
            acc = 0
            if a > 3:
                acc += 1
                if a > 19:
                    return 53
                else:  # extra control flow in comparison to foo
                    return 55
            elif a < 1000:
                if a >= 12:
                    acc += 1
                for x in range(10):
                    acc -= 1
                    if acc < 2:
                        break
                else:
                    acc += 7
            else:
                raise ValueError("some string")
            py310_defeat1 = 1  # noqa
            py310_defeat2 = 2  # noqa
            py310_defeat3 = 3  # noqa
            py310_defeat4 = 4  # noqa
            return acc

        def get_flat_cfg(func):
            """Return the CFG of ``func``'s IR after flattening block labels."""
            func_ir = ir_utils.compile_to_numba_ir(func, {})
            flat_blocks = ir_utils.flatten_labels(func_ir.blocks)
            # The largest flattened label must be one less than the number of
            # blocks, i.e. no label exceeds the block count.
            self.assertEqual(max(flat_blocks.keys()) + 1, len(func_ir.blocks))
            return ir_utils.compute_cfg_from_blocks(flat_blocks)

        foo_cfg = get_flat_cfg(foo)
        bar_cfg = get_flat_cfg(bar)
        baz_cfg = get_flat_cfg(baz)

        self.assertEqual(foo_cfg, bar_cfg)
        self.assertNotEqual(foo_cfg, baz_cfg)
|
|
|
|
|
|
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|