436 lines
14 KiB
Python
436 lines
14 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import inspect
|
|
import warnings
|
|
from contextlib import contextmanager
|
|
|
|
from numba.core import config, targetconfig
|
|
from numba.core.decorators import jit
|
|
from numba.core.descriptors import TargetDescriptor
|
|
from numba.core.extending import is_jitted
|
|
from numba.core.errors import NumbaDeprecationWarning
|
|
from numba.core.options import TargetOptions, include_default_options
|
|
from numba.core.registry import cpu_target
|
|
from numba.core.target_extension import dispatcher_registry, target_registry
|
|
from numba.core import utils, types, serialize, compiler, sigutils
|
|
from numba.np.numpy_support import as_dtype
|
|
from numba.np.ufunc import _internal
|
|
from numba.np.ufunc.sigparse import parse_signature
|
|
from numba.np.ufunc.wrappers import build_ufunc_wrapper, build_gufunc_wrapper
|
|
from numba.core.caching import FunctionCache, NullCache
|
|
from numba.core.compiler_lock import global_compiler_lock
|
|
|
|
|
|
# Option names pulled from Numba's default option set; the resulting mixin is
# one of the bases of UFuncTargetOptions.
_options_mixin = include_default_options(
    "nopython",
    "forceobj",
    "boundscheck",
    "fastmath",
    "target_backend",
    "writable_args",
)
|
|
|
|
|
|
class UFuncTargetOptions(_options_mixin, TargetOptions):
    """Target options for the ufunc target."""

    def finalize(self, flags, options):
        """Fill in the compiler flags that were not explicitly set.

        ``options`` is accepted but not consulted by this implementation.
        """
        # Both of these default to True when left unset.
        for flag_name in ("enable_pyobject", "enable_looplift"):
            if not flags.is_set(flag_name):
                setattr(flags, flag_name, True)

        flags.inherit_if_not_set("nrt", default=True)

        if not flags.is_set("debuginfo"):
            flags.debuginfo = config.DEBUGINFO_DEFAULT

        # Bounds checking follows the debuginfo setting unless chosen
        # explicitly, so this must run after debuginfo is resolved.
        if not flags.is_set("boundscheck"):
            flags.boundscheck = flags.debuginfo

        # Set unconditionally, regardless of any earlier value.
        flags.enable_pyobject_looplift = True

        flags.inherit_if_not_set("fastmath")
|
|
|
|
|
|
class UFuncTarget(TargetDescriptor):
    """Target descriptor registered under the name ``'ufunc'``.

    Both contexts are borrowed from ``cpu_target`` rather than created here.
    """
    options = UFuncTargetOptions

    def __init__(self):
        super().__init__('ufunc')

    @property
    def typing_context(self):
        """The typing context of ``cpu_target``."""
        return cpu_target.typing_context

    @property
    def target_context(self):
        """The target context of ``cpu_target``."""
        return cpu_target.target_context
|
|
|
|
|
|
# Module-level UFuncTarget instance; UFuncDispatcher uses it as `targetdescr`.
ufunc_target = UFuncTarget()
|
|
|
|
|
|
class UFuncDispatcher(serialize.ReduceMixin):
    """
    An object handling compilation of various signatures for a ufunc.

    Successful compilation results are recorded in ``self.overloads``,
    keyed by the signature of the compile result.
    """
    targetdescr = ufunc_target

    def __init__(self, py_func, locals=None, targetoptions=None):
        """
        Parameters
        ----------
        py_func
            The Python function to compile.
        locals : dict, optional
            Forwarded to the compiler as ``locals`` for every compilation.
        targetoptions : dict, optional
            Target options applied to every compilation.
        """
        self.py_func = py_func
        self.overloads = utils.UniqueDict()
        # Use a fresh dict per instance when nothing is supplied: a shared
        # ``{}`` default would be aliased by every such dispatcher.
        self.targetoptions = {} if targetoptions is None else targetoptions
        self.locals = {} if locals is None else locals
        self.cache = NullCache()

    def _reduce_states(self):
        """
        NOTE: part of ReduceMixin protocol
        """
        return dict(
            pyfunc=self.py_func,
            locals=self.locals,
            targetoptions=self.targetoptions,
        )

    @classmethod
    def _rebuild(cls, pyfunc, locals, targetoptions):
        """
        NOTE: part of ReduceMixin protocol
        """
        return cls(py_func=pyfunc, locals=locals, targetoptions=targetoptions)

    def enable_caching(self):
        """Replace the current cache with a FunctionCache for ``py_func``."""
        self.cache = FunctionCache(self.py_func)

    def compile(self, sig, locals=None, **targetoptions):
        """
        Compile ``py_func`` for ``sig`` and return the CompileResult.

        Per-call ``locals`` and ``targetoptions`` are layered on top of the
        values given at construction time; per-call values win on conflict.
        """
        locs = self.locals.copy()
        if locals:
            locs.update(locals)

        topt = self.targetoptions.copy()
        topt.update(targetoptions)

        flags = compiler.Flags()
        self.targetdescr.options.parse_as_flags(flags, topt)

        flags.no_cpython_wrapper = True
        flags.error_model = "numpy"
        # Disable loop lifting
        # The feature requires a real
        # python function
        flags.enable_looplift = False

        # Pass the merged dict; passing only the per-call ``locals`` would
        # discard the ones configured on the dispatcher.
        return self._compile_core(sig, flags, locs)

    def _compile_core(self, sig, flags, locals):
        """
        Trigger the compiler on the core function or load a previously
        compiled version from the cache. Returns the CompileResult.
        """
        typingctx = self.targetdescr.typing_context
        targetctx = self.targetdescr.target_context

        @contextmanager
        def store_overloads_on_success():
            # Anything raised inside the ``with`` body propagates out of the
            # ``yield``, so the lines after it run only on success. ``cres``
            # is resolved from the enclosing scope at that point.
            yield
            if self.overloads.get(cres.signature) is None:
                self.overloads[cres.signature] = cres

        # Use cache and compiler in a critical section
        with global_compiler_lock:
            with targetconfig.ConfigStack().enter(flags.copy()):
                with store_overloads_on_success():
                    # attempt look up of existing
                    cres = self.cache.load_overload(sig, targetctx)
                    if cres is not None:
                        return cres

                    # Compile
                    args, return_type = sigutils.normalize_signature(sig)
                    cres = compiler.compile_extra(typingctx, targetctx,
                                                  self.py_func, args=args,
                                                  return_type=return_type,
                                                  flags=flags, locals=locals)

                    # cache lookup failed before so safe to save
                    self.cache.save_overload(sig, cres)

                    return cres
|
|
|
|
|
|
# Register UFuncDispatcher as the dispatcher class for the 'npyufunc' target.
dispatcher_registry[target_registry['npyufunc']] = UFuncDispatcher
|
|
|
|
|
|
# Utility functions
|
|
|
|
def _compile_element_wise_function(nb_func, targetoptions, sig):
    """Compile ``sig`` with ``nb_func`` and split the signature into parts.

    Returns a ``(cres, args, return_type)`` tuple: ``cres`` is whatever
    ``nb_func.compile(sig, **targetoptions)`` returned (handed back so that
    callers can inspect it), and ``args`` / ``return_type`` come from
    ``sigutils.normalize_signature(sig)``.
    """
    compile_result = nb_func.compile(sig, **targetoptions)
    arg_types, declared_return = sigutils.normalize_signature(sig)
    return compile_result, arg_types, declared_return
|
|
|
|
|
|
def _finalize_ufunc_signature(cres, args, return_type):
|
|
'''Given a compilation result, argument types, and a return type,
|
|
build a valid Numba signature after validating that it doesn't
|
|
violate the constraints for the compilation mode.
|
|
'''
|
|
if return_type is None:
|
|
if cres.objectmode:
|
|
# Object mode is used and return type is not specified
|
|
raise TypeError("return type must be specified for object mode")
|
|
else:
|
|
return_type = cres.signature.return_type
|
|
|
|
assert return_type != types.pyobject
|
|
return return_type(*args)
|
|
|
|
|
|
def _build_element_wise_ufunc_wrapper(cres, signature):
    '''Wrap the compiled function in ``cres`` as a ufunc loop entry point.

    The wrapper is built for the element-wise ``signature``. Returns a tuple
    ``(dtypenums, ptr, environment)``: the dtype numbers of the argument
    types followed by that of the return type, the pointer to the wrapper
    function, and ``cres.environment``.
    '''
    target_ctx = cres.target_context
    code_library = cres.library
    llvm_name = cres.fndesc.llvm_func_name

    # Building the wrapper and resolving its address both happen under the
    # global compiler lock.
    with global_compiler_lock:
        info = build_ufunc_wrapper(code_library, target_ctx, llvm_name,
                                   signature, cres.objectmode, cres)
        ptr = info.library.get_pointer_to_function(info.name)

    # Inputs first, return type last.
    element_types = (*signature.args, signature.return_type)
    dtypenums = [as_dtype(ty).num for ty in element_types]
    return dtypenums, ptr, cres.environment
|
|
|
|
|
|
# Maps the identity values accepted by parse_identity() to the matching
# constants exposed by the _internal module.
_identities = {
    0: _internal.PyUFunc_Zero,
    1: _internal.PyUFunc_One,
    None: _internal.PyUFunc_None,
    "reorderable": _internal.PyUFunc_ReorderableNone,
}
|
|
|
|
|
|
def parse_identity(identity):
    """
    Translate an identity value into the corresponding low-level value
    for Numpy.

    Raises ValueError when ``identity`` is not a key of ``_identities``.
    """
    try:
        low_level = _identities[identity]
    except KeyError:
        raise ValueError("Invalid identity value %r" % (identity,))
    else:
        return low_level
|
|
|
|
|
|
@contextmanager
def _suppress_deprecation_warning_nopython_not_supplied():
    """Context manager hiding the NumbaDeprecationWarning emitted when `jit`
    is used without the `nopython` kwarg. The `{g,}ufunc` mechanism in Numba
    calls `jit` that way in a few places, predominantly to wrap the "kernel"
    function."""
    message_pattern = (".*The 'nopython' keyword argument "
                       "was not supplied*")
    # catch_warnings() restores the previous filter state on exit, so the
    # extra filter only applies inside this context.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore',
                                message=message_pattern,
                                category=NumbaDeprecationWarning)
        yield
|
|
|
|
|
|
# Class definitions
|
|
|
|
class _BaseUFuncBuilder(object):
|
|
|
|
def add(self, sig=None):
|
|
if hasattr(self, 'targetoptions'):
|
|
targetoptions = self.targetoptions
|
|
else:
|
|
targetoptions = self.nb_func.targetoptions
|
|
cres, args, return_type = _compile_element_wise_function(
|
|
self.nb_func, targetoptions, sig)
|
|
sig = self._finalize_signature(cres, args, return_type)
|
|
self._sigs.append(sig)
|
|
self._cres[sig] = cres
|
|
return cres
|
|
|
|
def disable_compile(self):
|
|
"""
|
|
Disable the compilation of new signatures at call time.
|
|
"""
|
|
# Override this for implementations that support lazy compilation
|
|
|
|
|
|
class UFuncBuilder(_BaseUFuncBuilder):
    """Builder that compiles element-wise kernels and assembles a ufunc."""

    def __init__(self, py_func, identity=None, cache=False, targetoptions={}):
        # An already-jitted function is unwrapped to its Python function.
        self.py_func = py_func.py_func if is_jitted(py_func) else py_func
        self.identity = parse_identity(identity)
        with _suppress_deprecation_warning_nopython_not_supplied():
            decorate = jit(_target='npyufunc', cache=cache, **targetoptions)
            self.nb_func = decorate(self.py_func)
        self._sigs = []
        self._cres = {}

    def _finalize_signature(self, cres, args, return_type):
        '''Slated for deprecation, use ufuncbuilder._finalize_ufunc_signature()
        instead.
        '''
        return _finalize_ufunc_signature(cres, args, return_type)

    def build_ufunc(self):
        """Assemble and return the ufunc from every signature added so far."""
        with global_compiler_lock:
            if not self.nb_func:
                raise TypeError("No definition")

            loop_dtypes = []
            loop_ptrs = []
            keepalive = []
            last_cres = None
            # Walk the signatures in the order they were added.
            for sig in self._sigs:
                last_cres = self._cres[sig]
                dtypenums, ptr, env = self.build(last_cres, sig)
                loop_dtypes.append(dtypenums)
                loop_ptrs.append(int(ptr))
                keepalive.append((last_cres.library, env))

            loop_data = [None] * len(loop_ptrs)

            # With no compiled signature, fall back to counting the Python
            # function's arguments; otherwise use the last compile result.
            if last_cres is None:
                n_inputs = len(inspect.getfullargspec(self.py_func).args)
            else:
                n_inputs = len(last_cres.signature.args)
            n_outputs = 1

            # Be careful: fromfunc does not provide full error checking yet.
            # An out-of-bound typenum causes nasty memory corruption; for
            # instance, -1 for a typenum will segfault. Passing tuples as
            # elements of the type list (2nd argument) also corrupts memory.
            return _internal.fromfunc(
                self.py_func.__name__, self.py_func.__doc__,
                loop_ptrs, loop_dtypes, n_inputs, n_outputs, loop_data,
                keepalive, self.identity,
            )

    def build(self, cres, signature):
        '''Slated for deprecation, use
        ufuncbuilder._build_element_wise_ufunc_wrapper().
        '''
        return _build_element_wise_ufunc_wrapper(cres, signature)
|
|
|
|
|
|
class GUFuncBuilder(_BaseUFuncBuilder):
    """Builder that compiles kernels and assembles them into a gufunc."""

    # TODO handle scalar
    def __init__(self, py_func, signature, identity=None, cache=False,
                 targetoptions=None, writable_args=()):
        """
        Parameters
        ----------
        py_func
            The Python kernel function.
        signature
            Layout signature; parsed by ``parse_signature`` into the input
            and output parts stored as ``self.sin`` and ``self.sout``.
        identity
            Identity value, validated by ``parse_identity``.
        cache
            Forwarded to ``jit`` and to ``build_gufunc_wrapper``.
        targetoptions : dict, optional
            Options used when compiling each added signature.
        writable_args
            Argument names or integer positions; names are translated to
            positions in ``py_func``'s argument list.
        """
        self.py_func = py_func
        self.identity = parse_identity(identity)
        with _suppress_deprecation_warning_nopython_not_supplied():
            self.nb_func = jit(_target='npyufunc', cache=cache)(py_func)
        self.signature = signature
        self.sin, self.sout = parse_signature(signature)
        # The options are stored on the instance, so a shared ``{}`` default
        # would be aliased by every builder created without them.
        self.targetoptions = {} if targetoptions is None else targetoptions
        self.cache = cache
        self._sigs = []
        self._cres = {}

        transform_arg = _get_transform_arg(py_func)
        self.writable_args = tuple(transform_arg(a) for a in writable_args)

    def _finalize_signature(self, cres, args, return_type):
        """Return the signature ``return_type(*args)``, defaulting to void.

        Raises TypeError when a kernel compiled outside object mode has a
        non-void return type.
        """
        if not cres.objectmode and cres.signature.return_type != types.void:
            raise TypeError("gufunc kernel must have void return type")

        if return_type is None:
            return_type = types.void

        return return_type(*args)

    @global_compiler_lock
    def build_ufunc(self):
        """Assemble and return the gufunc from every signature added so far."""
        type_list = []
        func_list = []
        if not self.nb_func:
            raise TypeError("No definition")

        # Get signature in the order they are added
        keepalive = []
        for sig in self._sigs:
            cres = self._cres[sig]
            dtypenums, ptr, env = self.build(cres)
            type_list.append(dtypenums)
            func_list.append(int(ptr))
            keepalive.append((cres.library, env))

        datalist = [None] * len(func_list)

        nin = len(self.sin)
        nout = len(self.sout)

        # Pass envs to fromfuncsig to bind to the lifetime of the ufunc object
        ufunc = _internal.fromfunc(
            self.py_func.__name__, self.py_func.__doc__,
            func_list, type_list, nin, nout, datalist,
            keepalive, self.identity, self.signature, self.writable_args
        )
        return ufunc

    def build(self, cres):
        """
        Returns (dtype numbers, function ptr, EnvironmentObject)
        """
        # Builder wrapper for ufunc entry point
        signature = cres.signature
        info = build_gufunc_wrapper(
            self.py_func, cres, self.sin, self.sout,
            cache=self.cache, is_parfors=False,
        )

        env = info.env
        ptr = info.library.get_pointer_to_function(info.name)
        # Get dtypes: array arguments contribute their element dtype, any
        # other argument contributes its own type.
        dtypenums = [
            as_dtype(a.dtype if isinstance(a, types.Array) else a).num
            for a in signature.args
        ]
        return dtypenums, ptr, env
|
|
|
|
|
|
def _get_transform_arg(py_func):
|
|
"""Return function that transform arg into index"""
|
|
args = inspect.getfullargspec(py_func).args
|
|
pos_by_arg = {arg: i for i, arg in enumerate(args)}
|
|
|
|
def transform_arg(arg):
|
|
if isinstance(arg, int):
|
|
return arg
|
|
|
|
try:
|
|
return pos_by_arg[arg]
|
|
except KeyError:
|
|
msg = (f"Specified writable arg {arg} not found in arg list "
|
|
f"{args} for function {py_func.__qualname__}")
|
|
raise RuntimeError(msg)
|
|
|
|
return transform_arg
|