import functools
import warnings

import numpy as np

from numba import jit, typeof
from numba.core import cgutils, types, serialize, sigutils, errors
from numba.core.extending import (is_jitted, overload_attribute,
                                  overload_method, register_jitable,
                                  intrinsic)
from numba.core.typing import npydecl
from numba.core.typing.templates import AbstractTemplate, signature
from numba.cpython.unsafe.tuple import tuple_setitem
from numba.np.ufunc import _internal
from numba.parfors import array_analysis
from numba.np.ufunc import ufuncbuilder
from numba.np import numpy_support
from typing import Callable
from llvmlite import ir


def make_dufunc_kernel(_dufunc):
    from numba.np import npyimpl

    class DUFuncKernel(npyimpl._Kernel):
        """
        npyimpl._Kernel subclass responsible for lowering a DUFunc kernel
        (element-wise function) inside a broadcast loop (which is
        generated by npyimpl.numpy_ufunc_kernel()).
        """
        dufunc = _dufunc

        def __init__(self, context, builder, outer_sig):
            super(DUFuncKernel, self).__init__(context, builder, outer_sig)
            self.inner_sig, self.cres = self.dufunc.find_ewise_function(
                outer_sig.args)

        def generate(self, *args):
            isig = self.inner_sig
            osig = self.outer_sig
            cast_args = [self.cast(val, inty, outty)
                         for val, inty, outty in
                         zip(args, osig.args, isig.args)]
            if self.cres.objectmode:
                func_type = self.context.call_conv.get_function_type(
                    types.pyobject, [types.pyobject] * len(isig.args))
            else:
                func_type = self.context.call_conv.get_function_type(
                    isig.return_type, isig.args)
            module = self.builder.block.function.module
            entry_point = cgutils.get_or_insert_function(
                module, func_type,
                self.cres.fndesc.llvm_func_name)
            entry_point.attributes.add("alwaysinline")

            _, res = self.context.call_conv.call_function(
                self.builder, entry_point, isig.return_type, isig.args,
                cast_args)
            return self.cast(res, isig.return_type, osig.return_type)

    DUFuncKernel.__name__ += _dufunc.ufunc.__name__
    return DUFuncKernel


class DUFuncLowerer(object):
    '''Callable class responsible for lowering calls to a specific DUFunc.
    '''
    def __init__(self, dufunc):
        self.kernel = make_dufunc_kernel(dufunc)
        self.libs = []

    def __call__(self, context, builder, sig, args):
        from numba.np import npyimpl
        return npyimpl.numpy_ufunc_kernel(context, builder, sig, args,
                                          self.kernel.dufunc.ufunc,
                                          self.kernel)


class DUFunc(serialize.ReduceMixin, _internal._DUFunc):
    """
    Dynamic universal function (DUFunc) intended to act like a normal
    Numpy ufunc, but capable of call-time (just-in-time) compilation
    of fast loops specialized to inputs.
    """
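
    # Illustrative usage sketch (not part of the implementation): a DUFunc
    # is normally obtained from numba.vectorize() with no signatures, and
    # new loops are compiled lazily at call time, e.g.
    #
    #     from numba import vectorize
    #     import numpy as np
    #
    #     @vectorize                 # lazy: no signatures supplied up front
    #     def add(x, y):
    #         return x + y
    #
    #     add(1, 2)                        # compiles an integer loop on demand
    #     add(np.linspace(0., 1., 5), 2.)  # compiles a float64 loop on demand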

    # NOTE: __base_kwargs must be kept in synch with the kwlist in
    # _internal.c:dufunc_init()
    __base_kwargs = set(('identity', '_keepalive', 'nin', 'nout'))

    def __init__(self, py_func, identity=None, cache=False, targetoptions={}):
        if is_jitted(py_func):
            py_func = py_func.py_func
        with ufuncbuilder._suppress_deprecation_warning_nopython_not_supplied():
            dispatcher = jit(_target='npyufunc',
                             cache=cache,
                             **targetoptions)(py_func)
        self._initialize(dispatcher, identity)
        functools.update_wrapper(self, py_func)

    def _initialize(self, dispatcher, identity):
        identity = ufuncbuilder.parse_identity(identity)
        super(DUFunc, self).__init__(dispatcher, identity=identity)
        # Loop over a copy of the keys instead of the keys themselves,
        # since we're changing the dictionary while looping.
        self._install_type()
        self._lower_me = DUFuncLowerer(self)
        self._install_cg()
        self.__name__ = dispatcher.py_func.__name__
        self.__doc__ = dispatcher.py_func.__doc__

    def _reduce_states(self):
        """
        NOTE: part of ReduceMixin protocol
        """
        siglist = list(self._dispatcher.overloads.keys())
        return dict(
            dispatcher=self._dispatcher,
            identity=self.identity,
            frozen=self._frozen,
            siglist=siglist,
        )

    @classmethod
    def _rebuild(cls, dispatcher, identity, frozen, siglist):
        """
        NOTE: part of ReduceMixin protocol
        """
        self = _internal._DUFunc.__new__(cls)
        self._initialize(dispatcher, identity)
        # Re-add signatures
        for sig in siglist:
            self.add(sig)
        if frozen:
            self.disable_compile()
        return self

    def build_ufunc(self):
        """
        For compatibility with the various *UFuncBuilder classes.
        """
        return self

    @property
    def targetoptions(self):
        return self._dispatcher.targetoptions

    @property
    def nin(self):
        return self.ufunc.nin

    @property
    def nout(self):
        return self.ufunc.nout

    @property
    def nargs(self):
        return self.ufunc.nargs

    @property
    def ntypes(self):
        return self.ufunc.ntypes

    @property
    def types(self):
        return self.ufunc.types

    @property
    def identity(self):
        return self.ufunc.identity

    @property
    def signature(self):
        return self.ufunc.signature

    def disable_compile(self):
        """
        Disable the compilation of new signatures at call time.
        """
        # If disabling compilation then there must be at least one signature
        assert len(self._dispatcher.overloads) > 0
        self._frozen = True

    def add(self, sig):
        """
        Compile the DUFunc for the given signature.
        """
        args, return_type = sigutils.normalize_signature(sig)
        return self._compile_for_argtys(args, return_type)
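
    # Illustrative sketch (assumed usage, not part of the implementation):
    # loops may be compiled eagerly with add() and further call-time
    # compilation can then be frozen with disable_compile(), e.g.
    #
    #     my_dufunc.add("int64(int64, int64)")
    #     my_dufunc.add("float64(float64, float64)")
    #     my_dufunc.disable_compile()  # later calls coerce to existing loops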

    def __call__(self, *args, **kws):
        """
        Allow any argument that has overridden __array_ufunc__ (NEP-13)
        to take control of DUFunc.__call__.
        """
        default = numpy_support.np.ndarray.__array_ufunc__

        for arg in args + tuple(kws.values()):
            if getattr(type(arg), "__array_ufunc__", default) is not default:
                output = arg.__array_ufunc__(self, "__call__", *args, **kws)
                if output is not NotImplemented:
                    return output
        else:
            return super().__call__(*args, **kws)
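
    # Illustrative sketch (hypothetical wrapper class, not part of the
    # implementation): an argument defining __array_ufunc__ takes over the
    # call, e.g.
    #
    #     class Logged:
    #         def __init__(self, data):
    #             self.data = data
    #
    #         def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    #             inputs = [getattr(x, "data", x) for x in inputs]
    #             return getattr(ufunc, method)(*inputs, **kwargs)
    #
    #     my_dufunc(Logged(np.arange(3)), 1)  # dispatched through Logged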

    def _compile_for_args(self, *args, **kws):
        nin = self.ufunc.nin
        if kws:
            if 'out' in kws:
                out = kws.pop('out')
                args += (out,)
            if kws:
                raise TypeError("unexpected keyword arguments to ufunc: %s"
                                % ", ".join(repr(k) for k in sorted(kws)))

        args_len = len(args)
        assert (args_len == nin) or (args_len == nin + self.ufunc.nout)
        assert not kws
        argtys = []
        for arg in args[:nin]:
            argty = typeof(arg)
            if isinstance(argty, types.Array):
                argty = argty.dtype
            else:
                # To avoid a mismatch in how Numba types scalar values as
                # opposed to Numpy, we need special logic for scalars.
                # For example, on 64-bit systems, numba.typeof(3) => int32, but
                # np.array(3).dtype => int64.

                # Note: this will not handle numpy "duckarrays" correctly,
                # including but not limited to those implementing `__array__`
                # and `__array_ufunc__`.
                argty = numpy_support.map_arrayscalar_type(arg)
            argtys.append(argty)
        return self._compile_for_argtys(tuple(argtys))

    def _compile_for_argtys(self, argtys, return_type=None):
        """
        Given a tuple of argument types (these should be the array
        dtypes, and not the array types themselves), compile the
        element-wise function for those inputs, generate a UFunc loop
        wrapper, and register the loop with the Numpy ufunc object for
        this DUFunc.
        """
        if self._frozen:
            raise RuntimeError("compilation disabled for %s" % (self,))
        assert isinstance(argtys, tuple)
        if return_type is None:
            sig = argtys
        else:
            sig = return_type(*argtys)
        cres, argtys, return_type = ufuncbuilder._compile_element_wise_function(
            self._dispatcher, self.targetoptions, sig)
        actual_sig = ufuncbuilder._finalize_ufunc_signature(
            cres, argtys, return_type)
        dtypenums, ptr, env = ufuncbuilder._build_element_wise_ufunc_wrapper(
            cres, actual_sig)
        self._add_loop(int(ptr), dtypenums)
        self._keepalive.append((ptr, cres.library, env))
        self._lower_me.libs.append(cres.library)
        return cres

    def _install_ufunc_attributes(self, template) -> None:

        def get_attr_fn(attr: str) -> Callable:

            def impl(ufunc):
                val = getattr(ufunc.key[0], attr)
                return lambda ufunc: val
            return impl

        # ntypes/types needs "at" to be a BoundFunction rather than a Function,
        # but this fails as it cannot take a weak reference to a ufunc, since
        # NumPy does not set the "tp_weaklistoffset" field. See:
        # https://github.com/numpy/numpy/blob/7fc72776b972bfbfdb909e4b15feb0308cf8adba/numpy/core/src/umath/ufunc_object.c#L6968-L6983 # noqa: E501

        at = types.Function(template)
        attributes = ('nin', 'nout', 'nargs',  # 'ntypes', # 'types',
                      'identity', 'signature')
        for attr in attributes:
            attr_fn = get_attr_fn(attr)
            overload_attribute(at, attr)(attr_fn)
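
    # Illustrative sketch (assumed usage, not part of the implementation):
    # the attributes installed above let jitted code read ufunc metadata
    # directly, e.g.
    #
    #     @numba.njit
    #     def arity():
    #         return my_dufunc.nin + my_dufunc.nout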

    def _install_ufunc_methods(self, template) -> None:
        self._install_ufunc_reduce(template)

    def _install_ufunc_reduce(self, template) -> None:
        at = types.Function(template)

        @overload_method(at, 'reduce')
        def ol_reduce(ufunc, array, axis=0, dtype=None, initial=None):

            warnings.warn("ufunc.reduce feature is experimental",
                          category=errors.NumbaExperimentalFeatureWarning)

            if not isinstance(array, types.Array):
                msg = 'The first argument "array" must be array-like'
                raise errors.NumbaTypeError(msg)

            axis_int = isinstance(axis, types.Integer)
            axis_int_tuple = isinstance(axis, types.UniTuple) and \
                isinstance(axis.dtype, types.Integer)
            axis_empty_tuple = isinstance(axis, types.Tuple) and len(axis) == 0
            axis_none = cgutils.is_nonelike(axis)
            axis_tuple_size = len(axis) if axis_int_tuple else 0

            if self.ufunc.identity is None and not (
                    (axis_int_tuple and axis_tuple_size == 1) or
                    axis_empty_tuple or axis_int or axis_none):
                msg = (f"reduction operation '{self.ufunc.__name__}' is not "
                       "reorderable, so at most one axis may be specified")
                raise errors.NumbaTypeError(msg)

            tup_init = (0,) * (array.ndim)
            tup_init_m1 = (0,) * (array.ndim - 1)
            nb_dtype = array.dtype if cgutils.is_nonelike(dtype) else dtype
            identity = self.identity

            id_none = cgutils.is_nonelike(identity)
            init_none = cgutils.is_nonelike(initial)

            @register_jitable
            def tuple_slice(tup, pos):
                # Same as
                # tup = tup[0 : pos] + tup[pos + 1:]
                s = tup_init_m1
                i = 0
                for j, e in enumerate(tup):
                    if j == pos:
                        continue
                    s = tuple_setitem(s, i, e)
                    i += 1
                return s

            @register_jitable
            def tuple_slice_append(tup, pos, val):
                # Same as
                # tup = tup[0 : pos] + val + tup[pos + 1:]
                s = tup_init
                i, j, sz = 0, 0, len(s)
                while j < sz:
                    if j == pos:
                        s = tuple_setitem(s, j, val)
                    else:
                        e = tup[i]
                        s = tuple_setitem(s, j, e)
                        i += 1
                    j += 1
                return s

            @intrinsic
            def compute_flat_idx(typingctx, strides, itemsize, idx, axis):
                sig = types.intp(strides, itemsize, idx, axis)
                len_idx = len(idx)

                def gen_block(builder, block_pos, block_name, bb_end, args):
                    strides, _, idx, _ = args
                    bb = builder.append_basic_block(name=block_name)

                    with builder.goto_block(bb):
                        zero = ir.IntType(64)(0)
                        flat_idx = zero

                        if block_pos == 0:
                            for i in range(1, len_idx):
                                stride = builder.extract_value(strides, i - 1)
                                idx_i = builder.extract_value(idx, i)
                                m = builder.mul(stride, idx_i)
                                flat_idx = builder.add(flat_idx, m)
                        elif 0 < block_pos < len_idx - 1:
                            for i in range(0, block_pos):
                                stride = builder.extract_value(strides, i)
                                idx_i = builder.extract_value(idx, i)
                                m = builder.mul(stride, idx_i)
                                flat_idx = builder.add(flat_idx, m)

                            for i in range(block_pos + 1, len_idx):
                                stride = builder.extract_value(strides, i - 1)
                                idx_i = builder.extract_value(idx, i)
                                m = builder.mul(stride, idx_i)
                                flat_idx = builder.add(flat_idx, m)
                        else:
                            for i in range(0, len_idx - 1):
                                stride = builder.extract_value(strides, i)
                                idx_i = builder.extract_value(idx, i)
                                m = builder.mul(stride, idx_i)
                                flat_idx = builder.add(flat_idx, m)

                        builder.branch(bb_end)

                    return bb, flat_idx

                def codegen(context, builder, sig, args):
                    strides, itemsize, idx, axis = args

                    bb = builder.basic_block
                    switch_end = builder.append_basic_block(name='axis_end')
                    l = []
                    for i in range(len_idx):
                        block, flat_idx = gen_block(builder, i, f"axis_{i}",
                                                    switch_end, args)
                        l.append((block, flat_idx))

                    with builder.goto_block(bb):
                        switch = builder.switch(axis, l[-1][0])
                        for i in range(len_idx):
                            switch.add_case(i, l[i][0])

                    builder.position_at_end(switch_end)
                    phi = builder.phi(l[0][1].type)
                    for block, value in l:
                        phi.add_incoming(value, block)
                    return builder.sdiv(phi, itemsize)

                return sig, codegen

            @register_jitable
            def fixup_axis(axis, ndim):
                ax = axis
                for i in range(len(axis)):
                    val = axis[i] + ndim if axis[i] < 0 else axis[i]
                    ax = tuple_setitem(ax, i, val)
                return ax

            @register_jitable
            def find_min(tup):
                idx, e = 0, tup[0]
                for i in range(len(tup)):
                    if tup[i] < e:
                        idx, e = i, tup[i]
                return idx, e

            def impl_1d(ufunc, array, axis=0, dtype=None, initial=None):
                start = 0
                if init_none and id_none:
                    start = 1
                    r = array[0]
                elif init_none:
                    r = identity
                else:
                    r = initial

                sz = array.shape[0]
                for i in range(start, sz):
                    r = ufunc(r, array[i])
                return r

            def impl_nd_axis_int(ufunc,
                                 array,
                                 axis=0,
                                 dtype=None,
                                 initial=None):
                if axis is None:
                    raise ValueError("'axis' must be specified")

                if axis < 0:
                    axis += array.ndim

                if axis < 0 or axis >= array.ndim:
                    raise ValueError("Invalid axis")

                # create result array
                shape = tuple_slice(array.shape, axis)

                if initial is None and identity is None:
                    r = np.empty(shape, dtype=nb_dtype)
                    for idx, _ in np.ndenumerate(r):
                        # shape[0:axis] + 0 + shape[axis:]
                        result_idx = tuple_slice_append(idx, axis, 0)
                        r[idx] = array[result_idx]
                elif initial is None and identity is not None:
                    # Checking that identity is not None is redundant here,
                    # but it is required for this branch to compile.
                    r = np.full(shape, fill_value=identity, dtype=nb_dtype)
                else:
                    r = np.full(shape, fill_value=initial, dtype=nb_dtype)

                # One approach to implement reduce is to remove the axis index
                # from the indexing tuple returned by "np.ndenumerate". For
                # instance, if idx = (X, Y, Z) and axis=1, the result index
                # is (X, Y).
                # Another way is to compute the result index using strides,
                # which is faster than manipulating tuples.
                view = r.ravel()
                if initial is None and identity is None:
                    for idx, val in np.ndenumerate(array):
                        if idx[axis] == 0:
                            continue
                        else:
                            flat_pos = compute_flat_idx(r.strides, r.itemsize,
                                                        idx, axis)
                            lhs, rhs = view[flat_pos], val
                            view[flat_pos] = ufunc(lhs, rhs)
                else:
                    for idx, val in np.ndenumerate(array):
                        if initial is None and identity is None and \
                                idx[axis] == 0:
                            continue
                        flat_pos = compute_flat_idx(r.strides, r.itemsize,
                                                    idx, axis)
                        lhs, rhs = view[flat_pos], val
                        view[flat_pos] = ufunc(lhs, rhs)
                return r

            def impl_nd_axis_tuple(ufunc,
                                   array,
                                   axis=0,
                                   dtype=None,
                                   initial=None):
                axis_ = fixup_axis(axis, array.ndim)
                for i in range(0, len(axis_)):
                    if axis_[i] < 0 or axis_[i] >= array.ndim:
                        raise ValueError("Invalid axis")

                    for j in range(i + 1, len(axis_)):
                        if axis_[i] == axis_[j]:
                            raise ValueError("duplicate value in 'axis'")

                min_idx, min_elem = find_min(axis_)
                r = ufunc.reduce(array,
                                 axis=min_elem,
                                 dtype=dtype,
                                 initial=initial)
                if len(axis) == 1:
                    return r
                elif len(axis) == 2:
                    return ufunc.reduce(r, axis=axis_[(min_idx + 1) % 2] - 1)
                else:
                    ax = axis_tup
                    for i in range(len(ax)):
                        if i != min_idx:
                            ax = tuple_setitem(ax, i, axis_[i])
                    return ufunc.reduce(r, axis=ax)

            def impl_axis_empty_tuple(ufunc,
                                      array,
                                      axis=0,
                                      dtype=None,
                                      initial=None):
                return array

            def impl_axis_none(ufunc,
                               array,
                               axis=0,
                               dtype=None,
                               initial=None):
                return ufunc.reduce(array, axis_tup, dtype, initial)

            if array.ndim == 1 and not axis_empty_tuple:
                return impl_1d
            elif axis_empty_tuple:
                # ufunc.reduce(array, axis=())
                return impl_axis_empty_tuple
            elif axis_none:
                # ufunc.reduce(array, axis=None)
                axis_tup = tuple(range(array.ndim))
                return impl_axis_none
            elif axis_int_tuple:
                # axis is a tuple of integers
                # ufunc.reduce(array, axis=(1, 2, ...))
                axis_tup = (0,) * (len(axis) - 1)
                return impl_nd_axis_tuple
            elif axis == 0 or isinstance(axis, (types.Integer,
                                                types.Omitted,
                                                types.IntegerLiteral)):
                # axis is the default value (0) or an integer
                # ufunc.reduce(array, axis=0)
                return impl_nd_axis_int
            # elif array.ndim == 1:
            #     return impl_1d
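
    # Illustrative sketch (assumed usage, not part of the implementation):
    # the overload above makes reduce() callable on a DUFunc from jitted
    # code, e.g.
    #
    #     @numba.njit
    #     def total(a):
    #         # with my_dufunc behaving like np.add, this sums over axis 0
    #         return my_dufunc.reduce(a, axis=0)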

    def _install_type(self, typingctx=None):
        """Constructs and installs a typing class for a DUFunc object in the
        input typing context.  If no typing context is given, then
        _install_type() installs into the typing context of the
        dispatcher object (should be same default context used by
        jit() and njit()).
        """
        if typingctx is None:
            typingctx = self._dispatcher.targetdescr.typing_context
        _ty_cls = type('DUFuncTyping_' + self.ufunc.__name__,
                       (AbstractTemplate,),
                       dict(key=self, generic=self._type_me))
        typingctx.insert_user_function(self, _ty_cls)
        self._install_ufunc_attributes(_ty_cls)
        self._install_ufunc_methods(_ty_cls)

    def find_ewise_function(self, ewise_types):
        """
        Given a tuple of element-wise argument types, find a matching
        signature in the dispatcher.

        Return a 2-tuple containing the matching signature, and
        compilation result.  Will return two None's if no matching
        signature was found.
        """
        if self._frozen:
            # If we cannot compile, coerce to the best matching loop
            loop = numpy_support.ufunc_find_matching_loop(self, ewise_types)
            if loop is None:
                return None, None
            ewise_types = tuple(loop.inputs + loop.outputs)[:len(ewise_types)]
        for sig, cres in self._dispatcher.overloads.items():
            if sig.args == ewise_types:
                return sig, cres
        return None, None

    def _type_me(self, argtys, kwtys):
        """
        Implement AbstractTemplate.generic() for the typing class
        built by DUFunc._install_type().

        Return the call-site signature after either validating the
        element-wise signature or compiling for it.
        """
        assert not kwtys
        ufunc = self.ufunc
        _handle_inputs_result = npydecl.Numpy_rules_ufunc._handle_inputs(
            ufunc, argtys, kwtys)
        base_types, explicit_outputs, ndims, layout = _handle_inputs_result
        explicit_output_count = len(explicit_outputs)
        if explicit_output_count > 0:
            ewise_types = tuple(base_types[:-len(explicit_outputs)])
        else:
            ewise_types = tuple(base_types)
        sig, cres = self.find_ewise_function(ewise_types)
        if sig is None:
            # Matching element-wise signature was not found; must
            # compile.
            if self._frozen:
                raise TypeError("cannot call %s with types %s"
                                % (self, argtys))
            self._compile_for_argtys(ewise_types)
            sig, cres = self.find_ewise_function(ewise_types)
            assert sig is not None
        if explicit_output_count > 0:
            outtys = list(explicit_outputs)
        elif ufunc.nout == 1:
            if ndims > 0:
                outtys = [types.Array(sig.return_type, ndims, layout)]
            else:
                outtys = [sig.return_type]
        else:
            raise NotImplementedError("typing gufuncs (nout > 1)")
        outtys.extend(argtys)
        return signature(*outtys)

    def _install_cg(self, targetctx=None):
        """
        Install an implementation function for a DUFunc object in the
        given target context.  If no target context is given, then
        _install_cg() installs into the target context of the
        dispatcher object (should be same default context used by
        jit() and njit()).
        """
        if targetctx is None:
            targetctx = self._dispatcher.targetdescr.target_context
        _any = types.Any
        _arr = types.Array
        # Either all outputs are explicit or none of them are
        sig0 = (_any,) * self.ufunc.nin + (_arr,) * self.ufunc.nout
        sig1 = (_any,) * self.ufunc.nin
        targetctx.insert_func_defn(
            [(self._lower_me, self, sig) for sig in (sig0, sig1)])


array_analysis.MAP_TYPES.append(DUFunc)