# Listing metadata: 341 lines, 12 KiB, Python
|
import os
|
||
|
import platform
|
||
|
import re
|
||
|
import textwrap
|
||
|
import warnings
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from numba.tests.support import (TestCase, override_config, override_env_config,
|
||
|
captured_stdout, forbid_codegen, skip_parfors_unsupported,
|
||
|
needs_blas)
|
||
|
from numba import jit, njit
|
||
|
from numba.core import types, compiler, utils
|
||
|
from numba.core.errors import NumbaPerformanceWarning
|
||
|
from numba import prange
|
||
|
from numba.experimental import jitclass
|
||
|
import unittest
|
||
|
|
||
|
|
||
|
def simple_nopython(somearg):
    # Compilation fixture for the debug-dump tests: a trivial increment.
    # The dump checks look for the parameter name and the addition in the
    # optimized LLVM (``add nsw i64 %arg.somearg, 1``), so keep this body
    # exactly as written.
    retval = somearg + 1
    return retval
|
||
|
|
||
|
def simple_gen(x, y):
    # Generator fixture: the --GENERATOR INFO dump is checked for the state
    # variables ['x', 'y'] and exactly two yield points, so the argument
    # names and the yield order must not change.
    yield x
    yield y
|
||
|
|
||
|
|
||
|
class SimpleClass:
    """Minimal class turned into a jitclass by the DISABLE_JIT tests.

    It carries a single member ``h``, always initialised to 5, matching the
    field listed in ``simple_class_spec``.
    """

    def __init__(self):
        self.h = 5
|
||
|
|
||
|
# jitclass field spec for SimpleClass: one int32 member named ``h``.
simple_class_spec = [
    ('h', types.int32),
]
|
||
|
|
||
|
def simple_class_user(obj):
    """Return the ``h`` member of *obj*; jitted in the DISABLE_JIT tests."""
    member = obj.h
    return member
|
||
|
|
||
|
def unsupported_parfor(a, b):
    # Fixture for TestParforsDebug.test_warns: compiled with parallel=True
    # on 2-D float64 arrays, it is expected to yield no parallel
    # transformation and therefore a NumbaPerformanceWarning.
    return np.dot(a, b) # dot as gemm unsupported
|
||
|
|
||
|
def supported_parfor(n):
    # Fixture for TestParforsDebug.test_array_debug_opt_stats. That test
    # expects one parallel loop from ``np.ones`` and one from the ``prange``,
    # with the second fused into the first, so the statement structure here
    # must not change.
    a = np.ones(n)
    for i in prange(n):
        a[i] = a[i] + np.sin(i)
    return a
|
||
|
|
||
|
def unsupported_prange(n):
    # Fixture for TestParforsDebug.test_unsupported_prange_warns: the
    # ``assert`` inside the ``prange`` body gives the loop an extra exit,
    # which is expected to produce the "will not be executed in parallel"
    # performance warning.
    a = np.ones(n)
    for i in prange(n):
        a[i] = a[i] + np.sin(i)
        assert i + 13 < 100000
    return a
|
||
|
|
||
|
|
||
|
class DebugTestBase(TestCase):
    """Common machinery for checking Numba debug-dump output.

    Subclasses must provide a ``func_name`` attribute. The IR,
    function-optimized LLVM, optimized LLVM and assembly checks look for it
    in the dump headers.
    """

    # Every dump kind check_debug_output() can verify; each entry has a
    # matching ``_check_dump_<name>`` method on this class.
    all_dumps = {'bytecode', 'cfg', 'ir', 'typeinfer', 'llvm',
                 'func_opt_llvm', 'optimized_llvm', 'assembly'}

    def assert_fails(self, *args, **kwargs):
        """Assert that calling ``args[0](*args[1:], **kwargs)`` raises
        AssertionError, i.e. that the given check does not hold."""
        self.assertRaises(AssertionError, *args, **kwargs)

    def check_debug_output(self, out, dump_names):
        """Check that *out* contains exactly the dumps named in *dump_names*.

        Every dump kind in ``all_dumps`` is verified. The ones listed in
        *dump_names* must be present in *out*; all the others must be absent.

        Raises AssertionError for an unknown dump name or a failed check.
        """
        enabled_dumps = dict.fromkeys(self.all_dumps, False)
        for name in dump_names:
            # A unittest assertion, unlike a bare ``assert``, reports an
            # unknown dump name clearly and is not stripped by ``python -O``.
            self.assertIn(name, enabled_dumps)
            enabled_dumps[name] = True
        for name, enabled in sorted(enabled_dumps.items()):
            check_meth = getattr(self, '_check_dump_%s' % name)
            if enabled:
                check_meth(out)
            else:
                self.assert_fails(check_meth, out)

    def _check_dump_bytecode(self, out):
        # The addition opcode's name depends on the Python version. Unknown
        # versions fail loudly so that this list gets updated.
        if utils.PYVERSION in ((3, 11), (3, 12)):
            self.assertIn('BINARY_OP', out)
        elif utils.PYVERSION in ((3, 9), (3, 10)):
            self.assertIn('BINARY_ADD', out)
        else:
            raise NotImplementedError(utils.PYVERSION)

    def _check_dump_cfg(self, out):
        self.assertIn('CFG dominators', out)

    def _check_dump_ir(self, out):
        self.assertIn('--IR DUMP: %s--' % self.func_name, out)

    def _check_dump_typeinfer(self, out):
        self.assertIn('--propagate--', out)

    def _check_dump_llvm(self, out):
        self.assertIn('--LLVM DUMP', out)
        # The return-value store is only checked when auto-parallelization
        # is disabled by default.
        if not compiler.Flags.options["auto_parallel"].default.enabled:
            # assertRegex already includes *out* in its failure message, so
            # there is no need to pass it again as ``msg``.
            self.assertRegex(out, r'store i64 %\"\.\d", i64\* %"retptr"')

    def _check_dump_func_opt_llvm(self, out):
        self.assertIn('--FUNCTION OPTIMIZED DUMP %s' % self.func_name, out)
        # allocas have been optimized away
        self.assertIn('add nsw i64 %arg.somearg, 1', out)

    def _check_dump_optimized_llvm(self, out):
        self.assertIn('--OPTIMIZED DUMP %s' % self.func_name, out)
        self.assertIn('add nsw i64 %arg.somearg, 1', out)

    def _check_dump_assembly(self, out):
        self.assertIn('--ASSEMBLY %s' % self.func_name, out)
        # The instruction check only applies to x86 machines.
        if platform.machine() in ('x86_64', 'AMD64', 'i386', 'i686'):
            self.assertIn('xorl', out)
|
||
|
|
||
|
|
||
|
class FunctionDebugTestBase(DebugTestBase):
    """Debug-dump tests driven by compiling ``simple_nopython``."""

    func_name = 'simple_nopython'

    def compile_simple_nopython(self):
        """Compile ``simple_nopython`` for int64 and return captured stdout.

        The compiled function is also called once to confirm it still
        computes ``2 + 1``.
        """
        with captured_stdout() as stdout:
            compiled = njit((types.int64,))(simple_nopython)
            # The compiled result must still be correct.
            self.assertPreciseEqual(compiled(2), 3)
        return stdout.getvalue()
|
||
|
|
||
|
|
||
|
class TestFunctionDebugOutput(FunctionDebugTestBase):
    """Each DUMP_* config option must produce its own dump and no other."""

    def _assert_only_dump(self, config_option, dump_name):
        """Compile with *config_option* switched on, then check that
        *dump_name* is the only debug dump in the captured output."""
        with override_config(config_option, True):
            captured = self.compile_simple_nopython()
        self.check_debug_output(captured, [dump_name])

    def test_dump_bytecode(self):
        self._assert_only_dump('DUMP_BYTECODE', 'bytecode')

    def test_dump_ir(self):
        self._assert_only_dump('DUMP_IR', 'ir')

    def test_dump_cfg(self):
        self._assert_only_dump('DUMP_CFG', 'cfg')

    def test_dump_llvm(self):
        self._assert_only_dump('DUMP_LLVM', 'llvm')

    def test_dump_func_opt_llvm(self):
        self._assert_only_dump('DUMP_FUNC_OPT', 'func_opt_llvm')

    def test_dump_optimized_llvm(self):
        self._assert_only_dump('DUMP_OPTIMIZED', 'optimized_llvm')

    def test_dump_assembly(self):
        self._assert_only_dump('DUMP_ASSEMBLY', 'assembly')
|
||
|
|
||
|
|
||
|
class TestGeneratorDebugOutput(DebugTestBase):
    """Debug-dump checks for a compiled generator (``simple_gen``)."""

    func_name = 'simple_gen'

    def compile_simple_gen(self):
        """Compile ``simple_gen`` for two int64 arguments; return captured
        stdout after confirming that the generator yields both arguments."""
        signature = (types.int64, types.int64)
        with captured_stdout() as stdout:
            gen_func = njit(signature)(simple_gen)
            # The compiled generator must yield its two arguments in order.
            self.assertPreciseEqual(list(gen_func(2, 5)), [2, 5])
        return stdout.getvalue()

    def test_dump_ir_generator(self):
        with override_config('DUMP_IR', True):
            dumped = self.compile_simple_gen()
        self.check_debug_output(dumped, ['ir'])
        self.assertIn('--GENERATOR INFO: %s' % self.func_name, dumped)
        expected_gen_info = textwrap.dedent("""
            generator state variables: ['x', 'y']
            yield point #1: live variables = ['y'], weak live variables = ['x']
            yield point #2: live variables = [], weak live variables = ['y']
            """)
        self.assertIn(expected_gen_info, dumped)
|
||
|
|
||
|
|
||
|
class TestDisableJIT(DebugTestBase):
    """
    Check that NUMBA_DISABLE_JIT turns jit/jitclass into pure-Python no-ops
    (no code generation happens).
    """

    def test_jit(self):
        with override_config('DISABLE_JIT', True), forbid_codegen():
            pyfunc = jit(nopython=True)(simple_nopython)
            self.assertPreciseEqual(pyfunc(2), 3)

    def test_jitclass(self):
        with override_config('DISABLE_JIT', True), forbid_codegen():
            SimpleJITClass = jitclass(simple_class_spec)(SimpleClass)

            instance = SimpleJITClass()
            self.assertPreciseEqual(instance.h, 5)

            reader = jit(nopython=True)(simple_class_user)
            self.assertPreciseEqual(reader(instance), 5)
|
||
|
|
||
|
|
||
|
class TestEnvironmentOverride(FunctionDebugTestBase):
    """
    Check that Numba picks up changes to its environment variables.
    """

    # The test modifies os.environ, so it cannot run in parallel with others.
    _numba_parallel_test_ = False

    def test_debug(self):
        # With NUMBA_DEBUG unset, compiling prints nothing...
        self.assertFalse(self.compile_simple_nopython())
        with override_env_config('NUMBA_DEBUG', '1'):
            debug_out = self.compile_simple_nopython()
            # ...with it set, every setting derived from NUMBA_DEBUG is
            # switched on as well, hence the full list of expected dumps.
            expected = ['ir', 'typeinfer', 'llvm', 'func_opt_llvm',
                        'optimized_llvm', 'assembly']
            self.check_debug_output(debug_out, expected)
        # ...and once the variable is restored the output disappears again.
        self.assertFalse(self.compile_simple_nopython())
|
||
|
|
||
|
class TestParforsDebug(TestCase):
    """
    Tests debug options associated with parfors
    """

    # mutates env with os.environ so must be run serially
    _numba_parallel_test_ = False

    def _assert_warning_emitted(self, warn_list, msg):
        """Assert that at least one recorded warning in *warn_list* has
        *msg* as a substring of its message."""
        found = any(msg in str(w.message) for w in warn_list)
        self.assertTrue(found, "Warning message should be found.")

    def check_parfors_warning(self, warn_list):
        """Assert that the 'no parallel transformation possible' warning
        was recorded in *warn_list*."""
        msg = ("'parallel=True' was specified but no transformation for "
               "parallel execution was possible.")
        self._assert_warning_emitted(warn_list, msg)

    def check_parfors_unsupported_prange_warning(self, warn_list):
        """Assert that the 'prange will not run in parallel' warning was
        recorded in *warn_list*."""
        msg = ("prange or pndindex loop will not be executed in parallel "
               "due to there being more than one entry to or exit from the "
               "loop (e.g., an assertion).")
        self._assert_warning_emitted(warn_list, msg)

    @needs_blas
    @skip_parfors_unsupported
    def test_warns(self):
        """
        Test that using parallel=True on a function that does not have parallel
        semantics warns.
        """
        arr_ty = types.Array(types.float64, 2, "C")
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", NumbaPerformanceWarning)
            njit((arr_ty, arr_ty), parallel=True)(unsupported_parfor)
        self.check_parfors_warning(w)

    @needs_blas
    @skip_parfors_unsupported
    def test_unsupported_prange_warns(self):
        """
        Test that prange with multiple exits issues a warning
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", NumbaPerformanceWarning)
            njit((types.int64,), parallel=True)(unsupported_prange)
        self.check_parfors_unsupported_prange_warning(w)

    @skip_parfors_unsupported
    def test_array_debug_opt_stats(self):
        """
        Test that NUMBA_DEBUG_ARRAY_OPT_STATS produces valid output
        """
        # deliberately trigger a compilation loop to increment the
        # Parfor class state, this is to ensure the test works based
        # on indices computed based on this state and not hard coded
        # indices.
        njit((types.int64,), parallel=True)(supported_parfor)

        with override_env_config('NUMBA_DEBUG_ARRAY_OPT_STATS', '1'):
            with captured_stdout() as out:
                njit((types.int64,), parallel=True)(supported_parfor)

            # grab the various parts out the output
            output = out.getvalue().split('\n')
            parallel_loop_output = \
                [x for x in output if 'is produced from pattern' in x]
            fuse_output = \
                [x for x in output if 'is fused into' in x]
            after_fusion_output = \
                [x for x in output if 'After fusion, function' in x]

            # The checks below loop over these lists; make sure none is
            # empty so that missing output fails instead of passing
            # vacuously.
            loop_patterns = ("('ones function', 'NumPy mapping')",
                             ('prange', 'user', ''))
            self.assertGreaterEqual(
                len(parallel_loop_output), len(loop_patterns),
                "missing 'is produced from pattern' lines in debug output")
            self.assertTrue(fuse_output,
                            "missing 'is fused into' lines in debug output")
            self.assertTrue(after_fusion_output,
                            "missing 'After fusion' lines in debug output")

            # Parfor's have a shared state index, grab the current value
            # as it will be used as an offset for all loop messages
            parfor_state = int(re.search(r'#([0-9]+)',
                                         parallel_loop_output[0]).group(1))
            bounds = range(parfor_state,
                           parfor_state + len(parallel_loop_output))

            # Check the Parallel for-loop <index> is produced from <pattern>
            # works first
            fmt = 'Parallel for-loop #{} is produced from pattern \'{}\' at'
            for i, trials, lpattern in zip(bounds, parallel_loop_output,
                                           loop_patterns):
                self.assertIn(fmt.format(i, lpattern), trials)

            # Check the fusion statements are correct: the second loop is
            # fused into the first.
            fuse_msg = ('Parallel for-loop #{} is fused into for-loop #{}.'
                        .format(parfor_state + 1, parfor_state))
            for trials in fuse_output:
                self.assertIn(fuse_msg, trials)

            # Check the post fusion statements are correct
            after_fusion_msg = (
                'After fusion, function {} has {} parallel for-loop(s) #{}.'
                .format(supported_parfor.__name__, 1, {parfor_state}))
            for trials in after_fusion_output:
                self.assertIn(after_fusion_msg, trials)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
|