632 lines
18 KiB
Python
632 lines
18 KiB
Python
|
"""
|
||
|
Tests for SSA reconstruction
|
||
|
"""
|
||
|
import sys
|
||
|
import copy
|
||
|
import logging
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from numba import njit, jit, types
|
||
|
from numba.core import errors, ir
|
||
|
from numba.core.compiler_machinery import FunctionPass, register_pass
|
||
|
from numba.core.compiler import DefaultPassBuilder, CompilerBase
|
||
|
from numba.core.untyped_passes import ReconstructSSA, PreserveIR
|
||
|
from numba.core.typed_passes import NativeLowering
|
||
|
from numba.extending import overload
|
||
|
from numba.tests.support import MemoryLeakMixin, TestCase, override_config
|
||
|
|
||
|
|
||
|
# Developer switch: set to True to stream the SSA reconstruction debug log
# (logger "numba.core.ssa") to stderr while running these tests.
_DEBUG = False

if _DEBUG:
    ssa_logger = logging.getLogger("numba.core.ssa")
    ssa_logger.addHandler(logging.StreamHandler(stream=sys.stderr))
    ssa_logger.setLevel(logging.DEBUG)
|
||
|
|
||
|
|
||
|
class SSABaseTest(TestCase):
    """Base class providing the comparison helper used by the SSA tests."""

    def check_func(self, func, *args):
        """Assert that ``func`` and ``func.py_func`` return equal results.

        ``func`` must expose the original Python function as ``py_func``.
        Each of the two calls gets its own deep copy of ``args`` so that an
        in-place mutation performed by one run cannot leak into the other.
        """
        jitted_result = func(*copy.deepcopy(args))
        python_result = func.py_func(*copy.deepcopy(args))
        self.assertEqual(jitted_result, python_result)
|
||
|
|
||
|
|
||
|
class TestSSA(SSABaseTest):
    """
    Contains tests to help isolate problems in SSA

    Each test defines a small function whose control flow re-binds a
    variable in a particular way and compares the jitted result against the
    pure-Python result (mostly through ``check_func``). The exact statement
    layout of the functions under test is deliberate; do not "tidy" them.
    """

    def test_argument_name_reused(self):
        # The argument name `x` is re-bound by an augmented assignment.
        @njit
        def foo(x):
            x += 1
            return x

        self.check_func(foo, 123)

    def test_if_else_redefine(self):
        # `z` is assigned before the branch and re-assigned in both arms.
        @njit
        def foo(x, y):
            z = x * y
            if x < y:
                z = x
            else:
                z = y
            return z

        # Exercise both arms: x > y, then x < y.
        self.check_func(foo, 3, 2)
        self.check_func(foo, 2, 3)

    def test_sum_loop(self):
        # A single loop-carried variable `c`.
        @njit
        def foo(n):
            c = 0
            for i in range(n):
                c += i
            return c

        # n=0: the loop body never runs; n=10: it runs ten times.
        self.check_func(foo, 0)
        self.check_func(foo, 10)

    def test_sum_loop_2vars(self):
        # Two loop-carried variables (`c` and `d`) updated in the same loop.
        @njit
        def foo(n):
            c = 0
            d = n
            for i in range(n):
                c += i
                d += n
            return c, d

        self.check_func(foo, 0)
        self.check_func(foo, 10)

    def test_sum_2d_loop(self):
        # `c` is updated in both the inner and the outer loop of a 2-deep nest.
        @njit
        def foo(n):
            c = 0
            for i in range(n):
                for j in range(n):
                    c += j
                c += i
            return c

        self.check_func(foo, 0)
        self.check_func(foo, 10)

    def check_undefined_var(self, should_warn):
        # Helper (not collected as a test): compiles a function in which `c`
        # may be used before assignment.
        #
        # should_warn: when true, running the check must emit a NumbaWarning
        # whose text contains "Detected uninitialized variable c"; when false
        # the check is simply run with no warning assertion.
        @njit
        def foo(n):
            if n:
                if n > 0:
                    c = 0
                return c
            else:
                # variable c is not defined in this branch
                c += 1
                return c

        if should_warn:
            with self.assertWarns(errors.NumbaWarning) as warns:
                # n=1 takes the `c = 0` path, so the branch that reads the
                # uninitialized variable is never actually executed.
                self.check_func(foo, 1)
            self.assertIn("Detected uninitialized variable c",
                          str(warns.warning))
        else:
            self.check_func(foo, 1)

        # In pure Python, n=0 reaches `c += 1` with `c` unbound.
        with self.assertRaises(UnboundLocalError):
            foo.py_func(0)

    def test_undefined_var(self):
        # Run the helper with the ALWAYS_WARN_UNINIT_VAR config overridden to
        # 0 (no warning asserted) and then to 1 (warning expected).
        with override_config('ALWAYS_WARN_UNINIT_VAR', 0):
            self.check_undefined_var(should_warn=False)
        with override_config('ALWAYS_WARN_UNINIT_VAR', 1):
            self.check_undefined_var(should_warn=True)

    def test_phi_propagation(self):
        # Nested while-loops with conditionals that update `n`, `ct` and the
        # input array element `actions[i]` in place. check_func deep-copies
        # the array for each run, so both runs start from the same data.
        @njit
        def foo(actions):
            n = 1

            i = 0
            ct = 0
            while n > 0 and i < len(actions):
                n -= 1

                while actions[i]:
                    if actions[i]:
                        if actions[i]:
                            n += 10
                        actions[i] -= 1
                    else:
                        if actions[i]:
                            n += 20
                        actions[i] += 1

                    ct += n
                ct += n
            return ct, n

        self.check_func(foo, np.array([1, 2]))

    def test_unhandled_undefined(self):
        # Compiles a function in which several variables (var1, var4, var6,
        # ...) are read on paths where they were never assigned, and which
        # contains code after `return` statements.
        def function1(arg1, arg2, arg3, arg4, arg5):
            # This function is auto-generated.
            if arg1:
                var1 = arg2
                var2 = arg3
                var3 = var2
                var4 = arg1
                return
            else:
                if arg2:
                    if arg4:
                        var5 = arg4         # noqa: F841
                        return
                    else:
                        var6 = var4
                        return
                    return var6
                else:
                    if arg5:
                        if var1:
                            if arg5:
                                var1 = var6
                                return
                            else:
                                var7 = arg2     # noqa: F841
                                return arg2
                            return
                        else:
                            if var2:
                                arg5 = arg2
                                return arg1
                            else:
                                var6 = var3
                                return var4
                            return
                        return
                    else:
                        var8 = var1
                        return
                    return var8
                var9 = var3         # noqa: F841
                var10 = arg5        # noqa: F841
            return var1

        # The argument values are not critical for re-creating the bug
        # because the bug is at compile time. With arg1=2 the first branch is
        # taken and the function returns None.
        expect = function1(2, 3, 6, 0, 7)
        got = njit(function1)(2, 3, 6, 0, 7)
        self.assertEqual(expect, got)
|
||
|
|
||
|
|
||
|
class TestReportedSSAIssues(SSABaseTest):
    # Tests from issues
    # https://github.com/numba/numba/issues?q=is%3Aopen+is%3Aissue+label%3ASSA
    #
    # Each test is named after the issue it reproduces. The functions under
    # test are kept in the exact shape that triggers the reported behaviour;
    # do not simplify them.

    def test_issue2194(self):
        # The loop variable `i` is reused by two consecutive loops whose
        # ranges are built from a uint32 scalar. `foo` returns None, so this
        # checks that compilation and execution succeed.

        @njit
        def foo():
            V = np.empty(1)
            s = np.uint32(1)

            for i in range(s):
                V[i] = 1
            for i in range(s, 1):
                pass

        self.check_func(foo)

    def test_issue3094(self):
        # `x` is assigned in both arms of a branch and then passed to another
        # jitted function.

        @njit
        def doit(x):
            return x

        @njit
        def foo(pred):
            if pred:
                x = True
            else:
                x = False
            # do something with x
            return doit(x)

        self.check_func(foo, False)

    def test_issue3931(self):
        # `arr` is re-bound inside a loop to arrays of different
        # dimensionality (1D, then back to 2D).

        @njit
        def foo(arr):
            for i in range(1):
                arr = arr.reshape(3 * 2)
                arr = arr.reshape(3, 2)
            return (arr)

        np.testing.assert_allclose(foo(np.zeros((3, 2))),
                                   foo.py_func(np.zeros((3, 2))))

    def test_issue3976(self):
        # `s` is re-bound from an int to the (string) result of an overloaded
        # function in one arm, and to a string literal in the other.

        def overload_this(a):
            return 'dummy'

        @njit
        def foo(a):
            if a:
                s = 5
                s = overload_this(s)
            else:
                s = 'b'

            return s

        # The overload reuses the pure-Python function as its implementation.
        @overload(overload_this)
        def ol(a):
            return overload_this

        self.check_func(foo, True)

    def test_issue3979(self):
        # The loop variable `i` is reused over two tuples with different
        # element types (ints, then strings).

        @njit
        def foo(A, B):
            x = A[0]
            y = B[0]
            for i in A:
                x = i
            for i in B:
                y = i
            return x, y

        self.check_func(foo, (1, 2), ('A', 'B'))

    def test_issue5219(self):
        # The overload implementation re-binds its argument `b` under a
        # condition that is a compile-time constant captured from the typing
        # scope (`b_is_tuple`).

        def overload_this(a, b=None):
            if isinstance(b, tuple):
                b = b[0]
            return b

        @overload(overload_this)
        def ol(a, b=None):
            b_is_tuple = isinstance(b, (types.Tuple, types.UniTuple))

            def impl(a, b=None):
                if b_is_tuple is True:
                    b = b[0]
                return b
            return impl

        # The wrapper discards the result, so both runs return None; the test
        # checks that compilation and the call succeed.
        @njit
        def test_tuple(a, b):
            overload_this(a, b)

        self.check_func(test_tuple, 1, (2, ))

    def test_issue5223(self):
        # `x` is re-bound from the (read-only) argument to a writable copy
        # after an early return.

        @njit
        def bar(x):
            if len(x) == 5:
                return x
            x = x.copy()
            for i in range(len(x)):
                x[i] += 1
            return x

        # The input has length 5, so at runtime the early return is taken and
        # the read-only array is never written to; the mutating path only has
        # to compile.
        a = np.ones(5)
        a.flags.writeable = False

        np.testing.assert_allclose(bar(a), bar.py_func(a))

    def test_issue5243(self):
        # `stencil` is written to but never read; only `lin[0]` is returned.

        @njit
        def foo(q):
            lin = np.array((0.1, 0.6, 0.3))
            stencil = np.zeros((3, 3))
            stencil[0, 0] = q[0, 0]
            return lin[0]

        self.check_func(foo, np.zeros((2, 2)))

    def test_issue5482_missing_variable_init(self):
        # Test error that lowering fails because variable is missing
        # a definition before use.
        #
        # The explicit signature makes the decorator compile eagerly, so the
        # test passes as long as defining `foo` does not raise; `foo` is
        # intentionally never called.
        @njit("(intp, intp, intp)")
        def foo(x, v, n):
            for i in range(n):
                if i == 0:
                    if i == x:
                        pass
                    else:
                        problematic = v
                else:
                    if i == x:
                        pass
                    else:
                        problematic = problematic + v
            return problematic

    def test_issue5482_objmode_expr_null_lowering(self):
        # Existing pipelines will not have the Expr.null in objmode.
        # We have to create a custom pipeline to force a SSA reconstruction
        # and stripping.
        #
        # CompilerBase, DefaultPassBuilder and ReconstructSSA come from the
        # module-level imports; only the passes not imported there are
        # imported locally.
        from numba.core.untyped_passes import IRProcessing
        from numba.core.typed_passes import PreLowerStripPhis

        class CustomPipeline(CompilerBase):
            def define_pipelines(self):
                pm = DefaultPassBuilder.define_objectmode_pipeline(self.state)
                # Force SSA reconstruction and stripping
                pm.add_pass_after(ReconstructSSA, IRProcessing)
                pm.add_pass_after(PreLowerStripPhis, ReconstructSSA)
                pm.finalize()
                return [pm]

        # As above, the explicit signature compiles eagerly; `foo` is never
        # called.
        @jit("(intp, intp, intp)", looplift=False,
             pipeline_class=CustomPipeline)
        def foo(x, v, n):
            for i in range(n):
                if i == n:
                    if i == x:
                        pass
                    else:
                        problematic = v
                else:
                    if i == x:
                        pass
                    else:
                        problematic = problematic + v
            return problematic

    def test_issue5493_unneeded_phi(self):
        # Test error that unneeded phi is inserted because variable does not
        # have a dominance definition.
        data = (np.ones(2), np.ones(2))
        A = np.ones(1)
        B = np.ones((1, 1))

        def foo(m, n, data):
            if len(data) == 1:
                v0 = data[0]
            else:
                v0 = data[0]
                # Unneeded PHI node for `problematic` would be placed here
                for _ in range(1, len(data)):
                    v0 += A

            for t in range(1, m):
                for idx in range(n):
                    t = B

                    if idx == 0:
                        if idx == n - 1:
                            pass
                        else:
                            problematic = t
                    else:
                        if idx == n - 1:
                            pass
                        else:
                            problematic = problematic + t
            return problematic

        # Compare pure Python against both nopython mode and forced object
        # mode (with loop-lifting disabled).
        expect = foo(10, 10, data)
        res1 = njit(foo)(10, 10, data)
        res2 = jit(forceobj=True, looplift=False)(foo)(10, 10, data)
        np.testing.assert_array_equal(expect, res1)
        np.testing.assert_array_equal(expect, res2)

    def test_issue5623_equal_statements_in_same_bb(self):
        # The same pair of statements (`stack[i] = c; i += 1`) appears twice
        # in a row; both stores must take effect, at indices 0 and 1.

        def foo(pred, stack):
            i = 0
            c = 1

            if pred is True:
                stack[i] = c
                i += 1
                stack[i] = c
                i += 1

        python = np.array([0, 666])
        foo(True, python)

        nb = np.array([0, 666])
        njit(foo)(True, nb)

        expect = np.array([1, 1])

        np.testing.assert_array_equal(python, expect)
        np.testing.assert_array_equal(nb, expect)

    def test_issue5678_non_minimal_phi(self):
        # There should be only one phi for variable "i"
        #
        # CompilerBase, DefaultPassBuilder, ReconstructSSA, FunctionPass and
        # register_pass are the module-level imports.

        # One entry (the phi count) is appended per pipeline run.
        phi_counter = []

        @register_pass(mutates_CFG=False, analysis_only=True)
        class CheckSSAMinimal(FunctionPass):
            # A custom pass to count the number of phis

            # `self` here is the enclosing test case; qualifying the pass
            # name with the test class name keeps it distinct.
            _name = self.__class__.__qualname__ + ".CheckSSAMinimal"

            def __init__(self):
                # Explicit base call, as the other passes in this file do.
                # The previous `super().__init__(self)` passed the instance a
                # second time as a stray positional argument.
                FunctionPass.__init__(self)

            def run_pass(self, state):
                ct = 0
                for blk in state.func_ir.blocks.values():
                    ct += len(list(blk.find_exprs('phi')))
                phi_counter.append(ct)
                # Analysis only: the IR is not mutated.
                return False

        class CustomPipeline(CompilerBase):
            def define_pipelines(self):
                pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
                # Count phis immediately after SSA reconstruction.
                pm.add_pass_after(CheckSSAMinimal, ReconstructSSA)
                pm.finalize()
                return [pm]

        @njit(pipeline_class=CustomPipeline)
        def while_for(n, max_iter=1):
            a = np.empty((n, n))
            i = 0
            while i <= max_iter:
                for j in range(len(a)):
                    for k in range(len(a)):
                        a[j, k] = j + k
                i += 1
            return a

        # Runs fine?
        self.assertPreciseEqual(while_for(10), while_for.py_func(10))
        # One phi?
        self.assertEqual(phi_counter, [1])

    def test_issue9242_use_not_dom_def(self):
        # `d` is defined on only one path but used after the branch; after SSA
        # reconstruction its definition must be a phi that has UNDEFINED among
        # its incoming values.
        from numba.core.ir import FunctionIR
        from numba.core.compiler_machinery import AnalysisPass

        def check(fir: FunctionIR):
            # Any block's scope will do for the name lookup; take the first.
            [blk, *_] = fir.blocks.values()
            var = blk.scope.get("d")
            defn = fir.get_definition(var)
            self.assertEqual(defn.op, "phi")
            self.assertIn(ir.UNDEFINED, defn.incoming_values)

        @register_pass(mutates_CFG=False, analysis_only=True)
        class SSACheck(AnalysisPass):
            """
            Check SSA on variable `d`
            """

            _name = "SSA_Check"

            def __init__(self):
                AnalysisPass.__init__(self)

            def run_pass(self, state):
                check(state.func_ir)
                return False

        class SSACheckPipeline(CompilerBase):
            """Inject SSACheck pass into the default pipeline following the SSA
            pass
            """

            def define_pipelines(self):
                pipeline = DefaultPassBuilder.define_nopython_pipeline(
                    self.state, "ssa_check_custom_pipeline")

                pipeline._finalized = False
                pipeline.add_pass_after(SSACheck, ReconstructSSA)

                pipeline.finalize()
                return [pipeline]

        @njit(pipeline_class=SSACheckPipeline)
        def py_func(a):
            c = a > 0
            if c:
                d = a + 5  # d is only defined here; undef in the else branch

            return c and d > 0

        # Calling triggers compilation, which runs the SSACheck assertions.
        py_func(10)
|
||
|
|
||
|
|
||
|
class TestSROAIssues(MemoryLeakMixin, TestCase):
    # Covers issues in the SROA optimization performed during lowering. That
    # optimization cuts time spent in the LLVM SROA pass; it is tied to SSA
    # and tries to lower the number of alloca statements for variables that
    # have only a single assignment.
    def test_issue7258_multiple_assignment_post_SSA(self):
        # A custom pass duplicates every assignment statement whose target is
        # named "foobar". In the reported issue this situation caused a
        # memory leak.
        cloned = []

        @register_pass(analysis_only=False, mutates_CFG=True)
        class CloneFoobarAssignments(FunctionPass):
            # Duplicates each assignment into "foobar", in place.
            _name = "clone_foobar_assignments_pass"

            def __init__(self):
                FunctionPass.__init__(self)

            def run_pass(self, state):
                changed = False
                for block in state.func_ir.blocks.values():
                    # Collect the matches first, then insert the copies.
                    targets = [stmt for stmt in block.find_insts(ir.Assign)
                               if stmt.target.name == "foobar"]
                    for stmt in targets:
                        duplicate = copy.deepcopy(stmt)
                        block.insert_after(duplicate, stmt)
                        changed = True
                        # remember what was duplicated for the assertions
                        cloned.append(duplicate)
                return changed

        class CustomCompiler(CompilerBase):
            def define_pipelines(self):
                pipeline = DefaultPassBuilder.define_nopython_pipeline(
                    self.state, "custom_pipeline",
                )
                pipeline._finalized = False
                # The cloning pass runs right after SSA reconstruction
                pipeline.add_pass_after(CloneFoobarAssignments, ReconstructSSA)
                # Keep a copy of the IR as it is after lowering
                pipeline.add_pass_after(PreserveIR, NativeLowering)
                pipeline.finalize()
                return [pipeline]

        @njit(pipeline_class=CustomCompiler)
        def udt(arr):
            foobar = arr + 1    # this assignment will be cloned
            return foobar

        arr = np.arange(10)
        # The function must still produce the right answer
        self.assertPreciseEqual(udt(arr), arr + 1)
        # Exactly one statement was duplicated, and it targets "foobar"
        self.assertEqual(len(cloned), 1)
        self.assertEqual(cloned[0].target.name, "foobar")
        # The preserved Numba IR must contain the duplicate as well
        signature = udt.signatures[0]
        metadata = udt.overloads[signature].metadata
        nir = metadata['preserved_ir']
        self.assertEqual(len(nir.blocks), 1, "only one block")
        (only_block,) = nir.blocks.values()
        foobar_assigns = [stmt for stmt in only_block.find_insts(ir.Assign)
                          if stmt.target.name == "foobar"]
        self.assertEqual(
            len(foobar_assigns), 2,
            "expected two assignment statements into 'foobar'",
        )
        first, second = foobar_assigns
        self.assertEqual(
            first, second,
            "expected the two assignment statements to be the same",
        )
|