# File: ai-content-maker/.venv/Lib/site-packages/numba/tests/test_debug.py
# (file-viewer metadata: 341 lines, 12 KiB, Python)

import os
import platform
import re
import textwrap
import warnings
import numpy as np
from numba.tests.support import (TestCase, override_config, override_env_config,
captured_stdout, forbid_codegen, skip_parfors_unsupported,
needs_blas)
from numba import jit, njit
from numba.core import types, compiler, utils
from numba.core.errors import NumbaPerformanceWarning
from numba import prange
from numba.experimental import jitclass
import unittest
# Minimal fixture that the debug-dump tests in this file compile in nopython
# mode. Keep the body exactly as written: the dump checks look for text
# derived from this code (for example 'add nsw i64 %arg.somearg, 1'), so the
# argument name and the single "+ 1" are part of the test contract.
def simple_nopython(somearg):
    retval = somearg + 1
    return retval
# Generator fixture compiled by TestGeneratorDebugOutput. It yields its two
# arguments in order; the expected generator-info dump names the variables
# 'x' and 'y' and two yield points, so the body must stay as written.
def simple_gen(x, y):
    yield x
    yield y
class SimpleClass:
    """Plain class with a single attribute ``h``; wrapped by the jitclass test."""

    def __init__(self):
        self.h = 5
# Field specification passed to jitclass() for SimpleClass: one int32 field
# named 'h'.
simple_class_spec = [('h', types.int32)]
# Fixture jitted in TestDisableJIT.test_jitclass: returns the 'h' attribute
# of the object it is given.
def simple_class_user(obj):
    return obj.h
# Fixture compiled with parallel=True by TestParforsDebug.test_warns, which
# expects the "no transformation for parallel execution was possible"
# warning to be issued for it.
def unsupported_parfor(a, b):
    return np.dot(a, b)  # dot as gemm unsupported
# Fixture compiled with parallel=True by test_array_debug_opt_stats. The test
# matches the parfor statistics printed for this exact shape (an np.ones call
# followed by a prange loop), so the body must not be restructured.
def supported_parfor(n):
    a = np.ones(n)
    for i in prange(n):
        a[i] = a[i] + np.sin(i)
    return a
# Same loop as supported_parfor plus an assert inside the prange body.
# test_unsupported_prange_warns compiles it with parallel=True and expects
# the "more than one entry to or exit from the loop (e.g., an assertion)"
# warning, so the in-loop assert is the point of this fixture.
def unsupported_prange(n):
    a = np.ones(n)
    for i in prange(n):
        a[i] = a[i] + np.sin(i)
        assert i + 13 < 100000
    return a
class DebugTestBase(TestCase):
    """Shared helpers for tests that inspect Numba's debug dump output.

    Subclasses must define a ``func_name`` class attribute: several of the
    ``_check_dump_*`` helpers interpolate it into the banner they look for.
    """

    # Every dump kind known to the helpers; each entry has a matching
    # ``_check_dump_<name>`` method below.
    all_dumps = {'bytecode', 'cfg', 'ir', 'typeinfer', 'llvm',
                 'func_opt_llvm', 'optimized_llvm', 'assembly'}

    def assert_fails(self, *args, **kwargs):
        """Assert that calling ``args[0](*args[1:], **kwargs)`` raises
        AssertionError."""
        self.assertRaises(AssertionError, *args, **kwargs)

    def check_debug_output(self, out, dump_names):
        """Check that *out* contains exactly the dumps listed in *dump_names*.

        Every dump in ``all_dumps`` is examined: the requested ones must be
        present in *out* and every other one must be absent (its checker is
        expected to fail).

        Raises AssertionError if *dump_names* contains an unknown dump name.
        """
        # Validate all names first, before running any checker. A unittest
        # assertion is used (rather than a bare ``assert``) so the validation
        # is not stripped when Python runs with -O.
        for name in dump_names:
            self.assertIn(name, self.all_dumps)
        expected = set(dump_names)

        for name in sorted(self.all_dumps):
            check_meth = getattr(self, '_check_dump_%s' % name)
            if name in expected:
                check_meth(out)
            else:
                self.assert_fails(check_meth, out)

    def _check_dump_bytecode(self, out):
        # The opcode used for ``somearg + 1`` depends on the Python version.
        if utils.PYVERSION in ((3, 11), (3, 12)):
            self.assertIn('BINARY_OP', out)
        elif utils.PYVERSION in ((3, 9), (3, 10)):
            self.assertIn('BINARY_ADD', out)
        else:
            raise NotImplementedError(utils.PYVERSION)

    def _check_dump_cfg(self, out):
        self.assertIn('CFG dominators', out)

    def _check_dump_ir(self, out):
        self.assertIn('--IR DUMP: %s--' % self.func_name, out)

    def _check_dump_typeinfer(self, out):
        self.assertIn('--propagate--', out)

    def _check_dump_llvm(self, out):
        self.assertIn('--LLVM DUMP', out)
        # The retptr store is only checked when auto-parallel is off by
        # default. NOTE(review): ``enabled`` is presumably a plain bool, so
        # ``not`` matches the previous ``== False`` test -- confirm.
        if not compiler.Flags.options["auto_parallel"].default.enabled:
            # No third argument: it is ``msg``, and passing the whole dump
            # there only duplicated the text in the failure report.
            self.assertRegex(out, r'store i64 %\"\.\d", i64\* %"retptr"')

    def _check_dump_func_opt_llvm(self, out):
        self.assertIn('--FUNCTION OPTIMIZED DUMP %s' % self.func_name, out)
        # allocas have been optimized away
        self.assertIn('add nsw i64 %arg.somearg, 1', out)

    def _check_dump_optimized_llvm(self, out):
        self.assertIn('--OPTIMIZED DUMP %s' % self.func_name, out)
        self.assertIn('add nsw i64 %arg.somearg, 1', out)

    def _check_dump_assembly(self, out):
        self.assertIn('--ASSEMBLY %s' % self.func_name, out)
        # The instruction check only applies on x86-family machines.
        if platform.machine() in ('x86_64', 'AMD64', 'i386', 'i686'):
            self.assertIn('xorl', out)
class FunctionDebugTestBase(DebugTestBase):
    """Base for dump tests that compile the ``simple_nopython`` fixture."""

    func_name = 'simple_nopython'

    def compile_simple_nopython(self):
        """Compile ``simple_nopython`` for one int64 argument.

        Returns everything written to stdout while compiling and calling it.
        """
        signature = (types.int64,)
        with captured_stdout() as stream:
            compiled = njit(signature)(simple_nopython)
            # Sanity check compiled function
            self.assertPreciseEqual(compiled(2), 3)
        return stream.getvalue()
class TestFunctionDebugOutput(FunctionDebugTestBase):
    """Each dump config switch must produce its own dump and no other."""

    def _assert_only_dump(self, config_flag, dump_name):
        # Compile with a single config flag turned on, then verify that
        # only the matching dump appears in the captured output.
        with override_config(config_flag, True):
            captured = self.compile_simple_nopython()
        self.check_debug_output(captured, [dump_name])

    def test_dump_bytecode(self):
        self._assert_only_dump('DUMP_BYTECODE', 'bytecode')

    def test_dump_ir(self):
        self._assert_only_dump('DUMP_IR', 'ir')

    def test_dump_cfg(self):
        self._assert_only_dump('DUMP_CFG', 'cfg')

    def test_dump_llvm(self):
        self._assert_only_dump('DUMP_LLVM', 'llvm')

    def test_dump_func_opt_llvm(self):
        self._assert_only_dump('DUMP_FUNC_OPT', 'func_opt_llvm')

    def test_dump_optimized_llvm(self):
        self._assert_only_dump('DUMP_OPTIMIZED', 'optimized_llvm')

    def test_dump_assembly(self):
        self._assert_only_dump('DUMP_ASSEMBLY', 'assembly')
class TestGeneratorDebugOutput(DebugTestBase):
    """Checks the IR dump printed when a generator function is compiled."""

    func_name = 'simple_gen'

    def compile_simple_gen(self):
        """Compile ``simple_gen`` for two int64 arguments.

        Returns everything written to stdout while compiling and running it.
        """
        arg_types = (types.int64, types.int64)
        with captured_stdout() as stream:
            compiled = njit(arg_types)(simple_gen)
            # Sanity check compiled function
            self.assertPreciseEqual(list(compiled(2, 5)), [2, 5])
        return stream.getvalue()

    def test_dump_ir_generator(self):
        with override_config('DUMP_IR', True):
            captured = self.compile_simple_gen()
        self.check_debug_output(captured, ['ir'])

        # Besides the regular IR dump, a generator-info section is expected.
        banner = '--GENERATOR INFO: %s' % self.func_name
        self.assertIn(banner, captured)

        gen_info = textwrap.dedent("""
            generator state variables: ['x', 'y']
            yield point #1: live variables = ['y'], weak live variables = ['x']
            yield point #2: live variables = [], weak live variables = ['y']
            """)
        self.assertIn(gen_info, captured)
class TestDisableJIT(DebugTestBase):
    """
    Exercise jit and jitclass with the DISABLE_JIT config switched on,
    while code generation is forbidden.
    """

    def test_jit(self):
        with override_config('DISABLE_JIT', True), forbid_codegen():
            wrapped = jit(nopython=True)(simple_nopython)
            self.assertPreciseEqual(wrapped(2), 3)

    def test_jitclass(self):
        with override_config('DISABLE_JIT', True), forbid_codegen():
            wrapped_cls = jitclass(simple_class_spec)(SimpleClass)
            instance = wrapped_cls()
            self.assertPreciseEqual(instance.h, 5)

            reader = jit(nopython=True)(simple_class_user)
            self.assertPreciseEqual(reader(instance), 5)
class TestEnvironmentOverride(FunctionDebugTestBase):
    """
    Check that Numba picks up changes to its environment variables.
    """

    # mutates env with os.environ so must be run serially
    _numba_parallel_test_ = False

    def test_debug(self):
        # Nothing is printed before the variable is set.
        before = self.compile_simple_nopython()
        self.assertFalse(before)

        # Note that all variables dependent on NUMBA_DEBUG are
        # updated too.
        dumps_with_debug = ['ir', 'typeinfer', 'llvm', 'func_opt_llvm',
                            'optimized_llvm', 'assembly']
        with override_env_config('NUMBA_DEBUG', '1'):
            during = self.compile_simple_nopython()
            self.check_debug_output(during, dumps_with_debug)

        # Output is silent again once the override has ended.
        after = self.compile_simple_nopython()
        self.assertFalse(after)
class TestParforsDebug(TestCase):
    """
    Tests debug options associated with parfors
    """

    # mutates env with os.environ so must be run serially
    _numba_parallel_test_ = False

    def _assert_warning_recorded(self, warn_list, msg):
        """Assert that some recorded warning's message contains *msg*."""
        found = any(msg in str(w.message) for w in warn_list)
        self.assertTrue(found, "Warning message should be found.")

    def check_parfors_warning(self, warn_list):
        """Assert *warn_list* holds the 'no parallel transformation' warning."""
        msg = ("'parallel=True' was specified but no transformation for "
               "parallel execution was possible.")
        self._assert_warning_recorded(warn_list, msg)

    def check_parfors_unsupported_prange_warning(self, warn_list):
        """Assert *warn_list* holds the 'prange not parallel' warning."""
        msg = ("prange or pndindex loop will not be executed in parallel "
               "due to there being more than one entry to or exit from the "
               "loop (e.g., an assertion).")
        self._assert_warning_recorded(warn_list, msg)

    @needs_blas
    @skip_parfors_unsupported
    def test_warns(self):
        """
        Test that using parallel=True on a function that does not have parallel
        semantics warns.
        """
        arr_ty = types.Array(types.float64, 2, "C")
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", NumbaPerformanceWarning)
            njit((arr_ty, arr_ty), parallel=True)(unsupported_parfor)
        self.check_parfors_warning(w)

    @needs_blas
    @skip_parfors_unsupported
    def test_unsupported_prange_warns(self):
        """
        Test that prange with multiple exits issues a warning
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", NumbaPerformanceWarning)
            njit((types.int64,), parallel=True)(unsupported_prange)
        self.check_parfors_unsupported_prange_warning(w)

    @skip_parfors_unsupported
    def test_array_debug_opt_stats(self):
        """
        Test that NUMBA_DEBUG_ARRAY_OPT_STATS produces valid output
        """
        # deliberately trigger a compilation loop to increment the
        # Parfor class state, this is to ensure the test works based
        # on indices computed based on this state and not hard coded
        # indices.
        njit((types.int64,), parallel=True)(supported_parfor)

        with override_env_config('NUMBA_DEBUG_ARRAY_OPT_STATS', '1'):
            with captured_stdout() as out:
                njit((types.int64,), parallel=True)(supported_parfor)

            # grab the various parts out the output
            output = out.getvalue().split('\n')
            parallel_loop_output = \
                [x for x in output if 'is produced from pattern' in x]
            fuse_output = \
                [x for x in output if 'is fused into' in x]
            after_fusion_output = \
                [x for x in output if 'After fusion, function' in x]

            # Fail with a readable message instead of an IndexError below
            # when the statistics were not printed at all.
            self.assertTrue(parallel_loop_output,
                            "no 'is produced from pattern' lines were printed")

            # Parfor's have a shared state index, grab the current value
            # as it will be used as an offset for all loop messages
            index_match = re.search(r'#([0-9]+)', parallel_loop_output[0])
            self.assertIsNotNone(index_match,
                                 "no '#<index>' found in the first loop line")
            parfor_state = int(index_match.group(1))
            bounds = range(parfor_state,
                           parfor_state + len(parallel_loop_output))

            # Check the Parallel for-loop <index> is produced from <pattern>
            # works first
            pattern = ("('ones function', 'NumPy mapping')",
                       ('prange', 'user', ''))
            fmt = 'Parallel for-loop #{} is produced from pattern \'{}\' at'
            for i, trials, lpattern in zip(bounds, parallel_loop_output,
                                           pattern):
                to_match = fmt.format(i, lpattern)
                self.assertIn(to_match, trials)

            # Check the fusion statements are correct
            pattern = (parfor_state + 1, parfor_state + 0)
            fmt = 'Parallel for-loop #{} is fused into for-loop #{}.'
            for trials in fuse_output:
                to_match = fmt.format(*pattern)
                self.assertIn(to_match, trials)

            # Check the post fusion statements are correct
            pattern = (supported_parfor.__name__, 1, {parfor_state})
            fmt = 'After fusion, function {} has {} parallel for-loop(s) #{}.'
            for trials in after_fusion_output:
                to_match = fmt.format(*pattern)
                self.assertIn(to_match, trials)
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()