ai-content-maker/.venv/Lib/site-packages/numba/tests/test_ssa.py

632 lines
18 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
"""
Tests for SSA reconstruction
"""
import sys
import copy
import logging
import numpy as np
from numba import njit, jit, types
from numba.core import errors, ir
from numba.core.compiler_machinery import FunctionPass, register_pass
from numba.core.compiler import DefaultPassBuilder, CompilerBase
from numba.core.untyped_passes import ReconstructSSA, PreserveIR
from numba.core.typed_passes import NativeLowering
from numba.extending import overload
from numba.tests.support import MemoryLeakMixin, TestCase, override_config
# Flip to True to stream debug output from the SSA reconstruction logger
# ("numba.core.ssa") to stderr while these tests run.
_DEBUG = False
if _DEBUG:
    ssa_logger = logging.getLogger("numba.core.ssa")
    ssa_logger.addHandler(logging.StreamHandler(sys.stderr))
    ssa_logger.setLevel(logging.DEBUG)
class SSABaseTest(TestCase):
    """Common base for the SSA tests: compares jitted and Python results."""

    def check_func(self, func, *args):
        """Run ``func`` and ``func.py_func`` on ``args`` and assert equality.

        ``func`` must expose a ``py_func`` attribute, as the ``@njit``
        dispatchers used throughout this file do. The arguments are
        deep-copied separately for each call, so an implementation that
        mutates its inputs cannot influence the other run.

        Results are compared with ``assertEqual``; tests in this file whose
        results are arrays compare them with ``np.testing`` instead.
        """
        jit_result = func(*copy.deepcopy(args))
        py_result = func.py_func(*copy.deepcopy(args))
        self.assertEqual(jit_result, py_result)
class TestSSA(SSABaseTest):
    """
    Small, focused tests that help isolate problems in SSA reconstruction.

    Each test jit-compiles a tiny function whose control flow rebinds
    variables in a particular way and compares it against the interpreter.
    """

    def test_argument_name_reused(self):
        # The argument `x` is rebound inside the function body.
        @njit
        def foo(x):
            x += 1
            return x

        self.check_func(foo, 123)

    def test_if_else_redefine(self):
        # `z` is defined before the branch and redefined on both arms, so the
        # two branch definitions meet at the join before `return z`.
        @njit
        def foo(x, y):
            z = x * y
            if x < y:
                z = x
            else:
                z = y
            return z

        # Exercise both arms of the branch.
        self.check_func(foo, 3, 2)
        self.check_func(foo, 2, 3)

    def test_sum_loop(self):
        # Loop-carried accumulator; n=0 covers the zero-trip loop.
        @njit
        def foo(n):
            c = 0
            for i in range(n):
                c += i
            return c

        self.check_func(foo, 0)
        self.check_func(foo, 10)

    def test_sum_loop_2vars(self):
        # Two independent loop-carried variables updated in the same loop.
        @njit
        def foo(n):
            c = 0
            d = n
            for i in range(n):
                c += i
                d += n
            return c, d

        self.check_func(foo, 0)
        self.check_func(foo, 10)

    def test_sum_2d_loop(self):
        # Accumulator updated in both the inner and the outer loop body.
        @njit
        def foo(n):
            c = 0
            for i in range(n):
                for j in range(n):
                    c += j
                c += i
            return c

        self.check_func(foo, 0)
        self.check_func(foo, 10)

    def check_undefined_var(self, should_warn):
        """Compile and check a function that may read an unassigned ``c``.

        When ``should_warn`` is true, assert that running the jitted function
        emits a ``errors.NumbaWarning`` whose message contains
        "Detected uninitialized variable c"; otherwise only compare the
        jitted and Python results. In both cases, assert that the pure-Python
        function raises ``UnboundLocalError`` on the path where ``c`` is
        never assigned (``n == 0``).
        """
        @njit
        def foo(n):
            if n:
                if n > 0:
                    c = 0
                return c
            else:
                # variable c is not defined in this branch
                c += 1
                return c

        if should_warn:
            with self.assertWarns(errors.NumbaWarning) as warns:
                # n=1 so we won't actually run the branch with the
                # uninitialized variable
                self.check_func(foo, 1)
            self.assertIn("Detected uninitialized variable c",
                          str(warns.warning))
        else:
            self.check_func(foo, 1)

        # The interpreter really does hit the unassigned read for n=0.
        with self.assertRaises(UnboundLocalError):
            foo.py_func(0)

    def test_undefined_var(self):
        # The warning is expected only when ALWAYS_WARN_UNINIT_VAR is on.
        with override_config('ALWAYS_WARN_UNINIT_VAR', 0):
            self.check_undefined_var(should_warn=False)
        with override_config('ALWAYS_WARN_UNINIT_VAR', 1):
            self.check_undefined_var(should_warn=True)

    def test_phi_propagation(self):
        # Nested `while`/`if` structure in which `n` and `ct` are updated at
        # several points, so their definitions are merged at multiple joins.
        # `i` is never incremented, so only `actions[0]` is ever read.
        @njit
        def foo(actions):
            n = 1

            i = 0
            ct = 0
            while n > 0 and i < len(actions):
                n -= 1

                while actions[i]:
                    if actions[i]:
                        if actions[i]:
                            n += 10
                        actions[i] -= 1
                    else:
                        if actions[i]:
                            n += 20
                        actions[i] += 1

                    ct += n
                ct += n
            return ct, n

        self.check_func(foo, np.array([1, 2]))

    def test_unhandled_undefined(self):
        # `function1` reads several variables (e.g. `var4`, `var6`, `var1`)
        # on paths where they were never assigned. With the arguments used
        # below, `arg1` is truthy, so it returns None from the first branch.
        def function1(arg1, arg2, arg3, arg4, arg5):
            # This function is auto-generated.
            if arg1:
                var1 = arg2
                var2 = arg3
                var3 = var2
                var4 = arg1
                return
            else:
                if arg2:
                    if arg4:
                        var5 = arg4  # noqa: F841
                        return
                    else:
                        var6 = var4
                        return
                    return var6
                else:
                    if arg5:
                        if var1:
                            if arg5:
                                var1 = var6
                                return
                            else:
                                var7 = arg2  # noqa: F841
                                return arg2
                            return
                        else:
                            if var2:
                                arg5 = arg2
                                return arg1
                            else:
                                var6 = var3
                                return var4
                            return
                        return
                    else:
                        var8 = var1
                        return
                    return var8
            var9 = var3  # noqa: F841
            var10 = arg5  # noqa: F841
            return var1

        # The argument values are not critical for re-creating the bug
        # because the bug is in compile-time.
        expect = function1(2, 3, 6, 0, 7)
        got = njit(function1)(2, 3, 6, 0, 7)
        self.assertEqual(expect, got)
class TestReportedSSAIssues(SSABaseTest):
    """Regression tests for SSA problems reported on the issue tracker.

    Each ``test_issueNNNN`` method is named after the GitHub issue it covers.
    The shape of the jitted functions (branch nesting, loop placement,
    variable rebinding) comes from the reports; keep it intact so the tests
    still exercise the original failure.
    """
    # Tests from issues
    # https://github.com/numba/numba/issues?q=is%3Aopen+is%3Aissue+label%3ASSA

    def test_issue2194(self):
        # Loops over ranges bounded by a np.uint32 value; the loop variable
        # `i` is reused by the second loop.
        @njit
        def foo():
            V = np.empty(1)
            s = np.uint32(1)

            for i in range(s):
                V[i] = 1
            for i in range(s, 1):
                pass

        self.check_func(foo)

    def test_issue3094(self):
        # `x` is assigned a bool on both branches and then passed to another
        # jitted function.
        @njit
        def doit(x):
            return x

        @njit
        def foo(pred):
            if pred:
                x = True
            else:
                x = False
            # do something with x
            return doit(x)

        self.check_func(foo, False)

    def test_issue3931(self):
        # `arr` is rebound inside the loop, first to a 1-D and then to a 2-D
        # array. The result is an array, so it is compared with
        # np.testing rather than check_func (which uses assertEqual).
        @njit
        def foo(arr):
            for i in range(1):
                arr = arr.reshape(3 * 2)
                arr = arr.reshape(3, 2)
            return (arr)

        np.testing.assert_allclose(foo(np.zeros((3, 2))),
                                   foo.py_func(np.zeros((3, 2))))

    def test_issue3976(self):
        # On the true branch `s` starts as an int and is rebound to the
        # result of an @overload-ed call; on the false branch it is a str.
        def overload_this(a):
            return 'dummy'

        @njit
        def foo(a):
            if a:
                s = 5
                s = overload_this(s)
            else:
                s = 'b'

            return s

        @overload(overload_this)
        def ol(a):
            return overload_this

        self.check_func(foo, True)

    def test_issue3979(self):
        # `x` and `y` are rebound by loops over a tuple of ints and a tuple
        # of strs; the loop variable `i` is reused with both element types.
        @njit
        def foo(A, B):
            x = A[0]
            y = B[0]
            for i in A:
                x = i
            for i in B:
                y = i
            return x, y

        self.check_func(foo, (1, 2), ('A', 'B'))

    def test_issue5219(self):
        # An @overload implementation that rebinds its defaulted argument
        # `b` when it is a tuple.
        def overload_this(a, b=None):
            if isinstance(b, tuple):
                b = b[0]
            return b

        @overload(overload_this)
        def ol(a, b=None):
            b_is_tuple = isinstance(b, (types.Tuple, types.UniTuple))

            def impl(a, b=None):
                if b_is_tuple is True:
                    b = b[0]
                return b
            return impl

        @njit
        def test_tuple(a, b):
            overload_this(a, b)

        self.check_func(test_tuple, 1, (2, ))

    def test_issue5223(self):
        # `x` is either returned unchanged (here a read-only array) or
        # rebound to a writable copy that is then modified in place.
        @njit
        def bar(x):
            if len(x) == 5:
                return x
            x = x.copy()
            for i in range(len(x)):
                x[i] += 1
            return x

        a = np.ones(5)
        a.flags.writeable = False

        np.testing.assert_allclose(bar(a), bar.py_func(a))

    def test_issue5243(self):
        @njit
        def foo(q):
            lin = np.array((0.1, 0.6, 0.3))
            stencil = np.zeros((3, 3))
            stencil[0, 0] = q[0, 0]
            return lin[0]

        self.check_func(foo, np.zeros((2, 2)))

    def test_issue5482_missing_variable_init(self):
        # Test error that lowering fails because variable is missing
        # a definition before use.
        # The explicit signature makes Numba compile `foo` at decoration
        # time, so successful compilation is the whole test; `foo` is never
        # called.
        @njit("(intp, intp, intp)")
        def foo(x, v, n):
            for i in range(n):
                if i == 0:
                    if i == x:
                        pass
                    else:
                        problematic = v
                else:
                    if i == x:
                        pass
                    else:
                        problematic = problematic + v
            return problematic

    def test_issue5482_objmode_expr_null_lowering(self):
        # Existing pipelines will not have the Expr.null in objmode.
        # We have to create a custom pipeline to force a SSA reconstruction
        # and stripping.
        from numba.core.compiler import CompilerBase, DefaultPassBuilder
        from numba.core.untyped_passes import ReconstructSSA, IRProcessing
        from numba.core.typed_passes import PreLowerStripPhis

        class CustomPipeline(CompilerBase):
            def define_pipelines(self):
                pm = DefaultPassBuilder.define_objectmode_pipeline(self.state)
                # Force SSA reconstruction and stripping
                pm.add_pass_after(ReconstructSSA, IRProcessing)
                pm.add_pass_after(PreLowerStripPhis, ReconstructSSA)
                pm.finalize()
                return [pm]

        # As in the nopython variant, the explicit signature compiles `foo`
        # at decoration time; successful compilation is the test.
        @jit("(intp, intp, intp)", looplift=False,
             pipeline_class=CustomPipeline)
        def foo(x, v, n):
            for i in range(n):
                if i == n:
                    if i == x:
                        pass
                    else:
                        problematic = v
                else:
                    if i == x:
                        pass
                    else:
                        problematic = problematic + v
            return problematic

    def test_issue5493_unneeded_phi(self):
        # Test error that unneeded phi is inserted because variable does not
        # have a dominance definition.
        # `foo` is run in plain Python, in nopython mode and in object mode
        # (without loop lifting); all three results must agree.
        data = (np.ones(2), np.ones(2))
        A = np.ones(1)
        B = np.ones((1,1))

        def foo(m, n, data):
            if len(data) == 1:
                v0 = data[0]
            else:
                v0 = data[0]
                # Unneeded PHI node for `problematic` would be placed here
                for _ in range(1, len(data)):
                    v0 += A

            for t in range(1, m):
                for idx in range(n):
                    t = B

                    if idx == 0:
                        if idx == n - 1:
                            pass
                        else:
                            problematic = t
                    else:
                        if idx == n - 1:
                            pass
                        else:
                            problematic = problematic + t
            return problematic

        expect = foo(10, 10, data)
        res1 = njit(foo)(10, 10, data)
        res2 = jit(forceobj=True, looplift=False)(foo)(10, 10, data)
        np.testing.assert_array_equal(expect, res1)
        np.testing.assert_array_equal(expect, res2)

    def test_issue5623_equal_statements_in_same_bb(self):
        # Two textually identical `stack[i] = c` statements sit in the same
        # basic block; both stores must happen.
        def foo(pred, stack):
            i = 0
            c = 1

            if pred is True:
                stack[i] = c
                i += 1
                stack[i] = c
                i += 1

        python = np.array([0, 666])
        foo(True, python)

        nb = np.array([0, 666])
        njit(foo)(True, nb)

        expect = np.array([1, 1])

        np.testing.assert_array_equal(python, expect)
        np.testing.assert_array_equal(nb, expect)

    def test_issue5678_non_minimal_phi(self):
        # There should be only one phi for variable "i"

        from numba.core.compiler import CompilerBase, DefaultPassBuilder
        from numba.core.untyped_passes import (
            ReconstructSSA, FunctionPass, register_pass,
        )

        # One entry per compilation: the number of phis found after SSA.
        phi_counter = []

        @register_pass(mutates_CFG=False, analysis_only=True)
        class CheckSSAMinimal(FunctionPass):
            # A custom pass to count the number of phis
            _name = self.__class__.__qualname__ + ".CheckSSAMinimal"

            def __init__(self):
                FunctionPass.__init__(self)

            def run_pass(self, state):
                ct = 0
                for blk in state.func_ir.blocks.values():
                    ct += len(list(blk.find_exprs('phi')))
                phi_counter.append(ct)
                # Analysis only: the IR is left untouched.
                return False

        class CustomPipeline(CompilerBase):
            def define_pipelines(self):
                pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
                pm.add_pass_after(CheckSSAMinimal, ReconstructSSA)
                pm.finalize()
                return [pm]

        @njit(pipeline_class=CustomPipeline)
        def while_for(n, max_iter=1):
            a = np.empty((n,n))
            i = 0
            while i <= max_iter:
                for j in range(len(a)):
                    for k in range(len(a)):
                        a[j,k] = j + k
                i += 1
            return a

        # The compiled function must produce the same result as Python ...
        self.assertPreciseEqual(while_for(10), while_for.py_func(10))
        # ... and the single compilation must have produced exactly one phi.
        self.assertEqual(phi_counter, [1])

    def test_issue9242_use_not_dom_def(self):
        from numba.core.ir import FunctionIR
        from numba.core.compiler_machinery import (
            AnalysisPass,
            register_pass,
        )

        def check(fir: FunctionIR):
            # The definition of `d` must be a phi that has ir.UNDEFINED
            # among its incoming values.
            [blk, *_] = fir.blocks.values()
            var = blk.scope.get("d")
            defn = fir.get_definition(var)
            self.assertEqual(defn.op, "phi")
            self.assertIn(ir.UNDEFINED, defn.incoming_values)

        @register_pass(mutates_CFG=False, analysis_only=True)
        class SSACheck(AnalysisPass):
            """
            Check SSA on variable `d`
            """

            _name = "SSA_Check"

            def __init__(self):
                AnalysisPass.__init__(self)

            def run_pass(self, state):
                check(state.func_ir)
                return False

        class SSACheckPipeline(CompilerBase):
            """Inject SSACheck pass into the default pipeline following the SSA
            pass
            """

            def define_pipelines(self):
                pipeline = DefaultPassBuilder.define_nopython_pipeline(
                    self.state, "ssa_check_custom_pipeline")

                pipeline._finalized = False
                pipeline.add_pass_after(SSACheck, ReconstructSSA)

                pipeline.finalize()
                return [pipeline]

        @njit(pipeline_class=SSACheckPipeline)
        def py_func(a):
            c = a > 0
            if c:
                d = a + 5  # d is only defined here; undef in the else branch
            return c and d > 0

        # Calling triggers compilation, which runs SSACheck.
        py_func(10)
class TestSROAIssues(MemoryLeakMixin, TestCase):
    # This tests issues related to the SROA optimization done in lowering, which
    # reduces time spent in the LLVM SROA pass. The optimization is related to
    # SSA and tries to reduce the number of alloca statements for variables with
    # only a single assignment.

    def test_issue7258_multiple_assignment_post_SSA(self):
        """Variables assigned more than once after SSA must still lower.

        A custom pass clones every assignment to ``foobar`` right after
        ``ReconstructSSA``, so ``foobar`` ends up assigned twice in the IR
        that reaches lowering. The test checks the computed result, that
        exactly one clone was made, and that the IR preserved after lowering
        contains both assignments. Presumably the leak from the original
        report is detected via ``MemoryLeakMixin`` — confirm against the mixin.
        """
        # This test adds a pass that will duplicate assignment statements to
        # variables named "foobar".
        # In the reported issue, the bug will cause a memory leak.
        cloned = []

        @register_pass(analysis_only=False, mutates_CFG=True)
        class CloneFoobarAssignments(FunctionPass):
            # A pass that clones variable assignments into "foobar"
            _name = "clone_foobar_assignments_pass"

            def __init__(self):
                FunctionPass.__init__(self)

            def run_pass(self, state):
                mutated = False
                for blk in state.func_ir.blocks.values():
                    to_clone = []
                    # find assignments to "foobar"
                    for assign in blk.find_insts(ir.Assign):
                        if assign.target.name == "foobar":
                            to_clone.append(assign)
                    # clone (collected first so the block is not modified
                    # while find_insts is iterating over it)
                    for assign in to_clone:
                        clone = copy.deepcopy(assign)
                        blk.insert_after(clone, assign)
                        mutated = True
                        # keep track of cloned statements
                        cloned.append(clone)
                return mutated

        class CustomCompiler(CompilerBase):
            def define_pipelines(self):
                pm = DefaultPassBuilder.define_nopython_pipeline(
                    self.state, "custom_pipeline",
                )
                # NOTE(review): presumably re-opens a pipeline that
                # define_nopython_pipeline returns finalized — confirm.
                pm._finalized = False
                # Insert the cloning pass after SSA
                pm.add_pass_after(CloneFoobarAssignments, ReconstructSSA)
                # Capture IR post lowering; it is read back below via
                # metadata['preserved_ir'].
                pm.add_pass_after(PreserveIR, NativeLowering)
                pm.finalize()
                return [pm]

        @njit(pipeline_class=CustomCompiler)
        def udt(arr):
            foobar = arr + 1  # this assignment will be cloned
            return foobar

        arr = np.arange(10)
        # Verify that the function works as expected
        self.assertPreciseEqual(udt(arr), arr + 1)
        # Verify that the expected statement is cloned
        self.assertEqual(len(cloned), 1)
        self.assertEqual(cloned[0].target.name, "foobar")
        # Verify in the Numba IR that the expected statement is cloned
        nir = udt.overloads[udt.signatures[0]].metadata['preserved_ir']
        self.assertEqual(len(nir.blocks), 1,
                         "only one block")
        [blk] = nir.blocks.values()
        assigns = blk.find_insts(ir.Assign)
        foobar_assigns = [stmt for stmt in assigns
                          if stmt.target.name == "foobar"]
        self.assertEqual(
            len(foobar_assigns), 2,
            "expected two assignment statements into 'foobar'",
        )
        self.assertEqual(
            foobar_assigns[0], foobar_assigns[1],
            "expected the two assignment statements to be the same",
        )