import os
import platform
import re
import textwrap
import warnings

import numpy as np

from numba.tests.support import (TestCase, override_config,
                                 override_env_config, captured_stdout,
                                 forbid_codegen, skip_parfors_unsupported,
                                 needs_blas)
from numba import jit, njit
from numba.core import types, compiler, utils
from numba.core.errors import NumbaPerformanceWarning
from numba import prange
from numba.experimental import jitclass
import unittest


# NOTE: the functions below are compiled by the tests and their bytecode / IR /
# LLVM dumps are inspected, so they are documented with comments (not
# docstrings) and their bodies must not be restyled.

# Adds one to its argument; the dump checks look for this addition.
def simple_nopython(somearg):
    retval = somearg + 1
    return retval


# Two-yield generator used by the generator IR dump test.
def simple_gen(x, y):
    yield x
    yield y


# Plain class that is wrapped with ``jitclass`` in TestDisableJIT.
class SimpleClass(object):
    def __init__(self):
        self.h = 5


simple_class_spec = [('h', types.int32)]


# Reads the ``h`` attribute of a SimpleClass / jitclass instance.
def simple_class_user(obj):
    return obj.h


def unsupported_parfor(a, b):
    return np.dot(a, b)  # dot as gemm unsupported


# prange loop with a single entry/exit.
def supported_parfor(n):
    a = np.ones(n)
    for i in prange(n):
        a[i] = a[i] + np.sin(i)
    return a


# Same loop as supported_parfor, but the assertion adds a second exit from
# the loop body (see the warning text checked in TestParforsDebug).
def unsupported_prange(n):
    a = np.ones(n)
    for i in prange(n):
        a[i] = a[i] + np.sin(i)
        assert i + 13 < 100000
    return a


class DebugTestBase(TestCase):
    """
    Shared helpers for checking which debug dumps appear in captured output.

    Subclasses must define ``func_name`` (the name of the compiled function),
    which several ``_check_dump_*`` methods interpolate into the expected
    dump headers.
    """

    # Every dump kind that has a matching ``_check_dump_<name>`` method.
    all_dumps = {'bytecode', 'cfg', 'ir', 'typeinfer', 'llvm',
                 'func_opt_llvm', 'optimized_llvm', 'assembly'}

    def assert_fails(self, *args, **kwargs):
        """Assert that calling ``args[0](*args[1:], **kwargs)`` raises
        AssertionError."""
        self.assertRaises(AssertionError, *args, **kwargs)

    def check_debug_output(self, out, dump_names):
        """
        Check that *out* contains exactly the dumps listed in *dump_names*.

        Every name in ``all_dumps`` is checked: the listed ones must pass
        their ``_check_dump_<name>`` method, all others must fail it.
        """
        enabled_dumps = dict.fromkeys(self.all_dumps, False)
        for name in dump_names:
            # Use a unittest assertion rather than a bare ``assert`` so the
            # check is not stripped when running under ``python -O``.
            self.assertIn(name, enabled_dumps)
            enabled_dumps[name] = True
        # Sorted so that the checks run in a deterministic order.
        for name, enabled in sorted(enabled_dumps.items()):
            check_meth = getattr(self, '_check_dump_%s' % name)
            if enabled:
                check_meth(out)
            else:
                self.assert_fails(check_meth, out)

    def _check_dump_bytecode(self, out):
        # The opcode name used for ``somearg + 1`` depends on the Python
        # version; unknown versions fail loudly instead of being guessed.
        if utils.PYVERSION in ((3, 11), (3, 12)):
            self.assertIn('BINARY_OP', out)
        elif utils.PYVERSION in ((3, 9), (3, 10)):
            self.assertIn('BINARY_ADD', out)
        else:
            raise NotImplementedError(utils.PYVERSION)

    def _check_dump_cfg(self, out):
        self.assertIn('CFG dominators', out)

    def _check_dump_ir(self, out):
        self.assertIn('--IR DUMP: %s--' % self.func_name, out)

    def _check_dump_typeinfer(self, out):
        self.assertIn('--propagate--', out)

    def _check_dump_llvm(self, out):
        self.assertIn('--LLVM DUMP', out)
        # The store pattern is only checked when auto-parallelisation is
        # disabled by default.
        if not compiler.Flags.options["auto_parallel"].default.enabled:
            # No custom ``msg``: the standard assertRegex failure message
            # already includes the searched text.
            self.assertRegex(out, r'store i64 %\"\.\d", i64\* %"retptr"')

    def _check_dump_func_opt_llvm(self, out):
        self.assertIn('--FUNCTION OPTIMIZED DUMP %s' % self.func_name, out)
        # allocas have been optimized away
        self.assertIn('add nsw i64 %arg.somearg, 1', out)

    def _check_dump_optimized_llvm(self, out):
        self.assertIn('--OPTIMIZED DUMP %s' % self.func_name, out)
        self.assertIn('add nsw i64 %arg.somearg, 1', out)

    def _check_dump_assembly(self, out):
        self.assertIn('--ASSEMBLY %s' % self.func_name, out)
        # The instruction check is only performed on x86 family machines.
        if platform.machine() in ('x86_64', 'AMD64', 'i386', 'i686'):
            self.assertIn('xorl', out)


class FunctionDebugTestBase(DebugTestBase):
    """Base class for tests that compile ``simple_nopython``."""

    func_name = 'simple_nopython'

    def compile_simple_nopython(self):
        """Compile and run ``simple_nopython``; return the captured stdout."""
        with captured_stdout() as out:
            cfunc = njit((types.int64,))(simple_nopython)
            # Sanity check compiled function
            self.assertPreciseEqual(cfunc(2), 3)
        return out.getvalue()


class TestFunctionDebugOutput(FunctionDebugTestBase):
    """
    Each test enables a single dump config option and checks that only the
    corresponding dump appears in the output.
    """

    def test_dump_bytecode(self):
        with override_config('DUMP_BYTECODE', True):
            out = self.compile_simple_nopython()
        self.check_debug_output(out, ['bytecode'])

    def test_dump_ir(self):
        with override_config('DUMP_IR', True):
            out = self.compile_simple_nopython()
        self.check_debug_output(out, ['ir'])

    def test_dump_cfg(self):
        with override_config('DUMP_CFG', True):
            out = self.compile_simple_nopython()
        self.check_debug_output(out, ['cfg'])

    def test_dump_llvm(self):
        with override_config('DUMP_LLVM', True):
            out = self.compile_simple_nopython()
        self.check_debug_output(out, ['llvm'])

    def test_dump_func_opt_llvm(self):
        with override_config('DUMP_FUNC_OPT', True):
            out = self.compile_simple_nopython()
        self.check_debug_output(out, ['func_opt_llvm'])

    def test_dump_optimized_llvm(self):
        with override_config('DUMP_OPTIMIZED', True):
            out = self.compile_simple_nopython()
        self.check_debug_output(out, ['optimized_llvm'])

    def test_dump_assembly(self):
        with override_config('DUMP_ASSEMBLY', True):
            out = self.compile_simple_nopython()
        self.check_debug_output(out, ['assembly'])


class TestGeneratorDebugOutput(DebugTestBase):
    """Checks the IR dump produced when compiling the ``simple_gen``
    generator."""

    func_name = 'simple_gen'

    def compile_simple_gen(self):
        """Compile and run ``simple_gen``; return the captured stdout."""
        with captured_stdout() as out:
            cfunc = njit((types.int64, types.int64))(simple_gen)
            # Sanity check compiled function
            self.assertPreciseEqual(list(cfunc(2, 5)), [2, 5])
        return out.getvalue()

    def test_dump_ir_generator(self):
        with override_config('DUMP_IR', True):
            out = self.compile_simple_gen()
        self.check_debug_output(out, ['ir'])
        self.assertIn('--GENERATOR INFO: %s' % self.func_name, out)
        expected_gen_info = textwrap.dedent("""
            generator state variables: ['x', 'y']
            yield point #1: live variables = ['y'], weak live variables = ['x']
            yield point #2: live variables = [], weak live variables = ['y']
            """)
        self.assertIn(expected_gen_info, out)


class TestDisableJIT(DebugTestBase):
    """
    Test the NUMBA_DISABLE_JIT environment variable.
    """

    def test_jit(self):
        with override_config('DISABLE_JIT', True):
            # Decorated functions must still be callable and give the right
            # result while running under forbid_codegen().
            with forbid_codegen():
                cfunc = jit(nopython=True)(simple_nopython)
                self.assertPreciseEqual(cfunc(2), 3)

    def test_jitclass(self):
        with override_config('DISABLE_JIT', True):
            with forbid_codegen():
                SimpleJITClass = jitclass(simple_class_spec)(SimpleClass)

                obj = SimpleJITClass()
                self.assertPreciseEqual(obj.h, 5)

                cfunc = jit(nopython=True)(simple_class_user)
                self.assertPreciseEqual(cfunc(obj), 5)


class TestEnvironmentOverride(FunctionDebugTestBase):
    """
    Test that environment variables are reloaded by Numba when modified.
    """

    # mutates env with os.environ so must be run serially
    _numba_parallel_test_ = False

    def test_debug(self):
        # Nothing is printed before the override ...
        out = self.compile_simple_nopython()
        self.assertFalse(out)
        with override_env_config('NUMBA_DEBUG', '1'):
            out = self.compile_simple_nopython()
            # Note that all variables dependent on NUMBA_DEBUG are
            # updated too.
            self.check_debug_output(out, ['ir', 'typeinfer',
                                          'llvm', 'func_opt_llvm',
                                          'optimized_llvm', 'assembly'])
        # ... and nothing is printed once the override has been left.
        out = self.compile_simple_nopython()
        self.assertFalse(out)


class TestParforsDebug(TestCase):
    """
    Tests debug options associated with parfors
    """

    # mutates env with os.environ so must be run serially
    _numba_parallel_test_ = False

    def _assert_warning_issued(self, warn_list, msg):
        """Assert that *msg* is a substring of at least one recorded warning
        message in *warn_list*."""
        warning_found = any(msg in str(w.message) for w in warn_list)
        self.assertTrue(warning_found, "Warning message should be found.")

    def check_parfors_warning(self, warn_list):
        """Assert that *warn_list* contains the "no transformation for
        parallel execution was possible" warning."""
        msg = ("'parallel=True' was specified but no transformation for "
               "parallel execution was possible.")
        self._assert_warning_issued(warn_list, msg)

    def check_parfors_unsupported_prange_warning(self, warn_list):
        """Assert that *warn_list* contains the "more than one entry to or
        exit from the loop" prange warning."""
        msg = ("prange or pndindex loop will not be executed in parallel "
               "due to there being more than one entry to or exit from the "
               "loop (e.g., an assertion).")
        self._assert_warning_issued(warn_list, msg)

    @needs_blas
    @skip_parfors_unsupported
    def test_warns(self):
        """
        Test that using parallel=True on a function that does not have parallel
        semantics warns.
        """
        arr_ty = types.Array(types.float64, 2, "C")
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", NumbaPerformanceWarning)
            njit((arr_ty, arr_ty), parallel=True)(unsupported_parfor)
        self.check_parfors_warning(w)

    @needs_blas
    @skip_parfors_unsupported
    def test_unsupported_prange_warns(self):
        """
        Test that prange with multiple exits issues a warning
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", NumbaPerformanceWarning)
            njit((types.int64,), parallel=True)(unsupported_prange)
        self.check_parfors_unsupported_prange_warning(w)

    @skip_parfors_unsupported
    def test_array_debug_opt_stats(self):
        """
        Test that NUMBA_DEBUG_ARRAY_OPT_STATS produces valid output
        """
        # deliberately trigger a compilation loop to increment the
        # Parfor class state, this is to ensure the test works based
        # on indices computed based on this state and not hard coded
        # indices.
        njit((types.int64,), parallel=True)(supported_parfor)

        with override_env_config('NUMBA_DEBUG_ARRAY_OPT_STATS', '1'):
            with captured_stdout() as out:
                njit((types.int64,), parallel=True)(supported_parfor)

            # grab the various parts out the output
            output = out.getvalue().split('\n')
            parallel_loop_output = \
                [x for x in output if 'is produced from pattern' in x]
            fuse_output = \
                [x for x in output if 'is fused into' in x]
            after_fusion_output = \
                [x for x in output if 'After fusion, function' in x]

            # Fail with a readable message rather than an IndexError /
            # AttributeError if the expected output is missing.
            self.assertTrue(parallel_loop_output,
                            "No 'is produced from pattern' lines in output.")

            # Parfor's have a shared state index, grab the current value
            # as it will be used as an offset for all loop messages
            state_match = re.search(r'#([0-9]+)', parallel_loop_output[0])
            self.assertIsNotNone(state_match)
            parfor_state = int(state_match.group(1))
            bounds = range(parfor_state,
                           parfor_state + len(parallel_loop_output))

            # Check the Parallel for-loop <index> is produced from <pattern>
            # works first
            pattern = ("('ones function', 'NumPy mapping')",
                       ('prange', 'user', ''))
            fmt = 'Parallel for-loop #{} is produced from pattern \'{}\' at'
            for i, trials, lpattern in zip(bounds, parallel_loop_output,
                                           pattern):
                to_match = fmt.format(i, lpattern)
                self.assertIn(to_match, trials)

            # Check the fusion statements are correct
            pattern = (parfor_state + 1, parfor_state + 0)
            fmt = 'Parallel for-loop #{} is fused into for-loop #{}.'
            for trials in fuse_output:
                to_match = fmt.format(*pattern)
                self.assertIn(to_match, trials)

            # Check the post fusion statements are correct; the set is
            # rendered by str.format as ``{<index>}``.
            pattern = (supported_parfor.__name__, 1, {parfor_state})
            fmt = 'After fusion, function {} has {} parallel for-loop(s) #{}.'
            for trials in after_fusion_output:
                to_match = fmt.format(*pattern)
                self.assertIn(to_match, trials)


if __name__ == '__main__':
    unittest.main()