# ai-content-maker/.venv/Lib/site-packages/numba/tests/test_analysis.py
# (1012 lines, 34 KiB, Python -- Raw | Normal View | History)
# 2024-05-03 04:18:51 +03:00
# Tests numba.analysis functions
import collections
import types as pytypes
import numpy as np
from numba.core.compiler import run_frontend, Flags, StateDict
from numba import jit, njit, literal_unroll
from numba.core import types, errors, ir, rewrites, ir_utils, utils, cpu
from numba.core import postproc
from numba.core.inline_closurecall import InlineClosureCallPass
from numba.tests.support import (TestCase, MemoryLeakMixin, SerialMixin,
IRPreservingTestPipeline)
from numba.core.analysis import dead_branch_prune, rewrite_semantic_constants
from numba.core.untyped_passes import (ReconstructSSA, TranslateByteCode,
IRProcessing, DeadBranchPrune,
PreserveIR)
from numba.core.compiler import DefaultPassBuilder, CompilerBase, PassManager
# Module-level global read by `test_global_bake_in`; that test temporarily
# rebinds it to check that pruning bakes in the global's current value.
_GLOBAL = 123

# Flags with object mode enabled.
# NOTE(review): not referenced anywhere else in this file -- confirm whether
# it is still needed.
enable_pyobj_flags = Flags()
enable_pyobj_flags.enable_pyobject = True
def compile_to_ir(func):
    """Return the untyped, SSA-form Numba IR for ``func``.

    The frontend IR is built with ``run_frontend``, converted to SSA via the
    ``ReconstructSSA`` pass, and then the 'before-inference' rewrites (e.g.
    the ones affecting ``print``) are applied to it.
    """
    ir_out = run_frontend(func)

    pipeline_state = StateDict()
    pipeline_state.func_ir = ir_out
    # untyped stage: no type information is available yet
    pipeline_state.typemap = None
    pipeline_state.calltypes = None

    ReconstructSSA().run_pass(pipeline_state)
    rewrites.rewrite_registry.apply('before-inference', pipeline_state)
    return ir_out
class TestBranchPruneBase(MemoryLeakMixin, TestCase):
    """
    Tests branch pruning
    """

    # set to True to dump the IR at each stage of `assert_prune`
    _DEBUG = False

    def find_branches(self, the_ir):
        """Return a list of *all* ``ir.Branch`` instructions in ``the_ir``."""
        branches = []
        for blk in the_ir.blocks.values():
            branches.extend(blk.find_insts(cls=ir.Branch))
        return branches

    def assert_prune(self, func, args_tys, prune, *args, **kwargs):
        """Check that the expected branches have indeed been pruned.

        func: the Python function to assess.
        args_tys: tuple of Numba types for the arguments of ``func``.
        prune: list with one entry per branch (in ``find_branches`` order).
            Each entry is encoded as follows:
            - True: using constant inference only, the True branch will be
              pruned.
            - False: using constant inference only, the False branch will be
              pruned.
            - None: under no circumstances should this branch be pruned.
            - 'both': both the True and the False branch will be pruned.
        *args: the argument instances passed to both the compiled function
            and ``func`` to check that execution is still valid post
            transform (may be empty for a zero-argument function).
        **kwargs:
            - flags: keyword arguments passed to `jit`, default is
              ``{'nopython': True}``; permits use of e.g. object mode.

        Raises TypeError if any keyword argument other than ``flags`` is
        supplied (previously such arguments were silently ignored).
        """
        supplied_flags = kwargs.pop('flags', {'nopython': True})
        if kwargs:
            raise TypeError("assert_prune() got unexpected keyword "
                            "argument(s): %s" % ", ".join(sorted(kwargs)))

        func_ir = compile_to_ir(func)
        before = func_ir.copy()
        if self._DEBUG:
            print("=" * 80)
            print("before inline")
            func_ir.dump()

        # run closure inlining to ensure that nonlocals in closures are visible
        inline_pass = InlineClosureCallPass(func_ir,
                                            cpu.ParallelOptions(False),)
        inline_pass.run()

        # Remove all Dels, and re-run postproc
        post_proc = postproc.PostProcessor(func_ir)
        post_proc.run()

        rewrite_semantic_constants(func_ir, args_tys)
        if self._DEBUG:
            print("=" * 80)
            print("before prune")
            func_ir.dump()

        dead_branch_prune(func_ir, args_tys)

        after = func_ir
        if self._DEBUG:
            print("after prune")
            func_ir.dump()

        before_branches = self.find_branches(before)
        self.assertEqual(len(before_branches), len(prune))

        # what is expected to be pruned; lengths were checked equal above
        expect_removed = []
        for branch, spec in zip(before_branches, prune):
            if spec is True:
                expect_removed.append(branch.truebr)
            elif spec is False:
                expect_removed.append(branch.falsebr)
            elif spec is None:
                pass  # nothing should be removed!
            elif spec == 'both':
                expect_removed.append(branch.falsebr)
                expect_removed.append(branch.truebr)
            else:
                # `self.fail` is not stripped under `python -O`, unlike assert
                self.fail("invalid prune specification: %r" % (spec,))

        # compare labels
        original_labels = set(before.blocks)
        new_labels = set(after.blocks)
        # assert that the new labels are precisely the original less the
        # expected pruned labels
        try:
            self.assertEqual(new_labels, original_labels - set(expect_removed))
        except AssertionError:
            print("new_labels", sorted(new_labels))
            print("original_labels", sorted(original_labels))
            print("expect_removed", sorted(expect_removed))
            raise

        # NOTE: original testing used `compile_isolated` hence use of `cres`.
        cres = jit(args_tys, **supplied_flags)(func).overloads[args_tys]
        # `args` is the (possibly empty) varargs tuple, so it can never be
        # None; unpacking it covers the zero-argument case as well.
        res = cres.entry_point(*args)
        expected = func(*args)
        self.assertEqual(res, expected)
class TestBranchPrune(TestBranchPruneBase, SerialMixin):
    # Each test builds small `impl` functions and calls `assert_prune` with a
    # list holding one entry per branch: True -> true branch pruned,
    # False -> false branch pruned, None -> not pruned, 'both' -> both pruned.
    # The `impl` bodies are deliberately written as-is: the expected prune
    # lists depend on the exact IR they produce.

    def test_single_if(self):
        # single `if` with no `else`: constant conditions, conditions on the
        # argument's type (None vs an integer literal) and a condition on a
        # local that is an alias of a None constant

        def impl(x):
            if 1 == 0:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)

        def impl(x):
            if 1 == 1:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [False], None)

        def impl(x):
            if x is None:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [False], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True], 10)

        def impl(x):
            if x == 10:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10)

        def impl(x):
            if x == 10:
                z = 3.14159  # noqa: F841 # no effect

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10)

        def impl(x):
            z = None
            y = z
            if x == y:
                return 100

        self.assert_prune(impl, (types.NoneType('none'),), [False], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True], 10)

    def test_single_if_else(self):
        # `x is None` with an explicit else branch

        def impl(x):
            if x is None:
                return 3.14159
            else:
                return 1.61803

        self.assert_prune(impl, (types.NoneType('none'),), [False], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True], 10)

    def test_single_if_const_val(self):
        # equality against a constant, in both operand orders; an integer
        # literal argument is not pruned

        def impl(x):
            if x == 100:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(100),), [None], 100)

        def impl(x):
            # switch the condition order
            if 100 == x:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(100),), [None], 100)

    def test_single_if_else_two_const_val(self):
        # comparison of two arguments across None/IntegerLiteral combinations

        def impl(x, y):
            if x == y:
                return 3.14159
            else:
                return 1.61803

        self.assert_prune(impl, (types.IntegerLiteral(100),) * 2, [None], 100,
                          100)
        self.assert_prune(impl, (types.NoneType('none'),) * 2, [False], None,
                          None)
        self.assert_prune(impl, (types.IntegerLiteral(100),
                                 types.NoneType('none'),), [True], 100, None)
        self.assert_prune(impl, (types.IntegerLiteral(100),
                                 types.IntegerLiteral(1000)), [None], 100, 1000)

    def test_single_if_else_w_following_undetermined(self):
        # the first branch is prunable; the second one depends on a variable
        # assigned in both arms of the first and must not be pruned

        def impl(x):
            x_is_none_work = False
            if x is None:
                x_is_none_work = True
            else:
                dead = 7  # noqa: F841 # no effect

            if x_is_none_work:
                y = 10
            else:
                y = -3
            return y

        self.assert_prune(impl, (types.NoneType('none'),), [False, None], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)

        def impl(x):
            x_is_none_work = False
            if x is None:
                x_is_none_work = True
            else:
                pass

            if x_is_none_work:
                y = 10
            else:
                y = -3
            return y

        if utils.PYVERSION >= (3, 10):
            # Python 3.10 creates a block with a NOP in it for the `pass` which
            # means it gets pruned.
            self.assert_prune(impl, (types.NoneType('none'),), [False, None],
                              None)
        else:
            self.assert_prune(impl, (types.NoneType('none'),), [None, None],
                              None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)

    def test_double_if_else_rt_const(self):
        # the second condition compares against a local that is only
        # conditionally redefined, so it must not be pruned

        def impl(x):
            one_hundred = 100
            x_is_none_work = 4
            if x is None:
                x_is_none_work = 100
            else:
                dead = 7  # noqa: F841 # no effect

            if x_is_none_work == one_hundred:
                y = 10
            else:
                y = -3

            return y, x_is_none_work

        self.assert_prune(impl, (types.NoneType('none'),), [False, None], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)

    def test_double_if_else_non_literal_const(self):

        def impl(x):
            one_hundred = 100
            if x == one_hundred:
                y = 3.14159
            else:
                y = 1.61803
            return y

        # no prune as compilation specialization on literal value not permitted
        self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10)
        self.assert_prune(impl, (types.IntegerLiteral(100),), [None], 100)

    def test_single_two_branches_same_cond(self):
        # two independent branches on `is None` / `is not None`

        def impl(x):
            if x is None:
                y = 10
            else:
                y = 40

            if x is not None:
                z = 100
            else:
                z = 400

            return z, y

        self.assert_prune(impl, (types.NoneType('none'),), [False, True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, False], 10)

    def test_cond_is_kwarg_none(self):
        # as above, but the argument has a default of None and may be omitted

        def impl(x=None):
            if x is None:
                y = 10
            else:
                y = 40

            if x is not None:
                z = 100
            else:
                z = 400

            return z, y

        self.assert_prune(impl, (types.Omitted(None),),
                          [False, True], None)
        self.assert_prune(impl, (types.NoneType('none'),), [False, True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, False], 10)

    def test_cond_is_kwarg_value(self):
        # argument with a non-None default: only the None-typed case prunes

        def impl(x=1000):
            if x == 1000:
                y = 10
            else:
                y = 40

            if x != 1000:
                z = 100
            else:
                z = 400

            return z, y

        self.assert_prune(impl, (types.Omitted(1000),), [None, None], 1000)
        self.assert_prune(impl, (types.IntegerLiteral(1000),), [None, None],
                          1000)
        self.assert_prune(impl, (types.IntegerLiteral(0),), [None, None], 0)
        self.assert_prune(impl, (types.NoneType('none'),), [True, False], None)

    def test_cond_rewrite_is_correct(self):
        # this checks that when a condition is replaced, it is replaced by a
        # true/false bit that correctly represents the evaluated condition
        def fn(x):
            if x is None:
                return 10
            return 12

        def check(func, arg_tys, bit_val):
            func_ir = compile_to_ir(func)

            # check there is 1 branch
            before_branches = self.find_branches(func_ir)
            self.assertEqual(len(before_branches), 1)

            # check the branch predicate is a call whose first argument is
            # defined by a binop (i.e. the condition itself)
            pred_var = before_branches[0].cond
            pred_defn = ir_utils.get_definition(func_ir, pred_var)
            self.assertEqual(pred_defn.op, 'call')
            condition_var = pred_defn.args[0]
            condition_op = ir_utils.get_definition(func_ir, condition_var)
            self.assertEqual(condition_op.op, 'binop')

            # do the prune, this should kill the dead branch and rewrite the
            # condition to a true/false const bit
            if self._DEBUG:
                print("=" * 80)
                print("before prune")
                func_ir.dump()
            dead_branch_prune(func_ir, arg_tys)
            if self._DEBUG:
                print("=" * 80)
                print("after prune")
                func_ir.dump()

            # after mutation, the condition should be a const value `bit_val`
            new_condition_defn = ir_utils.get_definition(func_ir, condition_var)
            self.assertTrue(isinstance(new_condition_defn, ir.Const))
            self.assertEqual(new_condition_defn.value, bit_val)

        check(fn, (types.NoneType('none'),), 1)
        check(fn, (types.IntegerLiteral(10),), 0)

    def test_global_bake_in(self):
        # the module global `_GLOBAL` is treated as a constant; rebinding it
        # (restored in `finally`) flips which branch is pruned

        def impl(x):
            if _GLOBAL == 123:
                return x
            else:
                return x + 10

        self.assert_prune(impl, (types.IntegerLiteral(1),), [False], 1)

        global _GLOBAL
        tmp = _GLOBAL

        try:
            _GLOBAL = 5

            def impl(x):
                if _GLOBAL == 123:
                    return x
                else:
                    return x + 10

            self.assert_prune(impl, (types.IntegerLiteral(1),), [True], 1)
        finally:
            _GLOBAL = tmp

    def test_freevar_bake_in(self):
        # as above but with a closure (free) variable

        _FREEVAR = 123

        def impl(x):
            if _FREEVAR == 123:
                return x
            else:
                return x + 10

        self.assert_prune(impl, (types.IntegerLiteral(1),), [False], 1)

        _FREEVAR = 12

        def impl(x):
            if _FREEVAR == 123:
                return x
            else:
                return x + 10

        self.assert_prune(impl, (types.IntegerLiteral(1),), [True], 1)

    def test_redefined_variables_are_not_considered_in_prune(self):
        # see issue #4163, checks that if a variable that is an argument is
        # redefined in the user code it is not considered const

        def impl(array, a=None):
            if a is None:
                a = 0
            if a < 0:
                return 10
            return 30

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.NoneType('none'),),
                          [None, None],
                          np.zeros((2, 3)), None)

    def test_comparison_operators(self):
        # see issue #4163, checks that a variable that is an argument and has
        # value None survives TypeError from invalid comparison which should be
        # dead

        def impl(array, a=None):
            x = 0
            if a is None:
                return 10  # dynamic exec would return here
            # static analysis requires that this is executed with a=None,
            # hence TypeError
            if a < 0:
                return 20
            return x

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.NoneType('none'),),
                          [False, 'both'],
                          np.zeros((2, 3)), None)

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.float64,),
                          [None, None],
                          np.zeros((2, 3)), 12.)

    def test_redefinition_analysis_same_block(self):
        # checks that a redefinition in a block with prunable potential doesn't
        # break

        def impl(array, x, a=None):
            b = 2
            if x < 4:
                b = 12
            if a is None:  # known true
                a = 7  # live
            else:
                b = 15  # dead
            if a < 0:  # valid as a result of the redefinition of 'a'
                return 10
            return 30 + b + a

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.float64, types.NoneType('none'),),
                          [None, False, None],
                          np.zeros((2, 3)), 1., None)

    def test_redefinition_analysis_different_block_can_exec(self):
        # checks that a redefinition in a block that may be executed prevents
        # pruning

        def impl(array, x, a=None):
            b = 0
            if x > 5:
                a = 11  # a redefined, cannot tell statically if this will exec
            if x < 4:
                b = 12
            if a is None:  # cannot prune, cannot determine if re-defn occurred
                b += 5
            else:
                b += 7
                if a < 0:
                    return 10
            return 30 + b

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.float64, types.NoneType('none'),),
                          [None, None, None, None],
                          np.zeros((2, 3)), 1., None)

    def test_redefinition_analysis_different_block_cannot_exec(self):
        # checks the handling of a redefinition in a block guarded by a
        # condition that itself has prune potential: the `a is None` branch
        # is never pruned, while the guard prunes depending on the type of x

        def impl(array, x=None, a=None):
            b = 0
            if x is not None:
                a = 11
            if a is None:
                b += 5
            else:
                b += 7
            return 30 + b

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.NoneType('none'), types.NoneType('none')),
                          [True, None],
                          np.zeros((2, 3)), None, None)

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.NoneType('none'), types.float64),
                          [True, None],
                          np.zeros((2, 3)), None, 1.2)

        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.float64, types.NoneType('none')),
                          [None, None],
                          np.zeros((2, 3)), 1.2, None)

    def test_closure_and_nonlocal_can_prune(self):
        # Closures must be inlined ahead of branch pruning in case nonlocal
        # is used. See issue #6585.
        def impl():
            x = 1000

            def closure():
                nonlocal x
                x = 0

            closure()

            if x == 0:
                return True
            else:
                return False

        self.assert_prune(impl, (), [False,],)

    def test_closure_and_nonlocal_cannot_prune(self):
        # Closures must be inlined ahead of branch pruning in case nonlocal
        # is used. See issue #6585.
        def impl(n):
            x = 1000

            def closure(t):
                nonlocal x
                x = t

            closure(n)

            if x == 0:
                return True
            else:
                return False

        self.assert_prune(impl, (types.int64,), [None,], 1)
class TestBranchPrunePredicates(TestBranchPruneBase, SerialMixin):
    # Really important thing to remember... the branch on predicates end up as
    # POP_JUMP_IF_<bool> and the targets are backwards compared to normal, i.e.
    # the true condition is far jump and the false the near i.e. `if x` would
    # end up in Numba IR as e.g. `branch x 10, 6`.
    # In the tests below, `if c:` and `if not c:` therefore use the same
    # expected prune encoding: truthy `c` -> False, falsey `c` -> True.

    _TRUTHY = (1, "String", True, 7.4, 3j)
    _FALSEY = (0, "", False, 0.0, 0j, None)

    def _literal_const_sample_generator(self, pyfunc, consts):
        """
        This takes a python function, pyfunc, and manipulates its co_const
        __code__ member to create a new function with different co_consts as
        supplied in argument consts.

        consts is a dict {index: value} of co_const tuple index to constant
        value used to update a pyfunc clone's co_const.

        The returned function uses this module's globals.
        """
        pyfunc_code = pyfunc.__code__

        # translate consts spec to update the constants
        co_consts = dict(enumerate(pyfunc_code.co_consts))
        co_consts.update(consts)
        new_consts = tuple(v for _, v in sorted(co_consts.items()))

        # create code object with mutation
        new_code = pyfunc_code.replace(co_consts=new_consts)

        # get function
        return pytypes.FunctionType(new_code, globals())

    def test_literal_const_code_gen(self):
        # sanity check of `_literal_const_sample_generator` itself
        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if _CONST1:
                return 3.14159
            else:
                _CONST2 = "PLACEHOLDER2"
            return _CONST2 + 4

        new = self._literal_const_sample_generator(impl, {1: 0, 3: 20})
        iconst = impl.__code__.co_consts
        nconst = new.__code__.co_consts
        self.assertEqual(iconst, (None, "PLACEHOLDER1", 3.14159,
                                  "PLACEHOLDER2", 4))
        self.assertEqual(nconst, (None, 0, 3.14159, 20, 4))
        self.assertEqual(impl(None), 3.14159)
        self.assertEqual(new(None), 24)

    # The `*_const` tests patch co_consts[1] ("PLACEHOLDER1") with each
    # sample value; the `impl` bodies must not gain a docstring or any other
    # constant that would shift that index.

    def test_single_if_const(self):

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if _CONST1:
                return 3.14159

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:
                func = self._literal_const_sample_generator(impl, {1: const})
                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_negate_const(self):

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if not _CONST1:
                return 3.14159

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:
                func = self._literal_const_sample_generator(impl, {1: const})
                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_const(self):

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if _CONST1:
                return 3.14159
            else:
                return 1.61803

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:
                func = self._literal_const_sample_generator(impl, {1: const})
                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_negate_const(self):

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if not _CONST1:
                return 3.14159
            else:
                return 1.61803

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:
                func = self._literal_const_sample_generator(impl, {1: const})
                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    # The `*_freevar` tests close over the loop variable `const`.

    def test_single_if_freevar(self):
        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:

                def func(x):
                    if const:
                        return 3.14159, const

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_negate_freevar(self):
        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:

                def func(x):
                    if not const:
                        return 3.14159, const

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_freevar(self):
        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:

                def func(x):
                    if const:
                        return 3.14159, const
                    else:
                        return 1.61803, const

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_negate_freevar(self):
        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for const in c_inp:

                def func(x):
                    if not const:
                        return 3.14159, const
                    else:
                        return 1.61803, const

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    # globals in this section have absurd names after their test usecase names
    # so as to prevent collisions and permit tests to run in parallel

    def test_single_if_global(self):
        global c_test_single_if_global

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for c in c_inp:
                c_test_single_if_global = c

                def func(x):
                    if c_test_single_if_global:
                        return 3.14159, c_test_single_if_global

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_negate_global(self):
        global c_test_single_if_negate_global

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for c in c_inp:
                c_test_single_if_negate_global = c

                # NOTE: the condition is negated, as in the sibling
                # `*_negate_*` tests; previously the `not` was missing and
                # this test merely duplicated `test_single_if_global`.
                def func(x):
                    if not c_test_single_if_negate_global:
                        return 3.14159, c_test_single_if_negate_global

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_global(self):
        global c_test_single_if_else_global

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for c in c_inp:
                c_test_single_if_else_global = c

                def func(x):
                    if c_test_single_if_else_global:
                        return 3.14159, c_test_single_if_else_global
                    else:
                        return 1.61803, c_test_single_if_else_global

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_negate_global(self):
        global c_test_single_if_else_negate_global

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for c in c_inp:
                c_test_single_if_else_negate_global = c

                def func(x):
                    if not c_test_single_if_else_negate_global:
                        return 3.14159, c_test_single_if_else_negate_global
                    else:
                        return 1.61803, c_test_single_if_else_negate_global

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_issue_5618(self):
        # a truthy integer local used as a predicate must compile and run the
        # same as the pure Python function

        @njit
        def foo():
            values = np.zeros(1)
            tmp = 666
            if tmp:
                values[0] = tmp
            return values

        self.assertPreciseEqual(foo.py_func()[0], 666.)
        self.assertPreciseEqual(foo()[0], 666.)
class TestBranchPruneSSA(MemoryLeakMixin, TestCase):
    # Exercises the rewiring of SSA phi nodes that must happen once dead
    # branch pruning has removed blocks feeding those phis.

    class SSAPrunerCompiler(CompilerBase):

        def define_pipelines(self):
            """Prune dead branches on SSA-form IR, then type and lower.

            Untyped stage: bytecode -> IR -> SSA -> dead branch prune ->
            preserve IR; followed by the default typed pipeline and the
            default nopython lowering pipeline.
            """
            pass_mgr = PassManager("testing pm")
            pass_mgr.add_pass(TranslateByteCode, "analyzing bytecode")
            pass_mgr.add_pass(IRProcessing, "processing IR")
            # SSA is constructed *before* pruning on purpose
            pass_mgr.add_pass(ReconstructSSA, "ssa")
            pass_mgr.add_pass(DeadBranchPrune, "dead branch pruning")
            # stash the IR in the metadata so tests can inspect it
            pass_mgr.add_pass(PreserveIR, "preserves the IR as metadata")

            # append the stock typing and lowering passes, in that order
            for build_stage in (
                    DefaultPassBuilder.define_typed_pipeline,
                    DefaultPassBuilder.define_nopython_lowering_pipeline):
                pass_mgr.passes.extend(build_stage(self.state).passes)

            pass_mgr.finalize()
            return [pass_mgr]

    def test_ssa_update_phi(self):
        # Pruning removes a block that supplies one incoming value to a phi;
        # the phi must be updated to drop it.

        @njit(pipeline_class=self.SSAPrunerCompiler)
        def impl(p=None, q=None):
            z = 1
            r = False
            if p is None:
                r = True  # live

            if r and q is not None:
                z = 20  # dead

            # one of the incoming blocks for z is dead, the phi needs an update
            # were this not done, it would refer to variables that do not exist
            # and result in a lowering error.
            return z, r

        got = impl()
        self.assertPreciseEqual(got, impl.py_func())

    def test_ssa_replace_phi(self):
        # Once pruning leaves a phi with a single incoming value, it must be
        # replaced by a direct assignment.

        @njit(pipeline_class=self.SSAPrunerCompiler)
        def impl(p=None):
            z = 0
            if p is None:
                z = 10
            else:
                z = 20

            return z

        got = impl()
        self.assertPreciseEqual(got, impl.py_func())

        first_sig = impl.signatures[0]
        preserved = impl.overloads[first_sig].metadata['preserved_ir']
        # not a single phi expression may remain in any block
        for block in preserved.blocks.values():
            phis = list(block.find_exprs('phi'))
            self.assertFalse(phis)
class TestBranchPrunePostSemanticConstRewrites(TestBranchPruneBase):
    # Tests that semantic constants rewriting works by virtue of branch pruning

    def test_array_ndim_attr(self):
        # `array.ndim` is rewritten to a constant from the array type, which
        # makes the outer branch prunable; for a 1D array the inner branch is
        # entirely dead ('both')

        def impl(array):
            if array.ndim == 2:
                if array.shape[1] == 2:
                    return 1
            else:
                return 10

        self.assert_prune(impl, (types.Array(types.float64, 2, 'C'),), [False,
                                                                         None],
                          np.zeros((2, 3)))
        self.assert_prune(impl, (types.Array(types.float64, 1, 'C'),), [True,
                                                                        'both'],
                          np.zeros((2,)))

    def test_tuple_len(self):
        # as above, with `len(tup)` taken from the UniTuple type's length

        def impl(tup):
            if len(tup) == 3:
                if tup[2] == 2:
                    return 1
            else:
                return 0

        self.assert_prune(impl, (types.UniTuple(types.int64, 3),), [False,
                                                                    None],
                          tuple([1, 2, 3]))
        self.assert_prune(impl, (types.UniTuple(types.int64, 2),), [True,
                                                                    'both'],
                          tuple([1, 2]))

    def test_attr_not_len(self):
        # The purpose of this test is to make sure that the conditions guarding
        # the rewrite part do not themselves raise exceptions.
        # This produces an `ir.Expr` call node for `float.as_integer_ratio`,
        # which is a getattr() on `float`.

        @njit
        def test():
            float.as_integer_ratio(1.23)

        # this should raise a TypingError
        with self.assertRaises(errors.TypingError) as e:
            test()

        self.assertIn("Unknown attribute 'as_integer_ratio'", str(e.exception))

    def test_ndim_not_on_array(self):
        # an object with an `ndim` attribute that is *not* an array must not
        # have `ndim` rewritten to a constant

        FakeArray = collections.namedtuple('FakeArray', ['ndim'])
        fa = FakeArray(ndim=2)

        def impl(fa):
            if fa.ndim == 2:
                return fa.ndim
            else:
                object()

        # check prune works for array ndim
        self.assert_prune(impl, (types.Array(types.float64, 2, 'C'),), [False],
                          np.zeros((2, 3)))

        # check prune fails for something with `ndim` attr that is not array
        FakeArrayType = types.NamedUniTuple(types.int64, 1, FakeArray)
        self.assert_prune(impl, (FakeArrayType,), [None], fa,
                          flags={'nopython':False, 'forceobj':True})

    def test_semantic_const_propagates_before_static_rewrites(self):
        # see issue #5015, the ndim needs writing in as a const before
        # the rewrite passes run to make e.g. getitems static where possible
        @njit
        def impl(a, b):
            return a.shape[:b.ndim]

        args = (np.zeros((5, 4, 3, 2)), np.zeros((1, 1)))

        self.assertPreciseEqual(impl(*args), impl.py_func(*args))

    def test_tuple_const_propagation(self):
        # `len(arg)` over literally-unrolled tuple members should be
        # propagated as constants into the preserved IR

        @njit(pipeline_class=IRPreservingTestPipeline)
        def impl(*args):
            s = 0
            for arg in literal_unroll(args):
                s += len(arg)
            return s

        inp = ((), (1, 2, 3), ())
        self.assertPreciseEqual(impl(*inp), impl.py_func(*inp))

        ol = impl.overloads[impl.signatures[0]]
        func_ir = ol.metadata['preserved_ir']
        # make sure the rhs of every inplace binop is assigned from a Const,
        # and that the set of those constants equals the set of input lengths
        binop_consts = set()
        for blk in func_ir.blocks.values():
            for expr in blk.find_exprs('inplace_binop'):
                inst = blk.find_variable_assignment(expr.rhs.name)
                self.assertIsInstance(inst.value, ir.Const)
                binop_consts.add(inst.value.value)
        self.assertEqual(binop_consts, {len(x) for x in inp})