185 lines
5.7 KiB
Python
185 lines
5.7 KiB
Python
import unittest
|
|
import warnings
|
|
from contextlib import contextmanager
|
|
|
|
import numpy as np
|
|
|
|
import llvmlite.binding as llvm
|
|
|
|
from numba import njit, types
|
|
from numba.core.errors import NumbaInvalidConfigWarning
|
|
from numba.core.codegen import _parse_refprune_flags
|
|
from numba.tests.support import override_config, TestCase
|
|
|
|
|
|
@contextmanager
def set_refprune_flags(flags):
    """Temporarily set the ``LLVM_REFPRUNE_FLAGS`` config option to *flags*.

    Thin wrapper around ``override_config``; restoring the previous value
    on exit is left to that helper.

    Parameters
    ----------
    flags : str
        Comma-separated refprune subpass names, e.g. ``'per_bb,diamond'``.
    """
    option_name = 'LLVM_REFPRUNE_FLAGS'
    with override_config(option_name, flags):
        yield
|
|
|
|
|
|
class TestRefOpPruning(TestCase):
    """Check the statistics reported by the LLVM refcount-op pruning pass
    for functions exercising each pruning strategy (per basic block,
    diamond, fanout and fanout-with-raise)."""

    # NOTE(review): presumably opts this class out of numba's parallel
    # test runner; the tests below already run in subprocesses.
    _numba_parallel_test_ = False

    def check(self, func, *argtys, **prune_types):
        """
        Assert that func, compiled with argument types "argtys", reports
        refop pruning statistics. Each keyword in **prune_types** names a
        kind of pruning and says whether its statistic should be zero
        (False) or greater than zero (True).

        Note: The exact statistic varies across platforms.

        NOTE: Tests using this `check` method need to run in subprocesses,
        because `njit` sets up the module pass manager etc. only once and
        the config overrides would otherwise have no effect.
        """
        # The compile signature doubles as the key into ``overloads``.
        sig = tuple(argtys)
        with override_config('LLVM_REFPRUNE_PASS', '1'):
            cres = njit(sig)(func).overloads[sig]

        pstats = cres.metadata.get('prune_stats', None)
        self.assertIsNotNone(pstats, msg='prune_stats missing from metadata')

        for k, v in prune_types.items():
            stat = getattr(pstats, k, None)
            msg = f'failed checking {k}'
            # Name the missing statistic instead of a bare "unexpectedly None".
            self.assertIsNotNone(stat, msg=msg)
            if v:
                self.assertGreater(stat, 0, msg=msg)
            else:
                self.assertEqual(stat, 0, msg=msg)

    @TestCase.run_test_in_subprocess
    def test_basic_block_1(self):
        # Some nominally involved control flow and ops; there are only
        # basic_block pruning opportunities present here.
        def func(n):
            a = np.zeros(n)
            acc = 0
            if n > 4:
                b = a[1:]
                acc += b[1]
            else:
                c = a[:-1]
                acc += c[0]
            return acc

        self.check(func, types.intp, basicblock=True)

    @TestCase.run_test_in_subprocess
    def test_diamond_1(self):
        # most basic?! diamond
        def func(n):
            a = np.ones(n)
            x = 0
            if n > 2:
                x = a.sum()
            return x + 1

        # disable fanout pruning
        with set_refprune_flags('per_bb,diamond'):
            self.check(func, types.intp, basicblock=True, diamond=True,
                       fanout=False, fanout_raise=False)

    @TestCase.run_test_in_subprocess
    def test_diamond_2(self):
        # more complex diamonds
        def func(n):
            con = []
            for i in range(n):
                con.append(np.arange(i))
            c = 0.0
            for arr in con:
                c += arr.sum() / (1 + arr.size)
            return c

        # disable fanout pruning
        with set_refprune_flags('per_bb,diamond'):
            self.check(func, types.intp, basicblock=True, diamond=True,
                       fanout=False, fanout_raise=False)

    @TestCase.run_test_in_subprocess
    def test_fanout_1(self):
        # most basic?! fan-out
        def func(n):
            a = np.zeros(n)
            b = np.zeros(n)
            x = (a, b)
            acc = 0.
            for i in x:
                acc += i[0]
            return acc

        self.check(func, types.intp, basicblock=True, fanout=True)

    @TestCase.run_test_in_subprocess
    def test_fanout_2(self):
        # Fanout where the loop body may raise; only the plain fanout
        # pruner (plus per_bb) is enabled, not fanout_raise.
        def func(n):
            a = np.zeros(n)
            b = np.zeros(n)
            x = (a, b)
            for i in x:
                if n:
                    raise ValueError
            return x

        with set_refprune_flags('per_bb,fanout'):
            self.check(func, types.intp, basicblock=True, diamond=False,
                       fanout=True, fanout_raise=False)

    @TestCase.run_test_in_subprocess
    def test_fanout_3(self):
        # fanout with raise
        def func(n):
            ary = np.arange(n)
            # basically an impl of array.sum
            c = 0
            # The raise is from StopIteration of next(iterator) implicit in
            # the for loop
            for v in np.nditer(ary):
                c += v.item()
            return 1

        with set_refprune_flags('per_bb,fanout_raise'):
            self.check(func, types.intp, basicblock=True, diamond=False,
                       fanout=False, fanout_raise=True)
|
|
|
|
|
|
class TestRefPruneFlags(TestCase):
    """Tests for parsing the ``LLVM_REFPRUNE_FLAGS`` config option into
    ``llvm.RefPruneSubpasses`` values."""

    def setUp(self):
        super().setUp()
        # Snapshot the global warning filters so tearDown can restore them
        # exactly. The previous ``warnings.resetwarnings()`` wiped *every*
        # filter (Python defaults, -W options, runner/library filters) and
        # leaked that state into later tests.
        self._warnings_ctx = warnings.catch_warnings()
        self._warnings_ctx.__enter__()
        # Turn unexpected invalid-config warnings into hard failures.
        warnings.simplefilter('error', NumbaInvalidConfigWarning)

    def tearDown(self):
        self._warnings_ctx.__exit__(None, None, None)
        super().tearDown()

    def test_warn_invalid_flags(self):
        # Unknown names are reported (one warning each) and skipped.
        with set_refprune_flags('abc,per_bb,cde'):
            with self.assertWarns(NumbaInvalidConfigWarning) as cm:
                optval = _parse_refprune_flags()
            self.assertEqual(len(cm.warnings), 2)
            self.assertIn('abc', str(cm.warnings[0].message))
            self.assertIn('cde', str(cm.warnings[1].message))
        self.assertEqual(optval, llvm.RefPruneSubpasses.PER_BB)

    def test_valid_flag(self):
        # Whitespace around the comma-separated names is tolerated.
        with set_refprune_flags('per_bb, diamond, fanout,fanout_raise'):
            optval = _parse_refprune_flags()
        self.assertEqual(optval, llvm.RefPruneSubpasses.ALL)

    def test_the_all_flag(self):
        with set_refprune_flags('all'):
            optval = _parse_refprune_flags()
        self.assertEqual(optval, llvm.RefPruneSubpasses.ALL)

    def test_some_flags(self):
        with set_refprune_flags('per_bb, fanout'):
            optval = _parse_refprune_flags()
        enumcls = llvm.RefPruneSubpasses
        self.assertEqual(optval, enumcls.PER_BB | enumcls.FANOUT)
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
|