ai-content-maker/.venv/Lib/site-packages/numba/tests/test_ssa.py

632 lines
18 KiB
Python

"""
Tests for SSA reconstruction
"""
import sys
import copy
import logging
import numpy as np
from numba import njit, jit, types
from numba.core import errors, ir
from numba.core.compiler_machinery import FunctionPass, register_pass
from numba.core.compiler import DefaultPassBuilder, CompilerBase
from numba.core.untyped_passes import ReconstructSSA, PreserveIR
from numba.core.typed_passes import NativeLowering
from numba.extending import overload
from numba.tests.support import MemoryLeakMixin, TestCase, override_config
_DEBUG = False
if _DEBUG:
# Enable debug logger on SSA reconstruction
ssa_logger = logging.getLogger("numba.core.ssa")
ssa_logger.setLevel(level=logging.DEBUG)
ssa_logger.addHandler(logging.StreamHandler(sys.stderr))
class SSABaseTest(TestCase):
    """Common helpers for the SSA reconstruction tests."""

    def check_func(self, func, *args):
        """Check that a jitted ``func`` agrees with its ``py_func``.

        ``func`` is expected to be a dispatcher exposing ``py_func``.  Each
        invocation receives its own deep copy of ``args`` so that in-place
        mutation performed by one call cannot leak into the other.
        """
        def run(callee):
            # Private argument copy per invocation.
            return callee(*copy.deepcopy(args))

        # Compiled version first, then the pure-Python reference.
        compiled_result = run(func)
        reference_result = run(func.py_func)
        self.assertEqual(compiled_result, reference_result)
class TestSSA(SSABaseTest):
    """
    Small, focused tests that help isolate individual problems in SSA
    reconstruction.

    Most tests compile a tiny function with ``@njit`` and use
    ``check_func`` to compare its result against the pure-Python original.
    The bodies of the compiled functions are deliberately shaped to hit a
    particular SSA situation, so they should not be "tidied up".
    """

    def test_argument_name_reused(self):
        # An argument that is reassigned in the body must be renamed by SSA
        # while keeping its incoming value.
        @njit
        def foo(x):
            x += 1
            return x

        self.check_func(foo, 123)

    def test_if_else_redefine(self):
        # ``z`` is defined before the branch and redefined on both arms, so
        # the two definitions must be merged after the if/else.
        @njit
        def foo(x, y):
            z = x * y
            if x < y:
                z = x
            else:
                z = y
            return z

        # Exercise both arms of the branch.
        self.check_func(foo, 3, 2)
        self.check_func(foo, 2, 3)

    def test_sum_loop(self):
        # Single loop-carried accumulator; n=0 covers the zero-trip loop.
        @njit
        def foo(n):
            c = 0
            for i in range(n):
                c += i
            return c

        self.check_func(foo, 0)
        self.check_func(foo, 10)

    def test_sum_loop_2vars(self):
        # Two independent loop-carried variables, one initialised from the
        # argument.
        @njit
        def foo(n):
            c = 0
            d = n
            for i in range(n):
                c += i
                d += n
            return c, d

        self.check_func(foo, 0)
        self.check_func(foo, 10)

    def test_sum_2d_loop(self):
        # The accumulator is updated in both the inner and the outer loop.
        @njit
        def foo(n):
            c = 0
            for i in range(n):
                for j in range(n):
                    c += j
                c += i
            return c

        self.check_func(foo, 0)
        self.check_func(foo, 10)

    def check_undefined_var(self, should_warn):
        """Compile a function that reads ``c`` on a path where it is unset.

        :param should_warn: when true, compiling/running ``foo`` is expected
            to emit a ``NumbaWarning`` mentioning the uninitialized ``c``.
        """
        @njit
        def foo(n):
            if n:
                if n > 0:
                    c = 0
                return c
            else:
                # variable c is not defined in this branch
                c += 1
                return c

        if should_warn:
            with self.assertWarns(errors.NumbaWarning) as warns:
                # n=1 so we won't actually run the branch with the
                # uninitialized variable.
                self.check_func(foo, 1)
            self.assertIn("Detected uninitialized variable c",
                          str(warns.warning))
        else:
            self.check_func(foo, 1)

        # Sanity check: plain Python does fail on the uninitialized path.
        with self.assertRaises(UnboundLocalError):
            foo.py_func(0)

    def test_undefined_var(self):
        # Run the same scenario with the ALWAYS_WARN_UNINIT_VAR config
        # switch off and on; only the latter is expected to warn.
        with override_config('ALWAYS_WARN_UNINIT_VAR', 0):
            self.check_undefined_var(should_warn=False)
        with override_config('ALWAYS_WARN_UNINIT_VAR', 1):
            self.check_undefined_var(should_warn=True)

    def test_phi_propagation(self):
        # Nested while loops and branches that all update ``n`` and ``ct``;
        # the merged values have to be propagated through several levels.
        # ``actions`` is mutated in place; check_func deep-copies it per call.
        @njit
        def foo(actions):
            n = 1

            i = 0
            ct = 0
            while n > 0 and i < len(actions):
                n -= 1

                while actions[i]:
                    if actions[i]:
                        if actions[i]:
                            n += 10
                        actions[i] -= 1
                    else:
                        if actions[i]:
                            n += 20
                        actions[i] += 1

                    ct += n
                ct += n
            return ct, n

        self.check_func(foo, np.array([1, 2]))

    def test_unhandled_undefined(self):
        # Several variables (e.g. var1, var4, var6) are read on paths where
        # they were never assigned; compilation must still succeed.
        def function1(arg1, arg2, arg3, arg4, arg5):
            # This function is auto-generated.
            if arg1:
                var1 = arg2
                var2 = arg3
                var3 = var2
                var4 = arg1
                return
            else:
                if arg2:
                    if arg4:
                        var5 = arg4         # noqa: F841
                        return
                    else:
                        var6 = var4
                        return
                    return var6
                else:
                    if arg5:
                        if var1:
                            if arg5:
                                var1 = var6
                                return
                            else:
                                var7 = arg2     # noqa: F841
                                return arg2
                            return
                        else:
                            if var2:
                                arg5 = arg2
                                return arg1
                            else:
                                var6 = var3
                                return var4
                            return
                        return
                    else:
                        var8 = var1
                        return
                    return var8
                var9 = var3         # noqa: F841
                var10 = arg5        # noqa: F841
                return var1

        # The argument values are not critical for re-creating the bug
        # because the bug is at compile time.
        expect = function1(2, 3, 6, 0, 7)
        got = njit(function1)(2, 3, 6, 0, 7)
        self.assertEqual(expect, got)
class TestReportedSSAIssues(SSABaseTest):
    """Regression tests reproducing SSA problems reported on GitHub.

    The compiled function bodies are the reproducers from the issues;
    presumably their exact control-flow shape is what triggered each bug,
    so keep them unchanged.  Tests without an explicit assertion pass as
    long as compilation (and, where called, execution) succeeds.
    """
    # Tests from issues
    # https://github.com/numba/numba/issues?q=is%3Aopen+is%3Aissue+label%3ASSA

    def test_issue2194(self):
        # The loop index ``i`` is reused by two consecutive range loops,
        # one of them driven by an unsigned integer.
        @njit
        def foo():
            V = np.empty(1)
            s = np.uint32(1)
            for i in range(s):
                V[i] = 1
            for i in range(s, 1):
                pass

        self.check_func(foo)

    def test_issue3094(self):
        # A variable defined on both arms of a branch is passed to another
        # jitted function.
        @njit
        def doit(x):
            return x

        @njit
        def foo(pred):
            if pred:
                x = True
            else:
                x = False
            # do something with x
            return doit(x)

        self.check_func(foo, False)

    def test_issue3931(self):
        # An argument is rebound twice (with different shapes) inside a loop.
        @njit
        def foo(arr):
            for i in range(1):
                arr = arr.reshape(3 * 2)
                arr = arr.reshape(3, 2)
            return (arr)

        np.testing.assert_allclose(foo(np.zeros((3, 2))),
                                   foo.py_func(np.zeros((3, 2))))

    def test_issue3976(self):
        # ``s`` is rebound to the result of an overloaded call on one arm
        # and to a string literal on the other.
        def overload_this(a):
            return 'dummy'

        @njit
        def foo(a):
            if a:
                s = 5
                s = overload_this(s)
            else:
                s = 'b'

            return s

        @overload(overload_this)
        def ol(a):
            return overload_this

        self.check_func(foo, True)

    def test_issue3979(self):
        # The loop variable ``i`` is reused across loops over tuples of
        # different element types.
        @njit
        def foo(A, B):
            x = A[0]
            y = B[0]
            for i in A:
                x = i
            for i in B:
                y = i
            return x, y

        self.check_func(foo, (1, 2), ('A', 'B'))

    def test_issue5219(self):
        # An overload implementation reassigns one of its own arguments.

        def overload_this(a, b=None):
            if isinstance(b, tuple):
                b = b[0]
            return b

        @overload(overload_this)
        def ol(a, b=None):
            b_is_tuple = isinstance(b, (types.Tuple, types.UniTuple))

            def impl(a, b=None):
                if b_is_tuple is True:
                    b = b[0]
                return b
            return impl

        # ``test_tuple`` returns None, so check_func effectively verifies
        # that the call compiles and runs.
        @njit
        def test_tuple(a, b):
            overload_this(a, b)

        self.check_func(test_tuple, 1, (2, ))

    def test_issue5223(self):
        # An early return of the (read-only) argument versus a rebound copy.

        @njit
        def bar(x):
            if len(x) == 5:
                return x
            x = x.copy()
            for i in range(len(x)):
                x[i] += 1
            return x

        a = np.ones(5)
        a.flags.writeable = False

        np.testing.assert_allclose(bar(a), bar.py_func(a))

    def test_issue5243(self):

        @njit
        def foo(q):
            lin = np.array((0.1, 0.6, 0.3))
            stencil = np.zeros((3, 3))
            stencil[0, 0] = q[0, 0]
            return lin[0]

        self.check_func(foo, np.zeros((2, 2)))

    def test_issue5482_missing_variable_init(self):
        # Regression: lowering failed because ``problematic`` is read on a
        # path where it has no prior definition.  The explicit signature
        # compiles eagerly, so successful decoration is the check.
        @njit("(intp, intp, intp)")
        def foo(x, v, n):
            for i in range(n):
                if i == 0:
                    if i == x:
                        pass
                    else:
                        problematic = v
                else:
                    if i == x:
                        pass
                    else:
                        problematic = problematic + v
            return problematic

    def test_issue5482_objmode_expr_null_lowering(self):
        # Existing pipelines will not have the Expr.null in objmode.
        # We have to create a custom pipeline to force a SSA reconstruction
        # and stripping.  As above, eager compilation is the check.
        from numba.core.compiler import CompilerBase, DefaultPassBuilder
        from numba.core.untyped_passes import ReconstructSSA, IRProcessing
        from numba.core.typed_passes import PreLowerStripPhis

        class CustomPipeline(CompilerBase):
            def define_pipelines(self):
                pm = DefaultPassBuilder.define_objectmode_pipeline(self.state)
                # Force SSA reconstruction and stripping
                pm.add_pass_after(ReconstructSSA, IRProcessing)
                pm.add_pass_after(PreLowerStripPhis, ReconstructSSA)
                pm.finalize()
                return [pm]

        @jit("(intp, intp, intp)", looplift=False,
             pipeline_class=CustomPipeline)
        def foo(x, v, n):
            for i in range(n):
                if i == n:
                    if i == x:
                        pass
                    else:
                        problematic = v
                else:
                    if i == x:
                        pass
                    else:
                        problematic = problematic + v
            return problematic

    def test_issue5493_unneeded_phi(self):
        # Regression: an unneeded phi was inserted because the variable
        # does not have a dominating definition.
        data = (np.ones(2), np.ones(2))
        A = np.ones(1)
        B = np.ones((1,1))

        def foo(m, n, data):
            if len(data) == 1:
                v0 = data[0]
            else:
                v0 = data[0]
                # Unneeded PHI node for `problematic` would be placed here
                for _ in range(1, len(data)):
                    v0 += A

            for t in range(1, m):
                for idx in range(n):
                    t = B

                    if idx == 0:
                        if idx == n - 1:
                            pass
                        else:
                            problematic = t
                    else:
                        if idx == n - 1:
                            pass
                        else:
                            problematic = problematic + t
            return problematic

        # Compare pure Python, nopython mode and object mode (no loop-lift).
        expect = foo(10, 10, data)
        res1 = njit(foo)(10, 10, data)
        res2 = jit(forceobj=True, looplift=False)(foo)(10, 10, data)
        np.testing.assert_array_equal(expect, res1)
        np.testing.assert_array_equal(expect, res2)

    def test_issue5623_equal_statements_in_same_bb(self):
        # Two textually identical ``stack[i] = c`` statements in one basic
        # block must not be conflated.

        def foo(pred, stack):
            i = 0
            c = 1

            if pred is True:
                stack[i] = c
                i += 1
                stack[i] = c
                i += 1

        python = np.array([0, 666])
        foo(True, python)

        nb = np.array([0, 666])
        njit(foo)(True, nb)

        expect = np.array([1, 1])

        np.testing.assert_array_equal(python, expect)
        np.testing.assert_array_equal(nb, expect)

    def test_issue5678_non_minimal_phi(self):
        # There should be only one phi for variable "i"

        from numba.core.compiler import CompilerBase, DefaultPassBuilder
        from numba.core.untyped_passes import (
            ReconstructSSA, FunctionPass, register_pass,
        )

        # Filled by CheckSSAMinimal with the number of phis per compilation.
        phi_counter = []

        @register_pass(mutates_CFG=False, analysis_only=True)
        class CheckSSAMinimal(FunctionPass):
            # A custom pass to count the number of phis

            # ``self`` here is the enclosing test case (class-body closure).
            _name = self.__class__.__qualname__ + ".CheckSSAMinimal"

            def __init__(self):
                # Previously ``super().__init__(self)``, which passed the
                # instance twice; match the other passes in this module.
                FunctionPass.__init__(self)

            def run_pass(self, state):
                ct = 0
                for blk in state.func_ir.blocks.values():
                    ct += len(list(blk.find_exprs('phi')))
                phi_counter.append(ct)
                return True

        class CustomPipeline(CompilerBase):
            def define_pipelines(self):
                pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
                pm.add_pass_after(CheckSSAMinimal, ReconstructSSA)
                pm.finalize()
                return [pm]

        @njit(pipeline_class=CustomPipeline)
        def while_for(n, max_iter=1):
            a = np.empty((n,n))
            i = 0
            while i <= max_iter:
                for j in range(len(a)):
                    for k in range(len(a)):
                        a[j,k] = j + k
                i += 1
            return a

        # Runs fine?
        self.assertPreciseEqual(while_for(10), while_for.py_func(10))
        # One phi?
        self.assertEqual(phi_counter, [1])

    def test_issue9242_use_not_dom_def(self):
        # ``d`` is used where its definition does not dominate the use; SSA
        # must produce a phi that includes ir.UNDEFINED as an incoming value.
        from numba.core.ir import FunctionIR
        from numba.core.compiler_machinery import (
            AnalysisPass,
            register_pass,
        )

        def check(fir: FunctionIR):
            # Look ``d`` up through the first block's scope.
            [blk, *_] = fir.blocks.values()
            var = blk.scope.get("d")
            defn = fir.get_definition(var)
            self.assertEqual(defn.op, "phi")
            self.assertIn(ir.UNDEFINED, defn.incoming_values)

        @register_pass(mutates_CFG=False, analysis_only=True)
        class SSACheck(AnalysisPass):
            """
            Check SSA on variable `d`
            """

            _name = "SSA_Check"

            def __init__(self):
                AnalysisPass.__init__(self)

            def run_pass(self, state):
                check(state.func_ir)
                return False

        class SSACheckPipeline(CompilerBase):
            """Inject SSACheck pass into the default pipeline following the SSA
            pass
            """

            def define_pipelines(self):
                pipeline = DefaultPassBuilder.define_nopython_pipeline(
                    self.state, "ssa_check_custom_pipeline")

                pipeline._finalized = False
                pipeline.add_pass_after(SSACheck, ReconstructSSA)

                pipeline.finalize()
                return [pipeline]

        @njit(pipeline_class=SSACheckPipeline)
        def py_func(a):
            c = a > 0
            if c:
                d = a + 5  # d is only defined here; undef in the else branch
            return c and d > 0

        # Compiling (triggered by the call) runs SSACheck.
        py_func(10)
class TestSROAIssues(MemoryLeakMixin, TestCase):
    """Tests for the SROA-style optimization done during lowering.

    The optimization avoids emitting ``alloca`` statements for variables
    that have a single assignment (reducing time spent in LLVM's own SROA
    pass).  These tests make sure it copes with IR where that assumption
    is broken after SSA.
    """

    def test_issue7258_multiple_assignment_post_SSA(self):
        # A custom pass duplicates every assignment into a variable named
        # "foobar" right after SSA, so "foobar" ends up assigned twice.
        # In the reported issue this led to a memory leak (presumably
        # detected here via MemoryLeakMixin).
        cloned = []

        @register_pass(analysis_only=False, mutates_CFG=True)
        class CloneFoobarAssignments(FunctionPass):
            # Clones every assignment whose target is "foobar".
            _name = "clone_foobar_assignments_pass"

            def __init__(self):
                FunctionPass.__init__(self)

            def run_pass(self, state):
                changed = False
                for block in state.func_ir.blocks.values():
                    # Collect the matches before inserting anything so the
                    # block body is not modified while it is being scanned.
                    originals = [stmt for stmt in block.find_insts(ir.Assign)
                                 if stmt.target.name == "foobar"]
                    for stmt in originals:
                        duplicate = copy.deepcopy(stmt)
                        block.insert_after(duplicate, stmt)
                        changed = True
                        # Remember the clone for the assertions below.
                        cloned.append(duplicate)
                return changed

        class CustomCompiler(CompilerBase):
            def define_pipelines(self):
                pipeline = DefaultPassBuilder.define_nopython_pipeline(
                    self.state, "custom_pipeline",
                )
                pipeline._finalized = False
                # Run the cloning pass immediately after SSA reconstruction.
                pipeline.add_pass_after(CloneFoobarAssignments,
                                        ReconstructSSA)
                # Keep the final IR around so it can be inspected.
                pipeline.add_pass_after(PreserveIR, NativeLowering)
                pipeline.finalize()
                return [pipeline]

        @njit(pipeline_class=CustomCompiler)
        def udt(arr):
            foobar = arr + 1 # this assignment will be cloned
            return foobar

        arr = np.arange(10)

        # The compiled function still computes the right answer.
        self.assertPreciseEqual(udt(arr), arr + 1)

        # Exactly one statement was cloned, and it targets "foobar".
        self.assertEqual(len(cloned), 1)
        (only_clone,) = cloned
        self.assertEqual(only_clone.target.name, "foobar")

        # The preserved Numba IR contains both the original and the clone.
        nir = udt.overloads[udt.signatures[0]].metadata['preserved_ir']
        self.assertEqual(len(nir.blocks), 1,
                         "only one block")
        (only_block,) = nir.blocks.values()
        foobar_assigns = [
            stmt for stmt in only_block.find_insts(ir.Assign)
            if stmt.target.name == "foobar"
        ]
        self.assertEqual(
            len(foobar_assigns), 2,
            "expected two assignment statements into 'foobar'",
        )
        first_assign, second_assign = foobar_assigns
        self.assertEqual(
            first_assign, second_assign,
            "expected the two assignment statements to be the same",
        )