1194 lines
38 KiB
Python
1194 lines
38 KiB
Python
|
import multiprocessing
|
||
|
import platform
|
||
|
import threading
|
||
|
import pickle
|
||
|
import weakref
|
||
|
from itertools import chain
|
||
|
from io import StringIO
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from numba import njit, jit, typeof, vectorize
|
||
|
from numba.core import types, errors
|
||
|
from numba import _dispatcher
|
||
|
from numba.tests.support import TestCase, captured_stdout
|
||
|
from numba.np.numpy_support import as_dtype
|
||
|
from numba.core.dispatcher import Dispatcher
|
||
|
from numba.extending import overload
|
||
|
from numba.tests.support import needs_lapack, SerialMixin
|
||
|
from numba.testing.main import _TIMEOUT as _RUNNER_TIMEOUT
|
||
|
import unittest
|
||
|
|
||
|
|
||
|
# Time budget derived from the test runner's timeout, leaving 60 units of
# headroom (presumably seconds -- TODO confirm against numba.testing.main).
# Not referenced in the visible part of this file; presumably used by later
# tests.
_TEST_TIMEOUT = _RUNNER_TIMEOUT - 60.


# Optional dependency: bound to None when missing so dependent tests can be
# skipped via ``unittest.skipIf(jinja2 is None, ...)``.
try:
    import jinja2
except ImportError:
    jinja2 = None

# Optional dependency: bound to None when missing so dependent tests can be
# skipped via ``unittest.skipIf(pygments is None, ...)``.
try:
    import pygments
except ImportError:
    pygments = None

# True when running on an 'armv7l' machine; tests that perform unaligned
# loads are skipped there ("Unaligned loads unsupported").
_is_armv7l = platform.machine() == 'armv7l'
|
||
|
|
||
|
|
||
|
def dummy(x):
    """Identity helper: hand *x* back unchanged."""
    result = x
    return result
|
||
|
|
||
|
|
||
|
def add(x, y):
    """Binary addition used as a simple jit target: return ``x + y``."""
    total = x + y
    return total
|
||
|
|
||
|
|
||
|
def addsub(x, y, z):
    """Return ``x - y + z`` (subtraction evaluated first)."""
    difference = x - y
    return difference + z
|
||
|
|
||
|
|
||
|
def addsub_defaults(x, y=2, z=3):
    """Return ``x - y + z`` where ``y`` defaults to 2 and ``z`` to 3."""
    difference = x - y
    return difference + z
|
||
|
|
||
|
|
||
|
def star_defaults(x, y=2, *z):
    """Return ``(x, y, z)`` where ``z`` is the tuple of extra positionals."""
    packed = (x, y, z)
    return packed
|
||
|
|
||
|
|
||
|
def generated_usecase(x, y=5):
    """
    Select an implementation from the numba type of ``x``.

    ``x`` is a numba type (it is tested against ``types.Complex``). For
    complex inputs the returned implementation adds its arguments;
    otherwise it subtracts them. Both implementations take ``(x, y)``.
    """
    if not isinstance(x, types.Complex):
        def impl(x, y):
            return x - y
        return impl

    def impl(x, y):
        return x + y
    return impl
|
||
|
|
||
|
|
||
|
def bad_generated_usecase(x, y=5):
    """
    Deliberately inconsistent variant of ``generated_usecase``.

    Neither returned implementation matches this function's ``(x, y=5)``
    signature. The complex branch drops ``y`` entirely; the other branch
    changes the default of ``y`` to 6. Presumably this exists to exercise
    signature-mismatch error reporting elsewhere in the file (not visible
    here; TODO confirm).
    """
    if not isinstance(x, types.Complex):
        def impl(x, y=6):
            return x - y
        return impl

    def impl(x):
        return x
    return impl
|
||
|
|
||
|
|
||
|
def dtype_generated_usecase(a, b, dtype=None):
    """
    Build an implementation returning ``np.ones(a.shape)`` with a chosen dtype.

    ``a``, ``b`` and ``dtype`` are numba types. When ``dtype`` is absent
    (a ``NoneType`` or ``Omitted`` type), the output dtype is the NumPy
    promotion of the two inputs' dtypes. When ``dtype`` is a ``DType`` or
    ``NumberClass`` type, it is converted with ``as_dtype``. Any other
    ``dtype`` type raises TypeError.
    """
    absent_dtype_kinds = (types.misc.NoneType, types.misc.Omitted)
    explicit_dtype_kinds = (types.DType, types.NumberClass)

    if isinstance(dtype, absent_dtype_kinds):
        # Promote the element types of both input arrays.
        input_dtypes = [np.dtype(arg.dtype.name) for arg in (a, b)]
        out_dtype = np.result_type(*input_dtypes)
    elif isinstance(dtype, explicit_dtype_kinds):
        out_dtype = as_dtype(dtype)
    else:
        raise TypeError("Unhandled Type %s" % type(dtype))

    def _fn(a, b, dtype=None):
        return np.ones(a.shape, dtype=out_dtype)

    return _fn
|
||
|
|
||
|
|
||
|
class BaseTest(TestCase):
    """
    Base class offering ``compile_func``, which jit-compiles a function
    with ``jit_args`` and returns a checker comparing it to the original.
    """

    # Keyword arguments forwarded to ``jit`` by ``compile_func``; subclasses
    # override this to change the compilation mode.
    jit_args = {"nopython": True}

    def compile_func(self, pyfunc):
        """
        Compile ``pyfunc`` with ``self.jit_args``.

        Returns ``(compiled, check)``. ``check(*args, **kwargs)`` calls the
        pure-Python function first, then the compiled one, and asserts the
        two results are precisely equal.
        """
        compiled = jit(**self.jit_args)(pyfunc)

        def check(*args, **kwargs):
            expected = pyfunc(*args, **kwargs)
            got = compiled(*args, **kwargs)
            self.assertPreciseEqual(got, expected)

        return compiled, check
|
||
|
|
||
|
|
||
|
class TestDispatcher(BaseTest):
    """
    Tests for core ``Dispatcher`` behaviour, covering:

    - equality;
    - overload selection and argument coercion;
    - explicit signatures and disabling compilation;
    - argument fingerprinting and pickling;
    - dispatch of misaligned, read-only and high-dimensional arrays.
    """

    @staticmethod
    def _make_complex_views(n, shape):
        """
        Return ``(C_aligned, C_misaligned, F_aligned, F_misaligned)``.

        An int8 buffer holding ``n`` complex128 values plus one spare byte
        is viewed as complex128 twice: from byte offset 0 ("aligned") and
        from byte offset 1 ("misaligned"). Both views are reshaped to
        ``shape`` (C order) and their transposes give the F-ordered
        variants. All four arrays share (overlapping) memory with the same
        buffer.
        """
        dt = np.int8
        # Number of int8 elements spanned by a single complex128.
        count = np.complex128().itemsize // dt().itemsize
        tmp = np.arange(n * count + 1, dtype=dt)
        c_aligned = tmp[:-1].view(np.complex128).reshape(shape)
        c_misaligned = tmp[1:].view(np.complex128).reshape(shape)
        return c_aligned, c_misaligned, c_aligned.T, c_misaligned.T

    def test_equality(self):
        @jit
        def foo(x):
            return x

        @jit
        def bar(x):
            return x

        # Written this way to verify `==` returns a bool (gh-5838). Using
        # `assertTrue(foo == foo)` or `assertEqual(foo, foo)` would defeat the
        # purpose of this test.
        self.assertEqual(foo == foo, True)
        self.assertEqual(foo == bar, False)
        self.assertEqual(foo == None, False)  # noqa: E711

    def test_dyn_pyfunc(self):
        @jit
        def foo(x):
            return x

        foo(1)
        [cr] = foo.overloads.values()
        # __module__ must match that of foo
        self.assertEqual(cr.entry_point.__module__, foo.py_func.__module__)

    def test_no_argument(self):
        @jit
        def foo():
            return 1

        # Just make sure this doesn't crash
        foo()

    def test_coerce_input_types(self):
        # Issue #486: do not allow unsafe conversions if we can still
        # compile other specializations.
        c_add = jit(nopython=True)(add)
        self.assertPreciseEqual(c_add(123, 456), add(123, 456))
        self.assertPreciseEqual(c_add(12.3, 45.6), add(12.3, 45.6))
        self.assertPreciseEqual(c_add(12.3, 45.6j), add(12.3, 45.6j))
        self.assertPreciseEqual(c_add(12300000000, 456), add(12300000000, 456))

        # Now force compilation of only a single specialization
        c_add = jit('(i4, i4)', nopython=True)(add)
        self.assertPreciseEqual(c_add(123, 456), add(123, 456))
        # Implicit (unsafe) conversion of float to int
        self.assertPreciseEqual(c_add(12.3, 45.6), add(12, 45))
        with self.assertRaises(TypeError):
            # Implicit conversion of complex to int disallowed
            c_add(12.3, 45.6j)

    def test_ambiguous_new_version(self):
        """Test compiling new version in an ambiguous case
        """
        @jit
        def foo(a, b):
            return a + b

        INT = 1
        FLT = 1.5
        self.assertAlmostEqual(foo(INT, FLT), INT + FLT)
        self.assertEqual(len(foo.overloads), 1)
        self.assertAlmostEqual(foo(FLT, INT), FLT + INT)
        self.assertEqual(len(foo.overloads), 2)
        self.assertAlmostEqual(foo(FLT, FLT), FLT + FLT)
        self.assertEqual(len(foo.overloads), 3)
        # The following call is ambiguous because (int, int) can resolve
        # to (float, int) or (int, float) with equal weight.
        self.assertAlmostEqual(foo(1, 1), INT + INT)
        self.assertEqual(len(foo.overloads), 4, "didn't compile a new "
                                                "version")

    def test_lock(self):
        """
        Test that (lazy) compiling from several threads at once doesn't
        produce errors (see issue #908).
        """
        # Named ``failures`` (not ``errors``) so it does not shadow the
        # module-level ``numba.core.errors`` import.
        failures = []

        @jit
        def foo(x):
            return x + 1

        def wrapper():
            # Exceptions raised in a worker thread do not propagate to the
            # test thread, so collect them and assert on them afterwards.
            try:
                self.assertEqual(foo(1), 2)
            except Exception as e:
                failures.append(e)

        threads = [threading.Thread(target=wrapper) for _ in range(16)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        self.assertFalse(failures)

    def test_explicit_signatures(self):
        f = jit("(int64,int64)")(add)
        # Approximate match (unsafe conversion)
        self.assertPreciseEqual(f(1.5, 2.5), 3)
        self.assertEqual(len(f.overloads), 1, f.overloads)
        f = jit(["(int64,int64)", "(float64,float64)"])(add)
        # Exact signature matches
        self.assertPreciseEqual(f(1, 2), 3)
        self.assertPreciseEqual(f(1.5, 2.5), 4.0)
        # Approximate match (int32 -> float64 is a safe conversion)
        self.assertPreciseEqual(f(np.int32(1), 2.5), 3.5)
        # No conversion
        with self.assertRaises(TypeError) as cm:
            f(1j, 1j)
        self.assertIn("No matching definition", str(cm.exception))
        self.assertEqual(len(f.overloads), 2, f.overloads)
        # A more interesting one...
        f = jit(["(float32,float32)", "(float64,float64)"])(add)
        self.assertPreciseEqual(f(np.float32(1), np.float32(2**-25)), 1.0)
        self.assertPreciseEqual(f(1, 2**-25), 1.0000000298023224)
        # Fail to resolve ambiguity between the two best overloads
        f = jit(["(float32,float64)",
                 "(float64,float32)",
                 "(int64,int64)"])(add)
        with self.assertRaises(TypeError) as cm:
            f(1.0, 2.0)
        # The two best matches are output in the error message, as well
        # as the actual argument types.
        self.assertRegex(
            str(cm.exception),
            r"Ambiguous overloading for <function add [^>]*> "
            r"\(float64, float64\):\n"
            r"\(float32, float64\) -> float64\n"
            r"\(float64, float32\) -> float64"
        )
        # The integer signature is not part of the best matches
        self.assertNotIn("int64", str(cm.exception))

    def test_signature_mismatch(self):
        tmpl = ("Signature mismatch: %d argument types given, but function "
                "takes 2 arguments")
        with self.assertRaises(TypeError) as cm:
            jit("()")(add)
        self.assertIn(tmpl % 0, str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            jit("(intc,)")(add)
        self.assertIn(tmpl % 1, str(cm.exception))
        with self.assertRaises(TypeError) as cm:
            jit("(intc,intc,intc)")(add)
        self.assertIn(tmpl % 3, str(cm.exception))
        # With forceobj=True, an empty tuple is accepted
        jit("()", forceobj=True)(add)
        with self.assertRaises(TypeError) as cm:
            jit("(intc,)", forceobj=True)(add)
        self.assertIn(tmpl % 1, str(cm.exception))

    def test_matching_error_message(self):
        f = jit("(intc,intc)")(add)
        with self.assertRaises(TypeError) as cm:
            f(1j, 1j)
        self.assertEqual(str(cm.exception),
                         "No matching definition for argument type(s) "
                         "complex128, complex128")

    def test_disabled_compilation(self):
        @jit
        def foo(a):
            return a

        foo.compile("(float32,)")
        foo.disable_compile()
        with self.assertRaises(RuntimeError) as raises:
            foo.compile("(int32,)")
        self.assertEqual(str(raises.exception), "compilation disabled")
        self.assertEqual(len(foo.signatures), 1)

    def test_disabled_compilation_through_list(self):
        # Giving @jit an explicit list of signatures disables further
        # compilation.
        @jit(["(float32,)", "(int32,)"])
        def foo(a):
            return a

        with self.assertRaises(RuntimeError) as raises:
            foo.compile("(complex64,)")
        self.assertEqual(str(raises.exception), "compilation disabled")
        self.assertEqual(len(foo.signatures), 2)

    def test_disabled_compilation_nested_call(self):
        @jit(["(intp,)"])
        def foo(a):
            return a

        @jit
        def bar():
            foo(1)
            foo(np.ones(1))  # no matching definition

        with self.assertRaises(errors.TypingError) as raises:
            bar()

        m = r".*Invalid use of.*with parameters \(array\(float64, 1d, C\)\).*"
        self.assertRegex(str(raises.exception), m)

    def test_fingerprint_failure(self):
        """
        Failure in computing the fingerprint cannot affect a nopython=False
        function. On the other hand, with nopython=True, a ValueError should
        be raised to report the failure with fingerprint.
        """
        def foo(x):
            return x

        # Empty list will trigger failure in compile_fingerprint
        errmsg = 'cannot compute fingerprint of empty list'
        with self.assertRaises(ValueError) as raises:
            _dispatcher.compute_fingerprint([])
        self.assertIn(errmsg, str(raises.exception))
        # It should work in objmode
        objmode_foo = jit(forceobj=True)(foo)
        self.assertEqual(objmode_foo([]), [])
        # But, not in nopython=True
        strict_foo = jit(nopython=True)(foo)
        with self.assertRaises(ValueError) as raises:
            strict_foo([])
        self.assertIn(errmsg, str(raises.exception))

        # Test in loop lifting context
        @jit(forceobj=True)
        def bar():
            object()  # force looplifting
            x = []
            for i in range(10):
                x = objmode_foo(x)
            return x

        self.assertEqual(bar(), [])
        # Make sure it was looplifted
        [cr] = bar.overloads.values()
        self.assertEqual(len(cr.lifted), 1)

    def test_serialization(self):
        """
        Test serialization of Dispatcher objects
        """
        @jit(nopython=True)
        def foo(x):
            return x + 1

        self.assertEqual(foo(1), 2)

        # get serialization memo
        memo = Dispatcher._memo
        Dispatcher._recent.clear()
        memo_size = len(memo)

        # pickle foo and check memo size
        serialized_foo = pickle.dumps(foo)
        # increases the memo size
        self.assertEqual(memo_size + 1, len(memo))

        # unpickle
        foo_rebuilt = pickle.loads(serialized_foo)
        self.assertEqual(memo_size + 1, len(memo))

        self.assertIs(foo, foo_rebuilt)

        # do we get the same object even if we delete all the explicit
        # references?
        id_orig = id(foo_rebuilt)
        del foo
        del foo_rebuilt
        self.assertEqual(memo_size + 1, len(memo))
        new_foo = pickle.loads(serialized_foo)
        self.assertEqual(id_orig, id(new_foo))

        # now clear the recent cache
        ref = weakref.ref(new_foo)
        del new_foo
        Dispatcher._recent.clear()
        self.assertEqual(memo_size, len(memo))

        # show that deserializing creates a new object
        pickle.loads(serialized_foo)
        self.assertIs(ref(), None)

    @needs_lapack
    @unittest.skipIf(_is_armv7l, "Unaligned loads unsupported")
    def test_misaligned_array_dispatch(self):
        # for context see issue #2937
        def foo(a):
            return np.linalg.matrix_power(a, 1)

        jitfoo = jit(nopython=True)(foo)

        n = 64
        r = int(np.sqrt(n))

        # create some arrays as Cartesian production of:
        # [F/C] x [aligned/misaligned]
        (C_contig_aligned, C_contig_misaligned,
         F_contig_aligned, F_contig_misaligned) = \
            self._make_complex_views(n, (r, r))

        # checking routine; ``name`` labels the case in failure messages
        def check(name, a):
            a[:, :] = np.arange(n, dtype=np.complex128).reshape(r, r)
            expected = foo(a)
            got = jitfoo(a)
            np.testing.assert_allclose(expected, got, err_msg=name)

        # The checks must be run in this order to create the dispatch key
        # sequence that causes invalid dispatch noted in #2937.
        # The first two should hit the cache as they are aligned, supported
        # order and under 5 dimensions. The second two should end up in the
        # fallback path as they are misaligned.
        check("C_contig_aligned", C_contig_aligned)
        check("F_contig_aligned", F_contig_aligned)
        check("C_contig_misaligned", C_contig_misaligned)
        check("F_contig_misaligned", F_contig_misaligned)

    @unittest.skipIf(_is_armv7l, "Unaligned loads unsupported")
    def test_immutability_in_array_dispatch(self):

        # RO operation in function
        def foo(a):
            return np.sum(a)

        jitfoo = jit(nopython=True)(foo)

        n = 64
        r = int(np.sqrt(n))

        # create some arrays as Cartesian production of:
        # [F/C] x [aligned/misaligned]
        (C_contig_aligned, C_contig_misaligned,
         F_contig_aligned, F_contig_misaligned) = \
            self._make_complex_views(n, (r, r))

        # checking routine; the data is written first, then the array is
        # optionally marked read-only before being dispatched.
        def check(name, a, disable_write_bit=False):
            a[:, :] = np.arange(n, dtype=np.complex128).reshape(r, r)
            if disable_write_bit:
                a.flags.writeable = False
            expected = foo(a)
            got = jitfoo(a)
            np.testing.assert_allclose(expected, got, err_msg=name)

        # all of these should end up in the fallback path as they have no write
        # bit set
        check("C_contig_aligned", C_contig_aligned, disable_write_bit=True)
        check("F_contig_aligned", F_contig_aligned, disable_write_bit=True)
        check("C_contig_misaligned", C_contig_misaligned,
              disable_write_bit=True)
        check("F_contig_misaligned", F_contig_misaligned,
              disable_write_bit=True)

    @needs_lapack
    @unittest.skipIf(_is_armv7l, "Unaligned loads unsupported")
    def test_misaligned_high_dimension_array_dispatch(self):

        def foo(a):
            return np.linalg.matrix_power(a[0, 0, 0, 0, :, :], 1)

        jitfoo = jit(nopython=True)(foo)

        def check_properties(arr, layout, aligned):
            self.assertEqual(arr.flags.aligned, aligned)
            if layout == "C":
                self.assertEqual(arr.flags.c_contiguous, True)
            if layout == "F":
                self.assertEqual(arr.flags.f_contiguous, True)

        n = 729
        r = 3
        shape = (r, r, r, r, r, r)  # 6 dimensions, r ** 6 == n elements

        # create some arrays as Cartesian production of:
        # [F/C] x [aligned/misaligned]
        (C_contig_aligned, C_contig_misaligned,
         F_contig_aligned, F_contig_misaligned) = \
            self._make_complex_views(n, shape)
        check_properties(C_contig_aligned, 'C', True)
        check_properties(C_contig_misaligned, 'C', False)
        check_properties(F_contig_aligned, 'F', True)
        check_properties(F_contig_misaligned, 'F', False)

        # checking routine; ``name`` labels the case in failure messages
        def check(name, a):
            a[:, :] = np.arange(n, dtype=np.complex128).reshape(shape)
            expected = foo(a)
            got = jitfoo(a)
            np.testing.assert_allclose(expected, got, err_msg=name)

        # these should all hit the fallback path as the cache is only for up to
        # 5 dimensions
        check("F_contig_misaligned", F_contig_misaligned)
        check("C_contig_aligned", C_contig_aligned)
        check("F_contig_aligned", F_contig_aligned)
        check("C_contig_misaligned", C_contig_misaligned)

    def test_dispatch_recompiles_for_scalars(self):
        # for context #3612, essentially, compiling a lambda x:x for a
        # numerically wide type (everything can be converted to a complex128)
        # and then calling again with e.g. an int32 would lead to the int32
        # being converted to a complex128 whereas it ought to compile an int32
        # specialization.
        def foo(x):
            return x

        # jit and compile on dispatch for 3 scalar types, expect 3 signatures
        jitfoo = jit(nopython=True)(foo)
        jitfoo(np.complex128(1 + 2j))
        jitfoo(np.int32(10))
        jitfoo(np.bool_(False))
        self.assertEqual(len(jitfoo.signatures), 3)
        expected_sigs = [(types.complex128,), (types.int32,), (types.bool_,)]
        self.assertEqual(jitfoo.signatures, expected_sigs)

        # now jit with signatures so recompilation is forbidden
        # expect 1 signature and type conversion
        jitfoo = jit([(types.complex128,)], nopython=True)(foo)
        jitfoo(np.complex128(1 + 2j))
        jitfoo(np.int32(10))
        jitfoo(np.bool_(False))
        self.assertEqual(len(jitfoo.signatures), 1)
        expected_sigs = [(types.complex128,)]
        self.assertEqual(jitfoo.signatures, expected_sigs)

    def test_dispatcher_raises_for_invalid_decoration(self):
        # For context see https://github.com/numba/numba/issues/4750.

        @jit(nopython=True)
        def foo(x):
            return x

        with self.assertRaises(TypeError) as raises:
            jit(foo)
        err_msg = str(raises.exception)
        self.assertIn(
            "A jit decorator was called on an already jitted function", err_msg)
        self.assertIn("foo", err_msg)
        self.assertIn(".py_func", err_msg)

        with self.assertRaises(TypeError) as raises:
            jit(BaseTest)
        err_msg = str(raises.exception)
        self.assertIn("The decorated object is not a function", err_msg)
        self.assertIn(f"{type(BaseTest)}", err_msg)
|
||
|
|
||
|
|
||
|
class TestSignatureHandling(BaseTest):
    """
    Check how compiled functions bind their arguments: positional,
    keyword, defaulted and ``*args`` parameters. Also check the TypeErrors
    raised for calls that cannot be bound.
    """

    def _check_type_error(self, expected_msg, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)``, require a TypeError, and require
        its message to contain ``expected_msg``."""
        with self.assertRaises(TypeError) as cm:
            func(*args, **kwargs)
        self.assertIn(expected_msg, str(cm.exception))

    def test_named_args(self):
        """
        Keyword arguments given to a dispatcher are matched to parameters.
        """
        f, check = self.compile_func(addsub)
        check(3, z=10, y=4)
        check(3, 4, 10)
        check(x=3, y=4, z=10)
        # A single specialization serves every calling style above
        self.assertEqual(len(f.overloads), 1)

        # Calls whose arguments cannot be bound
        self._check_type_error("too many arguments: expected 3, got 4",
                               f, 3, 4, y=6, z=7)
        self._check_type_error("not enough arguments: expected 3, got 0", f)
        self._check_type_error("missing argument 'z'", f, 3, 4, y=6)

    def test_default_args(self):
        """
        Parameters that have a default value may be left out of the call.
        """
        f, check = self.compile_func(addsub_defaults)
        # Every argument supplied explicitly
        check(3, z=10, y=4)
        check(3, 4, 10)
        check(x=3, y=4, z=10)
        # Some or all defaulted arguments left out
        check(3, z=10)
        check(3, 4)
        check(x=3, y=4)
        check(3)
        check(x=3)

        # Calls whose arguments cannot be bound
        self._check_type_error("too many arguments: expected 3, got 4",
                               f, 3, 4, y=6, z=7)
        self._check_type_error(
            "not enough arguments: expected at least 1, got 0", f)
        self._check_type_error("missing argument 'x'", f, y=6, z=7)

    def test_star_args(self):
        """
        A compiled function may declare ``*args`` in its signature.
        """
        f, check = self.compile_func(star_defaults)
        # check(4), check(4, 5), ..., check(4, 5, 6, 7, 8)
        for count in range(1, 6):
            check(*range(4, 4 + count))
        check(x=4)
        check(x=4, y=5)
        check(4, y=5)

        unexpected = "some keyword arguments unexpected"
        self._check_type_error(unexpected, f, 4, 5, y=6)
        self._check_type_error(unexpected, f, 4, 5, z=6)
        self._check_type_error(unexpected, f, 4, x=6)
|
||
|
|
||
|
|
||
|
class TestSignatureHandlingObjectMode(TestSignatureHandling):
    """
    Same as TestSignatureHandling, but in object mode.
    """

    # Compile the functions under test with forceobj=True instead of the
    # nopython=True default inherited from BaseTest.
    jit_args = dict(forceobj=True)
|
||
|
|
||
|
|
||
|
class TestDispatcherMethods(TestCase):
    """
    Tests for Dispatcher methods:

    - ``recompile``;
    - ``inspect_llvm``, ``inspect_asm``, ``inspect_cfg`` and
      ``inspect_types``;
    - ``get_annotation_info``.

    Also covers layout-dependent dispatch of arrays.
    """

    def test_recompile(self):
        closure = 1

        @jit
        def foo(x):
            return x + closure
        self.assertPreciseEqual(foo(1), 2)
        self.assertPreciseEqual(foo(1.5), 2.5)
        self.assertEqual(len(foo.signatures), 2)
        closure = 2
        # Existing compiled code still uses the old closure value.
        self.assertPreciseEqual(foo(1), 2)
        # Recompiling takes the new closure into account.
        foo.recompile()
        # Everything was recompiled
        self.assertEqual(len(foo.signatures), 2)
        self.assertPreciseEqual(foo(1), 3)
        self.assertPreciseEqual(foo(1.5), 3.5)

    def test_recompile_signatures(self):
        # Same as above, but with an explicit signature on @jit.
        closure = 1

        @jit("int32(int32)")
        def foo(x):
            return x + closure
        self.assertPreciseEqual(foo(1), 2)
        self.assertPreciseEqual(foo(1.5), 2)
        closure = 2
        self.assertPreciseEqual(foo(1), 2)
        # Recompiling takes the new closure into account.
        foo.recompile()
        self.assertPreciseEqual(foo(1), 3)
        self.assertPreciseEqual(foo(1.5), 3)

    def test_inspect_llvm(self):
        # Create a jited function
        @jit
        def foo(explicit_arg1, explicit_arg2):
            return explicit_arg1 + explicit_arg2

        # Call it in a way to create 3 signatures
        foo(1, 1)
        foo(1.0, 1)
        foo(1.0, 1.0)

        # base call to get all llvm in a dict
        llvms = foo.inspect_llvm()
        self.assertEqual(len(llvms), 3)

        # make sure the function name shows up in the llvm
        for llvm_bc in llvms.values():
            # Look for the function name
            self.assertIn("foo", llvm_bc)

            # Look for the argument names
            self.assertIn("explicit_arg1", llvm_bc)
            self.assertIn("explicit_arg2", llvm_bc)

    def test_inspect_asm(self):
        # Create a jited function
        @jit
        def foo(explicit_arg1, explicit_arg2):
            return explicit_arg1 + explicit_arg2

        # Call it in a way to create 3 signatures
        foo(1, 1)
        foo(1.0, 1)
        foo(1.0, 1.0)

        # base call to get all asm in a dict
        asms = foo.inspect_asm()
        self.assertEqual(len(asms), 3)

        # make sure the function name shows up in the asm
        for asm in asms.values():
            self.assertIn("foo", asm)

    def _check_cfg_display(self, cfg, wrapper=''):
        """
        Check that ``str(cfg)`` starts with a digraph header. The header
        names a mangled symbol built from ``_ZN``, the optional
        length-prefixed ``wrapper``, then the length-prefixed top-level
        package name of this module. Also check that ``cfg.display`` is
        callable.
        """
        # simple stringify test
        if wrapper:
            wrapper = "{}{}".format(len(wrapper), wrapper)
        module_name = __name__.split('.', 1)[0]
        module_len = len(module_name)
        prefix = r'^digraph "CFG for \'_ZN{}{}{}'.format(wrapper,
                                                         module_len,
                                                         module_name)
        self.assertRegex(str(cfg), prefix)
        # .display() requires an optional dependency on `graphviz`.
        # just test for the attribute without running it.
        self.assertTrue(callable(cfg.display))

    def test_inspect_cfg(self):
        # Exercise the .inspect_cfg(). These are minimal tests and do not fully
        # check the correctness of the function.
        @jit
        def foo(the_array):
            return the_array.sum()

        # Generate 3 overloads
        a1 = np.ones(1)
        a2 = np.ones((1, 1))
        a3 = np.ones((1, 1, 1))
        foo(a1)
        foo(a2)
        foo(a3)

        # Call inspect_cfg() without arguments
        cfgs = foo.inspect_cfg()

        # Correct count of overloads
        self.assertEqual(len(cfgs), 3)

        # Makes sure all the signatures are correct
        self.assertEqual(set(cfgs),
                         {(typeof(arr),) for arr in (a1, a2, a3)})

        for cfg in cfgs.values():
            self._check_cfg_display(cfg)

        # Call inspect_cfg(signature)
        cfg = foo.inspect_cfg(signature=foo.signatures[0])
        self._check_cfg_display(cfg)

    def test_inspect_cfg_with_python_wrapper(self):
        # Exercise the .inspect_cfg() including the python wrapper.
        # These are minimal tests and do not fully check the correctness of
        # the function.
        @jit
        def foo(the_array):
            return the_array.sum()

        # Generate 3 overloads
        a1 = np.ones(1)
        a2 = np.ones((1, 1))
        a3 = np.ones((1, 1, 1))
        foo(a1)
        foo(a2)
        foo(a3)

        # Call inspect_cfg(signature, show_wrapper="python")
        cfg = foo.inspect_cfg(signature=foo.signatures[0],
                              show_wrapper="python")
        self._check_cfg_display(cfg, wrapper='cpython')

    def test_inspect_types(self):
        @jit
        def foo(a, b):
            return a + b

        foo(1, 2)
        # Exercise the method
        foo.inspect_types(StringIO())

        # Test output: with no file, the annotation is printed to stdout.
        # assertIn (unlike a bare assert) is not stripped under ``-O``.
        expected = str(foo.overloads[foo.signatures[0]].type_annotation)
        with captured_stdout() as out:
            foo.inspect_types()
        self.assertIn(expected, out.getvalue())

    def test_inspect_types_with_signature(self):
        @jit
        def foo(a):
            return a + 1

        foo(1)
        foo(1.0)
        # Inspect all signatures
        with captured_stdout() as total:
            foo.inspect_types()
        # Inspect first signature
        with captured_stdout() as first:
            foo.inspect_types(signature=foo.signatures[0])
        # Inspect second signature
        with captured_stdout() as second:
            foo.inspect_types(signature=foo.signatures[1])

        self.assertEqual(total.getvalue(), first.getvalue() + second.getvalue())

    @unittest.skipIf(jinja2 is None, "please install the 'jinja2' package")
    @unittest.skipIf(pygments is None, "please install the 'pygments' package")
    def test_inspect_types_pretty(self):
        @jit
        def foo(a, b):
            return a + b

        foo(1, 2)

        # Exercise the method, dump the output
        with captured_stdout():
            ann = foo.inspect_types(pretty=True)

        # ensure HTML <span> is found in the annotation output
        for v in ann.ann.values():
            span_found = any('span' in line[2]
                             for line in v['pygments_lines'])
            self.assertTrue(span_found)

        # check that file+pretty kwarg combo raises
        with self.assertRaises(ValueError) as raises:
            foo.inspect_types(file=StringIO(), pretty=True)

        self.assertIn("`file` must be None if `pretty=True`",
                      str(raises.exception))

    def test_get_annotation_info(self):
        @jit
        def foo(a):
            return a + 1

        foo(1)
        foo(1.3)

        # Calling without a signature should merge the per-signature results.
        expected = dict(chain.from_iterable(foo.get_annotation_info(i).items()
                                            for i in foo.signatures))
        result = foo.get_annotation_info()
        self.assertEqual(expected, result)

    def test_issue_with_array_layout_conflict(self):
        """
        This tests an issue with the dispatcher when an array that is both
        C and F contiguous is supplied as the first signature.
        The dispatcher checks for F contiguous first but the compiler checks
        for C contiguous first. This results in C contiguous code being
        inserted as the F contiguous function.
        """
        def pyfunc(A, i, j):
            return A[i, j]

        cfunc = jit(pyfunc)

        ary_c_and_f = np.array([[1.]])
        ary_c = np.array([[0., 1.], [2., 3.]], order='C')
        ary_f = np.array([[0., 1.], [2., 3.]], order='F')

        exp_c = pyfunc(ary_c, 1, 0)
        exp_f = pyfunc(ary_f, 1, 0)

        self.assertEqual(1., cfunc(ary_c_and_f, 0, 0))
        got_c = cfunc(ary_c, 1, 0)
        got_f = cfunc(ary_f, 1, 0)

        self.assertEqual(exp_c, got_c)
        self.assertEqual(exp_f, got_f)
|
||
|
|
||
|
|
||
|
class TestDispatcherFunctionBoundaries(TestCase):
    """
    Dispatcher objects crossing function boundaries: passed as arguments,
    returned to the interpreter, and held in tuples/lists.
    """

    def test_pass_dispatcher_as_arg(self):
        # A Dispatcher object can be passed as an argument
        @jit(nopython=True)
        def add1(x):
            return x + 1

        @jit(nopython=True)
        def bar(fn, x):
            return fn(x)

        @jit(nopython=True)
        def foo(x):
            return bar(add1, x)

        inputs = [1, 11.1, np.arange(10)]
        cases = [(value, value + 1) for value in inputs]

        # Dispatcher passed as an argument from within nopython mode
        for value, want in cases:
            self.assertPreciseEqual(foo(value), want)

        # Dispatcher passed as an argument from the interpreter
        for value, want in cases:
            self.assertPreciseEqual(bar(add1, value), want)

    def test_dispatcher_as_arg_usecase(self):
        @jit(nopython=True)
        def maximum(seq, cmpfn):
            tmp = seq[0]
            for each in seq[1:]:
                cmpval = cmpfn(tmp, each)
                if cmpval < 0:
                    tmp = each
            return tmp

        # (sequence, python comparator to jit, expected maximum)
        scenarios = [
            ([1, 2, 3, 4], lambda x, y: x - y, 4),
            (list(zip(range(5), range(5)[::-1])),
             lambda x, y: x[0] - y[0], (4, 0)),
            (list(zip(range(5), range(5)[::-1])),
             lambda x, y: x[1] - y[1], (0, 4)),
        ]
        for seq, comparator, want in scenarios:
            got = maximum(seq, cmpfn=jit(comparator))
            self.assertEqual(got, want)

    def test_dispatcher_can_return_to_python(self):
        @jit(nopython=True)
        def foo(fn):
            return fn

        identity = jit(lambda x: x)

        self.assertEqual(foo(identity), identity)

    def test_dispatcher_in_sequence_arg(self):
        @jit(nopython=True)
        def one(x):
            return x + 1

        @jit(nopython=True)
        def two(x):
            return one(one(x))

        @jit(nopython=True)
        def three(x):
            return one(one(one(x)))

        @jit(nopython=True)
        def choose(fns, x):
            return fns[0](x), fns[1](x), fns[2](x)

        # Dispatchers held in a tuple
        tuple_result = choose((one, two, three), 1)
        self.assertEqual(tuple_result, (2, 3, 4))
        # Dispatchers held in a list
        list_result = choose([one, one, one], 1)
        self.assertEqual(list_result, (2, 2, 2))
|
||
|
|
||
|
|
||
|
class TestBoxingDefaultError(unittest.TestCase):
|
||
|
# Testing default error at boxing/unboxing
|
||
|
def test_unbox_runtime_error(self):
|
||
|
# Dummy type has no unbox support
|
||
|
def foo(x):
|
||
|
pass
|
||
|
argtys = (types.Dummy("dummy_type"),)
|
||
|
# This needs `compile_isolated`-like behaviour so as to bypass
|
||
|
# dispatcher type checking logic
|
||
|
cres = njit(argtys)(foo).overloads[argtys]
|
||
|
with self.assertRaises(TypeError) as raises:
|
||
|
# Can pass in whatever and the unbox logic will always raise
|
||
|
# without checking the input value.
|
||
|
cres.entry_point(None)
|
||
|
self.assertEqual(str(raises.exception), "can't unbox dummy_type type")
|
||
|
|
||
|
def test_box_runtime_error(self):
|
||
|
@njit
|
||
|
def foo():
|
||
|
return unittest # Module type has no boxing logic
|
||
|
with self.assertRaises(TypeError) as raises:
|
||
|
foo()
|
||
|
pat = "cannot convert native Module.* to Python object"
|
||
|
self.assertRegex(str(raises.exception), pat)
|
||
|
|
||
|
|
||
|
class TestNoRetryFailedSignature(unittest.TestCase):
|
||
|
"""Test that failed-to-compile signatures are not recompiled.
|
||
|
"""
|
||
|
|
||
|
def run_test(self, func):
|
||
|
fcom = func._compiler
|
||
|
self.assertEqual(len(fcom._failed_cache), 0)
|
||
|
# expected failure because `int` has no `__getitem__`
|
||
|
with self.assertRaises(errors.TypingError):
|
||
|
func(1)
|
||
|
self.assertEqual(len(fcom._failed_cache), 1)
|
||
|
# retry
|
||
|
with self.assertRaises(errors.TypingError):
|
||
|
func(1)
|
||
|
self.assertEqual(len(fcom._failed_cache), 1)
|
||
|
# retry with double
|
||
|
with self.assertRaises(errors.TypingError):
|
||
|
func(1.0)
|
||
|
self.assertEqual(len(fcom._failed_cache), 2)
|
||
|
|
||
|
def test_direct_call(self):
|
||
|
@jit(nopython=True)
|
||
|
def foo(x):
|
||
|
return x[0]
|
||
|
|
||
|
self.run_test(foo)
|
||
|
|
||
|
def test_nested_call(self):
|
||
|
@jit(nopython=True)
|
||
|
def bar(x):
|
||
|
return x[0]
|
||
|
|
||
|
@jit(nopython=True)
|
||
|
def foobar(x):
|
||
|
bar(x)
|
||
|
|
||
|
@jit(nopython=True)
|
||
|
def foo(x):
|
||
|
return bar(x) + foobar(x)
|
||
|
|
||
|
self.run_test(foo)
|
||
|
|
||
|
@unittest.expectedFailure
|
||
|
# NOTE: @overload does not have an error cache. See PR #9259 for this
|
||
|
# feature and remove the xfail once this is merged.
|
||
|
def test_error_count(self):
|
||
|
def check(field, would_fail):
|
||
|
# Slightly modified from the reproducer in issue #4117.
|
||
|
# Before the patch, the compilation time of the failing case is
|
||
|
# much longer than of the successful case. This can be detected
|
||
|
# by the number of times `trigger()` is visited.
|
||
|
k = 10
|
||
|
counter = {'c': 0}
|
||
|
|
||
|
def trigger(x):
|
||
|
assert 0, "unreachable"
|
||
|
|
||
|
@overload(trigger)
|
||
|
def ol_trigger(x):
|
||
|
# Keep track of every visit
|
||
|
counter['c'] += 1
|
||
|
if would_fail:
|
||
|
raise errors.TypingError("invoke_failed")
|
||
|
return lambda x: x
|
||
|
|
||
|
@jit(nopython=True)
|
||
|
def ident(out, x):
|
||
|
pass
|
||
|
|
||
|
def chain_assign(fs, inner=ident):
|
||
|
tab_head, tab_tail = fs[-1], fs[:-1]
|
||
|
|
||
|
@jit(nopython=True)
|
||
|
def assign(out, x):
|
||
|
inner(out, x)
|
||
|
out[0] += tab_head(x)
|
||
|
|
||
|
if tab_tail:
|
||
|
return chain_assign(tab_tail, assign)
|
||
|
else:
|
||
|
return assign
|
||
|
|
||
|
chain = chain_assign((trigger,) * k)
|
||
|
out = np.ones(2)
|
||
|
if would_fail:
|
||
|
with self.assertRaises(errors.TypingError) as raises:
|
||
|
chain(out, 1)
|
||
|
self.assertIn('invoke_failed', str(raises.exception))
|
||
|
else:
|
||
|
chain(out, 1)
|
||
|
|
||
|
# Returns the visit counts
|
||
|
return counter['c']
|
||
|
|
||
|
ct_ok = check('a', False)
|
||
|
ct_bad = check('c', True)
|
||
|
# `trigger()` is visited exactly once for both successful and failed
|
||
|
# compilation.
|
||
|
self.assertEqual(ct_ok, 1)
|
||
|
self.assertEqual(ct_bad, 1)
|
||
|
|
||
|
|
||
|
@njit
def add_y1(x, y=1):
    """Return ``x + y``, where ``y`` defaults to the integer 1."""
    total = x + y
    return total
|
||
|
|
||
|
|
||
|
@njit
def add_ynone(x, y=None):
    """Return ``x + 1`` when ``y`` is truthy, otherwise ``x + 2``.

    ``y`` defaults to ``None``, which takes the ``x + 2`` branch.
    """
    offset = 1 if y else 2
    return x + offset
|
||
|
|
||
|
|
||
|
@njit
def mult(x, y):
    """Return the product ``x * y``."""
    product = x * y
    return product
|
||
|
|
||
|
|
||
|
@njit
def add_func(x, func=mult):
    """Return ``x + func(x, x)``, where ``func`` defaults to the jitted
    ``mult``."""
    extra = func(x, x)
    return x + extra
|
||
|
|
||
|
|
||
|
def _checker(f1, arg):
|
||
|
assert f1(arg) == f1.py_func(arg)
|
||
|
|
||
|
|
||
|
class TestMultiprocessingDefaultParameters(SerialMixin, unittest.TestCase):
    """Jitted functions with default parameters, exercised in freshly
    spawned processes (see issue #4888)."""

    def run_fc_multiproc(self, fc):
        """Run ``_checker(fc, a)`` in a separate process for a few small
        integer arguments and require each child to exit with code 0.

        A child still alive after ``join(_TEST_TIMEOUT)`` is terminated
        before the test fails, so a hang does not leak a process.
        """
        try:
            ctx = multiprocessing.get_context('spawn')
        except AttributeError:
            ctx = multiprocessing

        # RE: issue #5973, this doesn't use multiprocessing.Pool.map. Doing
        # so causes the TBB library to segfault under certain conditions.
        # It's not clear whether the cause is the complexity of the Pool
        # itself (e.g. watcher threads) or a problem akin to a "timing
        # attack".
        for a in [1, 2, 3]:
            p = ctx.Process(target=_checker, args=(fc, a,))
            p.start()
            p.join(_TEST_TIMEOUT)
            if p.is_alive():
                # join() timed out: reap the child before failing.
                p.terminate()
                p.join()
                self.fail("subprocess timed out for argument %r" % (a,))
            self.assertEqual(p.exitcode, 0)

    def test_int_def_param(self):
        """ Tests issue #4888"""

        self.run_fc_multiproc(add_y1)

    def test_none_def_param(self):
        """ Tests None as a default parameter"""

        # Previously ran `add_func`, duplicating test_function_def_param
        # and leaving the `y=None` default untested.
        self.run_fc_multiproc(add_ynone)

    def test_function_def_param(self):
        """ Tests a function as a default parameter"""

        self.run_fc_multiproc(add_func)
|
||
|
|
||
|
|
||
|
class TestVectorizeDifferentTargets(unittest.TestCase):
|
||
|
"""Test that vectorize can be reapplied if the target is different
|
||
|
"""
|
||
|
|
||
|
def test_cpu_vs_parallel(self):
|
||
|
@jit
|
||
|
def add(x, y):
|
||
|
return x + y
|
||
|
|
||
|
custom_vectorize = vectorize([], identity=None, target='cpu')
|
||
|
|
||
|
custom_vectorize(add)
|
||
|
|
||
|
custom_vectorize_2 = vectorize([], identity=None, target='parallel')
|
||
|
|
||
|
custom_vectorize_2(add)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
|