# ai-content-maker/.venv/Lib/site-packages/numba/tests/test_extending.py
# 2249 lines
# 67 KiB
# Python
# Raw Normal View History
#
# 2024-05-03 04:18:51 +03:00
import inspect
import math
import operator
import sys
import pickle
import multiprocessing
import ctypes
import warnings
import re
import numpy as np
from llvmlite import ir
import numba
from numba import njit, jit, vectorize, guvectorize, objmode
from numba.core import types, errors, typing, compiler, cgutils
from numba.core.typed_passes import type_inference_stage
from numba.core.registry import cpu_target
from numba.core.imputils import lower_constant
from numba.tests.support import (
TestCase,
captured_stdout,
temp_directory,
override_config,
run_in_new_process_in_cache_dir,
skip_if_typeguard,
)
from numba.core.errors import LoweringError
import unittest
from numba.extending import (
typeof_impl,
type_callable,
lower_builtin,
lower_cast,
overload,
overload_attribute,
overload_method,
models,
register_model,
box,
unbox,
NativeValue,
intrinsic,
_Intrinsic,
register_jitable,
get_cython_function_address,
is_jitted,
overload_classmethod,
)
from numba.core.typing.templates import (
ConcreteTemplate,
signature,
infer,
infer_global,
AbstractTemplate,
)
# Pandas-like API implementation
from .pdlike_usecase import Index, Series
try:
import scipy.special.cython_special as sc
except ImportError:
sc = None
# -----------------------------------------------------------------------
# Define a custom type and an implicit cast on it
class MyDummy:
    """Marker Python class; typeof maps its instances to ``mydummy_type``."""
class MyDummyType(types.Opaque):
    """Opaque numba type for MyDummy values, implicitly castable to numbers."""

    def can_convert_to(self, context, toty):
        # Non-numeric targets are refused by returning None; numeric targets
        # are reported as a safe conversion.
        if not isinstance(toty, types.Number):
            return None
        from numba.core.typeconv import Conversion
        return Conversion.safe
# The single numba type instance used for every MyDummy value.
mydummy_type = MyDummyType("mydummy")
# A shared MyDummy instance, returned by get_dummy() and referenced as a
# global from jitted test functions.
mydummy = MyDummy()
@typeof_impl.register(MyDummy)
def typeof_mydummy(val, c):
    # Every MyDummy instance is typed as the single ``mydummy_type``.
    return mydummy_type
@lower_cast(MyDummyType, types.Number)
def mydummy_to_number(context, builder, fromty, toty, val):
    """
    Implicit conversion from MyDummy to a number.

    The value being cast is ignored: the result is always the constant 42
    of the target numeric type ``toty``.
    """
    return context.get_constant(toty, 42)
def get_dummy():
    # Returns the module-level MyDummy instance; test_cast_mydummy compiles
    # this with a float64 return type to exercise the implicit cast.
    return mydummy
# Give MyDummyType an opaque data model.
register_model(MyDummyType)(models.OpaqueModel)
@unbox(MyDummyType)
def unbox_index(typ, obj, c):
    # Unboxing ignores the Python object and yields a dummy native value.
    # Despite the name, this unboxes MyDummyType, not an Index.
    return NativeValue(c.context.get_dummy_value())
# -----------------------------------------------------------------------
# Define a second custom type but w/o implicit cast to Number
def base_dummy_type_factory(name):
    """
    Create a fresh opaque dummy type named ``name``.

    Defines a Python class and a matching ``types.Opaque`` subclass,
    registers a typeof handler that maps instances of the class to one
    instance of the numba type, and gives the numba type an opaque data
    model. No implicit cast to Number and no unboxing are defined here.

    Returns ``(numba type class, Python class, numba type instance)``.
    """
    class DynType(object):
        pass

    class DynTypeType(types.Opaque):
        pass

    dyn_type_type = DynTypeType(name)

    @typeof_impl.register(DynType)
    def typeof_mydummy(val, c):
        return dyn_type_type

    register_model(DynTypeType)(models.OpaqueModel)
    return DynTypeType, DynType, dyn_type_type
# A second dummy type that, unlike MyDummyType, has no implicit cast to a
# Number.
MyDummyType2, MyDummy2, mydummy_type_2 = base_dummy_type_factory("mydummy2")
@unbox(MyDummyType2)
def unbox_index2(typ, obj, c):
    # Unboxing a MyDummy2 ignores the Python object and yields a dummy
    # native value.
    return NativeValue(c.context.get_dummy_value())
# -----------------------------------------------------------------------
# Define a function's typing and implementation using the classical
# two-step API
def func1(x=None):
    """
    Python-level stub that always raises.

    Jitted callers are served by the ``type_callable`` and ``lower_builtin``
    registrations for ``func1`` in this module.
    """
    raise NotImplementedError
def type_func1_(context):
    """
    Typing function for ``func1``.

    The returned typer types a call with no argument or with ``None`` as
    int32, and a call with one float argument as that same float type.
    Any other argument type is rejected by returning None.
    """
    def typer(x=None):
        if isinstance(x, types.Float):
            # 1-arg with float: result has the argument's float type
            return x
        if x in (None, types.none):
            # 0-arg, or 1-arg with None
            return types.int32
        return None

    return typer
# type_callable() registers the typing function and returns it unchanged;
# test_type_callable_keeps_function checks that identity.
type_func1 = type_callable(func1)(type_func1_)
@lower_builtin(func1)
@lower_builtin(func1, types.none)
def func1_nullary(context, builder, sig, args):
    # func1() and func1(None) both lower to the constant 42 of the call's
    # return type.
    return context.get_constant(sig.return_type, 42)
@lower_builtin(func1, types.Float)
def func1_unary(context, builder, sig, args):
    # For a float argument, compile a small Python implementation with
    # compile_internal and call it: func1(x) == sqrt(2 * x).
    def func1_impl(x):
        return math.sqrt(2 * x)

    return context.compile_internal(builder, func1_impl, sig, args)
# We can do the same for a known internal operation, here "print_item"
# which we extend to support MyDummyType.
@infer
class PrintDummy(ConcreteTemplate):
    # Typing for the internal "print_item" operation applied to a
    # MyDummyType argument; the operation returns None.
    key = "print_item"
    cases = [signature(types.none, mydummy_type)]
@lower_builtin("print_item", MyDummyType)
def print_dummy(context, builder, sig, args):
    """
    Lower ``print_item`` for MyDummyType.

    The argument's value is ignored; the Python string "hello!" is printed
    through the Python C API instead.
    """
    (_value,) = args
    api = context.get_python_api(builder)
    # Materialize the constant string as a Python object, print it, then
    # drop our reference to it.
    message = api.unserialize(api.serialize_object("hello!"))
    api.print_object(message)
    api.decref(message)
    return context.get_dummy_value()
# -----------------------------------------------------------------------
# Define an overloaded function (combined API)
def where(cond, x, y):
    """
    Python-level stub that always raises.

    Jitted callers are served by the ``@overload(where)`` implementations
    in this module.
    """
    raise NotImplementedError
def np_where(cond, x, y):
    """
    Return ``np.where(cond, x, y)``.

    This wrapper lets the arguments be passed by keyword as well, so it can
    serve as the reference result for ``call_where``.
    """
    result = np.where(cond, x, y)
    return result
def call_where(cond, x, y):
    # Jitted in test_where; passes x and y as out-of-order keyword arguments
    # to exercise keyword handling in @overload.
    return where(cond, y=y, x=x)
@overload(where)
def overload_where_arrays(cond, x, y):
    """
    Implement where() for arrays.

    Applies only when ``cond`` is an array. Raises a TypingError at compile
    time if ``x`` and ``y`` have different dtypes. The selected
    implementation returns a new array created with ``np.empty_like(x)``
    and raises ValueError at run time unless ``cond``, ``x`` and ``y`` all
    have the same shape.
    """
    # Choose implementation based on argument types.
    if isinstance(cond, types.Array):
        if x.dtype != y.dtype:
            raise errors.TypingError("x and y should have the same dtype")

        # Array where() => return an array of the same shape
        if all(ty.layout == "C" for ty in (cond, x, y)):

            def where_impl(cond, x, y):
                """
                Fast implementation for C-contiguous arrays
                """
                shape = cond.shape
                if x.shape != shape or y.shape != shape:
                    raise ValueError("all inputs should have the same shape")
                res = np.empty_like(x)
                cf = cond.flat
                xf = x.flat
                yf = y.flat
                rf = res.flat
                for i in range(cond.size):
                    rf[i] = xf[i] if cf[i] else yf[i]
                return res
        else:

            def where_impl(cond, x, y):
                """
                Generic implementation for other arrays
                """
                shape = cond.shape
                if x.shape != shape or y.shape != shape:
                    raise ValueError("all inputs should have the same shape")
                res = np.empty_like(x)
                for idx, c in np.ndenumerate(cond):
                    res[idx] = x[idx] if c else y[idx]
                return res

        return where_impl
# We can define another overload function for the same function, they
# will be tried in turn until one succeeds.
@overload(where)
def overload_where_scalars(cond, x, y):
    """
    Implement where() for scalars.

    Applies only when ``cond`` is not an array. Raises a TypingError at
    compile time unless ``x`` and ``y`` have the same numba type. The
    implementation returns the selected value wrapped in a 0-dim array.
    """
    if not isinstance(cond, types.Array):
        if x != y:
            raise errors.TypingError("x and y should have the same type")

        def where_impl(cond, x, y):
            """
            Scalar where() => return a 0-dim array
            """
            scal = x if cond else y
            # Can't use full_like() on Numpy < 1.8
            arr = np.empty_like(scal)
            arr[()] = scal
            return arr

        return where_impl
# -----------------------------------------------------------------------
# Overload an already defined built-in function, extending it for new types.
@overload(len)
def overload_len_dummy(arg):
    """len() of a MyDummyType value is the constant 13."""
    if not isinstance(arg, MyDummyType):
        # Returning None declines this overload for other argument types.
        return None

    def len_impl(arg):
        return 13

    return len_impl
@overload(operator.add)
def overload_add_dummy(arg1, arg2):
    """
    Implement operator.add for any pair of MyDummyType / MyDummyType2
    operands, returning the constant 42.
    """
    dummy_types = (MyDummyType, MyDummyType2)
    if not (isinstance(arg1, dummy_types) and isinstance(arg2, dummy_types)):
        return None

    def dummy_add_impl(arg1, arg2):
        return 42

    return dummy_add_impl
@overload(operator.delitem)
def overload_dummy_delitem(obj, idx):
    """
    ``del obj[idx]`` on a MyDummyType with an integer index prints "del",
    the object and the index.
    """
    if not (isinstance(obj, MyDummyType) and isinstance(idx, types.Integer)):
        return None

    def dummy_delitem_impl(obj, idx):
        print("del", obj, idx)

    return dummy_delitem_impl
@overload(operator.getitem)
def overload_dummy_getitem(obj, idx):
    """``obj[idx]`` on a MyDummyType with an integer index yields idx + 123."""
    if not (isinstance(obj, MyDummyType) and isinstance(idx, types.Integer)):
        return None

    def dummy_getitem_impl(obj, idx):
        return idx + 123

    return dummy_getitem_impl
@overload(operator.setitem)
def overload_dummy_setitem(obj, idx, val):
    """
    ``obj[idx] = val`` on a MyDummyType with an integer index and an integer
    value prints the index and the value.
    """
    applicable = (
        isinstance(obj, MyDummyType)
        and isinstance(idx, types.Integer)
        and isinstance(val, types.Integer)
    )
    if not applicable:
        return None

    def dummy_setitem_impl(obj, idx, val):
        print(idx, val)

    return dummy_setitem_impl
def call_add_operator(arg1, arg2):
    # Jitted in test_add_operator: calls the operator function explicitly
    # rather than using the '+' syntax.
    return operator.add(arg1, arg2)
def call_add_binop(arg1, arg2):
    # Jitted in test_add_binop: uses the '+' syntax.
    return arg1 + arg2
@overload(operator.iadd)
def overload_iadd_dummy(arg1, arg2):
    """
    Implement operator.iadd for any pair of MyDummyType / MyDummyType2
    operands, returning the constant 42.
    """
    dummy_types = (MyDummyType, MyDummyType2)
    if not (isinstance(arg1, dummy_types) and isinstance(arg2, dummy_types)):
        return None

    def dummy_iadd_impl(arg1, arg2):
        return 42

    return dummy_iadd_impl
def call_iadd_operator(arg1, arg2):
    """
    Call the in-place add operator function explicitly (jitted by
    test_iadd_operator).

    Uses ``operator.iadd``, so that the ``@overload(operator.iadd)``
    implementation is the one exercised. The previous version called
    ``operator.add`` by mistake.
    """
    return operator.iadd(arg1, arg2)
def call_iadd_binop(arg1, arg2):
    # Jitted in test_iadd_binop: exercises the '+=' (in-place add) syntax.
    arg1 += arg2
    return arg1
def call_delitem(obj, idx):
    # Jitted in test_delitem: exercises the 'del obj[idx]' syntax.
    del obj[idx]
def call_getitem(obj, idx):
    # Jitted in test_getitem: exercises the 'obj[idx]' syntax.
    return obj[idx]
def call_setitem(obj, idx, val):
    # Jitted in test_setitem: exercises the 'obj[idx] = val' syntax.
    obj[idx] = val
@overload_method(MyDummyType, "length")
def overload_method_length(arg):
    # Adds a .length() method to MyDummyType that forwards to len().
    # TestOverloadMethodCaching searches the generated LLVM IR for this
    # function's name, so keep it (and the inner impl) stable.
    def imp(arg):
        return len(arg)

    return imp
def cache_overload_method_usecase(x):
    # Compiled with cache=True in TestOverloadMethodCaching to check that an
    # @overload_method survives on-disk caching.
    return x.length()
def call_func1_nullary():
    # Jitted in test_func1: zero-argument call to func1.
    return func1()
def call_func1_unary(x):
    # Jitted in test_func1: one-argument call to func1 (None or float).
    return func1(x)
def len_usecase(x):
    # Jitted in several tests to exercise len() overloads.
    return len(x)
def print_usecase(x):
    # Jitted in test_print; printing a MyDummy goes through the
    # "print_item" registration for MyDummyType.
    print(x)
def getitem_usecase(x, key):
    # Jitted in test_index_getitem with both integer and slice keys.
    return x[key]
def npyufunc_usecase(x):
    # Jitted in the ufunc tests on Index and Series objects.
    return np.cos(np.sin(x))
def get_data_usecase(x):
    # Jitted in test_index_get_data; reads the wrapped _data attribute.
    return x._data
def get_index_usecase(x):
    # Jitted in test_series_get_index; reads the Series' _index attribute.
    return x._index
def is_monotonic_usecase(x):
    # Jitted in test_index_is_monotonic; reads an attribute defined with
    # overload_attribute().
    return x.is_monotonic_increasing
def make_series_usecase(data, index):
    # Jitted in test_series_constructor; builds a Series in nopython mode.
    return Series(data, index)
def clip_usecase(x, lo, hi):
    # Jitted in test_series_clip; calls the clip() method on a Series.
    return x.clip(lo, hi)
# -----------------------------------------------------------------------
def return_non_boxable():
    # In Python this returns the numpy module. In jitted code the @overload
    # for this function returns it as a Module value, which
    # test_no_cpython_wrapper shows cannot be converted back to Python.
    return np
@overload(return_non_boxable)
def overload_return_non_boxable():
    # Jitted implementation: returns the numpy module as a native Module
    # value.
    def imp():
        return np

    return imp
def non_boxable_ok_usecase(sz):
    # The Module value is used only inside jitted code, so the result
    # (an array) can be returned to Python.
    mod = return_non_boxable()
    return mod.arange(sz)
def non_boxable_bad_usecase():
    # Returns the Module value itself; calling the jitted version raises
    # TypeError because it cannot be boxed (see test_no_cpython_wrapper).
    return return_non_boxable()
def mk_func_input(f):
    """Stub whose typing (MkFuncTyping) requires ``f`` to be a lambda literal."""
@infer_global(mk_func_input)
class MkFuncTyping(AbstractTemplate):
    # Typing for mk_func_input: the first argument must reach the typer as a
    # MakeFunctionLiteral; the call returns None.
    def generic(self, args, kws):
        assert isinstance(args[0], types.MakeFunctionLiteral)
        return signature(types.none, *args)
def mk_func_test_impl():
    # Run through the frontend in test_mk_func_literal; the inline lambda is
    # what must be typed as a MakeFunctionLiteral.
    mk_func_input(lambda a: a)
# -----------------------------------------------------------------------
# Define a type derived from types.Callable and overloads for it
class MyClass:
    """Python class that jitted code types as ``CallableTypeRef(MyClass)``."""
class CallableTypeRef(types.Callable):
    """
    Callable numba type standing in for a Python class (``instance_type``).

    Calls are typed by the templates the typing context holds for this
    type's class (e.g. those added with ``@overload(CallableTypeRef)``).
    The implementation key of the compiled overload that matched is
    remembered, so that ``get_impl_key()`` can return it.
    """

    def __init__(self, instance_type):
        self.instance_type = instance_type
        # Maps each resolved call signature to the implementation key of the
        # compiled overload that matched it.
        self.sig_to_impl_key = {}
        # Templates that produced a matching compiled overload.
        self.compiled_templates = []
        super(CallableTypeRef, self).__init__('callable_type_ref'
                                              '[{}]'.format(self.instance_type))

    def get_call_type(self, context, args, kws):
        # Try each registered template in turn; stop at the first one whose
        # compiled overloads contain these exact ``args``.
        # NOTE(review): if a template applies but has no compiled overload
        # for ``args``, and no later template matches, the last ``res_sig``
        # assigned is still returned without a ``sig_to_impl_key`` entry, so
        # get_impl_key() would raise KeyError for it. Confirm this is intended.
        res_sig = None
        for template in context._functions[type(self)]:
            try:
                res_sig = template.apply(args, kws)
            except Exception:
                pass  # for simplicity assume args must match exactly
            else:
                compiled_ovlds = getattr(template, '_compiled_overloads', {})
                if args in compiled_ovlds:
                    self.sig_to_impl_key[res_sig] = compiled_ovlds[args]
                    self.compiled_templates.append(template)
                    break

        return res_sig

    def get_call_signatures(self):
        # Only the signatures resolved so far are reported. The meaning of the
        # second element (True) follows numba's Callable contract; TODO confirm.
        sigs = list(self.sig_to_impl_key.keys())
        return sigs, True

    def get_impl_key(self, sig):
        # Raises KeyError for a signature that get_call_type() did not record.
        return self.sig_to_impl_key[sig]
@register_model(CallableTypeRef)
class CallableTypeModel(models.OpaqueModel):
    """Opaque data model used for CallableTypeRef values."""

    def __init__(self, dmm, fe_type):
        super(CallableTypeModel, self).__init__(dmm, fe_type)
# Type the global MyClass as a CallableTypeRef, so that calling MyClass(...)
# in jitted code is resolved by the @overload(CallableTypeRef) templates.
infer_global(MyClass, CallableTypeRef(MyClass))
@lower_constant(CallableTypeRef)
def constant_callable_typeref(context, builder, ty, pyval):
    # A CallableTypeRef constant (e.g. the global MyClass) carries no data,
    # so it lowers to a dummy value.
    return context.get_dummy_value()
# -----------------------------------------------------------------------
@overload(np.exp)
def overload_np_exp(obj):
    """np.exp() of a MyDummyType value is the constant 0xDEADBEEF."""
    if not isinstance(obj, MyDummyType):
        return None

    def imp(obj):
        # Returns a constant if a MyDummyType is seen
        return 0xDEADBEEF

    return imp
class TestLowLevelExtending(TestCase):
    """
    Test the low-level two-tier extension API.
    """

    # Check with `@jit` from within the test process and also in a new test
    # process so as to check the registration mechanism.

    def test_func1(self):
        # func1() / func1(None) lower to 42; func1(x) for a float computes
        # sqrt(2 * x).
        pyfunc = call_func1_nullary
        cfunc = jit(nopython=True)(pyfunc)
        self.assertPreciseEqual(cfunc(), 42)
        pyfunc = call_func1_unary
        cfunc = jit(nopython=True)(pyfunc)
        self.assertPreciseEqual(cfunc(None), 42)
        self.assertPreciseEqual(cfunc(18.0), 6.0)

    @TestCase.run_test_in_subprocess
    def test_func1_isolated(self):
        # Same checks as test_func1, run in a fresh subprocess.
        self.test_func1()

    def test_type_callable_keeps_function(self):
        # type_callable() must hand back the decorated typing function.
        self.assertIs(type_func1, type_func1_)
        self.assertIsNotNone(type_func1)

    @TestCase.run_test_in_subprocess
    def test_cast_mydummy(self):
        # Returning a MyDummy from a float64-returning function uses the
        # implicit MyDummyType -> Number cast, which yields 42.
        pyfunc = get_dummy
        cfunc = njit(types.float64(),)(pyfunc)
        self.assertPreciseEqual(cfunc(), 42.0)

    def test_mk_func_literal(self):
        """make sure make_function is passed to typer class as a literal
        """
        test_ir = compiler.run_frontend(mk_func_test_impl)
        typingctx = cpu_target.typing_context
        targetctx = cpu_target.target_context
        typingctx.refresh()
        targetctx.refresh()
        typing_res = type_inference_stage(typingctx, targetctx, test_ir, (),
                                          None)
        self.assertTrue(
            any(
                isinstance(a, types.MakeFunctionLiteral)
                for a in typing_res.typemap.values()
            )
        )
class TestPandasLike(TestCase):
    """
    Test implementing a pandas-like Index object.
    Also stresses most of the high-level API.
    """

    def test_index_len(self):
        # len() of an Index inside jitted code.
        i = Index(np.arange(3))
        cfunc = jit(nopython=True)(len_usecase)
        self.assertPreciseEqual(cfunc(i), 3)

    def test_index_getitem(self):
        # Integer indexing returns a scalar; slicing returns a new Index.
        i = Index(np.int32([42, 8, -5]))
        cfunc = jit(nopython=True)(getitem_usecase)
        self.assertPreciseEqual(cfunc(i, 1), 8)
        ii = cfunc(i, slice(1, None))
        self.assertIsInstance(ii, Index)
        self.assertEqual(list(ii), [8, -5])

    def test_index_ufunc(self):
        """
        Check Numpy ufunc on an Index object.
        """
        i = Index(np.int32([42, 8, -5]))
        cfunc = jit(nopython=True)(npyufunc_usecase)
        ii = cfunc(i)
        self.assertIsInstance(ii, Index)
        self.assertPreciseEqual(ii._data, np.cos(np.sin(i._data)))

    def test_index_get_data(self):
        # The _data attribute is exposed with make_attribute_wrapper()
        i = Index(np.int32([42, 8, -5]))
        cfunc = jit(nopython=True)(get_data_usecase)
        data = cfunc(i)
        self.assertIs(data, i._data)

    def test_index_is_monotonic(self):
        # The is_monotonic_increasing attribute is exposed with
        # overload_attribute()
        cfunc = jit(nopython=True)(is_monotonic_usecase)
        for values, expected in [
            ([8, 42, 5], False),
            ([5, 8, 42], True),
            ([], True),
        ]:
            i = Index(np.int32(values))
            got = cfunc(i)
            self.assertEqual(got, expected)

    def test_series_len(self):
        # len() of a Series inside jitted code.
        i = Index(np.int32([2, 4, 3]))
        s = Series(np.float64([1.5, 4.0, 2.5]), i)
        cfunc = jit(nopython=True)(len_usecase)
        self.assertPreciseEqual(cfunc(s), 3)

    def test_series_get_index(self):
        # The returned Index must wrap the very same data array.
        i = Index(np.int32([2, 4, 3]))
        s = Series(np.float64([1.5, 4.0, 2.5]), i)
        cfunc = jit(nopython=True)(get_index_usecase)
        got = cfunc(s)
        self.assertIsInstance(got, Index)
        self.assertIs(got._data, i._data)

    def test_series_ufunc(self):
        """
        Check Numpy ufunc on an Series object.
        """
        i = Index(np.int32([42, 8, -5]))
        s = Series(np.int64([1, 2, 3]), i)
        cfunc = jit(nopython=True)(npyufunc_usecase)
        ss = cfunc(s)
        self.assertIsInstance(ss, Series)
        self.assertIsInstance(ss._index, Index)
        self.assertIs(ss._index._data, i._data)
        self.assertPreciseEqual(ss._values, np.cos(np.sin(s._values)))

    def test_series_constructor(self):
        # A Series built in jitted code shares its inputs' arrays.
        i = Index(np.int32([42, 8, -5]))
        d = np.float64([1.5, 4.0, 2.5])
        cfunc = jit(nopython=True)(make_series_usecase)
        got = cfunc(d, i)
        self.assertIsInstance(got, Series)
        self.assertIsInstance(got._index, Index)
        self.assertIs(got._index._data, i._data)
        self.assertIs(got._values, d)

    def test_series_clip(self):
        # clip() returns a new Series with clipped values and the same Index.
        i = Index(np.int32([42, 8, -5]))
        s = Series(np.float64([1.5, 4.0, 2.5]), i)
        cfunc = jit(nopython=True)(clip_usecase)
        ss = cfunc(s, 1.6, 3.0)
        self.assertIsInstance(ss, Series)
        self.assertIsInstance(ss._index, Index)
        self.assertIs(ss._index._data, i._data)
        self.assertPreciseEqual(ss._values, np.float64([1.6, 3.0, 2.5]))
class TestHighLevelExtending(TestCase):
    """
    Test the high-level combined API.
    """

    def test_where(self):
        """
        Test implementing a function with @overload.
        """
        pyfunc = call_where
        cfunc = jit(nopython=True)(pyfunc)

        def check(*args, **kwargs):
            expected = np_where(*args, **kwargs)
            got = cfunc(*args, **kwargs)
            self.assertPreciseEqual(expected, got)

        check(x=3, cond=True, y=8)
        check(True, 3, 8)
        check(
            np.bool_([True, False, True]),
            np.int32([1, 2, 3]),
            np.int32([4, 5, 5]),
        )

        # The typing error is propagated
        with self.assertRaises(errors.TypingError) as raises:
            cfunc(np.bool_([]), np.int32([]), np.int64([]))
        self.assertIn(
            "x and y should have the same dtype", str(raises.exception)
        )

    def test_len(self):
        """
        Test re-implementing len() for a custom type with @overload.
        """
        cfunc = jit(nopython=True)(len_usecase)
        self.assertPreciseEqual(cfunc(MyDummy()), 13)
        self.assertPreciseEqual(cfunc([4, 5]), 2)

    def test_print(self):
        """
        Test extending print() for a custom type via the low-level
        "print_item" typing/lowering registration (PrintDummy/print_dummy).
        """
        cfunc = jit(nopython=True)(print_usecase)
        with captured_stdout():
            cfunc(MyDummy())
            self.assertEqual(sys.stdout.getvalue(), "hello!\n")

    def test_add_operator(self):
        """
        Test re-implementing operator.add() for a custom type with @overload.
        """
        pyfunc = call_add_operator
        cfunc = jit(nopython=True)(pyfunc)

        self.assertPreciseEqual(cfunc(1, 2), 3)
        self.assertPreciseEqual(cfunc(MyDummy2(), MyDummy2()), 42)

        # this will call add(Number, Number) as MyDummy implicitly casts to
        # Number
        self.assertPreciseEqual(cfunc(MyDummy(), MyDummy()), 84)

    def test_add_binop(self):
        """
        Test re-implementing '+' for a custom type via @overload(operator.add).
        """
        pyfunc = call_add_binop
        cfunc = jit(nopython=True)(pyfunc)

        self.assertPreciseEqual(cfunc(1, 2), 3)
        self.assertPreciseEqual(cfunc(MyDummy2(), MyDummy2()), 42)

        # this will call add(Number, Number) as MyDummy implicitly casts to
        # Number
        self.assertPreciseEqual(cfunc(MyDummy(), MyDummy()), 84)

    def test_iadd_operator(self):
        """
        Test re-implementing operator.add() for a custom type with @overload.
        """
        pyfunc = call_iadd_operator
        cfunc = jit(nopython=True)(pyfunc)

        self.assertPreciseEqual(cfunc(1, 2), 3)
        self.assertPreciseEqual(cfunc(MyDummy2(), MyDummy2()), 42)

        # this will call add(Number, Number) as MyDummy implicitly casts to
        # Number
        self.assertPreciseEqual(cfunc(MyDummy(), MyDummy()), 84)

    def test_iadd_binop(self):
        """
        Test re-implementing '+=' for a custom type via
        @overload(operator.iadd).
        """
        pyfunc = call_iadd_binop
        cfunc = jit(nopython=True)(pyfunc)

        self.assertPreciseEqual(cfunc(1, 2), 3)
        self.assertPreciseEqual(cfunc(MyDummy2(), MyDummy2()), 42)

        # this will call add(Number, Number) as MyDummy implicitly casts to
        # Number
        self.assertPreciseEqual(cfunc(MyDummy(), MyDummy()), 84)

    def test_delitem(self):
        # 'del obj[idx]' dispatches to overload_dummy_delitem, which prints.
        pyfunc = call_delitem
        cfunc = jit(nopython=True)(pyfunc)
        obj = MyDummy()
        e = None

        # Capture any exception and re-raise it only after stdout has been
        # restored.
        with captured_stdout() as out:
            try:
                cfunc(obj, 321)
            except Exception as exc:
                e = exc

        if e is not None:
            raise e
        self.assertEqual(out.getvalue(), "del hello! 321\n")

    def test_getitem(self):
        # 'obj[idx]' dispatches to overload_dummy_getitem (idx + 123).
        pyfunc = call_getitem
        cfunc = jit(nopython=True)(pyfunc)
        self.assertPreciseEqual(cfunc(MyDummy(), 321), 321 + 123)

    def test_setitem(self):
        # 'obj[idx] = val' dispatches to overload_dummy_setitem, which prints.
        pyfunc = call_setitem
        cfunc = jit(nopython=True)(pyfunc)
        obj = MyDummy()
        e = None

        # Capture any exception and re-raise it only after stdout has been
        # restored.
        with captured_stdout() as out:
            try:
                cfunc(obj, 321, 123)
            except Exception as exc:
                e = exc

        if e is not None:
            raise e
        self.assertEqual(out.getvalue(), "321 123\n")

    def test_no_cpython_wrapper(self):
        """
        Test overloading whose return value cannot be represented in CPython.
        """
        # Test passing Module type from a @overload implementation to ensure
        # that the *no_cpython_wrapper* flag works
        ok_cfunc = jit(nopython=True)(non_boxable_ok_usecase)
        n = 10
        got = ok_cfunc(n)
        expect = non_boxable_ok_usecase(n)
        np.testing.assert_equal(expect, got)
        # Verify that the Module type cannot be returned to CPython
        bad_cfunc = jit(nopython=True)(non_boxable_bad_usecase)
        with self.assertRaises(TypeError) as raises:
            bad_cfunc()
        errmsg = str(raises.exception)
        expectmsg = "cannot convert native Module"
        self.assertIn(expectmsg, errmsg)

    def test_typing_vs_impl_signature_mismatch_handling(self):
        """
        Tests that an overload which has a differing typing and implementing
        signature raises an exception.
        """

        # Build a fresh overload of a stub whose typing signature is
        # (a, b, c, kw=None) and whose implementation is ``impl``; return a
        # jitted function that calls it.
        def gen_ol(impl=None):
            def myoverload(a, b, c, kw=None):
                pass

            @overload(myoverload)
            def _myoverload_impl(a, b, c, kw=None):
                return impl

            @jit(nopython=True)
            def foo(a, b, c, d):
                myoverload(a, b, c, kw=d)

            return foo

        sentinel = "Typing and implementation arguments differ in"

        # kwarg value is different
        def impl1(a, b, c, kw=12):
            if a > 10:
                return 1
            else:
                return -1

        with self.assertRaises(errors.TypingError) as e:
            gen_ol(impl1)(1, 2, 3, 4)
        msg = str(e.exception)
        self.assertIn(sentinel, msg)
        self.assertIn("keyword argument default values", msg)
        self.assertIn('<Parameter "kw=12">', msg)
        self.assertIn('<Parameter "kw=None">', msg)

        # kwarg name is different
        def impl2(a, b, c, kwarg=None):
            if a > 10:
                return 1
            else:
                return -1

        with self.assertRaises(errors.TypingError) as e:
            gen_ol(impl2)(1, 2, 3, 4)
        msg = str(e.exception)
        self.assertIn(sentinel, msg)
        self.assertIn("keyword argument names", msg)
        self.assertIn('<Parameter "kwarg=None">', msg)
        self.assertIn('<Parameter "kw=None">', msg)

        # arg name is different
        def impl3(z, b, c, kw=None):
            if a > 10:  # noqa: F821
                return 1
            else:
                return -1

        with self.assertRaises(errors.TypingError) as e:
            gen_ol(impl3)(1, 2, 3, 4)
        msg = str(e.exception)
        self.assertIn(sentinel, msg)
        self.assertIn("argument names", msg)
        self.assertFalse("keyword" in msg)
        self.assertIn('<Parameter "a">', msg)
        self.assertIn('<Parameter "z">', msg)

        from .overload_usecases import impl4, impl5

        with self.assertRaises(errors.TypingError) as e:
            gen_ol(impl4)(1, 2, 3, 4)
        msg = str(e.exception)
        self.assertIn(sentinel, msg)
        self.assertIn("argument names", msg)
        self.assertFalse("keyword" in msg)
        self.assertIn("First difference: 'z'", msg)

        with self.assertRaises(errors.TypingError) as e:
            gen_ol(impl5)(1, 2, 3, 4)
        msg = str(e.exception)
        self.assertIn(sentinel, msg)
        self.assertIn("argument names", msg)
        self.assertFalse("keyword" in msg)
        self.assertIn('<Parameter "a">', msg)
        self.assertIn('<Parameter "z">', msg)

        # too many args
        def impl6(a, b, c, d, e, kw=None):
            if a > 10:
                return 1
            else:
                return -1

        with self.assertRaises(errors.TypingError) as e:
            gen_ol(impl6)(1, 2, 3, 4)
        msg = str(e.exception)
        self.assertIn(sentinel, msg)
        self.assertIn("argument names", msg)
        self.assertFalse("keyword" in msg)
        self.assertIn('<Parameter "d">', msg)
        self.assertIn('<Parameter "e">', msg)

        # too few args
        def impl7(a, b, kw=None):
            if a > 10:
                return 1
            else:
                return -1

        with self.assertRaises(errors.TypingError) as e:
            gen_ol(impl7)(1, 2, 3, 4)
        msg = str(e.exception)
        self.assertIn(sentinel, msg)
        self.assertIn("argument names", msg)
        self.assertFalse("keyword" in msg)
        self.assertIn('<Parameter "c">', msg)

        # too many kwargs
        def impl8(a, b, c, kw=None, extra_kwarg=None):
            if a > 10:
                return 1
            else:
                return -1

        with self.assertRaises(errors.TypingError) as e:
            gen_ol(impl8)(1, 2, 3, 4)
        msg = str(e.exception)
        self.assertIn(sentinel, msg)
        self.assertIn("keyword argument names", msg)
        self.assertIn('<Parameter "extra_kwarg=None">', msg)

        # too few kwargs
        def impl9(a, b, c):
            if a > 10:
                return 1
            else:
                return -1

        with self.assertRaises(errors.TypingError) as e:
            gen_ol(impl9)(1, 2, 3, 4)
        msg = str(e.exception)
        self.assertIn(sentinel, msg)
        self.assertIn("keyword argument names", msg)
        self.assertIn('<Parameter "kw=None">', msg)

    def test_typing_vs_impl_signature_mismatch_handling_var_positional(self):
        """
        Tests that an overload which has a differing typing and implementing
        signature raises an exception and uses VAR_POSITIONAL (*args) in typing
        """

        def myoverload(a, kw=None):
            pass

        from .overload_usecases import var_positional_impl

        overload(myoverload)(var_positional_impl)

        @jit(nopython=True)
        def foo(a, b):
            return myoverload(a, b, 9, kw=11)

        with self.assertRaises(errors.TypingError) as e:
            foo(1, 5)
        msg = str(e.exception)
        self.assertIn("VAR_POSITIONAL (e.g. *args) argument kind", msg)
        self.assertIn("offending argument name is '*star_args_token'", msg)

    def test_typing_vs_impl_signature_mismatch_handling_var_keyword(self):
        """
        Tests that an overload which uses **kwargs (VAR_KEYWORD)
        """

        # Register ``impl`` as the overload of a fresh stub, with the given
        # strictness, and return a jitted caller.
        def gen_ol(impl, strict=True):
            def myoverload(a, kw=None):
                pass

            overload(myoverload, strict=strict)(impl)

            @jit(nopython=True)
            def foo(a, b):
                return myoverload(a, kw=11)

            return foo

        # **kwargs in typing
        def ol1(a, **kws):
            def impl(a, kw=10):
                return a

            return impl

        gen_ol(ol1, False)(1, 2)  # no error if strictness not enforced
        with self.assertRaises(errors.TypingError) as e:
            gen_ol(ol1)(1, 2)

        msg = str(e.exception)
        self.assertIn("use of VAR_KEYWORD (e.g. **kwargs) is unsupported", msg)
        self.assertIn("offending argument name is '**kws'", msg)

        # **kwargs in implementation
        def ol2(a, kw=0):
            def impl(a, **kws):
                return a

            return impl

        with self.assertRaises(errors.TypingError) as e:
            gen_ol(ol2)(1, 2)

        msg = str(e.exception)
        self.assertIn("use of VAR_KEYWORD (e.g. **kwargs) is unsupported", msg)
        self.assertIn("offending argument name is '**kws'", msg)

    def test_overload_method_kwargs(self):
        # Issue #3489
        @overload_method(types.Array, "foo")
        def fooimpl(arr, a_kwarg=10):
            def impl(arr, a_kwarg=10):
                return a_kwarg

            return impl

        @njit
        def bar(A):
            return A.foo(), A.foo(20), A.foo(a_kwarg=30)

        Z = np.arange(5)

        self.assertEqual(bar(Z), (10, 20, 30))

    def test_overload_method_literal_unpack(self):
        # Issue #3683
        @overload_method(types.Array, "litfoo")
        def litfoo(arr, val):
            # Must be an integer
            if isinstance(val, types.Integer):
                # Must not be literal
                if not isinstance(val, types.Literal):

                    def impl(arr, val):
                        return val

                    return impl

        @njit
        def bar(A):
            return A.litfoo(0xCAFE)

        A = np.zeros(1)
        bar(A)
        self.assertEqual(bar(A), 0xCAFE)

    def test_overload_ufunc(self):
        # Issue #4133.
        # Use an extended type (MyDummyType) to use with a customized
        # ufunc (np.exp).
        @njit
        def test():
            return np.exp(mydummy)

        self.assertEqual(test(), 0xDEADBEEF)

    def test_overload_method_stararg(self):
        @overload_method(MyDummyType, "method_stararg")
        def _ov_method_stararg(obj, val, val2, *args):
            def get(obj, val, val2, *args):
                return (val, val2, args)

            return get

        @njit
        def foo(obj, *args):
            # Test with expanding stararg
            return obj.method_stararg(*args)

        obj = MyDummy()
        self.assertEqual(foo(obj, 1, 2), (1, 2, ()))
        self.assertEqual(foo(obj, 1, 2, 3), (1, 2, (3,)))
        self.assertEqual(foo(obj, 1, 2, 3, 4), (1, 2, (3, 4)))

        @njit
        def bar(obj):
            # Test with explicit argument
            return (
                obj.method_stararg(1, 2),
                obj.method_stararg(1, 2, 3),
                obj.method_stararg(1, 2, 3, 4),
            )

        self.assertEqual(
            bar(obj), ((1, 2, ()), (1, 2, (3,)), (1, 2, (3, 4))),
        )

        # Check cases that put tuple type into stararg
        # NOTE: the expected result has an extra tuple because of stararg.
        self.assertEqual(
            foo(obj, 1, 2, (3,)), (1, 2, ((3,),)),
        )
        self.assertEqual(
            foo(obj, 1, 2, (3, 4)), (1, 2, ((3, 4),)),
        )
        self.assertEqual(
            foo(obj, 1, 2, (3, (4, 5))), (1, 2, ((3, (4, 5)),)),
        )

    def test_overload_classmethod(self):
        # Add classmethod to a subclass of Array
        class MyArray(types.Array):
            pass

        @overload_classmethod(MyArray, "array_alloc")
        def ol_array_alloc(cls, nitems):
            def impl(cls, nitems):
                arr = np.arange(nitems)
                return arr
            return impl

        @njit
        def foo(nitems):
            return MyArray.array_alloc(nitems)

        nitems = 13
        self.assertPreciseEqual(foo(nitems), np.arange(nitems))

        # Check that the base type doesn't get the classmethod

        @njit
        def no_classmethod_in_base(nitems):
            return types.Array.array_alloc(nitems)

        with self.assertRaises(errors.TypingError) as raises:
            no_classmethod_in_base(nitems)
        self.assertIn(
            "Unknown attribute 'array_alloc' of",
            str(raises.exception),
        )

    def test_overload_callable_typeref(self):
        # Two overloads on CallableTypeRef, selected by argument type, make
        # MyClass(int) and MyClass(str) callable from jitted code.
        @overload(CallableTypeRef)
        def callable_type_call_ovld1(x):
            if isinstance(x, types.Integer):
                def impl(x):
                    return 42.5 + x
                return impl

        @overload(CallableTypeRef)
        def callable_type_call_ovld2(x):
            if isinstance(x, types.UnicodeType):
                def impl(x):
                    return '42.5' + x

                return impl

        @njit
        def foo(a, b):
            return MyClass(a), MyClass(b)

        args = (4, '4')
        expected = (42.5 + args[0], '42.5' + args[1])
        self.assertPreciseEqual(foo(*args), expected)
def _assert_cache_stats(cfunc, expect_hit, expect_misses):
hit = cfunc._cache_hits[cfunc.signatures[0]]
if hit != expect_hit:
raise AssertionError("cache not used")
miss = cfunc._cache_misses[cfunc.signatures[0]]
if miss != expect_misses:
raise AssertionError("cache not used")
@skip_if_typeguard
class TestOverloadMethodCaching(TestCase):
    """
    Check that a function using an @overload_method can be cached on disk,
    and that the cache is reused from a separate process.
    """

    # Nested multiprocessing.Pool raises AssertionError:
    # "daemonic processes are not allowed to have children"
    _numba_parallel_test_ = False

    def test_caching_overload_method(self):
        # Use a per-class temporary cache directory for the whole test.
        self._cache_dir = temp_directory(self.__class__.__name__)
        with override_config("CACHE_DIR", self._cache_dir):
            self.run_caching_overload_method()

    def run_caching_overload_method(self):
        # First compilation in this process: expect one miss and no hit.
        cfunc = jit(nopython=True, cache=True)(cache_overload_method_usecase)
        self.assertPreciseEqual(cfunc(MyDummy()), 13)
        _assert_cache_stats(cfunc, 0, 1)
        llvmir = cfunc.inspect_llvm((mydummy_type,))
        # Ensure the inner method is not a declaration
        decls = [
            ln
            for ln in llvmir.splitlines()
            if ln.startswith("declare") and "overload_method_length" in ln
        ]
        self.assertEqual(len(decls), 0)
        # Test in a separate process
        try:
            ctx = multiprocessing.get_context("spawn")
        except AttributeError:
            ctx = multiprocessing
        q = ctx.Queue()
        # The module-level run_caching_overload_method() (not this method)
        # is the subprocess entry point.
        p = ctx.Process(
            target=run_caching_overload_method, args=(q, self._cache_dir)
        )
        p.start()
        q.put(MyDummy())
        p.join()
        # Ensure subprocess exited normally
        self.assertEqual(p.exitcode, 0)
        res = q.get(timeout=1)
        self.assertEqual(res, 13)
def run_caching_overload_method(q, cache_dir):
    """
    Subprocess entry point for
    TestOverloadMethodCaching.test_caching_overload_method.

    With CACHE_DIR set to ``cache_dir``: take one argument from ``q``, call
    a cached jit build of ``cache_overload_method_usecase`` on it, put the
    result back on ``q``, and check the code was loaded from the cache
    (1 hit, 0 misses).
    """
    with override_config("CACHE_DIR", cache_dir):
        value = q.get()
        dispatcher = jit(nopython=True, cache=True)(
            cache_overload_method_usecase
        )
        q.put(dispatcher(value))
        # The parent process compiled first, so this must be a pure hit.
        _assert_cache_stats(dispatcher, 1, 0)
class TestIntrinsic(TestCase):
def test_void_return(self):
"""
Verify that returning a None from codegen function is handled
automatically for void functions, otherwise raise exception.
"""
@intrinsic
def void_func(typingctx, a):
sig = types.void(types.int32)
def codegen(context, builder, signature, args):
pass # do nothing, return None, should be turned into
# dummy value
return sig, codegen
@intrinsic
def non_void_func(typingctx, a):
sig = types.int32(types.int32)
def codegen(context, builder, signature, args):
pass # oops, should be returning a value here, raise exception
return sig, codegen
@jit(nopython=True)
def call_void_func():
void_func(1)
return 0
@jit(nopython=True)
def call_non_void_func():
non_void_func(1)
return 0
# void func should work
self.assertEqual(call_void_func(), 0)
# not void function should raise exception
with self.assertRaises(LoweringError) as e:
call_non_void_func()
self.assertIn("non-void function returns None", e.exception.msg)
def test_ll_pointer_cast(self):
"""
Usecase test: custom reinterpret cast to turn int values to pointers
"""
from ctypes import CFUNCTYPE, POINTER, c_float, c_int
# Use intrinsic to make a reinterpret_cast operation
def unsafe_caster(result_type):
assert isinstance(result_type, types.CPointer)
@intrinsic
def unsafe_cast(typingctx, src):
self.assertIsInstance(typingctx, typing.Context)
if isinstance(src, types.Integer):
sig = result_type(types.uintp)
# defines the custom code generation
def codegen(context, builder, signature, args):
[src] = args
rtype = signature.return_type
llrtype = context.get_value_type(rtype)
return builder.inttoptr(src, llrtype)
return sig, codegen
return unsafe_cast
# make a nopython function to use our cast op.
# this is not usable from cpython due to the returning of a pointer.
def unsafe_get_ctypes_pointer(src):
raise NotImplementedError("not callable from python")
@overload(unsafe_get_ctypes_pointer, strict=False)
def array_impl_unsafe_get_ctypes_pointer(arrtype):
if isinstance(arrtype, types.Array):
unsafe_cast = unsafe_caster(types.CPointer(arrtype.dtype))
def array_impl(arr):
return unsafe_cast(src=arr.ctypes.data)
return array_impl
# the ctype wrapped function for use in nopython mode
def my_c_fun_raw(ptr, n):
for i in range(n):
print(ptr[i])
prototype = CFUNCTYPE(None, POINTER(c_float), c_int)
my_c_fun = prototype(my_c_fun_raw)
# Call our pointer-cast in a @jit compiled function and use
# the pointer in a ctypes function
@jit(nopython=True)
def foo(arr):
ptr = unsafe_get_ctypes_pointer(arr)
my_c_fun(ptr, arr.size)
# Test
arr = np.arange(10, dtype=np.float32)
with captured_stdout() as buf:
foo(arr)
got = buf.getvalue().splitlines()
buf.close()
expect = list(map(str, arr))
self.assertEqual(expect, got)
def test_serialization(self):
"""
Test serialization of intrinsic objects
"""
# define a intrinsic
@intrinsic
def identity(context, x):
def codegen(context, builder, signature, args):
return args[0]
sig = x(x)
return sig, codegen
# use in a jit function
@jit(nopython=True)
def foo(x):
return identity(x)
self.assertEqual(foo(1), 1)
# get serialization memo
memo = _Intrinsic._memo
memo_size = len(memo)
# pickle foo and check memo size
serialized_foo = pickle.dumps(foo)
# increases the memo size
memo_size += 1
self.assertEqual(memo_size, len(memo))
# unpickle
foo_rebuilt = pickle.loads(serialized_foo)
self.assertEqual(memo_size, len(memo))
# check rebuilt foo
self.assertEqual(foo(1), foo_rebuilt(1))
# pickle identity directly
serialized_identity = pickle.dumps(identity)
# memo size unchanged
self.assertEqual(memo_size, len(memo))
# unpickle
identity_rebuilt = pickle.loads(serialized_identity)
# must be the same object
self.assertIs(identity, identity_rebuilt)
# memo size unchanged
self.assertEqual(memo_size, len(memo))
def test_deserialization(self):
    """
    Test deserialization of intrinsic.

    Pickling an ``_Intrinsic`` adds a memo entry. Deleting the object does
    not drop that entry while it is still held by the ``_Intrinsic._recent``
    queue; clearing the queue does drop it. Unpickling afterwards must
    rebuild a new definition, and a second unpickle must reuse the first
    rebuilt one.
    """
    def defn(context, x):
        def codegen(context, builder, signature, args):
            return args[0]
        return x(x), codegen

    memo = _Intrinsic._memo
    memo_size = len(memo)
    # invoke _Intrinsic indirectly to avoid registration which keeps an
    # internal reference inside the compiler
    original = _Intrinsic("foo", defn)
    self.assertIs(original._defn, defn)
    pickled = pickle.dumps(original)
    # by pickling, a new memo entry is created
    memo_size += 1
    self.assertEqual(memo_size, len(memo))
    del original  # remove original before unpickling

    # by deleting, the memo entry is NOT removed due to recent
    # function queue
    self.assertEqual(memo_size, len(memo))

    # Manually force clear of _recent queue; with no remaining reference the
    # memo entry goes away. NOTE(review): this suggests the memo holds weak
    # references, which is not shown here — confirm in _Intrinsic.
    _Intrinsic._recent.clear()
    memo_size -= 1
    self.assertEqual(memo_size, len(memo))

    rebuilt = pickle.loads(pickled)
    # verify that the rebuilt object is different (it was reconstructed
    # from the pickle, not fetched from the memo)
    self.assertIsNot(rebuilt._defn, defn)

    # the second rebuilt object is the same as the first
    second = pickle.loads(pickled)
    self.assertIs(rebuilt._defn, second._defn)
def test_docstring(self):
    """
    Check that ``@intrinsic`` preserves the wrapped function's metadata:
    ``__module__``, ``__name__``, ``__qualname__``, ``__annotations__`` and
    ``__doc__``.
    """
    @intrinsic
    def void_func(typingctx, a: int):
        """void_func docstring"""
        sig = types.void(types.int32)
        def codegen(context, builder, signature, args):
            pass  # emit nothing and return None; for a void signature this
            # is expected to be turned into a dummy value
        return sig, codegen

    self.assertEqual("numba.tests.test_extending", void_func.__module__)
    self.assertEqual("void_func", void_func.__name__)
    # NB: the expected qualname embeds this test's class and method names,
    # so renaming either requires updating this literal.
    self.assertEqual("TestIntrinsic.test_docstring.<locals>.void_func",
                     void_func.__qualname__)
    # Only `a` is annotated on void_func, so that is all that should appear.
    self.assertDictEqual({'a': int}, void_func.__annotations__)
    self.assertEqual("void_func docstring", void_func.__doc__)
class TestRegisterJitable(unittest.TestCase):
    """Tests for functions decorated with ``register_jitable``."""

    def test_no_flags(self):
        """A plain register_jitable function is callable from nopython code."""
        @register_jitable
        def add(x, y):
            return x + y

        def caller(x, y):
            return add(x, y)

        compiled = njit(caller)
        reference = caller(1, 2)
        self.assertEqual(reference, compiled(1, 2))

    def test_flags_no_nrt(self):
        """With ``_nrt=False`` a freshly allocated array cannot be returned."""
        @register_jitable(_nrt=False)
        def make_range(n):
            return np.arange(n)

        def caller(n):
            return make_range(n)

        # Interpreted execution ignores the jit options entirely.
        self.assertEqual(caller(3).tolist(), [0, 1, 2])

        compiled = njit(caller)
        expected_msg = ("Only accept returning of array passed into the "
                        "function as argument")
        with self.assertRaises(errors.TypingError) as ctx:
            compiled(2)
        self.assertIn(expected_msg, str(ctx.exception))
class TestImportCythonFunction(unittest.TestCase):
    """Tests for ``get_cython_function_address``."""

    @unittest.skipIf(sc is None, "Only run if SciPy >= 0.19 is installed")
    def test_getting_function(self):
        """scipy's Cython ``j0`` can be wrapped with ctypes and jitted."""
        address = get_cython_function_address(
            "scipy.special.cython_special", "j0"
        )
        double_to_double = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
        c_j0 = double_to_double(address)

        @jit(nopython=True)
        def j0(x):
            return c_j0(x)

        self.assertEqual(j0(0), 1)

    def test_missing_module(self):
        """An unknown module name surfaces as an ImportError."""
        with self.assertRaises(ImportError) as ctx:
            get_cython_function_address("fakemodule", "fakefunction")
        # The module name is matched with optional quotes so the check does
        # not depend on exactly how the interpreter words the message.
        pattern = "No module named '?fakemodule'?"
        self.assertIsNotNone(re.match(pattern, str(ctx.exception)))

    @unittest.skipIf(sc is None, "Only run if SciPy >= 0.19 is installed")
    def test_missing_function(self):
        """An unknown function name in a real module is a ValueError."""
        with self.assertRaises(ValueError) as ctx:
            get_cython_function_address(
                "scipy.special.cython_special", "foo"
            )
        expected = ("No function 'foo' found in __pyx_capi__ of "
                    "'scipy.special.cython_special'")
        self.assertEqual(expected, str(ctx.exception))
@overload_method(
    MyDummyType,
    "method_jit_option_check_nrt",
    jit_options={"_nrt": True},
)
def ov_method_jit_option_check_nrt(obj):
    """Method overload compiled with NRT on, so it may return a new array."""
    def impl(obj):
        return np.arange(10)
    return impl
@overload_method(
    MyDummyType,
    "method_jit_option_check_no_nrt",
    jit_options={"_nrt": False},
)
def ov_method_jit_option_check_no_nrt(obj):
    """Method overload compiled with NRT off; returning a new array must
    therefore fail to compile (see TestJitOptionsNoNRT)."""
    def impl(obj):
        return np.arange(10)
    return impl
@overload_attribute(
    MyDummyType,
    "attr_jit_option_check_nrt",
    jit_options={"_nrt": True},
)
def ov_attr_jit_option_check_nrt(obj):
    """Attribute overload compiled with NRT on, so it may return a new array."""
    def impl(obj):
        return np.arange(10)
    return impl
@overload_attribute(
    MyDummyType,
    "attr_jit_option_check_no_nrt",
    jit_options={"_nrt": False},
)
def ov_attr_jit_option_check_no_nrt(obj):
    """Attribute overload compiled with NRT off; returning a new array must
    therefore fail to compile (see TestJitOptionsNoNRT)."""
    def impl(obj):
        return np.arange(10)
    return impl
class TestJitOptionsNoNRT(TestCase):
    """Check that ``jit_options`` passed to overload / overload_method /
    overload_attribute are honoured, using ``_nrt`` as the probe option."""

    def check_error_no_nrt(self, func, *args, **kwargs):
        """Assert that calling *func* fails to type because its
        implementation returns a dynamically allocated array without NRT."""
        expected = ("Only accept returning of array passed into the "
                    "function as argument")
        with self.assertRaises(errors.TypingError) as ctx:
            func(*args, **kwargs)
        self.assertIn(expected, str(ctx.exception))

    def no_nrt_overload_check(self, flag):
        """Overload a fresh stub with ``_nrt=flag`` and check the result:
        a working array when NRT is on, a typing error when it is off."""
        def stub():
            return np.arange(10)

        @overload(stub, jit_options={"_nrt": flag})
        def ov_stub():
            def impl():
                return np.arange(10)
            return impl

        @njit
        def caller():
            return stub()

        if not flag:
            self.check_error_no_nrt(caller)
            return
        self.assertPreciseEqual(caller(), np.arange(10))

    def test_overload_no_nrt(self):
        for flag in (True, False):
            self.no_nrt_overload_check(flag)

    def test_overload_method_no_nrt(self):
        @njit
        def call_with_nrt(x):
            return x.method_jit_option_check_nrt()

        @njit
        def call_without_nrt(x):
            return x.method_jit_option_check_no_nrt()

        self.assertPreciseEqual(call_with_nrt(mydummy), np.arange(10))
        self.check_error_no_nrt(call_without_nrt, mydummy)

    def test_overload_attribute_no_nrt(self):
        @njit
        def read_with_nrt(x):
            return x.attr_jit_option_check_nrt

        @njit
        def read_without_nrt(x):
            return x.attr_jit_option_check_no_nrt

        self.assertPreciseEqual(read_with_nrt(mydummy), np.arange(10))
        self.check_error_no_nrt(read_without_nrt, mydummy)
class TestBoxingCallingJIT(TestCase):
    """
    Tests for ``c.pyapi.call_jit_code``, which lets hand-written ``@box`` /
    ``@unbox`` implementations call into jit-compiled Python functions,
    including propagation of exceptions raised by that jit code.
    """

    def setUp(self):
        super().setUp()
        # Create a fresh dummy type for every test so that each test can
        # register its own box/unbox implementations on it.
        # NOTE(review): base_dummy_type_factory is defined elsewhere in this
        # file; from how the results are used below, the first value is the
        # numba type class (passed to @box/@unbox) and the second is the
        # Python class that is instantiated as the test value — confirm.
        many = base_dummy_type_factory("mydummy2")
        self.DynTypeType, self.DynType, self.dyn_type_type = many
        self.dyn_type = self.DynType()

    def test_unboxer_basic(self):
        # Implements an unboxer on DynType that calls an intrinsic into the
        # unboxer code.
        magic_token = 0xCAFE
        magic_offset = 123

        @intrinsic
        def my_intrinsic(typingctx, val):
            # An intrinsic that returns `val + magic_offset`
            def impl(context, builder, sig, args):
                [val] = args
                return builder.add(val, val.type(magic_offset))
            sig = signature(val, val)
            return sig, impl

        @unbox(self.DynTypeType)
        def unboxer(typ, obj, c):
            # The unboxer that calls some jitcode
            def bridge(x):
                # proof that this is a jit'ed context by calling jit only
                # intrinsic
                return my_intrinsic(x)

            # `obj` is ignored; the bridge is always fed magic_token.
            args = [c.context.get_constant(types.intp, magic_token)]
            # NOTE(review): bridge returns an integer but is compiled with a
            # voidptr return type, so the result is presumably cast to a
            # pointer here; the boxer below turns it back into an int with
            # ptrtoint.
            sig = signature(types.voidptr, types.intp)
            is_error, res = c.pyapi.call_jit_code(bridge, sig, args)
            return NativeValue(res, is_error=is_error)

        @box(self.DynTypeType)
        def boxer(typ, val, c):
            # The boxer that returns an integer representation
            res = c.builder.ptrtoint(val, cgutils.intp_t)
            return c.pyapi.long_from_ssize_t(res)

        @njit
        def passthru(x):
            return x

        # Unbox (token + offset via the intrinsic), then box back to an int.
        out = passthru(self.dyn_type)
        self.assertEqual(out, magic_token + magic_offset)

    def test_unboxer_raise(self):
        # Testing exception raising in jitcode called from unboxing.
        @unbox(self.DynTypeType)
        def unboxer(typ, obj, c):
            # The unboxer that calls some jitcode
            def bridge(x):
                if x > 0:
                    raise ValueError("cannot be x > 0")
                return x

            # x is the constant 1, so bridge always raises.
            args = [c.context.get_constant(types.intp, 1)]
            sig = signature(types.voidptr, types.intp)
            is_error, res = c.pyapi.call_jit_code(bridge, sig, args)
            # Forwarding the error flag makes the failed unboxing surface
            # the ValueError raised in bridge (asserted below).
            return NativeValue(res, is_error=is_error)

        @box(self.DynTypeType)
        def boxer(typ, val, c):
            # The boxer that returns an integer representation
            res = c.builder.ptrtoint(val, cgutils.intp_t)
            return c.pyapi.long_from_ssize_t(res)

        @njit
        def passthru(x):
            return x

        with self.assertRaises(ValueError) as raises:
            passthru(self.dyn_type)
        self.assertIn(
            "cannot be x > 0", str(raises.exception),
        )

    def test_boxer(self):
        # Call jitcode inside the boxer
        magic_token = 0xCAFE
        magic_offset = 312

        @intrinsic
        def my_intrinsic(typingctx, val):
            # An intrinsic that returns `val + magic_offset`
            def impl(context, builder, sig, args):
                [val] = args
                return builder.add(val, val.type(magic_offset))
            sig = signature(val, val)
            return sig, impl

        @unbox(self.DynTypeType)
        def unboxer(typ, obj, c):
            # Unboxing only produces a placeholder; the boxer ignores it.
            return NativeValue(c.context.get_dummy_value())

        @box(self.DynTypeType)
        def boxer(typ, val, c):
            # Note: this doesn't do proper error handling
            # (is_error is ignored; test_boxer_raise shows the checked form)
            def bridge(x):
                return my_intrinsic(x)

            args = [c.context.get_constant(types.intp, magic_token)]
            sig = signature(types.intp, types.intp)
            is_error, res = c.pyapi.call_jit_code(bridge, sig, args)
            return c.pyapi.long_from_ssize_t(res)

        @njit
        def passthru(x):
            return x

        r = passthru(self.dyn_type)
        self.assertEqual(r, magic_token + magic_offset)

    def test_boxer_raise(self):
        # Call jitcode inside the boxer
        @unbox(self.DynTypeType)
        def unboxer(typ, obj, c):
            return NativeValue(c.context.get_dummy_value())

        @box(self.DynTypeType)
        def boxer(typ, val, c):
            def bridge(x):
                if x > 0:
                    raise ValueError("cannot do x > 0")
                return x

            # x is the constant 1, so bridge always raises.
            args = [c.context.get_constant(types.intp, 1)]
            sig = signature(types.intp, types.intp)
            is_error, res = c.pyapi.call_jit_code(bridge, sig, args)
            # The error handling: retval starts out NULL (zfill=True) and the
            # boxed int is stored into it only when bridge did not raise.
            # Returning NULL presumably signals the pending Python exception
            # to the caller; the test asserts that the ValueError surfaces.
            retval = cgutils.alloca_once(c.builder, c.pyapi.pyobj, zfill=True)
            with c.builder.if_then(c.builder.not_(is_error)):
                obj = c.pyapi.long_from_ssize_t(res)
                c.builder.store(obj, retval)
            return c.builder.load(retval)

        @njit
        def passthru(x):
            return x

        with self.assertRaises(ValueError) as raises:
            passthru(self.dyn_type)
        self.assertIn(
            "cannot do x > 0", str(raises.exception),
        )
def with_objmode_cache_ov_example(x):
    """Overload target for
    TestCachingOverloadObjmode.test_caching_overload_objmode.

    Deliberately empty: the real behaviour is supplied by an ``@overload``
    registered inside that test. Called directly from Python it returns None.
    """
@skip_if_typeguard
class TestCachingOverloadObjmode(TestCase):
    """Test caching of the use of overload implementations that use
    `with objmode`
    """
    # NOTE(review): presumably opts this class out of numba's parallel test
    # runner (it mutates warning filters and writes a cache dir) — confirm.
    _numba_parallel_test_ = False

    def setUp(self):
        super().setUp()
        # Escalate NumbaWarning to an error for the duration of each test so
        # any warning emitted while compiling or caching fails the test.
        # The change is made inside a catch_warnings() context so that the
        # previous filter list can be restored exactly in tearDown.
        self._warnings_ctx = warnings.catch_warnings()
        self._warnings_ctx.__enter__()
        warnings.simplefilter("error", errors.NumbaWarning)

    def tearDown(self):
        # Restore the warning filters saved in setUp. (warnings.resetwarnings()
        # would instead discard *every* filter -- Python's defaults, -W
        # options and anything the test runner installed -- leaking that
        # change into every test that runs afterwards.)
        self._warnings_ctx.__exit__(None, None, None)
        super().tearDown()

    def test_caching_overload_objmode(self):
        cache_dir = temp_directory(self.__class__.__name__)
        with override_config("CACHE_DIR", cache_dir):

            def realwork(x):
                # uses numpy code
                arr = np.arange(x) / x
                return np.linalg.norm(arr)

            def python_code(x):
                # create indirections
                return realwork(x)

            @overload(with_objmode_cache_ov_example)
            def _ov_with_objmode_cache_ov_example(x):
                def impl(x):
                    # Drop to object mode to run arbitrary Python code.
                    with objmode(y="float64"):
                        y = python_code(x)
                    return y
                return impl

            @njit(cache=True)
            def testcase(x):
                return with_objmode_cache_ov_example(x)

            expect = realwork(123)
            got = testcase(123)
            self.assertEqual(got, expect)

            # Build a second cache-enabled dispatcher from the same Python
            # function; presumably served from the cache written above (this
            # test only checks the result, not cache statistics).
            testcase_cached = njit(cache=True)(testcase.py_func)
            got = testcase_cached(123)
            self.assertEqual(got, expect)

    @classmethod
    def check_objmode_cache_ndarray(cls):
        """Compile (or load from cache) a function whose overload uses
        objmode with ndarray arguments, check its result, and return the
        dispatcher. Run in subprocesses by test_check_objmode_cache_ndarray.
        """
        def do_this(a, b):
            return np.sum(a + b)

        def do_something(a, b):
            return np.sum(a + b)

        @overload(do_something)
        def overload_do_something(a, b):
            def _do_something_impl(a, b):
                with objmode(y='float64'):
                    y = do_this(a, b)
                return y
            return _do_something_impl

        @njit(cache=True)
        def test_caching():
            a = np.arange(20)
            b = np.arange(20)
            return do_something(a, b)

        got = test_caching()
        expect = test_caching.py_func()

        # Check result
        if got != expect:
            raise AssertionError("incorrect result")
        return test_caching

    @classmethod
    def populate_objmode_cache_ndarray_check_cache(cls):
        """First subprocess: compile and write the cache entry."""
        cls.check_objmode_cache_ndarray()

    @classmethod
    def check_objmode_cache_ndarray_check_cache(cls):
        """Second subprocess: the function must be loaded from the cache."""
        disp = cls.check_objmode_cache_ndarray()
        if len(disp.stats.cache_misses) != 0:
            raise AssertionError('unexpected cache miss')
        if len(disp.stats.cache_hits) <= 0:
            raise AssertionError("unexpected missing cache hit")

    def test_check_objmode_cache_ndarray(self):
        # See issue #6130.
        # Env is missing after cache load.
        cache_dir = temp_directory(self.__class__.__name__)
        with override_config("CACHE_DIR", cache_dir):
            # Run in new process to populate the cache
            run_in_new_process_in_cache_dir(
                self.populate_objmode_cache_ndarray_check_cache, cache_dir
            )
            # Run in new process to use the cache in a fresh process.
            res = run_in_new_process_in_cache_dir(
                self.check_objmode_cache_ndarray_check_cache, cache_dir
            )
        self.assertEqual(res['exitcode'], 0)
class TestMisc(TestCase):
    """Assorted checks of the extension API helpers."""

    def test_is_jitted(self):
        """Only dispatcher objects created by njit count as "jitted";
        plain functions and (gu)vectorize wrappers do not."""
        def foo(x):
            pass

        self.assertFalse(is_jitted(foo))
        self.assertTrue(is_jitted(njit(foo)))
        wrappers = (
            vectorize(foo),
            vectorize(parallel=True)(foo),
            guvectorize("void(float64[:])", "(m)")(foo),
        )
        for wrapped in wrappers:
            self.assertFalse(is_jitted(wrapped))

    def test_overload_arg_binding(self):
        # See issue #7982: calling a function with named args must work
        # regardless of the order in which the names are supplied.
        @njit
        def standard_order():
            return np.full(shape=123, fill_value=456).shape

        @njit
        def reversed_order():
            return np.full(fill_value=456, shape=123).shape

        for compiled in (standard_order, reversed_order):
            self.assertPreciseEqual(compiled(), compiled.py_func())
class TestOverloadPreferLiteral(TestCase):
    """``prefer_literal`` on @overload / @overload_method controls whether
    the typing function is first offered literal argument types."""

    def test_overload(self):
        def prefer_lit(x):
            pass

        def non_lit(x):
            pass

        def ov(x):
            # Non-literal argument: scale by 100.
            if not isinstance(x, types.IntegerLiteral):
                def impl(x):
                    return x * 100
                return impl
            # Literal argument (only seen when prefer_literal=True): the
            # literal 1 is special-cased, any other literal is rejected so
            # typing falls back to the non-literal type.
            if x.literal_value != 1:
                raise errors.TypingError('literal value')

            def impl(x):
                return 0xcafe
            return impl

        overload(prefer_lit, prefer_literal=True)(ov)
        overload(non_lit)(ov)

        @njit
        def check_prefer_lit(x):
            return prefer_lit(1), prefer_lit(2), prefer_lit(x)

        @njit
        def check_non_lit(x):
            return non_lit(1), non_lit(2), non_lit(x)

        self.assertEqual(check_prefer_lit(3), (0xcafe, 200, 300))
        self.assertEqual(check_non_lit(3), (100, 200, 300))

    def test_overload_method(self):
        def ov(self, x):
            # Non-literal argument: scale by 100.
            if not isinstance(x, types.IntegerLiteral):
                def impl(self, x):
                    return x * 100
                return impl
            # Literal argument (only seen when prefer_literal=True).
            if x.literal_value != 1:
                raise errors.TypingError('literal value')

            def impl(self, x):
                return 0xcafe
            return impl

        for attr_name, prefer in (("method_prefer_literal", True),
                                  ("method_non_literal", False)):
            overload_method(
                MyDummyType, attr_name,
                prefer_literal=prefer,
            )(ov)

        @njit
        def check_prefer_lit(dummy, x):
            return (
                dummy.method_prefer_literal(1),
                dummy.method_prefer_literal(2),
                dummy.method_prefer_literal(x),
            )

        @njit
        def check_non_lit(dummy, x):
            return (
                dummy.method_non_literal(1),
                dummy.method_non_literal(2),
                dummy.method_non_literal(x),
            )

        self.assertEqual(check_prefer_lit(MyDummy(), 3), (0xcafe, 200, 300))
        self.assertEqual(check_non_lit(MyDummy(), 3), (100, 200, 300))
class TestIntrinsicPreferLiteral(TestCase):
    """``prefer_literal`` on @intrinsic controls whether the typing function
    is first offered literal argument types."""

    def test_intrinsic(self):
        def intrin(context, x):
            # This intrinsic will return 0xcafe if `x` is a literal `1`;
            # otherwise it returns `x * 100`.
            sig = signature(types.intp, x)
            if isinstance(x, types.IntegerLiteral):
                # With prefer_literal=False, this branch will not be reached
                if x.literal_value == 1:
                    def codegen(context, builder, signature, args):
                        # Emit 0xcafe as a constant of the argument's LLVM
                        # type. NOTE(review): the declared return type is
                        # intp, so this relies on the literal's value type
                        # matching intp — confirm on 32-bit targets.
                        atype = signature.args[0]
                        llrtype = context.get_value_type(atype)
                        return ir.Constant(llrtype, 0xcafe)
                    return sig, codegen
                else:
                    # Reject other literals; the assertions below expect
                    # typing to retry with the non-literal type, so that
                    # prefer_lit(2) == 200.
                    raise errors.TypingError('literal value')
            else:
                def codegen(context, builder, signature, args):
                    atype = signature.return_type
                    llrtype = context.get_value_type(atype)
                    int_100 = ir.Constant(llrtype, 100)
                    return builder.mul(args[0], int_100)
                return sig, codegen

        prefer_lit = intrinsic(prefer_literal=True)(intrin)
        non_lit = intrinsic(prefer_literal=False)(intrin)

        @njit
        def check_prefer_lit(x):
            return prefer_lit(1), prefer_lit(2), prefer_lit(x)

        # Literal 1 hits the special case; literal 2 falls back to x * 100.
        a, b, c = check_prefer_lit(3)
        self.assertEqual(a, 0xcafe)
        self.assertEqual(b, 200)
        self.assertEqual(c, 300)

        @njit
        def check_non_lit(x):
            return non_lit(1), non_lit(2), non_lit(x)

        # Without prefer_literal, every argument takes the x * 100 path.
        a, b, c = check_non_lit(3)
        self.assertEqual(a, 100)
        self.assertEqual(b, 200)
        self.assertEqual(c, 300)
class TestNumbaInternalOverloads(TestCase):
    """Consistency checks on Numba's own ``@overload`` registrations."""

    def test_signatures_match_overloaded_api(self):
        # This is a "best-effort" test to try and ensure that Numba's internal
        # overload declarations have signatures with argument names that match
        # the API they are overloading. The purpose of ensuring there is a
        # match is so that users can use call-by-name for positional arguments.

        # Set this to:
        # 0 to make violations raise a ValueError (default).
        # 1 to get violations reported to STDOUT
        # 2 to get a verbose output of everything that was checked and its state
        # reported to STDOUT.
        DEBUG = 0

        # np.random.* does not have a signature exposed to `inspect`... so
        # custom parse the docstrings.
        def sig_from_np_random(x):
            """Build an ``inspect.Signature`` for the builtin ``np.random.<x>``.

            Finds the first docstring line that starts with the function's
            name, exec()s it as the header of a dummy ``def`` and returns that
            function's signature. Returns None for names starting with ``_``,
            non-builtins, unparsable signature lines, or when no such
            docstring line exists.
            """
            if not x.startswith('_'):
                thing = getattr(np.random, x)
                if inspect.isbuiltin(thing):
                    docstr = thing.__doc__.splitlines()
                    for l in docstr:
                        if l:
                            sl = l.strip()
                            if sl.startswith(x): # it's the signature
                                # special case np.random.seed, it has `self` in
                                # the signature whereas all the other functions
                                # do not!?
                                if x == "seed":
                                    sl = "seed(seed)"

                                # The exec'd source comes from numpy's own
                                # docstrings, not from external input.
                                fake_impl = f"def {sl}:\n\tpass"
                                # NB: this rebinds the loop variable `l` to the
                                # exec() namespace; harmless only because both
                                # branches below return immediately.
                                l = {}
                                try:
                                    exec(fake_impl, {}, l)
                                except SyntaxError:
                                    # likely ellipsis, e.g. rand(d0, d1, ..., dn)
                                    if DEBUG == 2:
                                        print("... skipped as cannot parse "
                                              "signature")
                                    return None
                                else:
                                    fn = l.get(x)
                                    return inspect.signature(fn)

        def checker(func, overload_func):
            """Compare the parameter names of ``func`` with ``overload_func``.

            Walks the common prefix of both parameter lists. At the first
            differing name the pair is skipped if the API parameter is
            positional-only or variadic (``*`` in its repr); otherwise, for
            functions not from a ``numba*`` module, the mismatch is raised
            as ValueError (DEBUG == 0) or printed (DEBUG != 0).
            """
            if DEBUG == 2:
                print(f"Checking: {func}")

            def create_message(func, overload_func, func_sig, ol_sig):
                """Format a mismatch report including the overload's
                source location (relative to the numba package)."""
                msg = []
                s = (f"{func} from module '{getattr(func, '__module__')}' "
                     "has mismatched sig.")
                msg.append(s)
                msg.append(f"    - expected: {func_sig}")
                msg.append(f"    -      got: {ol_sig}")
                lineno = inspect.getsourcelines(overload_func)[1]
                tmpsrcfile = inspect.getfile(overload_func)
                srcfile = tmpsrcfile.replace(numba.__path__[0], '')
                msg.append(f"from {srcfile}:{lineno}")
                msgstr = '\n' + '\n'.join(msg)
                return msgstr

            func_sig = None
            try:
                func_sig = inspect.signature(func)
            except ValueError:
                # probably a built-in/C code, see if it's a np.random function
                if fname := getattr(func, '__name__', False):
                    if maybe_func := getattr(np.random, fname, False):
                        if maybe_func == func:
                            # it's a built-in from np.random
                            func_sig = sig_from_np_random(fname)

            # No signature could be obtained: nothing to compare.
            if func_sig is not None:
                ol_sig = inspect.signature(overload_func)
                x = list(func_sig.parameters.keys())
                y = list(ol_sig.parameters.keys())
                for a, b in zip(x[:len(y)], y):
                    if a != b:
                        p = func_sig.parameters[a]
                        if p.kind == p.POSITIONAL_ONLY:
                            # probably a built-in/C code
                            if DEBUG == 2:
                                print("... skipped as positional only "
                                      "arguments found")
                            break
                        elif '*' in str(p): # probably *args or similar
                            if DEBUG == 2:
                                print("... skipped as contains *args")
                            break
                        else:
                            # Only error/report on functions that have a module
                            # or are from somewhere other than Numba.
                            # NOTE(review): an empty __module__ is *reported*
                            # here, which makes the "no __module__" message in
                            # the else branch below unreachable — confirm
                            # whether `func.__module__ and not ...` was meant.
                            if (not func.__module__ or
                                    not func.__module__.startswith("numba")):
                                msgstr = create_message(func, overload_func,
                                                        func_sig, ol_sig)
                                if DEBUG != 0:
                                    if DEBUG == 2:
                                        print("... INVALID")
                                    if msgstr:
                                        print(msgstr)
                                    break
                                else:
                                    raise ValueError(msgstr)
                            else:
                                if DEBUG == 2:
                                    if not func.__module__:
                                        print("... skipped as no __module__ "
                                              "present")
                                    else:
                                        print("... skipped as Numba internal")
                                break
                else:
                    # for/else: no differing name in the common prefix.
                    if DEBUG == 2:
                        print("... OK")

        # Compile something to make sure that the typing context registries
        # are populated with everything from the CPU target.
        njit(lambda : None).compile(())

        tyctx = numba.core.typing.context.Context()
        tyctx.refresh()

        # Walk the registries and check each function that is an overload
        # (i.e. its typing template carries an `_overload_func`).
        regs = tyctx._registries
        for k, v in regs.items():
            for item in k.functions:
                if getattr(item, '_overload_func', False):
                    checker(item.key, item._overload_func)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()