ai-content-maker/.venv/Lib/site-packages/numba/tests/test_refop_pruning.py

185 lines
5.7 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
import unittest
import warnings
from contextlib import contextmanager
import numpy as np
import llvmlite.binding as llvm
from numba import njit, types
from numba.core.errors import NumbaInvalidConfigWarning
from numba.core.codegen import _parse_refprune_flags
from numba.tests.support import override_config, TestCase
@contextmanager
def set_refprune_flags(flags):
    """Temporarily set numba's ``LLVM_REFPRUNE_FLAGS`` config to *flags*.

    *flags* is the comma-separated subpass string to install (e.g.
    ``'per_bb,diamond'``). The previous value is restored on exit by
    ``override_config``.
    """
    config_key = 'LLVM_REFPRUNE_FLAGS'
    with override_config(config_key, flags):
        yield
class TestRefOpPruning(TestCase):
    """Check the statistics reported by the LLVM refop-pruning pass."""

    _numba_parallel_test_ = False

    def check(self, func, *argtys, **prune_types):
        """
        Assert that *func*, compiled for the argument types *argtys*,
        reports refop pruning statistics.

        The ``**prune_types`` kwargs name each kind of pruning (looked up as
        an attribute of the ``prune_stats`` compile metadata, e.g.
        ``basicblock``, ``diamond``, ``fanout``, ``fanout_raise``) and whether
        that statistic should be zero (False) or greater than zero (True).

        Note: the exact statistic varies across platforms, so only zero vs.
        non-zero is checked.

        NOTE: tests using this ``check`` method need to run in subprocesses,
        as ``njit`` sets up the module pass manager etc. once and the config
        overrides would otherwise have no effect.
        """
        with override_config('LLVM_REFPRUNE_PASS', '1'):
            cres = njit((*argtys,))(func).overloads[(*argtys,)]
        pstats = cres.metadata.get('prune_stats', None)
        self.assertIsNotNone(pstats, msg='no prune_stats in compile metadata')
        for k, v in prune_types.items():
            stat = getattr(pstats, k, None)
            msg = f'failed checking {k}'
            # Include the stat name so a missing attribute is identifiable.
            self.assertIsNotNone(stat, msg=msg)
            if v:
                self.assertGreater(stat, 0, msg=msg)
            else:
                self.assertEqual(stat, 0, msg=msg)

    @TestCase.run_test_in_subprocess
    def test_basic_block_1(self):
        # Nominally involved control flow and ops; only per-basic-block
        # pruning opportunities are present here.
        def func(n):
            a = np.zeros(n)
            acc = 0
            if n > 4:
                b = a[1:]
                acc += b[1]
            else:
                c = a[:-1]
                acc += c[0]
            return acc

        self.check(func, types.intp, basicblock=True)

    @TestCase.run_test_in_subprocess
    def test_diamond_1(self):
        # The most basic diamond.
        def func(n):
            a = np.ones(n)
            x = 0
            if n > 2:
                x = a.sum()
            return x + 1

        # Disable fanout pruning.
        with set_refprune_flags('per_bb,diamond'):
            self.check(func, types.intp, basicblock=True, diamond=True,
                       fanout=False, fanout_raise=False)

    @TestCase.run_test_in_subprocess
    def test_diamond_2(self):
        # More complex diamonds.
        def func(n):
            con = []
            for i in range(n):
                con.append(np.arange(i))
            c = 0.0
            for arr in con:
                c += arr.sum() / (1 + arr.size)
            return c

        # Disable fanout pruning.
        with set_refprune_flags('per_bb,diamond'):
            self.check(func, types.intp, basicblock=True, diamond=True,
                       fanout=False, fanout_raise=False)

    @TestCase.run_test_in_subprocess
    def test_fanout_1(self):
        # The most basic fan-out.
        def func(n):
            a = np.zeros(n)
            b = np.zeros(n)
            x = (a, b)
            acc = 0.
            for i in x:
                acc += i[0]
            return acc

        self.check(func, types.intp, basicblock=True, fanout=True)

    @TestCase.run_test_in_subprocess
    def test_fanout_2(self):
        # Fan-out with an explicit raise; only per_bb and fanout are enabled.
        def func(n):
            a = np.zeros(n)
            b = np.zeros(n)
            x = (a, b)
            for i in x:
                if n:
                    raise ValueError
            return x

        with set_refprune_flags('per_bb,fanout'):
            self.check(func, types.intp, basicblock=True, diamond=False,
                       fanout=True, fanout_raise=False)

    @TestCase.run_test_in_subprocess
    def test_fanout_3(self):
        # Fan-out with a raise; only per_bb and fanout_raise are enabled.
        def func(n):
            ary = np.arange(n)
            # Basically an implementation of array.sum.
            c = 0
            # The raise is from StopIteration of next(iterator), implicit in
            # the for loop.
            for v in np.nditer(ary):
                c += v.item()
            return 1

        with set_refprune_flags('per_bb,fanout_raise'):
            self.check(func, types.intp, basicblock=True, diamond=False,
                       fanout=False, fanout_raise=True)
class TestRefPruneFlags(TestCase):
    """Check parsing of the ``LLVM_REFPRUNE_FLAGS`` config value."""

    def setUp(self):
        # Snapshot the process-wide warning filters, then promote
        # NumbaInvalidConfigWarning to an error so a stray warning fails the
        # test. tearDown restores the snapshot exactly. Unlike
        # warnings.resetwarnings(), this does not discard filters configured
        # elsewhere (e.g. -W options or filters set by other code), so no
        # warning state leaks into later tests.
        self._warnings_ctx = warnings.catch_warnings()
        self._warnings_ctx.__enter__()
        warnings.simplefilter('error', NumbaInvalidConfigWarning)

    def tearDown(self):
        # Restore the warning filters saved in setUp.
        self._warnings_ctx.__exit__(None, None, None)

    def test_warn_invalid_flags(self):
        # Each unknown flag produces its own warning; the valid 'per_bb'
        # flag is still honoured.
        with set_refprune_flags('abc,per_bb,cde'):
            with self.assertWarns(NumbaInvalidConfigWarning) as cm:
                optval = _parse_refprune_flags()
            self.assertEqual(len(cm.warnings), 2)
            self.assertIn('abc', str(cm.warnings[0].message))
            self.assertIn('cde', str(cm.warnings[1].message))
            self.assertEqual(optval, llvm.RefPruneSubpasses.PER_BB)

    def test_valid_flag(self):
        # Spaces around the comma-separated names are tolerated; listing
        # every subpass is equivalent to ALL.
        with set_refprune_flags('per_bb, diamond, fanout,fanout_raise'):
            optval = _parse_refprune_flags()
            self.assertEqual(optval, llvm.RefPruneSubpasses.ALL)

    def test_the_all_flag(self):
        with set_refprune_flags('all'):
            optval = _parse_refprune_flags()
            self.assertEqual(optval, llvm.RefPruneSubpasses.ALL)

    def test_some_flags(self):
        # A subset of flags maps to the bitwise OR of the matching subpasses.
        with set_refprune_flags('per_bb, fanout'):
            optval = _parse_refprune_flags()
            enumcls = llvm.RefPruneSubpasses
            self.assertEqual(optval, enumcls.PER_BB | enumcls.FANOUT)
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()