""" Tests for SSA reconstruction """ import sys import copy import logging import numpy as np from numba import njit, jit, types from numba.core import errors, ir from numba.core.compiler_machinery import FunctionPass, register_pass from numba.core.compiler import DefaultPassBuilder, CompilerBase from numba.core.untyped_passes import ReconstructSSA, PreserveIR from numba.core.typed_passes import NativeLowering from numba.extending import overload from numba.tests.support import MemoryLeakMixin, TestCase, override_config _DEBUG = False if _DEBUG: # Enable debug logger on SSA reconstruction ssa_logger = logging.getLogger("numba.core.ssa") ssa_logger.setLevel(level=logging.DEBUG) ssa_logger.addHandler(logging.StreamHandler(sys.stderr)) class SSABaseTest(TestCase): def check_func(self, func, *args): got = func(*copy.deepcopy(args)) exp = func.py_func(*copy.deepcopy(args)) self.assertEqual(got, exp) class TestSSA(SSABaseTest): """ Contains tests to help isolate problems in SSA """ def test_argument_name_reused(self): @njit def foo(x): x += 1 return x self.check_func(foo, 123) def test_if_else_redefine(self): @njit def foo(x, y): z = x * y if x < y: z = x else: z = y return z self.check_func(foo, 3, 2) self.check_func(foo, 2, 3) def test_sum_loop(self): @njit def foo(n): c = 0 for i in range(n): c += i return c self.check_func(foo, 0) self.check_func(foo, 10) def test_sum_loop_2vars(self): @njit def foo(n): c = 0 d = n for i in range(n): c += i d += n return c, d self.check_func(foo, 0) self.check_func(foo, 10) def test_sum_2d_loop(self): @njit def foo(n): c = 0 for i in range(n): for j in range(n): c += j c += i return c self.check_func(foo, 0) self.check_func(foo, 10) def check_undefined_var(self, should_warn): @njit def foo(n): if n: if n > 0: c = 0 return c else: # variable c is not defined in this branch c += 1 return c if should_warn: with self.assertWarns(errors.NumbaWarning) as warns: # n=1 so we won't actually run the branch with the uninitialized self.check_func(foo, 1) self.assertIn("Detected uninitialized variable c", str(warns.warning)) else: self.check_func(foo, 1) with self.assertRaises(UnboundLocalError): foo.py_func(0) def test_undefined_var(self): with override_config('ALWAYS_WARN_UNINIT_VAR', 0): self.check_undefined_var(should_warn=False) with override_config('ALWAYS_WARN_UNINIT_VAR', 1): self.check_undefined_var(should_warn=True) def test_phi_propagation(self): @njit def foo(actions): n = 1 i = 0 ct = 0 while n > 0 and i < len(actions): n -= 1 while actions[i]: if actions[i]: if actions[i]: n += 10 actions[i] -= 1 else: if actions[i]: n += 20 actions[i] += 1 ct += n ct += n return ct, n self.check_func(foo, np.array([1, 2])) def test_unhandled_undefined(self): def function1(arg1, arg2, arg3, arg4, arg5): # This function is auto-generated. if arg1: var1 = arg2 var2 = arg3 var3 = var2 var4 = arg1 return else: if arg2: if arg4: var5 = arg4 # noqa: F841 return else: var6 = var4 return return var6 else: if arg5: if var1: if arg5: var1 = var6 return else: var7 = arg2 # noqa: F841 return arg2 return else: if var2: arg5 = arg2 return arg1 else: var6 = var3 return var4 return return else: var8 = var1 return return var8 var9 = var3 # noqa: F841 var10 = arg5 # noqa: F841 return var1 # The argument values is not critical for re-creating the bug # because the bug is in compile-time. expect = function1(2, 3, 6, 0, 7) got = njit(function1)(2, 3, 6, 0, 7) self.assertEqual(expect, got) class TestReportedSSAIssues(SSABaseTest): # Tests from issues # https://github.com/numba/numba/issues?q=is%3Aopen+is%3Aissue+label%3ASSA def test_issue2194(self): @njit def foo(): V = np.empty(1) s = np.uint32(1) for i in range(s): V[i] = 1 for i in range(s, 1): pass self.check_func(foo, ) def test_issue3094(self): @njit def doit(x): return x @njit def foo(pred): if pred: x = True else: x = False # do something with x return doit(x) self.check_func(foo, False) def test_issue3931(self): @njit def foo(arr): for i in range(1): arr = arr.reshape(3 * 2) arr = arr.reshape(3, 2) return (arr) np.testing.assert_allclose(foo(np.zeros((3, 2))), foo.py_func(np.zeros((3, 2)))) def test_issue3976(self): def overload_this(a): return 'dummy' @njit def foo(a): if a: s = 5 s = overload_this(s) else: s = 'b' return s @overload(overload_this) def ol(a): return overload_this self.check_func(foo, True) def test_issue3979(self): @njit def foo(A, B): x = A[0] y = B[0] for i in A: x = i for i in B: y = i return x, y self.check_func(foo, (1, 2), ('A', 'B')) def test_issue5219(self): def overload_this(a, b=None): if isinstance(b, tuple): b = b[0] return b @overload(overload_this) def ol(a, b=None): b_is_tuple = isinstance(b, (types.Tuple, types.UniTuple)) def impl(a, b=None): if b_is_tuple is True: b = b[0] return b return impl @njit def test_tuple(a, b): overload_this(a, b) self.check_func(test_tuple, 1, (2, )) def test_issue5223(self): @njit def bar(x): if len(x) == 5: return x x = x.copy() for i in range(len(x)): x[i] += 1 return x a = np.ones(5) a.flags.writeable = False np.testing.assert_allclose(bar(a), bar.py_func(a)) def test_issue5243(self): @njit def foo(q): lin = np.array((0.1, 0.6, 0.3)) stencil = np.zeros((3, 3)) stencil[0, 0] = q[0, 0] return lin[0] self.check_func(foo, np.zeros((2, 2))) def test_issue5482_missing_variable_init(self): # Test error that lowering fails because variable is missing # a definition before use. @njit("(intp, intp, intp)") def foo(x, v, n): for i in range(n): if i == 0: if i == x: pass else: problematic = v else: if i == x: pass else: problematic = problematic + v return problematic def test_issue5482_objmode_expr_null_lowering(self): # Existing pipelines will not have the Expr.null in objmode. # We have to create a custom pipeline to force a SSA reconstruction # and stripping. from numba.core.compiler import CompilerBase, DefaultPassBuilder from numba.core.untyped_passes import ReconstructSSA, IRProcessing from numba.core.typed_passes import PreLowerStripPhis class CustomPipeline(CompilerBase): def define_pipelines(self): pm = DefaultPassBuilder.define_objectmode_pipeline(self.state) # Force SSA reconstruction and stripping pm.add_pass_after(ReconstructSSA, IRProcessing) pm.add_pass_after(PreLowerStripPhis, ReconstructSSA) pm.finalize() return [pm] @jit("(intp, intp, intp)", looplift=False, pipeline_class=CustomPipeline) def foo(x, v, n): for i in range(n): if i == n: if i == x: pass else: problematic = v else: if i == x: pass else: problematic = problematic + v return problematic def test_issue5493_unneeded_phi(self): # Test error that unneeded phi is inserted because variable does not # have a dominance definition. data = (np.ones(2), np.ones(2)) A = np.ones(1) B = np.ones((1,1)) def foo(m, n, data): if len(data) == 1: v0 = data[0] else: v0 = data[0] # Unneeded PHI node for `problematic` would be placed here for _ in range(1, len(data)): v0 += A for t in range(1, m): for idx in range(n): t = B if idx == 0: if idx == n - 1: pass else: problematic = t else: if idx == n - 1: pass else: problematic = problematic + t return problematic expect = foo(10, 10, data) res1 = njit(foo)(10, 10, data) res2 = jit(forceobj=True, looplift=False)(foo)(10, 10, data) np.testing.assert_array_equal(expect, res1) np.testing.assert_array_equal(expect, res2) def test_issue5623_equal_statements_in_same_bb(self): def foo(pred, stack): i = 0 c = 1 if pred is True: stack[i] = c i += 1 stack[i] = c i += 1 python = np.array([0, 666]) foo(True, python) nb = np.array([0, 666]) njit(foo)(True, nb) expect = np.array([1, 1]) np.testing.assert_array_equal(python, expect) np.testing.assert_array_equal(nb, expect) def test_issue5678_non_minimal_phi(self): # There should be only one phi for variable "i" from numba.core.compiler import CompilerBase, DefaultPassBuilder from numba.core.untyped_passes import ( ReconstructSSA, FunctionPass, register_pass, ) phi_counter = [] @register_pass(mutates_CFG=False, analysis_only=True) class CheckSSAMinimal(FunctionPass): # A custom pass to count the number of phis _name = self.__class__.__qualname__ + ".CheckSSAMinimal" def __init__(self): super().__init__(self) def run_pass(self, state): ct = 0 for blk in state.func_ir.blocks.values(): ct += len(list(blk.find_exprs('phi'))) phi_counter.append(ct) return True class CustomPipeline(CompilerBase): def define_pipelines(self): pm = DefaultPassBuilder.define_nopython_pipeline(self.state) pm.add_pass_after(CheckSSAMinimal, ReconstructSSA) pm.finalize() return [pm] @njit(pipeline_class=CustomPipeline) def while_for(n, max_iter=1): a = np.empty((n,n)) i = 0 while i <= max_iter: for j in range(len(a)): for k in range(len(a)): a[j,k] = j + k i += 1 return a # Runs fine? self.assertPreciseEqual(while_for(10), while_for.py_func(10)) # One phi? self.assertEqual(phi_counter, [1]) def test_issue9242_use_not_dom_def(self): from numba.core.ir import FunctionIR from numba.core.compiler_machinery import ( AnalysisPass, register_pass, ) def check(fir: FunctionIR): [blk, *_] = fir.blocks.values() var = blk.scope.get("d") defn = fir.get_definition(var) self.assertEqual(defn.op, "phi") self.assertIn(ir.UNDEFINED, defn.incoming_values) @register_pass(mutates_CFG=False, analysis_only=True) class SSACheck(AnalysisPass): """ Check SSA on variable `d` """ _name = "SSA_Check" def __init__(self): AnalysisPass.__init__(self) def run_pass(self, state): check(state.func_ir) return False class SSACheckPipeline(CompilerBase): """Inject SSACheck pass into the default pipeline following the SSA pass """ def define_pipelines(self): pipeline = DefaultPassBuilder.define_nopython_pipeline( self.state, "ssa_check_custom_pipeline") pipeline._finalized = False pipeline.add_pass_after(SSACheck, ReconstructSSA) pipeline.finalize() return [pipeline] @njit(pipeline_class=SSACheckPipeline) def py_func(a): c = a > 0 if c: d = a + 5 # d is only defined here; undef in the else branch return c and d > 0 py_func(10) class TestSROAIssues(MemoryLeakMixin, TestCase): # This tests issues related to the SROA optimization done in lowering, which # reduces time spent in the LLVM SROA pass. The optimization is related to # SSA and tries to reduce the number of alloca statements for variables with # only a single assignment. def test_issue7258_multiple_assignment_post_SSA(self): # This test adds a pass that will duplicate assignment statements to # variables named "foobar". # In the reported issue, the bug will cause a memory leak. cloned = [] @register_pass(analysis_only=False, mutates_CFG=True) class CloneFoobarAssignments(FunctionPass): # A pass that clones variable assignments into "foobar" _name = "clone_foobar_assignments_pass" def __init__(self): FunctionPass.__init__(self) def run_pass(self, state): mutated = False for blk in state.func_ir.blocks.values(): to_clone = [] # find assignments to "foobar" for assign in blk.find_insts(ir.Assign): if assign.target.name == "foobar": to_clone.append(assign) # clone for assign in to_clone: clone = copy.deepcopy(assign) blk.insert_after(clone, assign) mutated = True # keep track of cloned statements cloned.append(clone) return mutated class CustomCompiler(CompilerBase): def define_pipelines(self): pm = DefaultPassBuilder.define_nopython_pipeline( self.state, "custom_pipeline", ) pm._finalized = False # Insert the cloning pass after SSA pm.add_pass_after(CloneFoobarAssignments, ReconstructSSA) # Capture IR post lowering pm.add_pass_after(PreserveIR, NativeLowering) pm.finalize() return [pm] @njit(pipeline_class=CustomCompiler) def udt(arr): foobar = arr + 1 # this assignment will be cloned return foobar arr = np.arange(10) # Verify that the function works as expected self.assertPreciseEqual(udt(arr), arr + 1) # Verify that the expected statement is cloned self.assertEqual(len(cloned), 1) self.assertEqual(cloned[0].target.name, "foobar") # Verify in the Numba IR that the expected statement is cloned nir = udt.overloads[udt.signatures[0]].metadata['preserved_ir'] self.assertEqual(len(nir.blocks), 1, "only one block") [blk] = nir.blocks.values() assigns = blk.find_insts(ir.Assign) foobar_assigns = [stmt for stmt in assigns if stmt.target.name == "foobar"] self.assertEqual( len(foobar_assigns), 2, "expected two assignment statements into 'foobar'", ) self.assertEqual( foobar_assigns[0], foobar_assigns[1], "expected the two assignment statements to be the same", )