"""Tests for the numba dispatcher: overload selection, signature handling,
introspection helpers, serialization and cross-process use."""

import multiprocessing
import platform
import threading
import pickle
import weakref
from itertools import chain
from io import StringIO

import numpy as np

from numba import njit, jit, typeof, vectorize
from numba.core import types, errors
from numba import _dispatcher
from numba.tests.support import TestCase, captured_stdout
from numba.np.numpy_support import as_dtype
from numba.core.dispatcher import Dispatcher
from numba.extending import overload
from numba.tests.support import needs_lapack, SerialMixin
from numba.testing.main import _TIMEOUT as _RUNNER_TIMEOUT
import unittest


# Per-subprocess join timeout, derived from the test runner's own timeout.
_TEST_TIMEOUT = _RUNNER_TIMEOUT - 60.


# Optional dependencies, only needed by test_inspect_types_pretty.
try:
    import jinja2
except ImportError:
    jinja2 = None

try:
    import pygments
except ImportError:
    pygments = None

_is_armv7l = platform.machine() == 'armv7l'


def dummy(x):
    """Return *x* unchanged."""
    return x


def add(x, y):
    """Return ``x + y``."""
    return x + y


def addsub(x, y, z):
    """Return ``x - y + z``; all three arguments are required."""
    return x - y + z


def addsub_defaults(x, y=2, z=3):
    """Return ``x - y + z`` with defaults for *y* and *z*."""
    return x - y + z


def star_defaults(x, y=2, *z):
    """Return ``(x, y, z)`` where *z* collects extra positional arguments."""
    return x, y, z


def generated_usecase(x, y=5):
    """Return an implementation that adds when *x* is a complex type and
    subtracts otherwise."""
    if isinstance(x, types.Complex):
        def impl(x, y):
            return x + y
    else:
        def impl(x, y):
            return x - y
    return impl


def bad_generated_usecase(x, y=5):
    """Return an implementation whose signature deliberately differs from
    this function's own signature."""
    if isinstance(x, types.Complex):
        def impl(x):
            return x
    else:
        def impl(x, y=6):
            return x - y
    return impl


def dtype_generated_usecase(a, b, dtype=None):
    """Return an implementation producing ``np.ones`` of ``a.shape``.

    The output dtype is the result type of *a* and *b* when *dtype* is
    ``None``/omitted, or the NumPy dtype equivalent of *dtype* otherwise.

    Raises:
        TypeError: if *dtype* is of an unhandled numba type.
    """
    if isinstance(dtype, (types.misc.NoneType, types.misc.Omitted)):
        out_dtype = np.result_type(*(np.dtype(ary.dtype.name)
                                     for ary in (a, b)))
    elif isinstance(dtype, (types.DType, types.NumberClass)):
        out_dtype = as_dtype(dtype)
    else:
        raise TypeError("Unhandled Type %s" % type(dtype))

    def _fn(a, b, dtype=None):
        return np.ones(a.shape, dtype=out_dtype)

    return _fn


class BaseTest(TestCase):
    """Base class providing a helper to jit a function with ``jit_args``."""

    jit_args = dict(nopython=True)

    def compile_func(self, pyfunc):
        """Jit *pyfunc* with ``self.jit_args``.

        Returns the compiled function together with a ``check`` callable
        that runs both versions with the same arguments and asserts that
        the results are precisely equal.
        """
        def check(*args, **kwargs):
            expected = pyfunc(*args, **kwargs)
            result = f(*args, **kwargs)
            self.assertPreciseEqual(result, expected)
        f = jit(**self.jit_args)(pyfunc)
        return f, check


class TestDispatcher(BaseTest):

    def test_equality(self):
        @jit
        def foo(x):
            return x

        @jit
        def bar(x):
            return x

        # Written this way to verify `==` returns a bool (gh-5838). Using
        # `assertTrue(foo == foo)` or `assertEqual(foo, foo)` would defeat the
        # purpose of this test.
        self.assertEqual(foo == foo, True)
        self.assertEqual(foo == bar, False)
        self.assertEqual(foo == None, False)  # noqa: E711

    def test_dyn_pyfunc(self):
        @jit
        def foo(x):
            return x

        foo(1)
        [cr] = foo.overloads.values()
        # __module__ must match that of foo
        self.assertEqual(cr.entry_point.__module__, foo.py_func.__module__)

    def test_no_argument(self):
        @jit
        def foo():
            return 1

        # Just make sure this doesn't crash
        foo()

    def test_coerce_input_types(self):
        # Issue #486: do not allow unsafe conversions if we can still
        # compile other specializations.
        c_add = jit(nopython=True)(add)
        self.assertPreciseEqual(c_add(123, 456), add(123, 456))
        self.assertPreciseEqual(c_add(12.3, 45.6), add(12.3, 45.6))
        self.assertPreciseEqual(c_add(12.3, 45.6j), add(12.3, 45.6j))
        self.assertPreciseEqual(c_add(12300000000, 456),
                                add(12300000000, 456))

        # Now force compilation of only a single specialization
        c_add = jit('(i4, i4)', nopython=True)(add)
        self.assertPreciseEqual(c_add(123, 456), add(123, 456))
        # Implicit (unsafe) conversion of float to int
        self.assertPreciseEqual(c_add(12.3, 45.6), add(12, 45))
        with self.assertRaises(TypeError):
            # Implicit conversion of complex to int disallowed
            c_add(12.3, 45.6j)

    def test_ambiguous_new_version(self):
        """Test compiling new version in an ambiguous case
        """
        @jit
        def foo(a, b):
            return a + b

        INT = 1
        FLT = 1.5
        self.assertAlmostEqual(foo(INT, FLT), INT + FLT)
        self.assertEqual(len(foo.overloads), 1)
        self.assertAlmostEqual(foo(FLT, INT), FLT + INT)
        self.assertEqual(len(foo.overloads), 2)
        self.assertAlmostEqual(foo(FLT, FLT), FLT + FLT)
        self.assertEqual(len(foo.overloads), 3)
        # The following call is ambiguous because (int, int) can resolve
        # to (float, int) or (int, float) with equal weight.
        self.assertAlmostEqual(foo(1, 1), INT + INT)
        self.assertEqual(len(foo.overloads), 4, "didn't compile a new "
                                                "version")

    def test_lock(self):
        """
        Test that (lazy) compiling from several threads at once
        doesn't produce errors (see issue #908).
        """
        # Named `failures` so the imported `errors` module is not shadowed.
        failures = []

        @jit
        def foo(x):
            return x + 1

        def wrapper():
            try:
                self.assertEqual(foo(1), 2)
            except Exception as e:
                # Collected here and asserted on in the main thread below.
                failures.append(e)

        threads = [threading.Thread(target=wrapper) for i in range(16)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        self.assertFalse(failures)

    def test_explicit_signatures(self):
        f = jit("(int64,int64)")(add)
        # Approximate match (unsafe conversion)
        self.assertPreciseEqual(f(1.5, 2.5), 3)
        self.assertEqual(len(f.overloads), 1, f.overloads)
        f = jit(["(int64,int64)", "(float64,float64)"])(add)
        # Exact signature matches
        self.assertPreciseEqual(f(1, 2), 3)
        self.assertPreciseEqual(f(1.5, 2.5), 4.0)
        # Approximate match (int32 -> float64 is a safe conversion)
        self.assertPreciseEqual(f(np.int32(1), 2.5), 3.5)
        # No conversion
        with self.assertRaises(TypeError) as cm:
            f(1j, 1j)
        self.assertIn("No matching definition", str(cm.exception))
        self.assertEqual(len(f.overloads), 2, f.overloads)
        # A more interesting one...
        f = jit(["(float32,float32)", "(float64,float64)"])(add)
        self.assertPreciseEqual(f(np.float32(1), np.float32(2**-25)), 1.0)
        self.assertPreciseEqual(f(1, 2**-25), 1.0000000298023224)
        # Fail to resolve ambiguity between the two best overloads
        f = jit(["(float32,float64)",
                 "(float64,float32)",
                 "(int64,int64)"])(add)
        with self.assertRaises(TypeError) as cm:
            f(1.0, 2.0)
        # The two best matches are output in the error message, as well
        # as the actual argument types.
        self.assertRegex(
            str(cm.exception),
            r"Ambiguous overloading for <function add [^>]*> "
            r"\(float64, float64\):\n"
            r"\(float32, float64\) -> float64\n"
            r"\(float64, float32\) -> float64"
        )
        # The integer signature is not part of the best matches
        self.assertNotIn("int64", str(cm.exception))

    def test_signature_mismatch(self):
        tmpl = ("Signature mismatch: %d argument types given, but function "
                "takes 2 arguments")
        with self.assertRaises(TypeError) as cm:
            jit("()")(add)
        self.assertIn(tmpl % 0, str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            jit("(intc,)")(add)
        self.assertIn(tmpl % 1, str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            jit("(intc,intc,intc)")(add)
        self.assertIn(tmpl % 3, str(cm.exception))
        # With forceobj=True, an empty tuple is accepted
        jit("()", forceobj=True)(add)
        with self.assertRaises(TypeError) as cm:
            jit("(intc,)", forceobj=True)(add)
        self.assertIn(tmpl % 1, str(cm.exception))

    def test_matching_error_message(self):
        f = jit("(intc,intc)")(add)
        with self.assertRaises(TypeError) as cm:
            f(1j, 1j)
        self.assertEqual(str(cm.exception),
                         "No matching definition for argument type(s) "
                         "complex128, complex128")

    def test_disabled_compilation(self):
        @jit
        def foo(a):
            return a

        foo.compile("(float32,)")
        foo.disable_compile()
        with self.assertRaises(RuntimeError) as raises:
            foo.compile("(int32,)")
        self.assertEqual(str(raises.exception), "compilation disabled")
        self.assertEqual(len(foo.signatures), 1)

    def test_disabled_compilation_through_list(self):
        @jit(["(float32,)", "(int32,)"])
        def foo(a):
            return a

        with self.assertRaises(RuntimeError) as raises:
            foo.compile("(complex64,)")
        self.assertEqual(str(raises.exception), "compilation disabled")
        self.assertEqual(len(foo.signatures), 2)

    def test_disabled_compilation_nested_call(self):
        @jit(["(intp,)"])
        def foo(a):
            return a

        @jit
        def bar():
            foo(1)
            foo(np.ones(1))  # no matching definition

        with self.assertRaises(errors.TypingError) as raises:
            bar()
        m = r".*Invalid use of.*with parameters \(array\(float64, 1d, C\)\).*"
        self.assertRegex(str(raises.exception), m)

    def test_fingerprint_failure(self):
        """
        Failure in computing the fingerprint cannot affect a nopython=False
        function.  On the other hand, with nopython=True, a ValueError should
        be raised to report the failure with fingerprint.
        """
        def foo(x):
            return x

        # Empty list will trigger failure in compile_fingerprint
        errmsg = 'cannot compute fingerprint of empty list'
        with self.assertRaises(ValueError) as raises:
            _dispatcher.compute_fingerprint([])
        self.assertIn(errmsg, str(raises.exception))
        # It should work in objmode
        objmode_foo = jit(forceobj=True)(foo)
        self.assertEqual(objmode_foo([]), [])
        # But, not in nopython=True
        strict_foo = jit(nopython=True)(foo)
        with self.assertRaises(ValueError) as raises:
            strict_foo([])
        self.assertIn(errmsg, str(raises.exception))

        # Test in loop lifting context
        @jit(forceobj=True)
        def bar():
            object()  # force looplifting
            x = []
            for i in range(10):
                x = objmode_foo(x)
            return x

        self.assertEqual(bar(), [])
        # Make sure it was looplifted
        [cr] = bar.overloads.values()
        self.assertEqual(len(cr.lifted), 1)

    def test_serialization(self):
        """
        Test serialization of Dispatcher objects
        """
        @jit(nopython=True)
        def foo(x):
            return x + 1

        self.assertEqual(foo(1), 2)

        # get serialization memo
        memo = Dispatcher._memo
        Dispatcher._recent.clear()
        memo_size = len(memo)

        # pickle foo and check memo size
        serialized_foo = pickle.dumps(foo)
        # increases the memo size
        self.assertEqual(memo_size + 1, len(memo))

        # unpickle
        foo_rebuilt = pickle.loads(serialized_foo)
        self.assertEqual(memo_size + 1, len(memo))

        self.assertIs(foo, foo_rebuilt)

        # do we get the same object even if we delete all the explicit
        # references?
        id_orig = id(foo_rebuilt)
        del foo
        del foo_rebuilt
        self.assertEqual(memo_size + 1, len(memo))
        new_foo = pickle.loads(serialized_foo)
        self.assertEqual(id_orig, id(new_foo))

        # now clear the recent cache
        ref = weakref.ref(new_foo)
        del new_foo
        Dispatcher._recent.clear()
        self.assertEqual(memo_size, len(memo))

        # show that deserializing creates a new object
        pickle.loads(serialized_foo)
        self.assertIs(ref(), None)

    @needs_lapack
    @unittest.skipIf(_is_armv7l, "Unaligned loads unsupported")
    def test_misaligned_array_dispatch(self):
        # for context see issue #2937
        def foo(a):
            return np.linalg.matrix_power(a, 1)

        jitfoo = jit(nopython=True)(foo)

        n = 64
        r = int(np.sqrt(n))
        dt = np.int8
        count = np.complex128().itemsize // dt().itemsize

        tmp = np.arange(n * count + 1, dtype=dt)

        # create some arrays as Cartesian product of:
        # [F/C] x [aligned/misaligned]
        C_contig_aligned = tmp[:-1].view(np.complex128).reshape(r, r)
        C_contig_misaligned = tmp[1:].view(np.complex128).reshape(r, r)
        F_contig_aligned = C_contig_aligned.T
        F_contig_misaligned = C_contig_misaligned.T

        # checking routine
        def check(name, a):
            a[:, :] = np.arange(n, dtype=np.complex128).reshape(r, r)
            expected = foo(a)
            got = jitfoo(a)
            np.testing.assert_allclose(expected, got)

        # The checks must be run in this order to create the dispatch key
        # sequence that causes invalid dispatch noted in #2937.
        # The first two should hit the cache as they are aligned, supported
        # order and under 5 dimensions. The second two should end up in the
        # fallback path as they are misaligned.
        check("C_contig_aligned", C_contig_aligned)
        check("F_contig_aligned", F_contig_aligned)
        check("C_contig_misaligned", C_contig_misaligned)
        check("F_contig_misaligned", F_contig_misaligned)

    @unittest.skipIf(_is_armv7l, "Unaligned loads unsupported")
    def test_immutability_in_array_dispatch(self):

        # RO operation in function
        def foo(a):
            return np.sum(a)

        jitfoo = jit(nopython=True)(foo)

        n = 64
        r = int(np.sqrt(n))
        dt = np.int8
        count = np.complex128().itemsize // dt().itemsize

        tmp = np.arange(n * count + 1, dtype=dt)

        # create some arrays as Cartesian product of:
        # [F/C] x [aligned/misaligned]
        C_contig_aligned = tmp[:-1].view(np.complex128).reshape(r, r)
        C_contig_misaligned = tmp[1:].view(np.complex128).reshape(r, r)
        F_contig_aligned = C_contig_aligned.T
        F_contig_misaligned = C_contig_misaligned.T

        # checking routine
        def check(name, a, disable_write_bit=False):
            a[:, :] = np.arange(n, dtype=np.complex128).reshape(r, r)
            if disable_write_bit:
                a.flags.writeable = False
            expected = foo(a)
            got = jitfoo(a)
            np.testing.assert_allclose(expected, got)

        # all of these should end up in the fallback path as they have no
        # write bit set
        check("C_contig_aligned", C_contig_aligned, disable_write_bit=True)
        check("F_contig_aligned", F_contig_aligned, disable_write_bit=True)
        check("C_contig_misaligned", C_contig_misaligned,
              disable_write_bit=True)
        check("F_contig_misaligned", F_contig_misaligned,
              disable_write_bit=True)

    @needs_lapack
    @unittest.skipIf(_is_armv7l, "Unaligned loads unsupported")
    def test_misaligned_high_dimension_array_dispatch(self):

        def foo(a):
            return np.linalg.matrix_power(a[0, 0, 0, 0, :, :], 1)

        jitfoo = jit(nopython=True)(foo)

        def check_properties(arr, layout, aligned):
            self.assertEqual(arr.flags.aligned, aligned)
            if layout == "C":
                self.assertEqual(arr.flags.c_contiguous, True)
            if layout == "F":
                self.assertEqual(arr.flags.f_contiguous, True)

        n = 729
        r = 3
        dt = np.int8
        count = np.complex128().itemsize // dt().itemsize

        tmp = np.arange(n * count + 1, dtype=dt)

        # create some arrays as Cartesian product of:
        # [F/C] x [aligned/misaligned]
        C_contig_aligned = tmp[:-1].view(np.complex128).\
            reshape(r, r, r, r, r, r)
        check_properties(C_contig_aligned, 'C', True)
        C_contig_misaligned = tmp[1:].view(np.complex128).\
            reshape(r, r, r, r, r, r)
        check_properties(C_contig_misaligned, 'C', False)
        F_contig_aligned = C_contig_aligned.T
        check_properties(F_contig_aligned, 'F', True)
        F_contig_misaligned = C_contig_misaligned.T
        check_properties(F_contig_misaligned, 'F', False)

        # checking routine
        def check(name, a):
            a[:, :] = np.arange(n, dtype=np.complex128).\
                reshape(r, r, r, r, r, r)
            expected = foo(a)
            got = jitfoo(a)
            np.testing.assert_allclose(expected, got)

        # these should all hit the fallback path as the cache is only for up
        # to 5 dimensions
        check("F_contig_misaligned", F_contig_misaligned)
        check("C_contig_aligned", C_contig_aligned)
        check("F_contig_aligned", F_contig_aligned)
        check("C_contig_misaligned", C_contig_misaligned)

    def test_dispatch_recompiles_for_scalars(self):
        # for context #3612, essentially, compiling a lambda x:x for a
        # numerically wide type (everything can be converted to a complex128)
        # and then calling again with e.g. an int32 would lead to the int32
        # being converted to a complex128 whereas it ought to compile an int32
        # specialization.
        def foo(x):
            return x

        # jit and compile on dispatch for 3 scalar types, expect 3 signatures
        jitfoo = jit(nopython=True)(foo)
        jitfoo(np.complex128(1 + 2j))
        jitfoo(np.int32(10))
        jitfoo(np.bool_(False))
        self.assertEqual(len(jitfoo.signatures), 3)
        expected_sigs = [(types.complex128,), (types.int32,), (types.bool_,)]
        self.assertEqual(jitfoo.signatures, expected_sigs)

        # now jit with signatures so recompilation is forbidden
        # expect 1 signature and type conversion
        jitfoo = jit([(types.complex128,)], nopython=True)(foo)
        jitfoo(np.complex128(1 + 2j))
        jitfoo(np.int32(10))
        jitfoo(np.bool_(False))
        self.assertEqual(len(jitfoo.signatures), 1)
        expected_sigs = [(types.complex128,)]
        self.assertEqual(jitfoo.signatures, expected_sigs)

    def test_dispatcher_raises_for_invalid_decoration(self):
        # For context see https://github.com/numba/numba/issues/4750.

        @jit(nopython=True)
        def foo(x):
            return x

        with self.assertRaises(TypeError) as raises:
            jit(foo)
        err_msg = str(raises.exception)
        self.assertIn(
            "A jit decorator was called on an already jitted function",
            err_msg)
        self.assertIn("foo", err_msg)
        self.assertIn(".py_func", err_msg)

        with self.assertRaises(TypeError) as raises:
            jit(BaseTest)
        err_msg = str(raises.exception)
        self.assertIn("The decorated object is not a function", err_msg)
        self.assertIn(f"{type(BaseTest)}", err_msg)


class TestSignatureHandling(BaseTest):
    """
    Test support for various parameter passing styles.
    """

    def test_named_args(self):
        """
        Test passing named arguments to a dispatcher.
        """
        f, check = self.compile_func(addsub)
        check(3, z=10, y=4)
        check(3, 4, 10)
        check(x=3, y=4, z=10)
        # All calls above fall under the same specialization
        self.assertEqual(len(f.overloads), 1)
        # Errors
        with self.assertRaises(TypeError) as cm:
            f(3, 4, y=6, z=7)
        self.assertIn("too many arguments: expected 3, got 4",
                      str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            f()
        self.assertIn("not enough arguments: expected 3, got 0",
                      str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            f(3, 4, y=6)
        self.assertIn("missing argument 'z'", str(cm.exception))

    def test_default_args(self):
        """
        Test omitting arguments with a default value.
        """
        f, check = self.compile_func(addsub_defaults)
        check(3, z=10, y=4)
        check(3, 4, 10)
        check(x=3, y=4, z=10)
        # Now omitting some values
        check(3, z=10)
        check(3, 4)
        check(x=3, y=4)
        check(3)
        check(x=3)
        # Errors
        with self.assertRaises(TypeError) as cm:
            f(3, 4, y=6, z=7)
        self.assertIn("too many arguments: expected 3, got 4",
                      str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            f()
        self.assertIn("not enough arguments: expected at least 1, got 0",
                      str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            f(y=6, z=7)
        self.assertIn("missing argument 'x'", str(cm.exception))

    def test_star_args(self):
        """
        Test a compiled function with starargs in the signature.
        """
        f, check = self.compile_func(star_defaults)
        check(4)
        check(4, 5)
        check(4, 5, 6)
        check(4, 5, 6, 7)
        check(4, 5, 6, 7, 8)
        check(x=4)
        check(x=4, y=5)
        check(4, y=5)
        with self.assertRaises(TypeError) as cm:
            f(4, 5, y=6)
        self.assertIn("some keyword arguments unexpected", str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            f(4, 5, z=6)
        self.assertIn("some keyword arguments unexpected", str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            f(4, x=6)
        self.assertIn("some keyword arguments unexpected", str(cm.exception))


class TestSignatureHandlingObjectMode(TestSignatureHandling):
    """
    Same as TestSignatureHandling, but in object mode.
    """

    jit_args = dict(forceobj=True)


class TestDispatcherMethods(TestCase):

    def test_recompile(self):
        closure = 1

        @jit
        def foo(x):
            return x + closure

        self.assertPreciseEqual(foo(1), 2)
        self.assertPreciseEqual(foo(1.5), 2.5)
        self.assertEqual(len(foo.signatures), 2)
        closure = 2
        self.assertPreciseEqual(foo(1), 2)
        # Recompiling takes the new closure into account.
        foo.recompile()
        # Everything was recompiled
        self.assertEqual(len(foo.signatures), 2)
        self.assertPreciseEqual(foo(1), 3)
        self.assertPreciseEqual(foo(1.5), 3.5)

    def test_recompile_signatures(self):
        # Same as above, but with an explicit signature on @jit.
        closure = 1

        @jit("int32(int32)")
        def foo(x):
            return x + closure

        self.assertPreciseEqual(foo(1), 2)
        self.assertPreciseEqual(foo(1.5), 2)
        closure = 2
        self.assertPreciseEqual(foo(1), 2)
        # Recompiling takes the new closure into account.
        foo.recompile()
        self.assertPreciseEqual(foo(1), 3)
        self.assertPreciseEqual(foo(1.5), 3)

    def test_inspect_llvm(self):
        # Create a jited function
        @jit
        def foo(explicit_arg1, explicit_arg2):
            return explicit_arg1 + explicit_arg2

        # Call it in a way to create 3 signatures
        foo(1, 1)
        foo(1.0, 1)
        foo(1.0, 1.0)

        # base call to get all llvm in a dict
        llvms = foo.inspect_llvm()
        self.assertEqual(len(llvms), 3)

        # make sure the function name shows up in the llvm
        for llvm_bc in llvms.values():
            # Look for the function name
            self.assertIn("foo", llvm_bc)

            # Look for the argument names
            self.assertIn("explicit_arg1", llvm_bc)
            self.assertIn("explicit_arg2", llvm_bc)

    def test_inspect_asm(self):
        # Create a jited function
        @jit
        def foo(explicit_arg1, explicit_arg2):
            return explicit_arg1 + explicit_arg2

        # Call it in a way to create 3 signatures
        foo(1, 1)
        foo(1.0, 1)
        foo(1.0, 1.0)

        # base call to get all asm in a dict
        asms = foo.inspect_asm()
        self.assertEqual(len(asms), 3)

        # make sure the function name shows up in the asm
        for asm in asms.values():
            # Look for the function name
            self.assertIn("foo", asm)

    def _check_cfg_display(self, cfg, wrapper=''):
        """Check that ``str(cfg)`` starts with the expected digraph header
        and that *cfg* exposes a callable ``display`` attribute.

        *wrapper*, when non-empty, is inserted length-prefixed before the
        module name in the expected symbol prefix.
        """
        # simple stringify test
        if wrapper:
            wrapper = "{}{}".format(len(wrapper), wrapper)
        module_name = __name__.split('.', 1)[0]
        module_len = len(module_name)
        prefix = r'^digraph "CFG for \'_ZN{}{}{}'.format(wrapper,
                                                          module_len,
                                                          module_name)
        self.assertRegex(str(cfg), prefix)
        # .display() requires an optional dependency on `graphviz`.
        # just test for the attribute without running it.
        self.assertTrue(callable(cfg.display))

    def test_inspect_cfg(self):
        # Exercise the .inspect_cfg(). These are minimal tests and do not
        # fully check the correctness of the function.
        @jit
        def foo(the_array):
            return the_array.sum()

        # Generate 3 overloads
        a1 = np.ones(1)
        a2 = np.ones((1, 1))
        a3 = np.ones((1, 1, 1))
        foo(a1)
        foo(a2)
        foo(a3)

        # Call inspect_cfg() without arguments
        cfgs = foo.inspect_cfg()

        # Correct count of overloads
        self.assertEqual(len(cfgs), 3)

        # Makes sure all the signatures are correct
        [s1, s2, s3] = cfgs.keys()
        self.assertEqual(set([s1, s2, s3]),
                         set(map(lambda x: (typeof(x),), [a1, a2, a3])))

        for cfg in cfgs.values():
            self._check_cfg_display(cfg)
        self.assertEqual(len(list(cfgs.values())), 3)

        # Call inspect_cfg(signature)
        cfg = foo.inspect_cfg(signature=foo.signatures[0])
        self._check_cfg_display(cfg)

    def test_inspect_cfg_with_python_wrapper(self):
        # Exercise the .inspect_cfg() including the python wrapper.
        # These are minimal tests and do not fully check the correctness of
        # the function.
        @jit
        def foo(the_array):
            return the_array.sum()

        # Generate 3 overloads
        a1 = np.ones(1)
        a2 = np.ones((1, 1))
        a3 = np.ones((1, 1, 1))
        foo(a1)
        foo(a2)
        foo(a3)

        # Call inspect_cfg(signature, show_wrapper="python")
        cfg = foo.inspect_cfg(signature=foo.signatures[0],
                              show_wrapper="python")
        self._check_cfg_display(cfg, wrapper='cpython')

    def test_inspect_types(self):
        @jit
        def foo(a, b):
            return a + b

        foo(1, 2)
        # Exercise the method
        foo.inspect_types(StringIO())

        # Test output
        expected = str(foo.overloads[foo.signatures[0]].type_annotation)
        with captured_stdout() as out:
            foo.inspect_types()
        # assertIn (not a bare assert) so the check survives `python -O`.
        self.assertIn(expected, out.getvalue())

    def test_inspect_types_with_signature(self):
        @jit
        def foo(a):
            return a + 1

        foo(1)
        foo(1.0)
        # Inspect all signatures
        with captured_stdout() as total:
            foo.inspect_types()
        # Inspect first signature
        with captured_stdout() as first:
            foo.inspect_types(signature=foo.signatures[0])
        # Inspect second signature
        with captured_stdout() as second:
            foo.inspect_types(signature=foo.signatures[1])

        self.assertEqual(total.getvalue(),
                         first.getvalue() + second.getvalue())

    @unittest.skipIf(jinja2 is None, "please install the 'jinja2' package")
    @unittest.skipIf(pygments is None, "please install the 'pygments' package")
    def test_inspect_types_pretty(self):
        @jit
        def foo(a, b):
            return a + b

        foo(1, 2)

        # Exercise the method, dump the output
        with captured_stdout():
            ann = foo.inspect_types(pretty=True)

        # ensure HTML is found in the annotation output
        for k, v in ann.ann.items():
            span_found = False
            for line in v['pygments_lines']:
                if 'span' in line[2]:
                    span_found = True
            self.assertTrue(span_found)

        # check that file+pretty kwarg combo raises
        with self.assertRaises(ValueError) as raises:
            foo.inspect_types(file=StringIO(), pretty=True)

        self.assertIn("`file` must be None if `pretty=True`",
                      str(raises.exception))

    def test_get_annotation_info(self):
        @jit
        def foo(a):
            return a + 1

        foo(1)
        foo(1.3)

        expected = dict(chain.from_iterable(foo.get_annotation_info(i).items()
                                            for i in foo.signatures))
        result = foo.get_annotation_info()
        self.assertEqual(expected, result)

    def test_issue_with_array_layout_conflict(self):
        """
        This tests an issue with the dispatcher when an array that is both
        C and F contiguous is supplied as the first signature.
        The dispatcher checks for F contiguous first but the compiler checks
        for C contiguous first. This results in C contiguous code being
        inserted as an F contiguous function.
        """
        def pyfunc(A, i, j):
            return A[i, j]

        cfunc = jit(pyfunc)

        ary_c_and_f = np.array([[1.]])
        ary_c = np.array([[0., 1.], [2., 3.]], order='C')
        ary_f = np.array([[0., 1.], [2., 3.]], order='F')

        exp_c = pyfunc(ary_c, 1, 0)
        exp_f = pyfunc(ary_f, 1, 0)

        self.assertEqual(1., cfunc(ary_c_and_f, 0, 0))
        got_c = cfunc(ary_c, 1, 0)
        got_f = cfunc(ary_f, 1, 0)

        self.assertEqual(exp_c, got_c)
        self.assertEqual(exp_f, got_f)


class TestDispatcherFunctionBoundaries(TestCase):

    def test_pass_dispatcher_as_arg(self):
        # Test that a Dispatcher object can be passed as argument
        @jit(nopython=True)
        def add1(x):
            return x + 1

        @jit(nopython=True)
        def bar(fn, x):
            return fn(x)

        @jit(nopython=True)
        def foo(x):
            return bar(add1, x)

        # Check dispatcher as argument inside NPM
        inputs = [1, 11.1, np.arange(10)]
        expected_results = [x + 1 for x in inputs]

        for arg, expect in zip(inputs, expected_results):
            self.assertPreciseEqual(foo(arg), expect)

        # Check dispatcher as argument from python
        for arg, expect in zip(inputs, expected_results):
            self.assertPreciseEqual(bar(add1, arg), expect)

    def test_dispatcher_as_arg_usecase(self):
        @jit(nopython=True)
        def maximum(seq, cmpfn):
            tmp = seq[0]
            for each in seq[1:]:
                cmpval = cmpfn(tmp, each)
                if cmpval < 0:
                    tmp = each
            return tmp

        got = maximum([1, 2, 3, 4], cmpfn=jit(lambda x, y: x - y))
        self.assertEqual(got, 4)
        got = maximum(list(zip(range(5), range(5)[::-1])),
                      cmpfn=jit(lambda x, y: x[0] - y[0]))
        self.assertEqual(got, (4, 0))
        got = maximum(list(zip(range(5), range(5)[::-1])),
                      cmpfn=jit(lambda x, y: x[1] - y[1]))
        self.assertEqual(got, (0, 4))

    def test_dispatcher_can_return_to_python(self):
        @jit(nopython=True)
        def foo(fn):
            return fn

        fn = jit(lambda x: x)

        self.assertEqual(foo(fn), fn)

    def test_dispatcher_in_sequence_arg(self):
        @jit(nopython=True)
        def one(x):
            return x + 1

        @jit(nopython=True)
        def two(x):
            return one(one(x))

        @jit(nopython=True)
        def three(x):
            return one(one(one(x)))

        @jit(nopython=True)
        def choose(fns, x):
            return fns[0](x), fns[1](x), fns[2](x)

        # Tuple case
        self.assertEqual(choose((one, two, three), 1), (2, 3, 4))
        # List case
        self.assertEqual(choose([one, one, one], 1), (2, 2, 2))


class TestBoxingDefaultError(unittest.TestCase):
    # Testing default error at boxing/unboxing
    def test_unbox_runtime_error(self):
        # Dummy type has no unbox support
        def foo(x):
            pass
        argtys = (types.Dummy("dummy_type"),)
        # This needs `compile_isolated`-like behaviour so as to bypass
        # dispatcher type checking logic
        cres = njit(argtys)(foo).overloads[argtys]
        with self.assertRaises(TypeError) as raises:
            # Can pass in whatever and the unbox logic will always raise
            # without checking the input value.
            cres.entry_point(None)
        self.assertEqual(str(raises.exception), "can't unbox dummy_type type")

    def test_box_runtime_error(self):
        @njit
        def foo():
            return unittest  # Module type has no boxing logic
        with self.assertRaises(TypeError) as raises:
            foo()
        pat = "cannot convert native Module.* to Python object"
        self.assertRegex(str(raises.exception), pat)


class TestNoRetryFailedSignature(unittest.TestCase):
    """Test that failed-to-compile signatures are not recompiled.
    """

    def run_test(self, func):
        """Call *func* with arguments expected to fail typing and check
        that the compiler's failed-signature cache grows only when a new
        argument type is seen."""
        fcom = func._compiler
        self.assertEqual(len(fcom._failed_cache), 0)
        # expected failure because `int` has no `__getitem__`
        with self.assertRaises(errors.TypingError):
            func(1)
        self.assertEqual(len(fcom._failed_cache), 1)
        # retry
        with self.assertRaises(errors.TypingError):
            func(1)
        self.assertEqual(len(fcom._failed_cache), 1)
        # retry with double
        with self.assertRaises(errors.TypingError):
            func(1.0)
        self.assertEqual(len(fcom._failed_cache), 2)

    def test_direct_call(self):
        @jit(nopython=True)
        def foo(x):
            return x[0]

        self.run_test(foo)

    def test_nested_call(self):
        @jit(nopython=True)
        def bar(x):
            return x[0]

        @jit(nopython=True)
        def foobar(x):
            bar(x)

        @jit(nopython=True)
        def foo(x):
            return bar(x) + foobar(x)

        self.run_test(foo)

    @unittest.expectedFailure
    # NOTE: @overload does not have an error cache. See PR #9259 for this
    # feature and remove the xfail once this is merged.
    def test_error_count(self):
        def check(field, would_fail):
            # Slightly modified from the reproducer in issue #4117.
            # Before the patch, the compilation time of the failing case is
            # much longer than of the successful case. This can be detected
            # by the number of times `trigger()` is visited.
            k = 10
            counter = {'c': 0}

            def trigger(x):
                assert 0, "unreachable"

            @overload(trigger)
            def ol_trigger(x):
                # Keep track of every visit
                counter['c'] += 1
                if would_fail:
                    raise errors.TypingError("invoke_failed")
                return lambda x: x

            @jit(nopython=True)
            def ident(out, x):
                pass

            def chain_assign(fs, inner=ident):
                tab_head, tab_tail = fs[-1], fs[:-1]

                @jit(nopython=True)
                def assign(out, x):
                    inner(out, x)
                    out[0] += tab_head(x)

                if tab_tail:
                    return chain_assign(tab_tail, assign)
                else:
                    return assign

            chain = chain_assign((trigger,) * k)
            out = np.ones(2)
            if would_fail:
                with self.assertRaises(errors.TypingError) as raises:
                    chain(out, 1)
                self.assertIn('invoke_failed', str(raises.exception))
            else:
                chain(out, 1)

            # Returns the visit counts
            return counter['c']

        ct_ok = check('a', False)
        ct_bad = check('c', True)
        # `trigger()` is visited exactly once for both successful and failed
        # compilation.
        self.assertEqual(ct_ok, 1)
        self.assertEqual(ct_bad, 1)


@njit
def add_y1(x, y=1):
    return x + y


@njit
def add_ynone(x, y=None):
    return x + (1 if y else 2)


@njit
def mult(x, y):
    return x * y


@njit
def add_func(x, func=mult):
    return x + func(x, x)


def _checker(f1, arg):
    """Subprocess target: the compiled result must equal the pure-Python
    result for *arg*. A failed assert gives a non-zero exit code."""
    assert f1(arg) == f1.py_func(arg)


class TestMultiprocessingDefaultParameters(SerialMixin, unittest.TestCase):
    def run_fc_multiproc(self, fc):
        """Run ``_checker(fc, a)`` in a fresh process for a few integer
        arguments and require every process to exit with code 0."""
        try:
            ctx = multiprocessing.get_context('spawn')
        except AttributeError:
            ctx = multiprocessing

        # RE: issue #5973, this doesn't use multiprocessing.Pool.map as doing
        # so causes the TBB library to segfault under certain conditions. It's
        # not clear whether the cause is something in the complexity of the
        # Pool itself, e.g. watcher threads etc, or if it's a problem
        # synonymous with a "timing attack".
        for a in [1, 2, 3]:
            p = ctx.Process(target=_checker, args=(fc, a,))
            p.start()
            p.join(_TEST_TIMEOUT)
            self.assertEqual(p.exitcode, 0)

    def test_int_def_param(self):
        """ Tests issue #4888"""

        self.run_fc_multiproc(add_y1)

    def test_none_def_param(self):
        """ Tests None as a default parameter"""

        # Uses add_ynone (y=None); add_func is covered by
        # test_function_def_param below.
        self.run_fc_multiproc(add_ynone)

    def test_function_def_param(self):
        """ Tests a function as a default parameter"""

        self.run_fc_multiproc(add_func)


class TestVectorizeDifferentTargets(unittest.TestCase):
    """Test that vectorize can be reapplied if the target is different
    """

    def test_cpu_vs_parallel(self):
        @jit
        def add(x, y):
            return x + y

        custom_vectorize = vectorize([], identity=None, target='cpu')

        custom_vectorize(add)

        custom_vectorize_2 = vectorize([], identity=None, target='parallel')

        custom_vectorize_2(add)


if __name__ == '__main__':
    unittest.main()