ai-content-maker/.venv/Lib/site-packages/numba/np/ufunc/gufunc.py

195 lines
6.0 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from numba import typeof
from numba.core import types
from numba.np.ufunc.ufuncbuilder import GUFuncBuilder
from numba.np.ufunc.sigparse import parse_signature
from numba.np.numpy_support import ufunc_find_matching_loop
from numba.core import serialize
import functools
class GUFunc(serialize.ReduceMixin):
    """
    Dynamic generalized universal function (GUFunc)
    intended to act like a normal Numpy gufunc, but capable
    of call-time (just-in-time) compilation of fast loops
    specialized to inputs.

    The compiled NumPy ufunc object is held in ``self.ufunc``. It is
    ``None`` until :meth:`build_ufunc` has run, so the forwarding
    properties below (``nin``, ``types``, ``reduce``, ...) only work
    after the first build.
    """

    def __init__(self, py_func, signature, identity=None, cache=None,
                 is_dynamic=False, targetoptions=None, writable_args=()):
        """
        Parameters
        ----------
        py_func : callable
            Python kernel the loops are compiled from.
        signature : str
            Layout (core-dimension) signature understood by
            ``parse_signature``.
        identity : optional
            Identity value; stored for pickling and forwarded to the builder.
        cache : optional
            Forwarded unchanged to ``GUFuncBuilder``.
        is_dynamic : bool
            When true, :meth:`__call__` compiles a new loop for argument
            types that the current ufunc has no matching loop for.
        targetoptions : dict, optional
            Forwarded to ``GUFuncBuilder``. ``None`` (the default) means
            "no options". A fresh dict is created per instance so separate
            GUFuncs never share one mutable default object.
        writable_args : tuple
            Forwarded unchanged to ``GUFuncBuilder``.
        """
        if targetoptions is None:
            targetoptions = {}
        self.ufunc = None
        self._frozen = False
        self._is_dynamic = is_dynamic
        self._identity = identity
        # GUFunc cannot inherit from GUFuncBuilder because "identity"
        # is a property of GUFunc. Thus, we hold a reference to a GUFuncBuilder
        # object here
        self.gufunc_builder = GUFuncBuilder(
            py_func, signature, identity, cache, targetoptions, writable_args)
        self.__name__ = self.gufunc_builder.py_func.__name__
        functools.update_wrapper(self, py_func)

    def _reduce_states(self):
        """Return the keyword arguments that :meth:`_rebuild` needs to
        reconstruct this object, including every signature added so far
        and the frozen flag."""
        gb = self.gufunc_builder
        return dict(
            py_func=gb.py_func,
            signature=gb.signature,
            identity=self._identity,
            cache=gb.cache,
            is_dynamic=self._is_dynamic,
            targetoptions=gb.targetoptions,
            writable_args=gb.writable_args,
            typesigs=gb._sigs,
            frozen=self._frozen,
        )

    @classmethod
    def _rebuild(cls, py_func, signature, identity, cache, is_dynamic,
                 targetoptions, writable_args, typesigs, frozen):
        """Recreate a GUFunc from the state produced by
        :meth:`_reduce_states`.

        Every saved type signature is re-added and the ufunc is rebuilt
        before the frozen flag is restored.
        """
        self = cls(py_func=py_func, signature=signature, identity=identity,
                   cache=cache, is_dynamic=is_dynamic,
                   targetoptions=targetoptions, writable_args=writable_args)
        for sig in typesigs:
            self.add(sig)
        self.build_ufunc()
        # Restore the flag last: the add()/build_ufunc() calls above must
        # run regardless of whether the original was frozen.
        self._frozen = frozen
        return self

    def __repr__(self):
        return f"<numba._GUFunc '{self.__name__}'>"

    def add(self, fty):
        """Register type signature ``fty`` with the underlying builder.

        The ufunc is not rebuilt here; call :meth:`build_ufunc` afterwards.
        """
        self.gufunc_builder.add(fty)

    def build_ufunc(self):
        """(Re)build ``self.ufunc`` from all registered signatures.

        Returns ``self``.
        """
        self.ufunc = self.gufunc_builder.build_ufunc()
        return self

    def disable_compile(self):
        """
        Disable the compilation of new signatures at call time.

        After this, :meth:`__call__` always dispatches straight to the
        existing ufunc.
        """
        # If disabling compilation then there must be at least one signature
        assert len(self.gufunc_builder._sigs) > 0
        self._frozen = True

    @property
    def is_dynamic(self):
        """Whether loops may be compiled at call time."""
        return self._is_dynamic

    # Forwarding properties: each one reads the attribute of the same name
    # from the underlying NumPy ufunc. While ``self.ufunc`` is still None
    # (nothing built yet) they raise AttributeError.

    @property
    def nin(self):
        return self.ufunc.nin

    @property
    def nout(self):
        return self.ufunc.nout

    @property
    def nargs(self):
        return self.ufunc.nargs

    @property
    def ntypes(self):
        return self.ufunc.ntypes

    @property
    def types(self):
        return self.ufunc.types

    @property
    def identity(self):
        return self.ufunc.identity

    @property
    def signature(self):
        return self.ufunc.signature

    @property
    def accumulate(self):
        return self.ufunc.accumulate

    @property
    def at(self):
        return self.ufunc.at

    @property
    def outer(self):
        return self.ufunc.outer

    @property
    def reduce(self):
        return self.ufunc.reduce

    @property
    def reduceat(self):
        return self.ufunc.reduceat

    def _parse_layout(self):
        """Return ``(inputs, outputs)`` from the builder's layout signature.

        Each entry of ``inputs``/``outputs`` is indexed by argument position
        and its ``len()`` is that argument's number of core dimensions.
        """
        parsed_sig = parse_signature(self.gufunc_builder.signature)
        return parsed_sig[0], parsed_sig[1]

    def _get_ewise_dtypes(self, args):
        """Return the Numba type of each argument, reducing arrays to their
        dtype (the element-wise type used for loop matching)."""
        argtys = [typeof(arg) for arg in args]
        return [ty.dtype if isinstance(ty, types.Array) else ty
                for ty in argtys]

    def _num_args_match(self, *args):
        """True if ``args`` supplies exactly one operand per input and
        output declared in the layout signature."""
        inputs, outputs = self._parse_layout()
        return len(args) == len(inputs) + len(outputs)

    def _get_signature(self, *args):
        """Build a Numba signature for a new loop specialised to ``args``.

        Inputs without core dimensions become scalars. All other inputs, and
        every output, become arrays of the matching dimensionality with
        layout ``'A'``. The return type is ``none`` because results are
        written into the output operands.
        """
        inputs, outputs = self._parse_layout()
        # ewise_types is a list of [int32, int32, int32, ...]
        ewise_types = self._get_ewise_dtypes(args)

        argtypes = []
        for idx, core_dims in enumerate(inputs):
            ndim = len(core_dims)
            if ndim == 0:
                argtypes.append(ewise_types[idx])  # scalar input
            else:
                argtypes.append(types.Array(ewise_types[idx], ndim, 'A'))

        offset = len(inputs)
        for idx, core_dims in enumerate(outputs):
            retty = ewise_types[offset + idx]
            # Outputs must be arrays. A scalar output (no core dimensions)
            # is therefore represented as a 1-d array.
            ret_ndim = len(core_dims) or 1
            argtypes.append(types.Array(retty, ret_ndim, 'A'))
        return types.none(*argtypes)

    def __call__(self, *args, **kwargs):
        """Call the gufunc, compiling a new loop first if it is dynamic and no
        existing loop matches the argument types.

        Raises
        ------
        TypeError
            For a dynamic, non-frozen gufunc, if the inputs plus outputs do
            not match the layout signature. Outputs are never allocated
            implicitly in that mode.
        """
        # If compilation is disabled OR it is NOT a dynamic gufunc
        # call the underlying gufunc
        if self._frozen or not self.is_dynamic:
            return self.ufunc(*args, **kwargs)

        if "out" in kwargs:
            # NumPy accepts ``out`` either as a single array or as a tuple
            # with one entry per output. Splice it into the positional
            # operands so that type dispatch below sees every output.
            out = kwargs.pop("out")
            args += out if isinstance(out, tuple) else (out,)

        if not self._num_args_match(*args):
            # It is not allowed to call a dynamic gufunc without
            # providing all the arguments
            # see: https://github.com/numba/numba/pull/5938#discussion_r506429392 # noqa: E501
            msg = (
                f"Too few arguments for function '{self.__name__}'. "
                "Note that the pattern `out = gufunc(Arg1, Arg2, ..., ArgN)` "
                "is not allowed. Use `gufunc(Arg1, Arg2, ..., ArgN, out) "
                "instead.")
            raise TypeError(msg)

        # at this point we know the gufunc is a dynamic one
        ewise = self._get_ewise_dtypes(args)
        if self.ufunc is None or not ufunc_find_matching_loop(self.ufunc,
                                                              ewise):
            # No loop for these types yet: compile one and rebuild the ufunc.
            sig = self._get_signature(*args)
            self.add(sig)
            self.build_ufunc()
        return self.ufunc(*args, **kwargs)