# Tests numba.analysis functions
import collections
import types as pytypes

import numpy as np
from numba.core.compiler import run_frontend, Flags, StateDict
from numba import jit, njit, literal_unroll
from numba.core import types, errors, ir, rewrites, ir_utils, utils, cpu
from numba.core import postproc
from numba.core.inline_closurecall import InlineClosureCallPass
from numba.tests.support import (TestCase, MemoryLeakMixin, SerialMixin,
                                 IRPreservingTestPipeline)
from numba.core.analysis import dead_branch_prune, rewrite_semantic_constants
from numba.core.untyped_passes import (ReconstructSSA, TranslateByteCode,
                                       IRProcessing, DeadBranchPrune,
                                       PreserveIR)
from numba.core.compiler import DefaultPassBuilder, CompilerBase, PassManager

# Module level global read by `test_global_bake_in`.
_GLOBAL = 123

enable_pyobj_flags = Flags()
enable_pyobj_flags.enable_pyobject = True


def compile_to_ir(func):
    """
    Run the Numba frontend on the Python function `func` and return its IR
    after SSA reconstruction and the 'before-inference' rewrites have been
    applied.
    """
    func_ir = run_frontend(func)
    state = StateDict()
    state.func_ir = func_ir
    state.typemap = None
    state.calltypes = None
    # Transform to SSA
    ReconstructSSA().run_pass(state)
    # call this to get print etc rewrites
    rewrites.rewrite_registry.apply('before-inference', state)
    return func_ir


class TestBranchPruneBase(MemoryLeakMixin, TestCase):
    """
    Tests branch pruning
    """
    # Set to True to dump the IR at each stage of `assert_prune`.
    _DEBUG = False

    def find_branches(self, the_ir):
        """
        Return a list of *all* `ir.Branch` instructions in `the_ir`, in block
        iteration order.
        """
        branches = []
        for blk in the_ir.blocks.values():
            branches.extend(blk.find_insts(cls=ir.Branch))
        return branches

    def assert_prune(self, func, args_tys, prune, *args, **kwargs):
        """
        Check that the expected branches of `func` have been pruned and that
        the compiled function still produces the same result as `func`.

        func: the python function to assess.
        args_tys: the tuple of Numba types of the arguments.
        prune: a list with one entry per branch found in the IR before
            pruning. The value in each entry is encoded as follows:
            True: using constant inference only, the True branch will be
                  pruned
            False: using constant inference only, the False branch will be
                   pruned
            None: under no circumstances should this branch be pruned
            'both': both the True and the False branch targets are expected
                    to be removed
        *args: the argument instances to pass to the function to check
            execution is still valid post transform.
        **kwargs:
            - flags: dict of keyword arguments to pass to `jit`, default is
              `{'nopython': True}`. Permits use of e.g. object mode.
        """
        func_ir = compile_to_ir(func)
        before = func_ir.copy()
        if self._DEBUG:
            print("=" * 80)
            print("before inline")
            func_ir.dump()

        # run closure inlining to ensure that nonlocals in closures are visible
        inline_pass = InlineClosureCallPass(func_ir,
                                            cpu.ParallelOptions(False),)
        inline_pass.run()

        # Remove all Dels, and re-run postproc
        post_proc = postproc.PostProcessor(func_ir)
        post_proc.run()

        rewrite_semantic_constants(func_ir, args_tys)
        if self._DEBUG:
            print("=" * 80)
            print("before prune")
            func_ir.dump()

        dead_branch_prune(func_ir, args_tys)

        after = func_ir
        if self._DEBUG:
            print("after prune")
            func_ir.dump()

        before_branches = self.find_branches(before)
        self.assertEqual(len(before_branches), len(prune))

        # what is expected to be pruned
        expect_removed = []
        for branch, action in zip(before_branches, prune):
            if action is True:
                expect_removed.append(branch.truebr)
            elif action is False:
                expect_removed.append(branch.falsebr)
            elif action is None:
                pass  # nothing should be removed!
            elif action == 'both':
                expect_removed.append(branch.falsebr)
                expect_removed.append(branch.truebr)
            else:
                self.fail("unexpected prune specification: %r" % (action,))

        # compare labels
        original_labels = set(before.blocks)
        new_labels = set(after.blocks)
        # assert that the new labels are precisely the original less the
        # expected pruned labels
        try:
            self.assertEqual(new_labels, original_labels - set(expect_removed))
        except AssertionError:
            # print some context to help diagnose the failure, then re-raise
            print("new_labels", sorted(new_labels))
            print("original_labels", sorted(original_labels))
            print("expect_removed", sorted(expect_removed))
            raise

        supplied_flags = kwargs.pop('flags', {'nopython': True})
        # NOTE: original testing used `compile_isolated` hence use of `cres`.
        cres = jit(args_tys, **supplied_flags)(func).overloads[args_tys]
        # `args` is always a tuple (possibly empty), so unpacking it covers
        # the zero-argument case too.
        res = cres.entry_point(*args)
        expected = func(*args)
        self.assertEqual(res, expected)


class TestBranchPrune(TestBranchPruneBase, SerialMixin):

    def test_single_if(self):

        def impl(x):
            if 1 == 0:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)

        def impl(x):
            if 1 == 1:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [False], None)

        def impl(x):
            if x is None:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [False], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True], 10)

        def impl(x):
            if x == 10:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10)

        def impl(x):
            if x == 10:
                z = 3.14159  # noqa: F841 # no effect

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10)

        def impl(x):
            z = None
            y = z
            if x == y:
                return 100

        self.assert_prune(impl, (types.NoneType('none'),), [False], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True], 10)

    def test_single_if_else(self):

        def impl(x):
            if x is None:
                return 3.14159
            else:
                return 1.61803

        self.assert_prune(impl, (types.NoneType('none'),), [False], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True], 10)

    def test_single_if_const_val(self):

        def impl(x):
            if x == 100:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(100),), [None], 100)

        def impl(x):
            # switch the condition order
            if 100 == x:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(100),), [None], 100)

    def test_single_if_else_two_const_val(self):

        def impl(x, y):
            if x == y:
                return 3.14159
            else:
                return 1.61803

        self.assert_prune(impl, (types.IntegerLiteral(100),) * 2, [None], 100,
                          100)
        self.assert_prune(impl, (types.NoneType('none'),) * 2, [False], None,
                          None)
        self.assert_prune(impl, (types.IntegerLiteral(100),
                                 types.NoneType('none'),), [True], 100, None)
        self.assert_prune(impl, (types.IntegerLiteral(100),
                                 types.IntegerLiteral(1000)), [None], 100,
                          1000)

    def test_single_if_else_w_following_undetermined(self):

        def impl(x):
            x_is_none_work = False
            if x is None:
                x_is_none_work = True
            else:
                dead = 7  # noqa: F841 # no effect

            if x_is_none_work:
                y = 10
            else:
                y = -3
            return y

        self.assert_prune(impl, (types.NoneType('none'),), [False, None], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)

        def impl(x):
            x_is_none_work = False
            if x is None:
                x_is_none_work = True
            else:
                pass

            if x_is_none_work:
                y = 10
            else:
                y = -3
            return y

        if utils.PYVERSION >= (3, 10):
            # Python 3.10 creates a block with a NOP in it for the `pass` which
            # means it gets pruned.
            self.assert_prune(impl, (types.NoneType('none'),), [False, None],
                              None)
        else:
            self.assert_prune(impl, (types.NoneType('none'),), [None, None],
                              None)

        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)

    def test_double_if_else_rt_const(self):

        def impl(x):
            one_hundred = 100
            x_is_none_work = 4
            if x is None:
                x_is_none_work = 100
            else:
                dead = 7  # noqa: F841 # no effect

            if x_is_none_work == one_hundred:
                y = 10
            else:
                y = -3

            return y, x_is_none_work

        self.assert_prune(impl, (types.NoneType('none'),), [False, None], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)

    def test_double_if_else_non_literal_const(self):

        def impl(x):
            one_hundred = 100
            if x == one_hundred:
                y = 3.14159
            else:
                y = 1.61803
            return y

        # no prune as compilation specialization on literal value not permitted
        self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10)
        self.assert_prune(impl, (types.IntegerLiteral(100),), [None], 100)

    def test_single_two_branches_same_cond(self):

        def impl(x):
            if x is None:
                y = 10
            else:
                y = 40

            if x is not None:
                z = 100
            else:
                z = 400

            return z, y

        self.assert_prune(impl, (types.NoneType('none'),), [False, True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, False], 10)

    def test_cond_is_kwarg_none(self):

        def impl(x=None):
            if x is None:
                y = 10
            else:
                y = 40

            if x is not None:
                z = 100
            else:
                z = 400

            return z, y

        self.assert_prune(impl, (types.Omitted(None),),
                          [False, True], None)
        self.assert_prune(impl, (types.NoneType('none'),), [False, True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, False], 10)

    def test_cond_is_kwarg_value(self):

        def impl(x=1000):
            if x == 1000:
                y = 10
            else:
                y = 40

            if x != 1000:
                z = 100
            else:
                z = 400

            return z, y

        self.assert_prune(impl, (types.Omitted(1000),), [None, None], 1000)
        self.assert_prune(impl, (types.IntegerLiteral(1000),), [None, None],
                          1000)
        self.assert_prune(impl, (types.IntegerLiteral(0),), [None, None], 0)
        self.assert_prune(impl, (types.NoneType('none'),), [True, False], None)

    def test_cond_rewrite_is_correct(self):
        # this checks that when a condition is replaced, it is replaced by a
        # true/false bit that correctly represents the evaluated condition
        def fn(x):
            if x is None:
                return 10
            return 12

        def check(func, arg_tys, bit_val):
            func_ir = compile_to_ir(func)

            # check there is 1 branch
            before_branches = self.find_branches(func_ir)
            self.assertEqual(len(before_branches), 1)

            # check the condition in the branch is a binop
            pred_var = before_branches[0].cond
            pred_defn = ir_utils.get_definition(func_ir, pred_var)
            self.assertEqual(pred_defn.op, 'call')
            condition_var = pred_defn.args[0]
            condition_op = ir_utils.get_definition(func_ir, condition_var)
            self.assertEqual(condition_op.op, 'binop')

            # do the prune, this should kill the dead branch and rewrite the
            # condition to a true/false const bit
            if self._DEBUG:
                print("=" * 80)
                print("before prune")
                func_ir.dump()
            dead_branch_prune(func_ir, arg_tys)
            if self._DEBUG:
                print("=" * 80)
                print("after prune")
                func_ir.dump()

            # after mutation, the condition should be a const value `bit_val`
            new_condition_defn = ir_utils.get_definition(func_ir,
                                                         condition_var)
            self.assertIsInstance(new_condition_defn, ir.Const)
            self.assertEqual(new_condition_defn.value, bit_val)

        check(fn, (types.NoneType('none'),), 1)
        check(fn, (types.IntegerLiteral(10),), 0)

    def test_global_bake_in(self):

        def impl(x):
            if _GLOBAL == 123:
                return x
            else:
                return x + 10

        self.assert_prune(impl, (types.IntegerLiteral(1),), [False], 1)

        global _GLOBAL
        tmp = _GLOBAL

        # temporarily rebind the module global, restoring it afterwards even
        # if the check fails
        try:
            _GLOBAL = 5

            def impl(x):
                if _GLOBAL == 123:
                    return x
                else:
                    return x + 10

            self.assert_prune(impl, (types.IntegerLiteral(1),), [True], 1)
        finally:
            _GLOBAL = tmp

    def test_freevar_bake_in(self):

        _FREEVAR = 123

        def impl(x):
            if _FREEVAR == 123:
                return x
            else:
                return x + 10

        self.assert_prune(impl, (types.IntegerLiteral(1),), [False], 1)

        _FREEVAR = 12

        def impl(x):
            if _FREEVAR == 123:
                return x
            else:
                return x + 10

        self.assert_prune(impl, (types.IntegerLiteral(1),), [True], 1)

    def test_redefined_variables_are_not_considered_in_prune(self):
        # see issue #4163, checks that if a variable that is an argument is
        # redefined in the user code it is not considered const

        def impl(array, a=None):
            if a is None:
                a = 0
            if a < 0:
                return 10
            return 30

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.NoneType('none'),),
                          [None, None],
                          np.zeros((2, 3)), None)

    def test_comparison_operators(self):
        # see issue #4163, checks that a variable that is an argument and has
        # value None survives TypeError from invalid comparison which should be
        # dead

        def impl(array, a=None):
            x = 0
            if a is None:
                return 10  # dynamic exec would return here
            # static analysis requires that this is executed with a=None,
            # hence TypeError
            if a < 0:
                return 20
            return x

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.NoneType('none'),),
                          [False, 'both'],
                          np.zeros((2, 3)), None)

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.float64,),
                          [None, None],
                          np.zeros((2, 3)), 12.)

    def test_redefinition_analysis_same_block(self):
        # checks that a redefinition in a block with prunable potential doesn't
        # break

        def impl(array, x, a=None):
            b = 2
            if x < 4:
                b = 12
            if a is None:  # known true
                a = 7  # live
            else:
                b = 15  # dead
            if a < 0:  # valid as a result of the redefinition of 'a'
                return 10
            return 30 + b + a

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.float64, types.NoneType('none'),),
                          [None, False, None],
                          np.zeros((2, 3)), 1., None)

    def test_redefinition_analysis_different_block_can_exec(self):
        # checks that a redefinition in a block that may be executed prevents
        # pruning

        def impl(array, x, a=None):
            b = 0
            if x > 5:
                a = 11  # a redefined, cannot tell statically if this will exec
            if x < 4:
                b = 12
            if a is None:  # cannot prune, cannot determine if re-defn occurred
                b += 5
            else:
                b += 7
                if a < 0:
                    return 10
            return 30 + b

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.float64, types.NoneType('none'),),
                          [None, None, None, None],
                          np.zeros((2, 3)), 1., None)

    def test_redefinition_analysis_different_block_cannot_exec(self):
        # checks that a redefinition in a block guarded by something that
        # has prune potential

        def impl(array, x=None, a=None):
            b = 0
            if x is not None:
                a = 11
            if a is None:
                b += 5
            else:
                b += 7
            return 30 + b

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.NoneType('none'), types.NoneType('none')),
                          [True, None],
                          np.zeros((2, 3)), None, None)

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.NoneType('none'), types.float64),
                          [True, None],
                          np.zeros((2, 3)), None, 1.2)

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.float64, types.NoneType('none')),
                          [None, None],
                          np.zeros((2, 3)), 1.2, None)

    def test_closure_and_nonlocal_can_prune(self):
        # Closures must be inlined ahead of branch pruning in case nonlocal
        # is used. See issue #6585.
        def impl():
            x = 1000

            def closure():
                nonlocal x
                x = 0

            closure()

            if x == 0:
                return True
            else:
                return False

        self.assert_prune(impl, (), [False,],)

    def test_closure_and_nonlocal_cannot_prune(self):
        # Closures must be inlined ahead of branch pruning in case nonlocal
        # is used. See issue #6585.
        def impl(n):
            x = 1000

            def closure(t):
                nonlocal x
                x = t

            closure(n)

            if x == 0:
                return True
            else:
                return False

        self.assert_prune(impl, (types.int64,), [None,], 1)


class TestBranchPrunePredicates(TestBranchPruneBase, SerialMixin):
    # Really important thing to remember... the branch on predicates end up as
    # POP_JUMP_IF_ and the targets are backwards compared to normal, i.e.
    # the true condition is far jump and the false the near i.e. `if x` would
    # end up in Numba IR as e.g. `branch x 10, 6`.

    _TRUTHY = (1, "String", True, 7.4, 3j)
    _FALSEY = (0, "", False, 0.0, 0j, None)

    def _literal_const_sample_generator(self, pyfunc, consts):
        """
        This takes a python function, pyfunc, and manipulates its co_const
        __code__ member to create a new function with different co_consts as
        supplied in argument consts.

        consts is a dict {index: value} of co_const tuple index to constant
        value used to update a pyfunc clone's co_const.
        """
        pyfunc_code = pyfunc.__code__

        # translate consts spec to update the constants
        co_consts = dict(enumerate(pyfunc_code.co_consts))
        co_consts.update(consts)
        new_consts = tuple([v for _, v in sorted(co_consts.items())])

        # create code object with mutation
        new_code = pyfunc_code.replace(co_consts=new_consts)

        # get function
        return pytypes.FunctionType(new_code, globals())

    def test_literal_const_code_gen(self):
        # NOTE: the exact constants of `impl` are asserted on below, do not
        # add a docstring or further literals to it.
        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if _CONST1:
                return 3.14159
            else:
                _CONST2 = "PLACEHOLDER2"
            return _CONST2 + 4

        new = self._literal_const_sample_generator(impl, {1:0, 3:20})
        iconst = impl.__code__.co_consts
        nconst = new.__code__.co_consts
        self.assertEqual(iconst, (None, "PLACEHOLDER1", 3.14159,
                                  "PLACEHOLDER2", 4))
        self.assertEqual(nconst, (None, 0, 3.14159, 20, 4))
        self.assertEqual(impl(None), 3.14159)
        self.assertEqual(new(None), 24)

    def test_single_if_const(self):

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if _CONST1:
                return 3.14159

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:
                func = self._literal_const_sample_generator(impl, {1: const})
                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_negate_const(self):

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if not _CONST1:
                return 3.14159

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:
                func = self._literal_const_sample_generator(impl, {1: const})
                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_const(self):

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if _CONST1:
                return 3.14159
            else:
                return 1.61803

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:
                func = self._literal_const_sample_generator(impl, {1: const})
                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_negate_const(self):

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if not _CONST1:
                return 3.14159
            else:
                return 1.61803

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:
                func = self._literal_const_sample_generator(impl, {1: const})
                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_freevar(self):
        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:

                def func(x):
                    if const:
                        return 3.14159, const

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_negate_freevar(self):
        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:

                def func(x):
                    if not const:
                        return 3.14159, const

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_freevar(self):
        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:

                def func(x):
                    if const:
                        return 3.14159, const
                    else:
                        return 1.61803, const

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_negate_freevar(self):
        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:

                def func(x):
                    if not const:
                        return 3.14159, const
                    else:
                        return 1.61803, const

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    # globals in this section have absurd names after their test usecase names
    # so as to prevent collisions and permit tests to run in parallel
    def test_single_if_global(self):
        global c_test_single_if_global

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for c in c_inp:
                c_test_single_if_global = c

                def func(x):
                    if c_test_single_if_global:
                        return 3.14159, c_test_single_if_global

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_negate_global(self):
        global c_test_single_if_negate_global

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for c in c_inp:
                c_test_single_if_negate_global = c

                # The predicate is negated, mirroring the const and freevar
                # variants of this test; without the `not` this would just
                # repeat `test_single_if_global`.
                def func(x):
                    if not c_test_single_if_negate_global:
                        return 3.14159, c_test_single_if_negate_global

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_global(self):
        global c_test_single_if_else_global

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for c in c_inp:
                c_test_single_if_else_global = c

                def func(x):
                    if c_test_single_if_else_global:
                        return 3.14159, c_test_single_if_else_global
                    else:
                        return 1.61803, c_test_single_if_else_global

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_negate_global(self):
        global c_test_single_if_else_negate_global

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for c in c_inp:
                c_test_single_if_else_negate_global = c

                def func(x):
                    if not c_test_single_if_else_negate_global:
                        return 3.14159, c_test_single_if_else_negate_global
                    else:
                        return 1.61803, c_test_single_if_else_negate_global

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_issue_5618(self):

        @njit
        def foo():
            values = np.zeros(1)
            tmp = 666
            if tmp:
                values[0] = tmp
            return values

        self.assertPreciseEqual(foo.py_func()[0], 666.)
        self.assertPreciseEqual(foo()[0], 666.)


class TestBranchPruneSSA(MemoryLeakMixin, TestCase):
    # Tests SSA rewiring of phi nodes after branch pruning.

    class SSAPrunerCompiler(CompilerBase):
        def define_pipelines(self):
            # This is a simple pipeline that does branch pruning on IR in SSA
            # form, then types and lowers as per the standard nopython pipeline.
            pm = PassManager("testing pm")
            pm.add_pass(TranslateByteCode, "analyzing bytecode")
            pm.add_pass(IRProcessing, "processing IR")
            # SSA early
            pm.add_pass(ReconstructSSA, "ssa")
            pm.add_pass(DeadBranchPrune, "dead branch pruning")
            # type and then lower as usual
            pm.add_pass(PreserveIR, "preserves the IR as metadata")
            dpb = DefaultPassBuilder
            typed_passes = dpb.define_typed_pipeline(self.state)
            pm.passes.extend(typed_passes.passes)
            lowering_passes = dpb.define_nopython_lowering_pipeline(self.state)
            pm.passes.extend(lowering_passes.passes)
            pm.finalize()
            return [pm]

    def test_ssa_update_phi(self):
        # This checks that dead branch pruning is rewiring phi nodes correctly
        # after a block containing an incoming for a phi is removed.

        @njit(pipeline_class=self.SSAPrunerCompiler)
        def impl(p=None, q=None):
            z = 1
            r = False
            if p is None:
                r = True  # live

            if r and q is not None:
                z = 20  # dead

            # one of the incoming blocks for z is dead, the phi needs an update
            # were this not done, it would refer to variables that do not exist
            # and result in a lowering error.
            return z, r

        self.assertPreciseEqual(impl(), impl.py_func())

    def test_ssa_replace_phi(self):
        # This checks that when a phi only has one incoming, because the other
        # has been pruned, that a direct assignment is used instead.

        @njit(pipeline_class=self.SSAPrunerCompiler)
        def impl(p=None):
            z = 0
            if p is None:
                z = 10
            else:
                z = 20

            return z

        self.assertPreciseEqual(impl(), impl.py_func())
        func_ir = impl.overloads[impl.signatures[0]].metadata['preserved_ir']

        # check the func_ir, make sure there's no phi nodes
        for blk in func_ir.blocks.values():
            self.assertFalse([*blk.find_exprs('phi')])


class TestBranchPrunePostSemanticConstRewrites(TestBranchPruneBase):
    # Tests that semantic constants rewriting works by virtue of branch pruning

    def test_array_ndim_attr(self):

        def impl(array):
            if array.ndim == 2:
                if array.shape[1] == 2:
                    return 1
            else:
                return 10

        self.assert_prune(impl, (types.Array(types.float64, 2, 'C'),),
                          [False, None], np.zeros((2, 3)))
        self.assert_prune(impl, (types.Array(types.float64, 1, 'C'),),
                          [True, 'both'], np.zeros((2,)))

    def test_tuple_len(self):

        def impl(tup):
            if len(tup) == 3:
                if tup[2] == 2:
                    return 1
            else:
                return 0

        self.assert_prune(impl, (types.UniTuple(types.int64, 3),),
                          [False, None], tuple([1, 2, 3]))
        self.assert_prune(impl, (types.UniTuple(types.int64, 2),),
                          [True, 'both'], tuple([1, 2]))

    def test_attr_not_len(self):
        # The purpose of this test is to make sure that the conditions guarding
        # the rewrite part do not themselves raise exceptions.
        # This produces an `ir.Expr` call node for `float.as_integer_ratio`,
        # which is a getattr() on `float`.

        @njit
        def test():
            float.as_integer_ratio(1.23)

        # this should raise a TypingError
        with self.assertRaises(errors.TypingError) as e:
            test()

        self.assertIn("Unknown attribute 'as_integer_ratio'", str(e.exception))

    def test_ndim_not_on_array(self):

        FakeArray = collections.namedtuple('FakeArray', ['ndim'])
        fa = FakeArray(ndim=2)

        def impl(fa):
            if fa.ndim == 2:
                return fa.ndim
            else:
                object()

        # check prune works for array ndim
        self.assert_prune(impl, (types.Array(types.float64, 2, 'C'),),
                          [False], np.zeros((2, 3)))

        # check prune fails for something with `ndim` attr that is not array
        FakeArrayType = types.NamedUniTuple(types.int64, 1, FakeArray)
        self.assert_prune(impl, (FakeArrayType,), [None], fa,
                          flags={'nopython':False, 'forceobj':True})

    def test_semantic_const_propagates_before_static_rewrites(self):
        # see issue #5015, the ndim needs writing in as a const before
        # the rewrite passes run to make e.g. getitems static where possible
        @njit
        def impl(a, b):
            return a.shape[:b.ndim]

        args = (np.zeros((5, 4, 3, 2)), np.zeros((1, 1)))

        self.assertPreciseEqual(impl(*args), impl.py_func(*args))

    def test_tuple_const_propagation(self):
        @njit(pipeline_class=IRPreservingTestPipeline)
        def impl(*args):
            s = 0
            for arg in literal_unroll(args):
                s += len(arg)
            return s

        inp = ((), (1, 2, 3), ())
        self.assertPreciseEqual(impl(*inp), impl.py_func(*inp))

        ol = impl.overloads[impl.signatures[0]]
        func_ir = ol.metadata['preserved_ir']
        # make sure one of the inplace binop args is a Const
        binop_consts = set()
        for blk in func_ir.blocks.values():
            for expr in blk.find_exprs('inplace_binop'):
                inst = blk.find_variable_assignment(expr.rhs.name)
                self.assertIsInstance(inst.value, ir.Const)
                binop_consts.add(inst.value.value)
        self.assertEqual(binop_consts, {len(x) for x in inp})