# -*- coding: utf-8 -*- import inspect import warnings from contextlib import contextmanager from numba.core import config, targetconfig from numba.core.decorators import jit from numba.core.descriptors import TargetDescriptor from numba.core.extending import is_jitted from numba.core.errors import NumbaDeprecationWarning from numba.core.options import TargetOptions, include_default_options from numba.core.registry import cpu_target from numba.core.target_extension import dispatcher_registry, target_registry from numba.core import utils, types, serialize, compiler, sigutils from numba.np.numpy_support import as_dtype from numba.np.ufunc import _internal from numba.np.ufunc.sigparse import parse_signature from numba.np.ufunc.wrappers import build_ufunc_wrapper, build_gufunc_wrapper from numba.core.caching import FunctionCache, NullCache from numba.core.compiler_lock import global_compiler_lock _options_mixin = include_default_options( "nopython", "forceobj", "boundscheck", "fastmath", "target_backend", "writable_args" ) class UFuncTargetOptions(_options_mixin, TargetOptions): def finalize(self, flags, options): if not flags.is_set("enable_pyobject"): flags.enable_pyobject = True if not flags.is_set("enable_looplift"): flags.enable_looplift = True flags.inherit_if_not_set("nrt", default=True) if not flags.is_set("debuginfo"): flags.debuginfo = config.DEBUGINFO_DEFAULT if not flags.is_set("boundscheck"): flags.boundscheck = flags.debuginfo flags.enable_pyobject_looplift = True flags.inherit_if_not_set("fastmath") class UFuncTarget(TargetDescriptor): options = UFuncTargetOptions def __init__(self): super().__init__('ufunc') @property def typing_context(self): return cpu_target.typing_context @property def target_context(self): return cpu_target.target_context ufunc_target = UFuncTarget() class UFuncDispatcher(serialize.ReduceMixin): """ An object handling compilation of various signatures for a ufunc. """ targetdescr = ufunc_target def __init__(self, py_func, locals={}, targetoptions={}): self.py_func = py_func self.overloads = utils.UniqueDict() self.targetoptions = targetoptions self.locals = locals self.cache = NullCache() def _reduce_states(self): """ NOTE: part of ReduceMixin protocol """ return dict( pyfunc=self.py_func, locals=self.locals, targetoptions=self.targetoptions, ) @classmethod def _rebuild(cls, pyfunc, locals, targetoptions): """ NOTE: part of ReduceMixin protocol """ return cls(py_func=pyfunc, locals=locals, targetoptions=targetoptions) def enable_caching(self): self.cache = FunctionCache(self.py_func) def compile(self, sig, locals={}, **targetoptions): locs = self.locals.copy() locs.update(locals) topt = self.targetoptions.copy() topt.update(targetoptions) flags = compiler.Flags() self.targetdescr.options.parse_as_flags(flags, topt) flags.no_cpython_wrapper = True flags.error_model = "numpy" # Disable loop lifting # The feature requires a real # python function flags.enable_looplift = False return self._compile_core(sig, flags, locals) def _compile_core(self, sig, flags, locals): """ Trigger the compiler on the core function or load a previously compiled version from the cache. Returns the CompileResult. """ typingctx = self.targetdescr.typing_context targetctx = self.targetdescr.target_context @contextmanager def store_overloads_on_success(): # use to ensure overloads are stored on success try: yield except Exception: raise else: exists = self.overloads.get(cres.signature) if exists is None: self.overloads[cres.signature] = cres # Use cache and compiler in a critical section with global_compiler_lock: with targetconfig.ConfigStack().enter(flags.copy()): with store_overloads_on_success(): # attempt look up of existing cres = self.cache.load_overload(sig, targetctx) if cres is not None: return cres # Compile args, return_type = sigutils.normalize_signature(sig) cres = compiler.compile_extra(typingctx, targetctx, self.py_func, args=args, return_type=return_type, flags=flags, locals=locals) # cache lookup failed before so safe to save self.cache.save_overload(sig, cres) return cres dispatcher_registry[target_registry['npyufunc']] = UFuncDispatcher # Utility functions def _compile_element_wise_function(nb_func, targetoptions, sig): # Do compilation # Return CompileResult to test cres = nb_func.compile(sig, **targetoptions) args, return_type = sigutils.normalize_signature(sig) return cres, args, return_type def _finalize_ufunc_signature(cres, args, return_type): '''Given a compilation result, argument types, and a return type, build a valid Numba signature after validating that it doesn't violate the constraints for the compilation mode. ''' if return_type is None: if cres.objectmode: # Object mode is used and return type is not specified raise TypeError("return type must be specified for object mode") else: return_type = cres.signature.return_type assert return_type != types.pyobject return return_type(*args) def _build_element_wise_ufunc_wrapper(cres, signature): '''Build a wrapper for the ufunc loop entry point given by the compilation result object, using the element-wise signature. ''' ctx = cres.target_context library = cres.library fname = cres.fndesc.llvm_func_name with global_compiler_lock: info = build_ufunc_wrapper(library, ctx, fname, signature, cres.objectmode, cres) ptr = info.library.get_pointer_to_function(info.name) # Get dtypes dtypenums = [as_dtype(a).num for a in signature.args] dtypenums.append(as_dtype(signature.return_type).num) return dtypenums, ptr, cres.environment _identities = { 0: _internal.PyUFunc_Zero, 1: _internal.PyUFunc_One, None: _internal.PyUFunc_None, "reorderable": _internal.PyUFunc_ReorderableNone, } def parse_identity(identity): """ Parse an identity value and return the corresponding low-level value for Numpy. """ try: identity = _identities[identity] except KeyError: raise ValueError("Invalid identity value %r" % (identity,)) return identity @contextmanager def _suppress_deprecation_warning_nopython_not_supplied(): """This suppresses the NumbaDeprecationWarning that occurs through the use of `jit` without the `nopython` kwarg. This use of `jit` occurs in a few places in the `{g,}ufunc` mechanism in Numba, predominantly to wrap the "kernel" function.""" with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=NumbaDeprecationWarning, message=(".*The 'nopython' keyword argument " "was not supplied*"),) yield # Class definitions class _BaseUFuncBuilder(object): def add(self, sig=None): if hasattr(self, 'targetoptions'): targetoptions = self.targetoptions else: targetoptions = self.nb_func.targetoptions cres, args, return_type = _compile_element_wise_function( self.nb_func, targetoptions, sig) sig = self._finalize_signature(cres, args, return_type) self._sigs.append(sig) self._cres[sig] = cres return cres def disable_compile(self): """ Disable the compilation of new signatures at call time. """ # Override this for implementations that support lazy compilation class UFuncBuilder(_BaseUFuncBuilder): def __init__(self, py_func, identity=None, cache=False, targetoptions={}): if is_jitted(py_func): py_func = py_func.py_func self.py_func = py_func self.identity = parse_identity(identity) with _suppress_deprecation_warning_nopython_not_supplied(): self.nb_func = jit(_target='npyufunc', cache=cache, **targetoptions)(py_func) self._sigs = [] self._cres = {} def _finalize_signature(self, cres, args, return_type): '''Slated for deprecation, use ufuncbuilder._finalize_ufunc_signature() instead. ''' return _finalize_ufunc_signature(cres, args, return_type) def build_ufunc(self): with global_compiler_lock: dtypelist = [] ptrlist = [] if not self.nb_func: raise TypeError("No definition") # Get signature in the order they are added keepalive = [] cres = None for sig in self._sigs: cres = self._cres[sig] dtypenums, ptr, env = self.build(cres, sig) dtypelist.append(dtypenums) ptrlist.append(int(ptr)) keepalive.append((cres.library, env)) datlist = [None] * len(ptrlist) if cres is None: argspec = inspect.getfullargspec(self.py_func) inct = len(argspec.args) else: inct = len(cres.signature.args) outct = 1 # Becareful that fromfunc does not provide full error checking yet. # If typenum is out-of-bound, we have nasty memory corruptions. # For instance, -1 for typenum will cause segfault. # If elements of type-list (2nd arg) is tuple instead, # there will also memory corruption. (Seems like code rewrite.) ufunc = _internal.fromfunc( self.py_func.__name__, self.py_func.__doc__, ptrlist, dtypelist, inct, outct, datlist, keepalive, self.identity, ) return ufunc def build(self, cres, signature): '''Slated for deprecation, use ufuncbuilder._build_element_wise_ufunc_wrapper(). ''' return _build_element_wise_ufunc_wrapper(cres, signature) class GUFuncBuilder(_BaseUFuncBuilder): # TODO handle scalar def __init__(self, py_func, signature, identity=None, cache=False, targetoptions={}, writable_args=()): self.py_func = py_func self.identity = parse_identity(identity) with _suppress_deprecation_warning_nopython_not_supplied(): self.nb_func = jit(_target='npyufunc', cache=cache)(py_func) self.signature = signature self.sin, self.sout = parse_signature(signature) self.targetoptions = targetoptions self.cache = cache self._sigs = [] self._cres = {} transform_arg = _get_transform_arg(py_func) self.writable_args = tuple([transform_arg(a) for a in writable_args]) def _finalize_signature(self, cres, args, return_type): if not cres.objectmode and cres.signature.return_type != types.void: raise TypeError("gufunc kernel must have void return type") if return_type is None: return_type = types.void return return_type(*args) @global_compiler_lock def build_ufunc(self): type_list = [] func_list = [] if not self.nb_func: raise TypeError("No definition") # Get signature in the order they are added keepalive = [] for sig in self._sigs: cres = self._cres[sig] dtypenums, ptr, env = self.build(cres) type_list.append(dtypenums) func_list.append(int(ptr)) keepalive.append((cres.library, env)) datalist = [None] * len(func_list) nin = len(self.sin) nout = len(self.sout) # Pass envs to fromfuncsig to bind to the lifetime of the ufunc object ufunc = _internal.fromfunc( self.py_func.__name__, self.py_func.__doc__, func_list, type_list, nin, nout, datalist, keepalive, self.identity, self.signature, self.writable_args ) return ufunc def build(self, cres): """ Returns (dtype numbers, function ptr, EnvironmentObject) """ # Builder wrapper for ufunc entry point signature = cres.signature info = build_gufunc_wrapper( self.py_func, cres, self.sin, self.sout, cache=self.cache, is_parfors=False, ) env = info.env ptr = info.library.get_pointer_to_function(info.name) # Get dtypes dtypenums = [] for a in signature.args: if isinstance(a, types.Array): ty = a.dtype else: ty = a dtypenums.append(as_dtype(ty).num) return dtypenums, ptr, env def _get_transform_arg(py_func): """Return function that transform arg into index""" args = inspect.getfullargspec(py_func).args pos_by_arg = {arg: i for i, arg in enumerate(args)} def transform_arg(arg): if isinstance(arg, int): return arg try: return pos_by_arg[arg] except KeyError: msg = (f"Specified writable arg {arg} not found in arg list " f"{args} for function {py_func.__qualname__}") raise RuntimeError(msg) return transform_arg