ai-content-maker/.venv/Lib/site-packages/numba/tests/test_inlining.py

280 lines
11 KiB
Python

import re
import numpy as np
from numba.tests.support import (TestCase, override_config, captured_stdout,
skip_parfors_unsupported)
from numba import jit, njit
from numba.core import types, ir, postproc, compiler
from numba.core.ir_utils import (guard, find_callname, find_const,
get_definition, simplify_CFG)
from numba.core.registry import CPUDispatcher
from numba.core.inline_closurecall import inline_closure_call
from numba.core.untyped_passes import (ExtractByteCode, TranslateByteCode, FixupArgs,
IRProcessing, DeadBranchPrune,
RewriteSemanticConstants, GenericRewrites,
WithLifting, PreserveIR, InlineClosureLikes)
from numba.core.typed_passes import (NopythonTypeInference, AnnotateTypes,
NopythonRewrites, PreParforPass, ParforPass,
DumpParforDiagnostics, NativeLowering,
NativeParforLowering, IRLegalization,
NoPythonBackend, NativeLowering,
ParforFusionPass, ParforPreLoweringPass)
from numba.core.compiler_machinery import FunctionPass, PassManager, register_pass
import unittest
@register_pass(analysis_only=False, mutates_CFG=True)
class InlineTestPass(FunctionPass):
    """Function pass that inlines ``lambda: None`` at the first call site.

    The pass expects the function IR to consist of exactly one block. The
    first statement of that block for which ``guard(find_callname, ...)``
    yields a result is replaced through ``inline_closure_call``.
    """
    _name = "inline_test_pass"

    def __init__(self):
        super().__init__()

    def run_pass(self, state):
        """Inline a no-op closure at the first call, then post-process the IR.

        *state* supplies ``func_ir``, ``typingctx``, ``targetctx``,
        ``typemap`` and ``calltypes``. The method always returns True.
        """
        func_ir = state.func_ir
        # assuming the function has one block with one call inside
        assert len(func_ir.blocks) == 1
        only_block = list(func_ir.blocks.values())[0]

        # Lazily scan for the first statement whose value resolves to a
        # call name; the scan stops as soon as one is found.
        call_positions = (
            idx for idx, stmt in enumerate(only_block.body)
            if guard(find_callname, func_ir, stmt.value) is not None
        )
        first_call = next(call_positions, None)
        if first_call is not None:
            inline_closure_call(
                func_ir, {}, only_block, first_call, lambda: None,
                state.typingctx, state.targetctx, (),
                state.typemap, state.calltypes,
            )

        # also fix up the IR
        fixer = postproc.PostProcessor(func_ir)
        fixer.run()
        fixer.remove_dels()
        return True
def gen_pipeline(state, test_pass):
    """Assemble the pass pipeline used by the inlining tests.

    Parameters
    ----------
    state
        Compiler state; only ``state.flags.no_rewrites`` and
        ``state.flags.auto_parallel.enabled`` are consulted here.
    test_pass
        Pass class inserted after the typing/rewrite/parfor passes and
        before IR legalization.

    Returns
    -------
    PassManager
        The populated pass manager named ``'inline_test'``. It is not
        finalized here; the caller is expected to do that.
    """
    flags = state.flags
    rewrites_on = not flags.no_rewrites
    parfors_on = flags.auto_parallel.enabled

    pipeline = PassManager('inline_test')

    def add_all(*entries):
        # Register (pass, description) pairs in the order given.
        for pass_cls, description in entries:
            pipeline.add_pass(pass_cls, description)

    add_all(
        (TranslateByteCode, "analyzing bytecode"),
        (FixupArgs, "fix up args"),
        (IRProcessing, "processing IR"),
        (WithLifting, "Handle with contexts"),
    )

    # pre typing
    if rewrites_on:
        add_all(
            (GenericRewrites, "nopython rewrites"),
            (RewriteSemanticConstants, "rewrite semantic constants"),
            (DeadBranchPrune, "dead branch pruning"),
        )
    add_all((InlineClosureLikes, "inline calls to locally defined closures"))

    # typing
    add_all((NopythonTypeInference, "nopython frontend"))
    if parfors_on:
        add_all((PreParforPass, "Preprocessing for parfors"))
    if rewrites_on:
        add_all((NopythonRewrites, "nopython rewrites"))
    if parfors_on:
        add_all(
            (ParforPass, "convert to parfors"),
            (ParforFusionPass, "fuse parfors"),
            (ParforPreLoweringPass, "parfor prelowering"),
        )

    # the pass under test runs on typed IR, right before legalization
    add_all((test_pass, "inline test"))

    # legalise
    add_all(
        (IRLegalization, "ensure IR is legal prior to lowering"),
        (AnnotateTypes, "annotate types"),
        (PreserveIR, "preserve IR"),
    )

    # lower
    if parfors_on:
        lowering = (NativeParforLowering, "native parfor lowering")
    else:
        lowering = (NativeLowering, "native lowering")
    add_all(
        lowering,
        (NoPythonBackend, "nopython mode backend"),
        (DumpParforDiagnostics, "dump parfor diagnostics"),
    )
    return pipeline
class InlineTestPipeline(compiler.CompilerBase):
    """Compiler whose only pipeline is ``gen_pipeline`` with ``InlineTestPass``.

    Used to exercise inlining after the optimization passes have run.
    """

    def define_pipelines(self):
        """Return a one-element list holding the finalized test pipeline."""
        pipeline = gen_pipeline(self.state, InlineTestPass)
        pipeline.finalize()
        return [pipeline]
class TestInlining(TestCase):
    """
    Check that jitted inner functions are inlined into outer functions,
    in nopython mode.
    Note that not all inner functions are guaranteed to be inlined.
    We just trust LLVM's inlining heuristics.
    """

    def make_pattern(self, fullname):
        """
        Build a regular expression matching the mangled form of *fullname*.

        The dotted name is split on ``.``; the pattern is ``_Z`` with an
        optional ``N``, followed by ``<digits><component>`` for every
        component in order.
        """
        parts = fullname.split('.')
        return r'_ZN?' + r''.join(r'\d+{}'.format(p) for p in parts)

    def assert_has_pattern(self, fullname, text):
        """Fail unless the mangled form of *fullname* occurs in *text*."""
        pat = self.make_pattern(fullname)
        self.assertIsNotNone(re.search(pat, text),
                             msg='expected {}'.format(pat))

    def assert_not_has_pattern(self, fullname, text):
        """Fail if the mangled form of *fullname* occurs in *text*."""
        pat = self.make_pattern(fullname)
        self.assertIsNone(re.search(pat, text),
                          msg='unexpected {}'.format(pat))

    def test_inner_function(self):
        from numba.tests.inlining_usecases import (outer_simple,
                                                   __name__ as prefix)

        # Compile with assembly dumping enabled so the generated code can be
        # inspected through the captured stdout.
        with override_config('DUMP_ASSEMBLY', True):
            with captured_stdout() as out:
                cfunc = jit((types.int32,), nopython=True)(outer_simple)
        self.assertPreciseEqual(cfunc(1), 4)
        # Check the inner function was elided from the output (which also
        # guarantees it was inlined into the outer function).
        asm = out.getvalue()
        self.assert_has_pattern('%s.outer_simple' % prefix, asm)
        self.assert_not_has_pattern('%s.inner' % prefix, asm)

    def test_multiple_inner_functions(self):
        from numba.tests.inlining_usecases import (outer_multiple,
                                                   __name__ as prefix)

        # Same with multiple inner functions, and multiple calls to
        # the same inner function (inner()).  This checks that linking in
        # the same library/module twice doesn't produce linker errors.
        with override_config('DUMP_ASSEMBLY', True):
            with captured_stdout() as out:
                cfunc = jit((types.int32,), nopython=True)(outer_multiple)
        self.assertPreciseEqual(cfunc(1), 6)
        asm = out.getvalue()
        self.assert_has_pattern('%s.outer_multiple' % prefix, asm)
        self.assert_not_has_pattern('%s.more' % prefix, asm)
        self.assert_not_has_pattern('%s.inner' % prefix, asm)

    @skip_parfors_unsupported
    def test_inline_call_after_parfor(self):
        from numba.tests.inlining_usecases import __dummy__

        # replace the call to make sure inlining doesn't cause label conflict
        # with parfor body
        def test_impl(A):
            __dummy__()
            return A.sum()

        j_func = njit(parallel=True, pipeline_class=InlineTestPipeline)(
            test_impl)
        A = np.arange(10)
        self.assertEqual(test_impl(A), j_func(A))

    @skip_parfors_unsupported
    def test_inline_update_target_def(self):

        def test_impl(a):
            if a == 1:
                b = 2
            else:
                b = 3
            return b

        func_ir = compiler.run_frontend(test_impl)
        blocks = list(func_ir.blocks.values())
        for block in blocks:
            for i, stmt in enumerate(block.body):
                # match b = 2 and replace with lambda: 2
                if (isinstance(stmt, ir.Assign)
                        and isinstance(stmt.value, ir.Var)
                        and guard(find_const, func_ir, stmt.value) == 2):
                    # replace expr with a dummy call, keeping the
                    # definitions table in sync with the new value
                    func_ir._definitions[stmt.target.name].remove(stmt.value)
                    callee = ir.Var(block.scope, "myvar", loc=stmt.loc)
                    stmt.value = ir.Expr.call(callee, (), (), stmt.loc)
                    func_ir._definitions[stmt.target.name].append(stmt.value)
                    inline_closure_call(func_ir, {}, block, i, lambda: 2)
                    break

        # 'b' must still have exactly two definitions after inlining
        self.assertEqual(len(func_ir._definitions['b']), 2)

    @skip_parfors_unsupported
    def test_inline_var_dict_ret(self):
        # make sure inline_closure_call returns the variable replacement dict
        # and it contains the original variable name used in locals
        @njit(locals={'b': types.float64})
        def g(a):
            b = a + 1
            return b

        def test_impl():
            return g(1)

        func_ir = compiler.run_frontend(test_impl)
        blocks = list(func_ir.blocks.values())
        # Stays None if no dispatcher call is found, so that case is
        # reported as a test failure rather than a NameError.
        var_map = None
        for block in blocks:
            for i, stmt in enumerate(block.body):
                if (isinstance(stmt, ir.Assign)
                        and isinstance(stmt.value, ir.Expr)
                        and stmt.value.op == 'call'):
                    func_def = guard(get_definition, func_ir, stmt.value.func)
                    if (isinstance(func_def, (ir.Global, ir.FreeVar))
                            and isinstance(func_def.value, CPUDispatcher)):
                        py_func = func_def.value.py_func
                        _, var_map = inline_closure_call(
                            func_ir, py_func.__globals__, block, i, py_func)
                        break

        self.assertIsNotNone(var_map, msg='no dispatcher call was inlined')
        self.assertIn('b', var_map)

    @skip_parfors_unsupported
    def test_inline_call_branch_pruning(self):
        # branch pruning pass should run properly in inlining to enable
        # functions with type checks
        @njit
        def foo(A=None):
            if A is None:
                return 2
            else:
                return A

        def test_impl(A=None):
            return foo(A)

        @register_pass(analysis_only=False, mutates_CFG=True)
        class PruningInlineTestPass(FunctionPass):
            _name = "pruning_inline_test_pass"

            def __init__(self):
                FunctionPass.__init__(self)

            def run_pass(self, state):
                # assuming the function has one block with one call inside
                assert len(state.func_ir.blocks) == 1
                block = list(state.func_ir.blocks.values())[0]
                for i, stmt in enumerate(block.body):
                    if (guard(find_callname, state.func_ir, stmt.value)
                            is not None):
                        # inline foo's Python function, passing the type of
                        # the call's first argument as the argument types
                        inline_closure_call(
                            state.func_ir, {}, block, i,
                            foo.py_func, state.typingctx, state.targetctx,
                            (state.typemap[stmt.value.args[0].name],),
                            state.typemap, state.calltypes)
                        break
                return True

        class InlineTestPipelinePrune(compiler.CompilerBase):

            def define_pipelines(self):
                pm = gen_pipeline(self.state, PruningInlineTestPass)
                pm.finalize()
                return [pm]

        # make sure inline_closure_call runs in full pipeline
        j_func = njit(pipeline_class=InlineTestPipelinePrune)(test_impl)
        A = 3
        self.assertEqual(test_impl(A), j_func(A))
        self.assertEqual(test_impl(), j_func())

        # make sure IR doesn't have branches
        fir = j_func.overloads[(types.Omitted(None),)].metadata['preserved_ir']
        fir.blocks = simplify_CFG(fir.blocks)
        self.assertEqual(len(fir.blocks), 1)
# Run the tests in this module when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()