195 lines
6.0 KiB
Python
195 lines
6.0 KiB
Python
from numba import typeof
|
|
from numba.core import types
|
|
from numba.np.ufunc.ufuncbuilder import GUFuncBuilder
|
|
from numba.np.ufunc.sigparse import parse_signature
|
|
from numba.np.numpy_support import ufunc_find_matching_loop
|
|
from numba.core import serialize
|
|
import functools
|
|
|
|
|
|
class GUFunc(serialize.ReduceMixin):
    """
    Dynamic generalized universal function (GUFunc)
    intended to act like a normal Numpy gufunc, but capable
    of call-time (just-in-time) compilation of fast loops
    specialized to inputs.

    Compilation is delegated to a ``GUFuncBuilder`` kept in
    ``self.gufunc_builder``. The NumPy ufunc it produces is stored in
    ``self.ufunc``, which is ``None`` until ``build_ufunc`` has run. The
    ufunc-forwarding properties (``nin``, ``types``, ``reduce``, ...) therefore
    only work after a build.

    Parameters
    ----------
    py_func : callable
        Python kernel. Its metadata is copied onto this object with
        ``functools.update_wrapper``.
    signature : str
        Gufunc layout signature, as understood by ``parse_signature``.
    identity, cache, writable_args
        Forwarded unchanged to ``GUFuncBuilder``.
    is_dynamic : bool
        If true, calling the object compiles a new loop whenever no existing
        loop matches the argument types. This stops once compilation has been
        disabled.
    targetoptions : dict or None
        Forwarded to ``GUFuncBuilder``. ``None`` means an empty dict.
    """

    def __init__(self, py_func, signature, identity=None, cache=None,
                 is_dynamic=False, targetoptions=None, writable_args=()):
        # Use a fresh dict per instance instead of a shared ``{}`` default,
        # so no GUFunc can leak options into another through its builder.
        if targetoptions is None:
            targetoptions = {}

        self.ufunc = None
        self._frozen = False
        self._is_dynamic = is_dynamic
        self._identity = identity

        # GUFunc cannot inherit from GUFuncBuilder because "identity"
        # is a property of GUFunc. Thus, we hold a reference to a GUFuncBuilder
        # object here
        self.gufunc_builder = GUFuncBuilder(
            py_func, signature, identity, cache, targetoptions, writable_args)

        self.__name__ = self.gufunc_builder.py_func.__name__
        functools.update_wrapper(self, py_func)

    def _reduce_states(self):
        """
        Collect the state needed to recreate this object.

        Returns a dict of the constructor arguments, the registered type
        signatures and the frozen flag. Its keys match the parameters of
        ``_rebuild``.
        """
        gb = self.gufunc_builder
        return dict(
            py_func=gb.py_func,
            signature=gb.signature,
            identity=self._identity,
            cache=gb.cache,
            is_dynamic=self._is_dynamic,
            targetoptions=gb.targetoptions,
            writable_args=gb.writable_args,
            typesigs=gb._sigs,
            frozen=self._frozen,
        )

    @classmethod
    def _rebuild(cls, py_func, signature, identity, cache, is_dynamic,
                 targetoptions, writable_args, typesigs, frozen):
        """
        Recreate a GUFunc from the output of ``_reduce_states``.

        Every saved type signature is re-added and the ufunc is rebuilt.
        The frozen flag is restored last, so the rebuild itself is never
        blocked by it.
        """
        self = cls(py_func=py_func, signature=signature, identity=identity,
                   cache=cache, is_dynamic=is_dynamic,
                   targetoptions=targetoptions, writable_args=writable_args)
        for sig in typesigs:
            self.add(sig)
        self.build_ufunc()
        self._frozen = frozen
        return self

    def __repr__(self):
        return f"<numba._GUFunc '{self.__name__}'>"

    def add(self, fty):
        """
        Register the type signature ``fty`` with the builder.

        This does not rebuild ``self.ufunc``. Call ``build_ufunc`` afterwards.
        """
        self.gufunc_builder.add(fty)

    def build_ufunc(self):
        """(Re)build ``self.ufunc`` from all registered signatures; return self."""
        self.ufunc = self.gufunc_builder.build_ufunc()
        return self

    def disable_compile(self):
        """
        Disable the compilation of new signatures at call time.
        """
        # If disabling compilation then there must be at least one signature
        assert len(self.gufunc_builder._sigs) > 0, \
            "disable_compile() requires at least one registered signature"
        self._frozen = True

    @property
    def is_dynamic(self):
        """Whether call-time compilation was requested at construction."""
        return self._is_dynamic

    @property
    def nin(self):
        """``nin`` of the underlying NumPy ufunc."""
        return self.ufunc.nin

    @property
    def nout(self):
        """``nout`` of the underlying NumPy ufunc."""
        return self.ufunc.nout

    @property
    def nargs(self):
        """``nargs`` of the underlying NumPy ufunc."""
        return self.ufunc.nargs

    @property
    def ntypes(self):
        """``ntypes`` of the underlying NumPy ufunc."""
        return self.ufunc.ntypes

    @property
    def types(self):
        """``types`` of the underlying NumPy ufunc."""
        return self.ufunc.types

    @property
    def identity(self):
        """``identity`` of the underlying NumPy ufunc."""
        return self.ufunc.identity

    @property
    def signature(self):
        """``signature`` of the underlying NumPy ufunc."""
        return self.ufunc.signature

    @property
    def accumulate(self):
        """``accumulate`` method of the underlying NumPy ufunc."""
        return self.ufunc.accumulate

    @property
    def at(self):
        """``at`` method of the underlying NumPy ufunc."""
        return self.ufunc.at

    @property
    def outer(self):
        """``outer`` method of the underlying NumPy ufunc."""
        return self.ufunc.outer

    @property
    def reduce(self):
        """``reduce`` method of the underlying NumPy ufunc."""
        return self.ufunc.reduce

    @property
    def reduceat(self):
        """``reduceat`` method of the underlying NumPy ufunc."""
        return self.ufunc.reduceat

    def _get_ewise_dtypes(self, args):
        """
        Return the Numba type used for loop matching for each argument.

        Array arguments contribute their element dtype. Any other argument
        contributes its ``typeof`` type unchanged.
        """
        argtys = (typeof(x) for x in args)
        return [ty.dtype if isinstance(ty, types.Array) else ty
                for ty in argtys]

    def _num_args_match(self, *args):
        """
        Return True if ``args`` covers every input and output.

        The required count is the number of inputs plus the number of
        outputs in the layout signature.
        """
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        return len(args) == len(parsed_sig[0]) + len(parsed_sig[1])

    def _get_signature(self, *args):
        """
        Build a ``none(...)`` Numba signature specialised to ``args``.

        Inputs with no core dimensions keep their scalar type. Other inputs
        become ``types.Array(dtype, ndim, 'A')``, where ``ndim`` is the number
        of core dimensions in the layout signature. Every output is typed as
        an ``'A'`` array, and outputs with no core dimensions become 1-d
        arrays.

        ``args`` must list the inputs followed by the outputs.
        """
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        in_layouts, out_layouts = parsed_sig[0], parsed_sig[1]
        # Element types for every argument, e.g. [int32, int32, int32, ...]
        ewise_types = self._get_ewise_dtypes(args)

        # Reached whenever no compiled loop matches the argument types,
        # not only on the first call: derive a signature from the arguments.
        argtys = []
        for idx, core_dims in enumerate(in_layouts):
            dtype = ewise_types[idx]
            ndim = len(core_dims)
            if ndim == 0:  # scalar input
                argtys.append(dtype)
            else:
                argtys.append(types.Array(dtype, ndim, 'A'))

        # Outputs follow the inputs in ``args``/``ewise_types``.
        offset = len(in_layouts)
        for idx, core_dims in enumerate(out_layouts):
            retty = ewise_types[offset + idx]
            ret_ndim = len(core_dims) or 1  # small hack to return scalars
            argtys.append(types.Array(retty, ret_ndim, 'A'))

        return types.none(*argtys)

    def __call__(self, *args, **kwargs):
        """
        Apply the gufunc to ``args``.

        Frozen or non-dynamic instances forward directly to ``self.ufunc``.
        A dynamic instance does three things before calling:

        * It moves an ``out=`` keyword into the positional arguments. A tuple
          of outputs is spread out, one positional argument per output.
        * It requires every input and output to be supplied.
        * It compiles a new loop and rebuilds the ufunc when no existing loop
          matches the argument element types.

        Raises
        ------
        TypeError
            For a dynamic gufunc whose argument count (after moving ``out``)
            differs from inputs + outputs in its layout signature.
        """
        # If compilation is disabled OR it is NOT a dynamic gufunc
        # call the underlying gufunc
        if self._frozen or not self.is_dynamic:
            return self.ufunc(*args, **kwargs)

        if "out" in kwargs:
            out = kwargs.pop("out")
            # NumPy also accepts ``out`` as a tuple with one entry per output.
            # Append the entries individually so each is typed as an array
            # rather than as a single tuple argument.
            args += out if isinstance(out, tuple) else (out,)

        if not self._num_args_match(*args):
            # It is not allowed to call a dynamic gufunc without
            # providing all the arguments
            # see: https://github.com/numba/numba/pull/5938#discussion_r506429392 # noqa: E501
            msg = (
                f"Too few arguments for function '{self.__name__}'. "
                "Note that the pattern `out = gufunc(Arg1, Arg2, ..., ArgN)` "
                "is not allowed. Use `gufunc(Arg1, Arg2, ..., ArgN, out) "
                "instead.")
            raise TypeError(msg)

        # At this point we know the gufunc is a dynamic one: compile a new
        # loop only if none of the existing ones matches these element types.
        ewise = self._get_ewise_dtypes(args)
        if not (self.ufunc and ufunc_find_matching_loop(self.ufunc, ewise)):
            sig = self._get_signature(*args)
            self.add(sig)
            self.build_ufunc()
        return self.ufunc(*args, **kwargs)
|