ai-content-maker/.venv/Lib/site-packages/numba/np/ufunc/ufuncbuilder.py

436 lines
14 KiB
Python
Raw Permalink Normal View History

2024-05-03 04:18:51 +03:00
# -*- coding: utf-8 -*-
import inspect
import warnings
from contextlib import contextmanager
from numba.core import config, targetconfig
from numba.core.decorators import jit
from numba.core.descriptors import TargetDescriptor
from numba.core.extending import is_jitted
from numba.core.errors import NumbaDeprecationWarning
from numba.core.options import TargetOptions, include_default_options
from numba.core.registry import cpu_target
from numba.core.target_extension import dispatcher_registry, target_registry
from numba.core import utils, types, serialize, compiler, sigutils
from numba.np.numpy_support import as_dtype
from numba.np.ufunc import _internal
from numba.np.ufunc.sigparse import parse_signature
from numba.np.ufunc.wrappers import build_ufunc_wrapper, build_gufunc_wrapper
from numba.core.caching import FunctionCache, NullCache
from numba.core.compiler_lock import global_compiler_lock
# Mixin carrying the subset of default target options accepted by the
# ufunc target; it is combined with TargetOptions in UFuncTargetOptions.
_options_mixin = include_default_options(
    "nopython",
    "forceobj",
    "boundscheck",
    "fastmath",
    "target_backend",
    "writable_args",
)
class UFuncTargetOptions(_options_mixin, TargetOptions):
    """Target options for ufunc kernels, with ufunc-specific defaults."""

    def finalize(self, flags, options):
        """Fill in default values for flags the user did not set.

        Flags already set (as reported by ``flags.is_set``) are left alone,
        except ``enable_pyobject_looplift``, which is always turned on.
        """
        # Object mode and loop-lifting default to enabled.
        for flag_name in ("enable_pyobject", "enable_looplift"):
            if not flags.is_set(flag_name):
                setattr(flags, flag_name, True)

        flags.inherit_if_not_set("nrt", default=True)

        if not flags.is_set("debuginfo"):
            flags.debuginfo = config.DEBUGINFO_DEFAULT

        # Bounds checking follows the (possibly just defaulted) debuginfo flag.
        if not flags.is_set("boundscheck"):
            flags.boundscheck = flags.debuginfo

        flags.enable_pyobject_looplift = True

        flags.inherit_if_not_set("fastmath")
class UFuncTarget(TargetDescriptor):
    """Target descriptor named 'ufunc'.

    Typing and target contexts are taken directly from the CPU target;
    option handling uses UFuncTargetOptions.
    """
    options = UFuncTargetOptions

    def __init__(self):
        super(UFuncTarget, self).__init__('ufunc')

    @property
    def typing_context(self):
        """The CPU target's typing context."""
        return cpu_target.typing_context

    @property
    def target_context(self):
        """The CPU target's target context."""
        return cpu_target.target_context


# Module-level instance, used as ``UFuncDispatcher.targetdescr``.
ufunc_target = UFuncTarget()
class UFuncDispatcher(serialize.ReduceMixin):
    """
    An object handling compilation of various signatures for a ufunc.

    Compilation results are recorded in ``self.overloads`` keyed by their
    signature.  ``self.cache`` is a NullCache until :meth:`enable_caching`
    replaces it with a FunctionCache for ``py_func``.
    """
    targetdescr = ufunc_target

    def __init__(self, py_func, locals=None, targetoptions=None):
        self.py_func = py_func
        self.overloads = utils.UniqueDict()
        # Fresh dicts instead of shared mutable defaults, so that state
        # stored on one dispatcher can never leak into another.
        self.targetoptions = {} if targetoptions is None else targetoptions
        self.locals = {} if locals is None else locals
        self.cache = NullCache()

    def _reduce_states(self):
        """
        NOTE: part of ReduceMixin protocol
        """
        return dict(
            pyfunc=self.py_func,
            locals=self.locals,
            targetoptions=self.targetoptions,
        )

    @classmethod
    def _rebuild(cls, pyfunc, locals, targetoptions):
        """
        NOTE: part of ReduceMixin protocol
        """
        return cls(py_func=pyfunc, locals=locals, targetoptions=targetoptions)

    def enable_caching(self):
        """Switch from the no-op cache to a FunctionCache for ``py_func``."""
        self.cache = FunctionCache(self.py_func)

    def compile(self, sig, locals=None, **targetoptions):
        """Compile ``py_func`` for *sig* and return the CompileResult.

        *locals* is merged over the dispatcher-level ``self.locals``.
        *targetoptions* is merged over ``self.targetoptions``; the merged
        options are turned into compiler flags.
        """
        locs = self.locals.copy()
        if locals:
            locs.update(locals)

        topt = self.targetoptions.copy()
        topt.update(targetoptions)

        flags = compiler.Flags()
        self.targetdescr.options.parse_as_flags(flags, topt)

        flags.no_cpython_wrapper = True
        flags.error_model = "numpy"
        # Disable loop lifting: the feature requires a real Python function.
        flags.enable_looplift = False

        # Pass the *merged* locals so that dispatcher-level locals (given at
        # construction) are honoured, not only the per-call ones.
        return self._compile_core(sig, flags, locs)

    def _compile_core(self, sig, flags, locals):
        """
        Trigger the compiler on the core function or load a previously
        compiled version from the cache. Returns the CompileResult.

        On success the result is added to ``self.overloads`` unless an
        entry for its signature is already present.
        """
        typingctx = self.targetdescr.typing_context
        targetctx = self.targetdescr.target_context

        # Use cache and compiler in a critical section
        with global_compiler_lock:
            with targetconfig.ConfigStack().enter(flags.copy()):
                # Attempt look up of an existing result first.
                cres = self.cache.load_overload(sig, targetctx)
                if cres is None:
                    args, return_type = sigutils.normalize_signature(sig)
                    cres = compiler.compile_extra(typingctx, targetctx,
                                                  self.py_func, args=args,
                                                  return_type=return_type,
                                                  flags=flags, locals=locals)
                    # Cache lookup failed above, so it is safe to save.
                    self.cache.save_overload(sig, cres)
                # Only reached when loading/compiling succeeded.  Keep the
                # first result recorded for a signature (``overloads`` is a
                # UniqueDict, so an existing key is not re-assigned).
                if self.overloads.get(cres.signature) is None:
                    self.overloads[cres.signature] = cres
                return cres


# Make UFuncDispatcher the dispatcher for the 'npyufunc' target.
dispatcher_registry[target_registry['npyufunc']] = UFuncDispatcher
# Utility functions
def _compile_element_wise_function(nb_func, targetoptions, sig):
    """Compile *nb_func* for *sig* using *targetoptions*.

    Returns a ``(cres, args, return_type)`` triple:
    - ``cres`` is the CompileResult from ``nb_func.compile``;
    - ``args`` and ``return_type`` come from normalizing *sig*.

    Callers treat a ``None`` return type as "not specified".
    """
    compile_result = nb_func.compile(sig, **targetoptions)
    arg_types, ret_type = sigutils.normalize_signature(sig)
    return compile_result, arg_types, ret_type
def _finalize_ufunc_signature(cres, args, return_type):
'''Given a compilation result, argument types, and a return type,
build a valid Numba signature after validating that it doesn't
violate the constraints for the compilation mode.
'''
if return_type is None:
if cres.objectmode:
# Object mode is used and return type is not specified
raise TypeError("return type must be specified for object mode")
else:
return_type = cres.signature.return_type
assert return_type != types.pyobject
return return_type(*args)
def _build_element_wise_ufunc_wrapper(cres, signature):
    '''Wrap the compiled element-wise kernel as a ufunc loop.

    Returns ``(dtypenums, ptr, env)``:
    - ``dtypenums`` holds the NumPy type numbers for every input of
      *signature*, followed by its return type;
    - ``ptr`` is the address of the generated wrapper;
    - ``env`` is ``cres.environment``.
    '''
    target_ctx = cres.target_context
    lib = cres.library
    kernel_name = cres.fndesc.llvm_func_name

    # Wrapper generation and pointer lookup happen under the compiler lock.
    with global_compiler_lock:
        wrapper = build_ufunc_wrapper(lib, target_ctx, kernel_name, signature,
                                      cres.objectmode, cres)
        func_ptr = wrapper.library.get_pointer_to_function(wrapper.name)

    # Inputs first, output last.
    type_nums = [as_dtype(ty).num
                 for ty in (*signature.args, signature.return_type)]
    return type_nums, func_ptr, cres.environment
# Accepted user-facing identity values and the low-level constants they map
# to; the parsed value is what the builders pass to ``_internal.fromfunc``.
_identities = {
    0: _internal.PyUFunc_Zero,
    1: _internal.PyUFunc_One,
    None: _internal.PyUFunc_None,
    "reorderable": _internal.PyUFunc_ReorderableNone,
}


def parse_identity(identity):
    """
    Parse an identity value and return the corresponding low-level value
    for Numpy.

    Valid values are the keys of ``_identities`` (``0``, ``1``, ``None``,
    ``"reorderable"``), or values that compare and hash equal to them.

    Raises ValueError for anything else, including unhashable values such
    as lists.
    """
    try:
        return _identities[identity]
    except (KeyError, TypeError):
        # TypeError means the value is unhashable and so cannot be a key;
        # report it like any other invalid identity.
        raise ValueError("Invalid identity value %r" % (identity,)) from None
@contextmanager
def _suppress_deprecation_warning_nopython_not_supplied():
    """This suppresses the NumbaDeprecationWarning that occurs through the use
    of `jit` without the `nopython` kwarg. This use of `jit` occurs in a few
    places in the `{g,}ufunc` mechanism in Numba, predominantly to wrap the
    "kernel" function.

    The filter lives inside ``warnings.catch_warnings()``, so the previous
    warning-filter state is restored when the context exits.
    """
    with warnings.catch_warnings():
        # ``message`` is a regex matched against the start of the warning
        # text; ".*" on both ends allows any surrounding text.
        warnings.filterwarnings(
            'ignore',
            category=NumbaDeprecationWarning,
            message=(".*The 'nopython' keyword argument "
                     "was not supplied.*"),
        )
        yield
# Class definitions
class _BaseUFuncBuilder(object):
def add(self, sig=None):
if hasattr(self, 'targetoptions'):
targetoptions = self.targetoptions
else:
targetoptions = self.nb_func.targetoptions
cres, args, return_type = _compile_element_wise_function(
self.nb_func, targetoptions, sig)
sig = self._finalize_signature(cres, args, return_type)
self._sigs.append(sig)
self._cres[sig] = cres
return cres
def disable_compile(self):
"""
Disable the compilation of new signatures at call time.
"""
# Override this for implementations that support lazy compilation
class UFuncBuilder(_BaseUFuncBuilder):
    """Build a NumPy ufunc from an element-wise Python kernel.

    Signatures are compiled via ``add``; ``build_ufunc`` then combines every
    compiled loop into one ufunc object.
    """

    def __init__(self, py_func, identity=None, cache=False,
                 targetoptions=None):
        # Unwrap an already-jitted function so that jit() below receives
        # the plain Python function.
        if is_jitted(py_func):
            py_func = py_func.py_func
        self.py_func = py_func
        self.identity = parse_identity(identity)
        # Avoid a shared mutable default argument.
        if targetoptions is None:
            targetoptions = {}
        with _suppress_deprecation_warning_nopython_not_supplied():
            self.nb_func = jit(_target='npyufunc',
                               cache=cache,
                               **targetoptions)(py_func)
        self._sigs = []
        self._cres = {}

    def _finalize_signature(self, cres, args, return_type):
        '''Slated for deprecation, use ufuncbuilder._finalize_ufunc_signature()
        instead.
        '''
        return _finalize_ufunc_signature(cres, args, return_type)

    def build_ufunc(self):
        """Create the ufunc object from all signatures added so far.

        Loops are registered in the order their signatures were added.
        With no signatures, the input count comes from ``py_func``'s
        positional parameters.  There is always one output.
        """
        with global_compiler_lock:
            dtypelist = []
            ptrlist = []
            if not self.nb_func:
                raise TypeError("No definition")

            # Get signature in the order they are added
            keepalive = []
            cres = None
            for sig in self._sigs:
                cres = self._cres[sig]
                dtypenums, ptr, env = self.build(cres, sig)
                dtypelist.append(dtypenums)
                ptrlist.append(int(ptr))
                # Hold the library and environment for the ufunc's lifetime.
                keepalive.append((cres.library, env))

            datlist = [None] * len(ptrlist)

            if cres is None:
                argspec = inspect.getfullargspec(self.py_func)
                inct = len(argspec.args)
            else:
                inct = len(cres.signature.args)
            outct = 1

            # Be careful: fromfunc does not provide full error checking yet.
            # If a typenum is out of bounds we get nasty memory corruption;
            # for instance, -1 as a typenum will cause a segfault.  If the
            # elements of the type-list (2nd arg) are tuples instead, there
            # will also be memory corruption.
            ufunc = _internal.fromfunc(
                self.py_func.__name__, self.py_func.__doc__,
                ptrlist, dtypelist, inct, outct, datlist,
                keepalive, self.identity,
            )

            return ufunc

    def build(self, cres, signature):
        '''Slated for deprecation, use
        ufuncbuilder._build_element_wise_ufunc_wrapper().
        '''
        return _build_element_wise_ufunc_wrapper(cres, signature)
class GUFuncBuilder(_BaseUFuncBuilder):
    """Build a generalized ufunc from a Python kernel and a layout signature.

    *signature* is parsed with ``parse_signature`` into input (``sin``) and
    output (``sout``) layouts.  *writable_args* may name arguments by
    position or by parameter name; names are converted to positions.
    """

    # TODO handle scalar
    def __init__(self, py_func, signature, identity=None, cache=False,
                 targetoptions=None, writable_args=()):
        # Unwrap an already-jitted function, mirroring UFuncBuilder, so the
        # plain Python function is what gets jitted and introspected.
        if is_jitted(py_func):
            py_func = py_func.py_func
        self.py_func = py_func
        self.identity = parse_identity(identity)
        with _suppress_deprecation_warning_nopython_not_supplied():
            self.nb_func = jit(_target='npyufunc', cache=cache)(py_func)
        self.signature = signature
        self.sin, self.sout = parse_signature(signature)
        # Stored on the instance (and used by ``add``), so never share a
        # mutable default dict between builders.
        self.targetoptions = {} if targetoptions is None else targetoptions
        self.cache = cache
        self._sigs = []
        self._cres = {}

        transform_arg = _get_transform_arg(py_func)
        self.writable_args = tuple(transform_arg(a) for a in writable_args)

    def _finalize_signature(self, cres, args, return_type):
        """Return ``return_type(*args)``, defaulting the return type to void.

        Raises TypeError if a non-object-mode kernel does not return void.
        """
        if not cres.objectmode and cres.signature.return_type != types.void:
            raise TypeError("gufunc kernel must have void return type")

        if return_type is None:
            return_type = types.void

        return return_type(*args)

    @global_compiler_lock
    def build_ufunc(self):
        """Create the gufunc object from all signatures added so far."""
        type_list = []
        func_list = []
        if not self.nb_func:
            raise TypeError("No definition")

        # Get signature in the order they are added
        keepalive = []
        for sig in self._sigs:
            cres = self._cres[sig]
            dtypenums, ptr, env = self.build(cres)
            type_list.append(dtypenums)
            func_list.append(int(ptr))
            keepalive.append((cres.library, env))

        datalist = [None] * len(func_list)
        nin = len(self.sin)
        nout = len(self.sout)

        # Pass envs to fromfuncsig to bind to the lifetime of the ufunc object
        ufunc = _internal.fromfunc(
            self.py_func.__name__, self.py_func.__doc__,
            func_list, type_list, nin, nout, datalist,
            keepalive, self.identity, self.signature, self.writable_args
        )
        return ufunc

    def build(self, cres):
        """
        Returns (dtype numbers, function ptr, EnvironmentObject)
        """
        # Builder wrapper for ufunc entry point
        signature = cres.signature
        info = build_gufunc_wrapper(
            self.py_func, cres, self.sin, self.sout,
            cache=self.cache, is_parfors=False,
        )

        env = info.env
        ptr = info.library.get_pointer_to_function(info.name)
        # Get dtypes: array arguments contribute their element dtype,
        # other (scalar) arguments their own type.
        dtypenums = []
        for a in signature.args:
            if isinstance(a, types.Array):
                ty = a.dtype
            else:
                ty = a
            dtypenums.append(as_dtype(ty).num)
        return dtypenums, ptr, env
def _get_transform_arg(py_func):
"""Return function that transform arg into index"""
args = inspect.getfullargspec(py_func).args
pos_by_arg = {arg: i for i, arg in enumerate(args)}
def transform_arg(arg):
if isinstance(arg, int):
return arg
try:
return pos_by_arg[arg]
except KeyError:
msg = (f"Specified writable arg {arg} not found in arg list "
f"{args} for function {py_func.__qualname__}")
raise RuntimeError(msg)
return transform_arg