import inspect
import math
import operator
import sys
import pickle
import multiprocessing
import ctypes
import warnings
import re

import numpy as np
from llvmlite import ir

import numba
from numba import njit, jit, vectorize, guvectorize, objmode
from numba.core import types, errors, typing, compiler, cgutils
from numba.core.typed_passes import type_inference_stage
from numba.core.registry import cpu_target
from numba.core.imputils import lower_constant
from numba.tests.support import (
    TestCase,
    captured_stdout,
    temp_directory,
    override_config,
    run_in_new_process_in_cache_dir,
    skip_if_typeguard,
)
from numba.core.errors import LoweringError
import unittest

from numba.extending import (
    typeof_impl,
    type_callable,
    lower_builtin,
    lower_cast,
    overload,
    overload_attribute,
    overload_method,
    models,
    register_model,
    box,
    unbox,
    NativeValue,
    intrinsic,
    _Intrinsic,
    register_jitable,
    get_cython_function_address,
    is_jitted,
    overload_classmethod,
)
from numba.core.typing.templates import (
    ConcreteTemplate,
    signature,
    infer,
    infer_global,
    AbstractTemplate,
)

# Pandas-like API implementation
from .pdlike_usecase import Index, Series

try:
    import scipy.special.cython_special as sc
except ImportError:
    sc = None


# -----------------------------------------------------------------------
# Define a custom type and an implicit cast on it


class MyDummy(object):
    """Plain Python class whose instances are typed as ``mydummy_type``."""
    pass


class MyDummyType(types.Opaque):
    """Numba type for ``MyDummy``; safely convertible to any Number."""

    def can_convert_to(self, context, toty):
        # Only Number targets are convertible; any other target falls
        # through and implicitly returns None.
        if isinstance(toty, types.Number):
            from numba.core.typeconv import Conversion

            return Conversion.safe


mydummy_type = MyDummyType("mydummy")
mydummy = MyDummy()


@typeof_impl.register(MyDummy)
def typeof_mydummy(val, c):
    """Map every ``MyDummy`` instance to the single ``mydummy_type``."""
    return mydummy_type


@lower_cast(MyDummyType, types.Number)
def mydummy_to_number(context, builder, fromty, toty, val):
    """
    Implicit conversion from MyDummy to a Number.

    The input value is ignored: the result is always the constant 42
    typed as *toty*.
    """
    return context.get_constant(toty, 42)


def get_dummy():
    """Return the module-level ``mydummy`` instance."""
    return mydummy


register_model(MyDummyType)(models.OpaqueModel)


@unbox(MyDummyType)
def unbox_index(typ, obj, c):
    """Unbox a ``MyDummy`` as a dummy native value (it carries no data)."""
    return NativeValue(c.context.get_dummy_value())


# -----------------------------------------------------------------------
# Define a second custom type but w/o implicit cast to Number


def base_dummy_type_factory(name):
    """
    Create a fresh opaque type family named *name*.

    Returns the tuple ``(numba type class, python class, numba type
    instance)``.  The typeof hook and the data model are registered as a
    side effect.
    """

    class DynType(object):
        pass

    class DynTypeType(types.Opaque):
        pass

    dyn_type_type = DynTypeType(name)

    @typeof_impl.register(DynType)
    def typeof_mydummy(val, c):
        return dyn_type_type

    register_model(DynTypeType)(models.OpaqueModel)
    return DynTypeType, DynType, dyn_type_type


MyDummyType2, MyDummy2, mydummy_type_2 = base_dummy_type_factory("mydummy2")


@unbox(MyDummyType2)
def unbox_index2(typ, obj, c):
    """Unbox a ``MyDummy2`` as a dummy native value."""
    return NativeValue(c.context.get_dummy_value())


# -----------------------------------------------------------------------
# Define a function's typing and implementation using the classical
# two-step API


def func1(x=None):
    """Pure-Python stub; only the registered lowerings implement it."""
    raise NotImplementedError


def type_func1_(context):
    """Typing function for ``func1``: returns the typer closure."""

    def typer(x=None):
        if x in (None, types.none):
            # 0-arg or 1-arg with None
            return types.int32
        elif isinstance(x, types.Float):
            # 1-arg with float
            return x

    return typer


type_func1 = type_callable(func1)(type_func1_)


@lower_builtin(func1)
@lower_builtin(func1, types.none)
def func1_nullary(context, builder, sig, args):
    """Lower ``func1()`` / ``func1(None)`` to the constant 42."""
    return context.get_constant(sig.return_type, 42)


@lower_builtin(func1, types.Float)
def func1_unary(context, builder, sig, args):
    """Lower ``func1(float)`` by compiling ``sqrt(2 * x)`` internally."""

    def func1_impl(x):
        return math.sqrt(2 * x)

    return context.compile_internal(builder, func1_impl, sig, args)


# We can do the same for a known internal operation, here "print_item"
# which we extend to support MyDummyType.
@infer class PrintDummy(ConcreteTemplate): key = "print_item" cases = [signature(types.none, mydummy_type)] @lower_builtin("print_item", MyDummyType) def print_dummy(context, builder, sig, args): [x] = args pyapi = context.get_python_api(builder) strobj = pyapi.unserialize(pyapi.serialize_object("hello!")) pyapi.print_object(strobj) pyapi.decref(strobj) return context.get_dummy_value() # ----------------------------------------------------------------------- # Define an overloaded function (combined API) def where(cond, x, y): raise NotImplementedError def np_where(cond, x, y): """ Wrap np.where() to allow for keyword arguments """ return np.where(cond, x, y) def call_where(cond, x, y): return where(cond, y=y, x=x) @overload(where) def overload_where_arrays(cond, x, y): """ Implement where() for arrays. """ # Choose implementation based on argument types. if isinstance(cond, types.Array): if x.dtype != y.dtype: raise errors.TypingError("x and y should have the same dtype") # Array where() => return an array of the same shape if all(ty.layout == "C" for ty in (cond, x, y)): def where_impl(cond, x, y): """ Fast implementation for C-contiguous arrays """ shape = cond.shape if x.shape != shape or y.shape != shape: raise ValueError("all inputs should have the same shape") res = np.empty_like(x) cf = cond.flat xf = x.flat yf = y.flat rf = res.flat for i in range(cond.size): rf[i] = xf[i] if cf[i] else yf[i] return res else: def where_impl(cond, x, y): """ Generic implementation for other arrays """ shape = cond.shape if x.shape != shape or y.shape != shape: raise ValueError("all inputs should have the same shape") res = np.empty_like(x) for idx, c in np.ndenumerate(cond): res[idx] = x[idx] if c else y[idx] return res return where_impl # We can define another overload function for the same function, they # will be tried in turn until one succeeds. @overload(where) def overload_where_scalars(cond, x, y): """ Implement where() for scalars. 
""" if not isinstance(cond, types.Array): if x != y: raise errors.TypingError("x and y should have the same type") def where_impl(cond, x, y): """ Scalar where() => return a 0-dim array """ scal = x if cond else y # Can't use full_like() on Numpy < 1.8 arr = np.empty_like(scal) arr[()] = scal return arr return where_impl # ----------------------------------------------------------------------- # Overload an already defined built-in function, extending it for new types. @overload(len) def overload_len_dummy(arg): if isinstance(arg, MyDummyType): def len_impl(arg): return 13 return len_impl @overload(operator.add) def overload_add_dummy(arg1, arg2): if isinstance(arg1, (MyDummyType, MyDummyType2)) and isinstance( arg2, (MyDummyType, MyDummyType2) ): def dummy_add_impl(arg1, arg2): return 42 return dummy_add_impl @overload(operator.delitem) def overload_dummy_delitem(obj, idx): if isinstance(obj, MyDummyType) and isinstance(idx, types.Integer): def dummy_delitem_impl(obj, idx): print("del", obj, idx) return dummy_delitem_impl @overload(operator.getitem) def overload_dummy_getitem(obj, idx): if isinstance(obj, MyDummyType) and isinstance(idx, types.Integer): def dummy_getitem_impl(obj, idx): return idx + 123 return dummy_getitem_impl @overload(operator.setitem) def overload_dummy_setitem(obj, idx, val): if all( [ isinstance(obj, MyDummyType), isinstance(idx, types.Integer), isinstance(val, types.Integer), ] ): def dummy_setitem_impl(obj, idx, val): print(idx, val) return dummy_setitem_impl def call_add_operator(arg1, arg2): return operator.add(arg1, arg2) def call_add_binop(arg1, arg2): return arg1 + arg2 @overload(operator.iadd) def overload_iadd_dummy(arg1, arg2): if isinstance(arg1, (MyDummyType, MyDummyType2)) and isinstance( arg2, (MyDummyType, MyDummyType2) ): def dummy_iadd_impl(arg1, arg2): return 42 return dummy_iadd_impl def call_iadd_operator(arg1, arg2): return operator.add(arg1, arg2) def call_iadd_binop(arg1, arg2): arg1 += arg2 return arg1 def 
call_delitem(obj, idx): del obj[idx] def call_getitem(obj, idx): return obj[idx] def call_setitem(obj, idx, val): obj[idx] = val @overload_method(MyDummyType, "length") def overload_method_length(arg): def imp(arg): return len(arg) return imp def cache_overload_method_usecase(x): return x.length() def call_func1_nullary(): return func1() def call_func1_unary(x): return func1(x) def len_usecase(x): return len(x) def print_usecase(x): print(x) def getitem_usecase(x, key): return x[key] def npyufunc_usecase(x): return np.cos(np.sin(x)) def get_data_usecase(x): return x._data def get_index_usecase(x): return x._index def is_monotonic_usecase(x): return x.is_monotonic_increasing def make_series_usecase(data, index): return Series(data, index) def clip_usecase(x, lo, hi): return x.clip(lo, hi) # ----------------------------------------------------------------------- def return_non_boxable(): return np @overload(return_non_boxable) def overload_return_non_boxable(): def imp(): return np return imp def non_boxable_ok_usecase(sz): mod = return_non_boxable() return mod.arange(sz) def non_boxable_bad_usecase(): return return_non_boxable() def mk_func_input(f): pass @infer_global(mk_func_input) class MkFuncTyping(AbstractTemplate): def generic(self, args, kws): assert isinstance(args[0], types.MakeFunctionLiteral) return signature(types.none, *args) def mk_func_test_impl(): mk_func_input(lambda a: a) # ----------------------------------------------------------------------- # Define a types derived from types.Callable and overloads for them class MyClass(object): pass class CallableTypeRef(types.Callable): def __init__(self, instance_type): self.instance_type = instance_type self.sig_to_impl_key = {} self.compiled_templates = [] super(CallableTypeRef, self).__init__('callable_type_ref' '[{}]'.format(self.instance_type)) def get_call_type(self, context, args, kws): res_sig = None for template in context._functions[type(self)]: try: res_sig = template.apply(args, kws) except 
Exception: pass # for simplicity assume args must match exactly else: compiled_ovlds = getattr(template, '_compiled_overloads', {}) if args in compiled_ovlds: self.sig_to_impl_key[res_sig] = compiled_ovlds[args] self.compiled_templates.append(template) break return res_sig def get_call_signatures(self): sigs = list(self.sig_to_impl_key.keys()) return sigs, True def get_impl_key(self, sig): return self.sig_to_impl_key[sig] @register_model(CallableTypeRef) class CallableTypeModel(models.OpaqueModel): def __init__(self, dmm, fe_type): models.OpaqueModel.__init__(self, dmm, fe_type) infer_global(MyClass, CallableTypeRef(MyClass)) @lower_constant(CallableTypeRef) def constant_callable_typeref(context, builder, ty, pyval): return context.get_dummy_value() # ----------------------------------------------------------------------- @overload(np.exp) def overload_np_exp(obj): if isinstance(obj, MyDummyType): def imp(obj): # Returns a constant if a MyDummyType is seen return 0xDEADBEEF return imp class TestLowLevelExtending(TestCase): """ Test the low-level two-tier extension API. """ # Check with `@jit` from within the test process and also in a new test # process so as to check the registration mechanism. 
def test_func1(self): pyfunc = call_func1_nullary cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(), 42) pyfunc = call_func1_unary cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(None), 42) self.assertPreciseEqual(cfunc(18.0), 6.0) @TestCase.run_test_in_subprocess def test_func1_isolated(self): self.test_func1() def test_type_callable_keeps_function(self): self.assertIs(type_func1, type_func1_) self.assertIsNotNone(type_func1) @TestCase.run_test_in_subprocess def test_cast_mydummy(self): pyfunc = get_dummy cfunc = njit(types.float64(),)(pyfunc) self.assertPreciseEqual(cfunc(), 42.0) def test_mk_func_literal(self): """make sure make_function is passed to typer class as a literal """ test_ir = compiler.run_frontend(mk_func_test_impl) typingctx = cpu_target.typing_context targetctx = cpu_target.target_context typingctx.refresh() targetctx.refresh() typing_res = type_inference_stage(typingctx, targetctx, test_ir, (), None) self.assertTrue( any( isinstance(a, types.MakeFunctionLiteral) for a in typing_res.typemap.values() ) ) class TestPandasLike(TestCase): """ Test implementing a pandas-like Index object. Also stresses most of the high-level API. """ def test_index_len(self): i = Index(np.arange(3)) cfunc = jit(nopython=True)(len_usecase) self.assertPreciseEqual(cfunc(i), 3) def test_index_getitem(self): i = Index(np.int32([42, 8, -5])) cfunc = jit(nopython=True)(getitem_usecase) self.assertPreciseEqual(cfunc(i, 1), 8) ii = cfunc(i, slice(1, None)) self.assertIsInstance(ii, Index) self.assertEqual(list(ii), [8, -5]) def test_index_ufunc(self): """ Check Numpy ufunc on an Index object. 
""" i = Index(np.int32([42, 8, -5])) cfunc = jit(nopython=True)(npyufunc_usecase) ii = cfunc(i) self.assertIsInstance(ii, Index) self.assertPreciseEqual(ii._data, np.cos(np.sin(i._data))) def test_index_get_data(self): # The _data attribute is exposed with make_attribute_wrapper() i = Index(np.int32([42, 8, -5])) cfunc = jit(nopython=True)(get_data_usecase) data = cfunc(i) self.assertIs(data, i._data) def test_index_is_monotonic(self): # The is_monotonic_increasing attribute is exposed with # overload_attribute() cfunc = jit(nopython=True)(is_monotonic_usecase) for values, expected in [ ([8, 42, 5], False), ([5, 8, 42], True), ([], True), ]: i = Index(np.int32(values)) got = cfunc(i) self.assertEqual(got, expected) def test_series_len(self): i = Index(np.int32([2, 4, 3])) s = Series(np.float64([1.5, 4.0, 2.5]), i) cfunc = jit(nopython=True)(len_usecase) self.assertPreciseEqual(cfunc(s), 3) def test_series_get_index(self): i = Index(np.int32([2, 4, 3])) s = Series(np.float64([1.5, 4.0, 2.5]), i) cfunc = jit(nopython=True)(get_index_usecase) got = cfunc(s) self.assertIsInstance(got, Index) self.assertIs(got._data, i._data) def test_series_ufunc(self): """ Check Numpy ufunc on an Series object. 
""" i = Index(np.int32([42, 8, -5])) s = Series(np.int64([1, 2, 3]), i) cfunc = jit(nopython=True)(npyufunc_usecase) ss = cfunc(s) self.assertIsInstance(ss, Series) self.assertIsInstance(ss._index, Index) self.assertIs(ss._index._data, i._data) self.assertPreciseEqual(ss._values, np.cos(np.sin(s._values))) def test_series_constructor(self): i = Index(np.int32([42, 8, -5])) d = np.float64([1.5, 4.0, 2.5]) cfunc = jit(nopython=True)(make_series_usecase) got = cfunc(d, i) self.assertIsInstance(got, Series) self.assertIsInstance(got._index, Index) self.assertIs(got._index._data, i._data) self.assertIs(got._values, d) def test_series_clip(self): i = Index(np.int32([42, 8, -5])) s = Series(np.float64([1.5, 4.0, 2.5]), i) cfunc = jit(nopython=True)(clip_usecase) ss = cfunc(s, 1.6, 3.0) self.assertIsInstance(ss, Series) self.assertIsInstance(ss._index, Index) self.assertIs(ss._index._data, i._data) self.assertPreciseEqual(ss._values, np.float64([1.6, 3.0, 2.5])) class TestHighLevelExtending(TestCase): """ Test the high-level combined API. """ def test_where(self): """ Test implementing a function with @overload. """ pyfunc = call_where cfunc = jit(nopython=True)(pyfunc) def check(*args, **kwargs): expected = np_where(*args, **kwargs) got = cfunc(*args, **kwargs) self.assertPreciseEqual(expected, got) check(x=3, cond=True, y=8) check(True, 3, 8) check( np.bool_([True, False, True]), np.int32([1, 2, 3]), np.int32([4, 5, 5]), ) # The typing error is propagated with self.assertRaises(errors.TypingError) as raises: cfunc(np.bool_([]), np.int32([]), np.int64([])) self.assertIn( "x and y should have the same dtype", str(raises.exception) ) def test_len(self): """ Test re-implementing len() for a custom type with @overload. """ cfunc = jit(nopython=True)(len_usecase) self.assertPreciseEqual(cfunc(MyDummy()), 13) self.assertPreciseEqual(cfunc([4, 5]), 2) def test_print(self): """ Test re-implementing print() for a custom type with @overload. 
""" cfunc = jit(nopython=True)(print_usecase) with captured_stdout(): cfunc(MyDummy()) self.assertEqual(sys.stdout.getvalue(), "hello!\n") def test_add_operator(self): """ Test re-implementing operator.add() for a custom type with @overload. """ pyfunc = call_add_operator cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(1, 2), 3) self.assertPreciseEqual(cfunc(MyDummy2(), MyDummy2()), 42) # this will call add(Number, Number) as MyDummy implicitly casts to # Number self.assertPreciseEqual(cfunc(MyDummy(), MyDummy()), 84) def test_add_binop(self): """ Test re-implementing '+' for a custom type via @overload(operator.add). """ pyfunc = call_add_binop cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(1, 2), 3) self.assertPreciseEqual(cfunc(MyDummy2(), MyDummy2()), 42) # this will call add(Number, Number) as MyDummy implicitly casts to # Number self.assertPreciseEqual(cfunc(MyDummy(), MyDummy()), 84) def test_iadd_operator(self): """ Test re-implementing operator.add() for a custom type with @overload. """ pyfunc = call_iadd_operator cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(1, 2), 3) self.assertPreciseEqual(cfunc(MyDummy2(), MyDummy2()), 42) # this will call add(Number, Number) as MyDummy implicitly casts to # Number self.assertPreciseEqual(cfunc(MyDummy(), MyDummy()), 84) def test_iadd_binop(self): """ Test re-implementing '+' for a custom type via @overload(operator.add). 
""" pyfunc = call_iadd_binop cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(1, 2), 3) self.assertPreciseEqual(cfunc(MyDummy2(), MyDummy2()), 42) # this will call add(Number, Number) as MyDummy implicitly casts to # Number self.assertPreciseEqual(cfunc(MyDummy(), MyDummy()), 84) def test_delitem(self): pyfunc = call_delitem cfunc = jit(nopython=True)(pyfunc) obj = MyDummy() e = None with captured_stdout() as out: try: cfunc(obj, 321) except Exception as exc: e = exc if e is not None: raise e self.assertEqual(out.getvalue(), "del hello! 321\n") def test_getitem(self): pyfunc = call_getitem cfunc = jit(nopython=True)(pyfunc) self.assertPreciseEqual(cfunc(MyDummy(), 321), 321 + 123) def test_setitem(self): pyfunc = call_setitem cfunc = jit(nopython=True)(pyfunc) obj = MyDummy() e = None with captured_stdout() as out: try: cfunc(obj, 321, 123) except Exception as exc: e = exc if e is not None: raise e self.assertEqual(out.getvalue(), "321 123\n") def test_no_cpython_wrapper(self): """ Test overloading whose return value cannot be represented in CPython. """ # Test passing Module type from a @overload implementation to ensure # that the *no_cpython_wrapper* flag works ok_cfunc = jit(nopython=True)(non_boxable_ok_usecase) n = 10 got = ok_cfunc(n) expect = non_boxable_ok_usecase(n) np.testing.assert_equal(expect, got) # Verify that the Module type cannot be returned to CPython bad_cfunc = jit(nopython=True)(non_boxable_bad_usecase) with self.assertRaises(TypeError) as raises: bad_cfunc() errmsg = str(raises.exception) expectmsg = "cannot convert native Module" self.assertIn(expectmsg, errmsg) def test_typing_vs_impl_signature_mismatch_handling(self): """ Tests that an overload which has a differing typing and implementing signature raises an exception. 
""" def gen_ol(impl=None): def myoverload(a, b, c, kw=None): pass @overload(myoverload) def _myoverload_impl(a, b, c, kw=None): return impl @jit(nopython=True) def foo(a, b, c, d): myoverload(a, b, c, kw=d) return foo sentinel = "Typing and implementation arguments differ in" # kwarg value is different def impl1(a, b, c, kw=12): if a > 10: return 1 else: return -1 with self.assertRaises(errors.TypingError) as e: gen_ol(impl1)(1, 2, 3, 4) msg = str(e.exception) self.assertIn(sentinel, msg) self.assertIn("keyword argument default values", msg) self.assertIn('', msg) self.assertIn('', msg) # kwarg name is different def impl2(a, b, c, kwarg=None): if a > 10: return 1 else: return -1 with self.assertRaises(errors.TypingError) as e: gen_ol(impl2)(1, 2, 3, 4) msg = str(e.exception) self.assertIn(sentinel, msg) self.assertIn("keyword argument names", msg) self.assertIn('', msg) self.assertIn('', msg) # arg name is different def impl3(z, b, c, kw=None): if a > 10: # noqa: F821 return 1 else: return -1 with self.assertRaises(errors.TypingError) as e: gen_ol(impl3)(1, 2, 3, 4) msg = str(e.exception) self.assertIn(sentinel, msg) self.assertIn("argument names", msg) self.assertFalse("keyword" in msg) self.assertIn('', msg) self.assertIn('', msg) from .overload_usecases import impl4, impl5 with self.assertRaises(errors.TypingError) as e: gen_ol(impl4)(1, 2, 3, 4) msg = str(e.exception) self.assertIn(sentinel, msg) self.assertIn("argument names", msg) self.assertFalse("keyword" in msg) self.assertIn("First difference: 'z'", msg) with self.assertRaises(errors.TypingError) as e: gen_ol(impl5)(1, 2, 3, 4) msg = str(e.exception) self.assertIn(sentinel, msg) self.assertIn("argument names", msg) self.assertFalse("keyword" in msg) self.assertIn('', msg) self.assertIn('', msg) # too many args def impl6(a, b, c, d, e, kw=None): if a > 10: return 1 else: return -1 with self.assertRaises(errors.TypingError) as e: gen_ol(impl6)(1, 2, 3, 4) msg = str(e.exception) self.assertIn(sentinel, msg) 
self.assertIn("argument names", msg) self.assertFalse("keyword" in msg) self.assertIn('', msg) self.assertIn('', msg) # too few args def impl7(a, b, kw=None): if a > 10: return 1 else: return -1 with self.assertRaises(errors.TypingError) as e: gen_ol(impl7)(1, 2, 3, 4) msg = str(e.exception) self.assertIn(sentinel, msg) self.assertIn("argument names", msg) self.assertFalse("keyword" in msg) self.assertIn('', msg) # too many kwargs def impl8(a, b, c, kw=None, extra_kwarg=None): if a > 10: return 1 else: return -1 with self.assertRaises(errors.TypingError) as e: gen_ol(impl8)(1, 2, 3, 4) msg = str(e.exception) self.assertIn(sentinel, msg) self.assertIn("keyword argument names", msg) self.assertIn('', msg) # too few kwargs def impl9(a, b, c): if a > 10: return 1 else: return -1 with self.assertRaises(errors.TypingError) as e: gen_ol(impl9)(1, 2, 3, 4) msg = str(e.exception) self.assertIn(sentinel, msg) self.assertIn("keyword argument names", msg) self.assertIn('', msg) def test_typing_vs_impl_signature_mismatch_handling_var_positional(self): """ Tests that an overload which has a differing typing and implementing signature raises an exception and uses VAR_POSITIONAL (*args) in typing """ def myoverload(a, kw=None): pass from .overload_usecases import var_positional_impl overload(myoverload)(var_positional_impl) @jit(nopython=True) def foo(a, b): return myoverload(a, b, 9, kw=11) with self.assertRaises(errors.TypingError) as e: foo(1, 5) msg = str(e.exception) self.assertIn("VAR_POSITIONAL (e.g. 
*args) argument kind", msg) self.assertIn("offending argument name is '*star_args_token'", msg) def test_typing_vs_impl_signature_mismatch_handling_var_keyword(self): """ Tests that an overload which uses **kwargs (VAR_KEYWORD) """ def gen_ol(impl, strict=True): def myoverload(a, kw=None): pass overload(myoverload, strict=strict)(impl) @jit(nopython=True) def foo(a, b): return myoverload(a, kw=11) return foo # **kwargs in typing def ol1(a, **kws): def impl(a, kw=10): return a return impl gen_ol(ol1, False)(1, 2) # no error if strictness not enforced with self.assertRaises(errors.TypingError) as e: gen_ol(ol1)(1, 2) msg = str(e.exception) self.assertIn("use of VAR_KEYWORD (e.g. **kwargs) is unsupported", msg) self.assertIn("offending argument name is '**kws'", msg) # **kwargs in implementation def ol2(a, kw=0): def impl(a, **kws): return a return impl with self.assertRaises(errors.TypingError) as e: gen_ol(ol2)(1, 2) msg = str(e.exception) self.assertIn("use of VAR_KEYWORD (e.g. **kwargs) is unsupported", msg) self.assertIn("offending argument name is '**kws'", msg) def test_overload_method_kwargs(self): # Issue #3489 @overload_method(types.Array, "foo") def fooimpl(arr, a_kwarg=10): def impl(arr, a_kwarg=10): return a_kwarg return impl @njit def bar(A): return A.foo(), A.foo(20), A.foo(a_kwarg=30) Z = np.arange(5) self.assertEqual(bar(Z), (10, 20, 30)) def test_overload_method_literal_unpack(self): # Issue #3683 @overload_method(types.Array, "litfoo") def litfoo(arr, val): # Must be an integer if isinstance(val, types.Integer): # Must not be literal if not isinstance(val, types.Literal): def impl(arr, val): return val return impl @njit def bar(A): return A.litfoo(0xCAFE) A = np.zeros(1) bar(A) self.assertEqual(bar(A), 0xCAFE) def test_overload_ufunc(self): # Issue #4133. # Use an extended type (MyDummyType) to use with a customized # ufunc (np.exp). 
@njit def test(): return np.exp(mydummy) self.assertEqual(test(), 0xDEADBEEF) def test_overload_method_stararg(self): @overload_method(MyDummyType, "method_stararg") def _ov_method_stararg(obj, val, val2, *args): def get(obj, val, val2, *args): return (val, val2, args) return get @njit def foo(obj, *args): # Test with expanding stararg return obj.method_stararg(*args) obj = MyDummy() self.assertEqual(foo(obj, 1, 2), (1, 2, ())) self.assertEqual(foo(obj, 1, 2, 3), (1, 2, (3,))) self.assertEqual(foo(obj, 1, 2, 3, 4), (1, 2, (3, 4))) @njit def bar(obj): # Test with explicit argument return ( obj.method_stararg(1, 2), obj.method_stararg(1, 2, 3), obj.method_stararg(1, 2, 3, 4), ) self.assertEqual( bar(obj), ((1, 2, ()), (1, 2, (3,)), (1, 2, (3, 4))), ) # Check cases that put tuple type into stararg # NOTE: the expected result has an extra tuple because of stararg. self.assertEqual( foo(obj, 1, 2, (3,)), (1, 2, ((3,),)), ) self.assertEqual( foo(obj, 1, 2, (3, 4)), (1, 2, ((3, 4),)), ) self.assertEqual( foo(obj, 1, 2, (3, (4, 5))), (1, 2, ((3, (4, 5)),)), ) def test_overload_classmethod(self): # Add classmethod to a subclass of Array class MyArray(types.Array): pass @overload_classmethod(MyArray, "array_alloc") def ol_array_alloc(cls, nitems): def impl(cls, nitems): arr = np.arange(nitems) return arr return impl @njit def foo(nitems): return MyArray.array_alloc(nitems) nitems = 13 self.assertPreciseEqual(foo(nitems), np.arange(nitems)) # Check that the base type doesn't get the classmethod @njit def no_classmethod_in_base(nitems): return types.Array.array_alloc(nitems) with self.assertRaises(errors.TypingError) as raises: no_classmethod_in_base(nitems) self.assertIn( "Unknown attribute 'array_alloc' of", str(raises.exception), ) def test_overload_callable_typeref(self): @overload(CallableTypeRef) def callable_type_call_ovld1(x): if isinstance(x, types.Integer): def impl(x): return 42.5 + x return impl @overload(CallableTypeRef) def callable_type_call_ovld2(x): if 
isinstance(x, types.UnicodeType): def impl(x): return '42.5' + x return impl @njit def foo(a, b): return MyClass(a), MyClass(b) args = (4, '4') expected = (42.5 + args[0], '42.5' + args[1]) self.assertPreciseEqual(foo(*args), expected) def _assert_cache_stats(cfunc, expect_hit, expect_misses): hit = cfunc._cache_hits[cfunc.signatures[0]] if hit != expect_hit: raise AssertionError("cache not used") miss = cfunc._cache_misses[cfunc.signatures[0]] if miss != expect_misses: raise AssertionError("cache not used") @skip_if_typeguard class TestOverloadMethodCaching(TestCase): # Nested multiprocessing.Pool raises AssertionError: # "daemonic processes are not allowed to have children" _numba_parallel_test_ = False def test_caching_overload_method(self): self._cache_dir = temp_directory(self.__class__.__name__) with override_config("CACHE_DIR", self._cache_dir): self.run_caching_overload_method() def run_caching_overload_method(self): cfunc = jit(nopython=True, cache=True)(cache_overload_method_usecase) self.assertPreciseEqual(cfunc(MyDummy()), 13) _assert_cache_stats(cfunc, 0, 1) llvmir = cfunc.inspect_llvm((mydummy_type,)) # Ensure the inner method is not a declaration decls = [ ln for ln in llvmir.splitlines() if ln.startswith("declare") and "overload_method_length" in ln ] self.assertEqual(len(decls), 0) # Test in a separate process try: ctx = multiprocessing.get_context("spawn") except AttributeError: ctx = multiprocessing q = ctx.Queue() p = ctx.Process( target=run_caching_overload_method, args=(q, self._cache_dir) ) p.start() q.put(MyDummy()) p.join() # Ensure subprocess exited normally self.assertEqual(p.exitcode, 0) res = q.get(timeout=1) self.assertEqual(res, 13) def run_caching_overload_method(q, cache_dir): """ Used by TestOverloadMethodCaching.test_caching_overload_method """ with override_config("CACHE_DIR", cache_dir): arg = q.get() cfunc = jit(nopython=True, cache=True)(cache_overload_method_usecase) res = cfunc(arg) q.put(res) # Check cache stat 
_assert_cache_stats(cfunc, 1, 0) class TestIntrinsic(TestCase): def test_void_return(self): """ Verify that returning a None from codegen function is handled automatically for void functions, otherwise raise exception. """ @intrinsic def void_func(typingctx, a): sig = types.void(types.int32) def codegen(context, builder, signature, args): pass # do nothing, return None, should be turned into # dummy value return sig, codegen @intrinsic def non_void_func(typingctx, a): sig = types.int32(types.int32) def codegen(context, builder, signature, args): pass # oops, should be returning a value here, raise exception return sig, codegen @jit(nopython=True) def call_void_func(): void_func(1) return 0 @jit(nopython=True) def call_non_void_func(): non_void_func(1) return 0 # void func should work self.assertEqual(call_void_func(), 0) # not void function should raise exception with self.assertRaises(LoweringError) as e: call_non_void_func() self.assertIn("non-void function returns None", e.exception.msg) def test_ll_pointer_cast(self): """ Usecase test: custom reinterpret cast to turn int values to pointers """ from ctypes import CFUNCTYPE, POINTER, c_float, c_int # Use intrinsic to make a reinterpret_cast operation def unsafe_caster(result_type): assert isinstance(result_type, types.CPointer) @intrinsic def unsafe_cast(typingctx, src): self.assertIsInstance(typingctx, typing.Context) if isinstance(src, types.Integer): sig = result_type(types.uintp) # defines the custom code generation def codegen(context, builder, signature, args): [src] = args rtype = signature.return_type llrtype = context.get_value_type(rtype) return builder.inttoptr(src, llrtype) return sig, codegen return unsafe_cast # make a nopython function to use our cast op. # this is not usable from cpython due to the returning of a pointer. 
def unsafe_get_ctypes_pointer(src): raise NotImplementedError("not callable from python") @overload(unsafe_get_ctypes_pointer, strict=False) def array_impl_unsafe_get_ctypes_pointer(arrtype): if isinstance(arrtype, types.Array): unsafe_cast = unsafe_caster(types.CPointer(arrtype.dtype)) def array_impl(arr): return unsafe_cast(src=arr.ctypes.data) return array_impl # the ctype wrapped function for use in nopython mode def my_c_fun_raw(ptr, n): for i in range(n): print(ptr[i]) prototype = CFUNCTYPE(None, POINTER(c_float), c_int) my_c_fun = prototype(my_c_fun_raw) # Call our pointer-cast in a @jit compiled function and use # the pointer in a ctypes function @jit(nopython=True) def foo(arr): ptr = unsafe_get_ctypes_pointer(arr) my_c_fun(ptr, arr.size) # Test arr = np.arange(10, dtype=np.float32) with captured_stdout() as buf: foo(arr) got = buf.getvalue().splitlines() buf.close() expect = list(map(str, arr)) self.assertEqual(expect, got) def test_serialization(self): """ Test serialization of intrinsic objects """ # define a intrinsic @intrinsic def identity(context, x): def codegen(context, builder, signature, args): return args[0] sig = x(x) return sig, codegen # use in a jit function @jit(nopython=True) def foo(x): return identity(x) self.assertEqual(foo(1), 1) # get serialization memo memo = _Intrinsic._memo memo_size = len(memo) # pickle foo and check memo size serialized_foo = pickle.dumps(foo) # increases the memo size memo_size += 1 self.assertEqual(memo_size, len(memo)) # unpickle foo_rebuilt = pickle.loads(serialized_foo) self.assertEqual(memo_size, len(memo)) # check rebuilt foo self.assertEqual(foo(1), foo_rebuilt(1)) # pickle identity directly serialized_identity = pickle.dumps(identity) # memo size unchanged self.assertEqual(memo_size, len(memo)) # unpickle identity_rebuilt = pickle.loads(serialized_identity) # must be the same object self.assertIs(identity, identity_rebuilt) # memo size unchanged self.assertEqual(memo_size, len(memo)) def 
test_deserialization(self): """ Test deserialization of intrinsic """ def defn(context, x): def codegen(context, builder, signature, args): return args[0] return x(x), codegen memo = _Intrinsic._memo memo_size = len(memo) # invoke _Intrinsic indirectly to avoid registration which keeps an # internal reference inside the compiler original = _Intrinsic("foo", defn) self.assertIs(original._defn, defn) pickled = pickle.dumps(original) # by pickling, a new memo entry is created memo_size += 1 self.assertEqual(memo_size, len(memo)) del original # remove original before unpickling # by deleting, the memo entry is NOT removed due to recent # function queue self.assertEqual(memo_size, len(memo)) # Manually force clear of _recent queue _Intrinsic._recent.clear() memo_size -= 1 self.assertEqual(memo_size, len(memo)) rebuilt = pickle.loads(pickled) # verify that the rebuilt object is different self.assertIsNot(rebuilt._defn, defn) # the second rebuilt object is the same as the first second = pickle.loads(pickled) self.assertIs(rebuilt._defn, second._defn) def test_docstring(self): @intrinsic def void_func(typingctx, a: int): """void_func docstring""" sig = types.void(types.int32) def codegen(context, builder, signature, args): pass # do nothing, return None, should be turned into # dummy value return sig, codegen self.assertEqual("numba.tests.test_extending", void_func.__module__) self.assertEqual("void_func", void_func.__name__) self.assertEqual("TestIntrinsic.test_docstring..void_func", void_func.__qualname__) self.assertDictEqual({'a': int}, void_func.__annotations__) self.assertEqual("void_func docstring", void_func.__doc__) class TestRegisterJitable(unittest.TestCase): def test_no_flags(self): @register_jitable def foo(x, y): return x + y def bar(x, y): return foo(x, y) cbar = jit(nopython=True)(bar) expect = bar(1, 2) got = cbar(1, 2) self.assertEqual(expect, got) def test_flags_no_nrt(self): @register_jitable(_nrt=False) def foo(n): return np.arange(n) def bar(n): 
return foo(n) self.assertEqual(bar(3).tolist(), [0, 1, 2]) cbar = jit(nopython=True)(bar) with self.assertRaises(errors.TypingError) as raises: cbar(2) msg = ( "Only accept returning of array passed into the function as " "argument" ) self.assertIn(msg, str(raises.exception)) class TestImportCythonFunction(unittest.TestCase): @unittest.skipIf(sc is None, "Only run if SciPy >= 0.19 is installed") def test_getting_function(self): addr = get_cython_function_address( "scipy.special.cython_special", "j0" ) functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double) _j0 = functype(addr) j0 = jit(nopython=True)(lambda x: _j0(x)) self.assertEqual(j0(0), 1) def test_missing_module(self): with self.assertRaises(ImportError) as raises: get_cython_function_address("fakemodule", "fakefunction") # The quotes are not there in Python 2 msg = "No module named '?fakemodule'?" match = re.match(msg, str(raises.exception)) self.assertIsNotNone(match) @unittest.skipIf(sc is None, "Only run if SciPy >= 0.19 is installed") def test_missing_function(self): with self.assertRaises(ValueError) as raises: get_cython_function_address( "scipy.special.cython_special", "foo" ) msg = ( "No function 'foo' found in __pyx_capi__ of " "'scipy.special.cython_special'" ) self.assertEqual(msg, str(raises.exception)) @overload_method( MyDummyType, "method_jit_option_check_nrt", jit_options={"_nrt": True} ) def ov_method_jit_option_check_nrt(obj): def imp(obj): return np.arange(10) return imp @overload_method( MyDummyType, "method_jit_option_check_no_nrt", jit_options={"_nrt": False} ) def ov_method_jit_option_check_no_nrt(obj): def imp(obj): return np.arange(10) return imp @overload_attribute( MyDummyType, "attr_jit_option_check_nrt", jit_options={"_nrt": True} ) def ov_attr_jit_option_check_nrt(obj): def imp(obj): return np.arange(10) return imp @overload_attribute( MyDummyType, "attr_jit_option_check_no_nrt", jit_options={"_nrt": False} ) def ov_attr_jit_option_check_no_nrt(obj): def imp(obj): return 
np.arange(10) return imp class TestJitOptionsNoNRT(TestCase): # Test overload*(jit_options={...}) by turning off _nrt def check_error_no_nrt(self, func, *args, **kwargs): # Check that the compilation fails with a complaint about dynamic array msg = ( "Only accept returning of array passed into " "the function as argument" ) with self.assertRaises(errors.TypingError) as raises: func(*args, **kwargs) self.assertIn(msg, str(raises.exception)) def no_nrt_overload_check(self, flag): def dummy(): return np.arange(10) @overload(dummy, jit_options={"_nrt": flag}) def ov_dummy(): def dummy(): return np.arange(10) return dummy @njit def foo(): return dummy() if flag: self.assertPreciseEqual(foo(), np.arange(10)) else: self.check_error_no_nrt(foo) def test_overload_no_nrt(self): self.no_nrt_overload_check(True) self.no_nrt_overload_check(False) def test_overload_method_no_nrt(self): @njit def udt(x): return x.method_jit_option_check_nrt() self.assertPreciseEqual(udt(mydummy), np.arange(10)) @njit def udt(x): return x.method_jit_option_check_no_nrt() self.check_error_no_nrt(udt, mydummy) def test_overload_attribute_no_nrt(self): @njit def udt(x): return x.attr_jit_option_check_nrt self.assertPreciseEqual(udt(mydummy), np.arange(10)) @njit def udt(x): return x.attr_jit_option_check_no_nrt self.check_error_no_nrt(udt, mydummy) class TestBoxingCallingJIT(TestCase): def setUp(self): super().setUp() many = base_dummy_type_factory("mydummy2") self.DynTypeType, self.DynType, self.dyn_type_type = many self.dyn_type = self.DynType() def test_unboxer_basic(self): # Implements an unboxer on DynType that calls an intrinsic into the # unboxer code. 
magic_token = 0xCAFE magic_offset = 123 @intrinsic def my_intrinsic(typingctx, val): # An intrinsic that returns `val + magic_offset` def impl(context, builder, sig, args): [val] = args return builder.add(val, val.type(magic_offset)) sig = signature(val, val) return sig, impl @unbox(self.DynTypeType) def unboxer(typ, obj, c): # The unboxer that calls some jitcode def bridge(x): # proof that this is a jit'ed context by calling jit only # intrinsic return my_intrinsic(x) args = [c.context.get_constant(types.intp, magic_token)] sig = signature(types.voidptr, types.intp) is_error, res = c.pyapi.call_jit_code(bridge, sig, args) return NativeValue(res, is_error=is_error) @box(self.DynTypeType) def boxer(typ, val, c): # The boxer that returns an integer representation res = c.builder.ptrtoint(val, cgutils.intp_t) return c.pyapi.long_from_ssize_t(res) @njit def passthru(x): return x out = passthru(self.dyn_type) self.assertEqual(out, magic_token + magic_offset) def test_unboxer_raise(self): # Testing exception raising in jitcode called from unboxing. 
@unbox(self.DynTypeType) def unboxer(typ, obj, c): # The unboxer that calls some jitcode def bridge(x): if x > 0: raise ValueError("cannot be x > 0") return x args = [c.context.get_constant(types.intp, 1)] sig = signature(types.voidptr, types.intp) is_error, res = c.pyapi.call_jit_code(bridge, sig, args) return NativeValue(res, is_error=is_error) @box(self.DynTypeType) def boxer(typ, val, c): # The boxer that returns an integer representation res = c.builder.ptrtoint(val, cgutils.intp_t) return c.pyapi.long_from_ssize_t(res) @njit def passthru(x): return x with self.assertRaises(ValueError) as raises: passthru(self.dyn_type) self.assertIn( "cannot be x > 0", str(raises.exception), ) def test_boxer(self): # Call jitcode inside the boxer magic_token = 0xCAFE magic_offset = 312 @intrinsic def my_intrinsic(typingctx, val): # An intrinsic that returns `val + magic_offset` def impl(context, builder, sig, args): [val] = args return builder.add(val, val.type(magic_offset)) sig = signature(val, val) return sig, impl @unbox(self.DynTypeType) def unboxer(typ, obj, c): return NativeValue(c.context.get_dummy_value()) @box(self.DynTypeType) def boxer(typ, val, c): # Note: this doesn't do proper error handling def bridge(x): return my_intrinsic(x) args = [c.context.get_constant(types.intp, magic_token)] sig = signature(types.intp, types.intp) is_error, res = c.pyapi.call_jit_code(bridge, sig, args) return c.pyapi.long_from_ssize_t(res) @njit def passthru(x): return x r = passthru(self.dyn_type) self.assertEqual(r, magic_token + magic_offset) def test_boxer_raise(self): # Call jitcode inside the boxer @unbox(self.DynTypeType) def unboxer(typ, obj, c): return NativeValue(c.context.get_dummy_value()) @box(self.DynTypeType) def boxer(typ, val, c): def bridge(x): if x > 0: raise ValueError("cannot do x > 0") return x args = [c.context.get_constant(types.intp, 1)] sig = signature(types.intp, types.intp) is_error, res = c.pyapi.call_jit_code(bridge, sig, args) # The error handling 
retval = cgutils.alloca_once(c.builder, c.pyapi.pyobj, zfill=True) with c.builder.if_then(c.builder.not_(is_error)): obj = c.pyapi.long_from_ssize_t(res) c.builder.store(obj, retval) return c.builder.load(retval) @njit def passthru(x): return x with self.assertRaises(ValueError) as raises: passthru(self.dyn_type) self.assertIn( "cannot do x > 0", str(raises.exception), ) def with_objmode_cache_ov_example(x): # This is the function stub for overloading inside # TestCachingOverloadObjmode.test_caching_overload_objmode pass @skip_if_typeguard class TestCachingOverloadObjmode(TestCase): """Test caching of the use of overload implementations that use `with objmode` """ _numba_parallel_test_ = False def setUp(self): warnings.simplefilter("error", errors.NumbaWarning) def tearDown(self): warnings.resetwarnings() def test_caching_overload_objmode(self): cache_dir = temp_directory(self.__class__.__name__) with override_config("CACHE_DIR", cache_dir): def realwork(x): # uses numpy code arr = np.arange(x) / x return np.linalg.norm(arr) def python_code(x): # create indirections return realwork(x) @overload(with_objmode_cache_ov_example) def _ov_with_objmode_cache_ov_example(x): def impl(x): with objmode(y="float64"): y = python_code(x) return y return impl @njit(cache=True) def testcase(x): return with_objmode_cache_ov_example(x) expect = realwork(123) got = testcase(123) self.assertEqual(got, expect) testcase_cached = njit(cache=True)(testcase.py_func) got = testcase_cached(123) self.assertEqual(got, expect) @classmethod def check_objmode_cache_ndarray(cls): def do_this(a, b): return np.sum(a + b) def do_something(a, b): return np.sum(a + b) @overload(do_something) def overload_do_something(a, b): def _do_something_impl(a, b): with objmode(y='float64'): y = do_this(a, b) return y return _do_something_impl @njit(cache=True) def test_caching(): a = np.arange(20) b = np.arange(20) return do_something(a, b) got = test_caching() expect = test_caching.py_func() # Check result if 
got != expect: raise AssertionError("incorrect result") return test_caching @classmethod def populate_objmode_cache_ndarray_check_cache(cls): cls.check_objmode_cache_ndarray() @classmethod def check_objmode_cache_ndarray_check_cache(cls): disp = cls.check_objmode_cache_ndarray() if len(disp.stats.cache_misses) != 0: raise AssertionError('unexpected cache miss') if len(disp.stats.cache_hits) <= 0: raise AssertionError("unexpected missing cache hit") def test_check_objmode_cache_ndarray(self): # See issue #6130. # Env is missing after cache load. cache_dir = temp_directory(self.__class__.__name__) with override_config("CACHE_DIR", cache_dir): # Run in new process to populate the cache run_in_new_process_in_cache_dir( self.populate_objmode_cache_ndarray_check_cache, cache_dir ) # Run in new process to use the cache in a fresh process. res = run_in_new_process_in_cache_dir( self.check_objmode_cache_ndarray_check_cache, cache_dir ) self.assertEqual(res['exitcode'], 0) class TestMisc(TestCase): def test_is_jitted(self): def foo(x): pass self.assertFalse(is_jitted(foo)) self.assertTrue(is_jitted(njit(foo))) self.assertFalse(is_jitted(vectorize(foo))) self.assertFalse(is_jitted(vectorize(parallel=True)(foo))) self.assertFalse( is_jitted(guvectorize("void(float64[:])", "(m)")(foo)) ) def test_overload_arg_binding(self): # See issue #7982, checks that calling a function with named args works # correctly irrespective of the order in which the names are supplied. @njit def standard_order(): return np.full(shape=123, fill_value=456).shape @njit def reversed_order(): return np.full(fill_value=456, shape=123).shape self.assertPreciseEqual(standard_order(), standard_order.py_func()) self.assertPreciseEqual(reversed_order(), reversed_order.py_func()) class TestOverloadPreferLiteral(TestCase): def test_overload(self): def prefer_lit(x): pass def non_lit(x): pass def ov(x): if isinstance(x, types.IntegerLiteral): # With prefer_literal=False, this branch will not be reached. 
if x.literal_value == 1: def impl(x): return 0xcafe return impl else: raise errors.TypingError('literal value') else: def impl(x): return x * 100 return impl overload(prefer_lit, prefer_literal=True)(ov) overload(non_lit)(ov) @njit def check_prefer_lit(x): return prefer_lit(1), prefer_lit(2), prefer_lit(x) a, b, c = check_prefer_lit(3) self.assertEqual(a, 0xcafe) self.assertEqual(b, 200) self.assertEqual(c, 300) @njit def check_non_lit(x): return non_lit(1), non_lit(2), non_lit(x) a, b, c = check_non_lit(3) self.assertEqual(a, 100) self.assertEqual(b, 200) self.assertEqual(c, 300) def test_overload_method(self): def ov(self, x): if isinstance(x, types.IntegerLiteral): # With prefer_literal=False, this branch will not be reached. if x.literal_value == 1: def impl(self, x): return 0xcafe return impl else: raise errors.TypingError('literal value') else: def impl(self, x): return x * 100 return impl overload_method( MyDummyType, "method_prefer_literal", prefer_literal=True, )(ov) overload_method( MyDummyType, "method_non_literal", prefer_literal=False, )(ov) @njit def check_prefer_lit(dummy, x): return ( dummy.method_prefer_literal(1), dummy.method_prefer_literal(2), dummy.method_prefer_literal(x), ) a, b, c = check_prefer_lit(MyDummy(), 3) self.assertEqual(a, 0xcafe) self.assertEqual(b, 200) self.assertEqual(c, 300) @njit def check_non_lit(dummy, x): return ( dummy.method_non_literal(1), dummy.method_non_literal(2), dummy.method_non_literal(x), ) a, b, c = check_non_lit(MyDummy(), 3) self.assertEqual(a, 100) self.assertEqual(b, 200) self.assertEqual(c, 300) class TestIntrinsicPreferLiteral(TestCase): def test_intrinsic(self): def intrin(context, x): # This intrinsic will return 0xcafe if `x` is a literal `1`. 
sig = signature(types.intp, x) if isinstance(x, types.IntegerLiteral): # With prefer_literal=False, this branch will not be reached if x.literal_value == 1: def codegen(context, builder, signature, args): atype = signature.args[0] llrtype = context.get_value_type(atype) return ir.Constant(llrtype, 0xcafe) return sig, codegen else: raise errors.TypingError('literal value') else: def codegen(context, builder, signature, args): atype = signature.return_type llrtype = context.get_value_type(atype) int_100 = ir.Constant(llrtype, 100) return builder.mul(args[0], int_100) return sig, codegen prefer_lit = intrinsic(prefer_literal=True)(intrin) non_lit = intrinsic(prefer_literal=False)(intrin) @njit def check_prefer_lit(x): return prefer_lit(1), prefer_lit(2), prefer_lit(x) a, b, c = check_prefer_lit(3) self.assertEqual(a, 0xcafe) self.assertEqual(b, 200) self.assertEqual(c, 300) @njit def check_non_lit(x): return non_lit(1), non_lit(2), non_lit(x) a, b, c = check_non_lit(3) self.assertEqual(a, 100) self.assertEqual(b, 200) self.assertEqual(c, 300) class TestNumbaInternalOverloads(TestCase): def test_signatures_match_overloaded_api(self): # This is a "best-effort" test to try and ensure that Numba's internal # overload declarations have signatures with argument names that match # the API they are overloading. The purpose of ensuring there is a # match is so that users can use call-by-name for positional arguments. # Set this to: # 0 to make violations raise a ValueError (default). # 1 to get violations reported to STDOUT # 2 to get a verbose output of everything that was checked and its state # reported to STDOUT. DEBUG = 0 # np.random.* does not have a signature exposed to `inspect`... so # custom parse the docstrings. 
def sig_from_np_random(x): if not x.startswith('_'): thing = getattr(np.random, x) if inspect.isbuiltin(thing): docstr = thing.__doc__.splitlines() for l in docstr: if l: sl = l.strip() if sl.startswith(x): # its the signature # special case np.random.seed, it has `self` in # the signature whereas all the other functions # do not!? if x == "seed": sl = "seed(seed)" fake_impl = f"def {sl}:\n\tpass" l = {} try: exec(fake_impl, {}, l) except SyntaxError: # likely elipsis, e.g. rand(d0, d1, ..., dn) if DEBUG == 2: print("... skipped as cannot parse " "signature") return None else: fn = l.get(x) return inspect.signature(fn) def checker(func, overload_func): if DEBUG == 2: print(f"Checking: {func}") def create_message(func, overload_func, func_sig, ol_sig): msg = [] s = (f"{func} from module '{getattr(func, '__module__')}' " "has mismatched sig.") msg.append(s) msg.append(f" - expected: {func_sig}") msg.append(f" - got: {ol_sig}") lineno = inspect.getsourcelines(overload_func)[1] tmpsrcfile = inspect.getfile(overload_func) srcfile = tmpsrcfile.replace(numba.__path__[0], '') msg.append(f"from {srcfile}:{lineno}") msgstr = '\n' + '\n'.join(msg) return msgstr func_sig = None try: func_sig = inspect.signature(func) except ValueError: # probably a built-in/C code, see if it's a np.random function if fname := getattr(func, '__name__', False): if maybe_func := getattr(np.random, fname, False): if maybe_func == func: # it's a built-in from np.random func_sig = sig_from_np_random(fname) if func_sig is not None: ol_sig = inspect.signature(overload_func) x = list(func_sig.parameters.keys()) y = list(ol_sig.parameters.keys()) for a, b in zip(x[:len(y)], y): if a != b: p = func_sig.parameters[a] if p.kind == p.POSITIONAL_ONLY: # probably a built-in/C code if DEBUG == 2: print("... skipped as positional only " "arguments found") break elif '*' in str(p): # probably *args or similar if DEBUG == 2: print("... 
skipped as contains *args") break else: # Only error/report on functions that have a module # or are from somewhere other than Numba. if (not func.__module__ or not func.__module__.startswith("numba")): msgstr = create_message(func, overload_func, func_sig, ol_sig) if DEBUG != 0: if DEBUG == 2: print("... INVALID") if msgstr: print(msgstr) break else: raise ValueError(msgstr) else: if DEBUG == 2: if not func.__module__: print("... skipped as no __module__ " "present") else: print("... skipped as Numba internal") break else: if DEBUG == 2: print("... OK") # Compile something to make sure that the typing context registries # are populated with everything from the CPU target. njit(lambda : None).compile(()) tyctx = numba.core.typing.context.Context() tyctx.refresh() # Walk the registries and check each function that is an overload regs = tyctx._registries for k, v in regs.items(): for item in k.functions: if getattr(item, '_overload_func', False): checker(item.key, item._overload_func) if __name__ == "__main__": unittest.main()