335 lines
8.2 KiB
Python
335 lines
8.2 KiB
Python
# mypy: ignore-errors
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util
|
|
from ._normalizations import (
|
|
ArrayLike,
|
|
ArrayLikeOrScalar,
|
|
CastingModes,
|
|
DTypeLike,
|
|
normalizer,
|
|
NotImplementedType,
|
|
OutArray,
|
|
)
|
|
|
|
|
|
def _ufunc_postprocess(result, out, casting):
|
|
if out is not None:
|
|
result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
|
|
result = torch.broadcast_to(result, out.shape)
|
|
return result
|
|
|
|
|
|
# ############# Binary ufuncs ######################

# Public names of _binary_ufuncs_impl that get the generic binary-ufunc wrapper.
# `torch` is excluded explicitly; matmul, divmod and ldexp have hand-written
# wrappers further down in this module.
_binary = [
    attr
    for attr in dir(_binary_ufuncs_impl)
    if not (attr.startswith("_") or attr in ("torch", "matmul", "divmod", "ldexp"))
]
|
|
|
|
|
|
# Torch function names (``torch_func.__name__``) for which deco_binary_ufunc
# passes a true flag to ``_dtypes_impl.nep50_to_tensors`` when the operands are
# not both tensors.
NEP50_FUNCS = (
    # arithmetic
    "add", "subtract", "multiply", "floor_divide", "true_divide", "divide",
    "remainder",
    # bitwise
    "bitwise_and", "bitwise_or", "bitwise_xor",
    "bitwise_left_shift", "bitwise_right_shift",
    # floating-point binary functions
    "hypot", "arctan2", "logaddexp", "logaddexp2", "heaviside", "copysign",
    # extrema
    "fmax", "minimum", "fmin", "maximum",
    # misc
    "fmod", "gcd", "lcm", "pow",
)
|
|
|
|
|
|
def deco_binary_ufunc(torch_func):
    """Wrap a binary torch function so it can be called like a NumPy binary ufunc.

    The returned callable (named after ``torch_func``) normalizes its two
    operands, settles the computation dtype, calls ``torch_func`` and fits the
    result to ``out``. Dtype handling, first match wins:

    * explicit ``dtype=``: each tensor operand is cast with ``casting``
      honoured; any non-tensor operand is converted straight to ``dtype``;
    * both operands are tensors: both are cast to their common result type;
    * otherwise (at least one non-tensor, presumably a Python scalar): the
      operands go through ``_dtypes_impl.nep50_to_tensors``, flagged by
      membership of the function name in ``NEP50_FUNCS``.
    """

    def _to_dtype(operand, target, casting):
        # Non-tensor operands are converted directly; tensors respect `casting`.
        if not isinstance(operand, torch.Tensor):
            return torch.as_tensor(operand, dtype=target)
        return _util.typecast_tensor(operand, target, casting)

    @normalizer
    def wrapped(
        x1: ArrayLikeOrScalar,
        x2: ArrayLikeOrScalar,
        /,
        out: Optional[OutArray] = None,
        *,
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order: NotImplementedType = "K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature: NotImplementedType = None,
        extobj: NotImplementedType = None,
    ):
        both_tensors = isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor)
        if dtype is not None:
            x1 = _to_dtype(x1, dtype, casting)
            x2 = _to_dtype(x2, dtype, casting)
        elif both_tensors:
            common = _dtypes_impl.result_type_impl(x1, x2)
            x1, x2 = _util.typecast_tensors((x1, x2), common, casting)
        else:
            fname = torch_func.__name__
            x1, x2 = _dtypes_impl.nep50_to_tensors(
                x1, x2, fname in NEP50_FUNCS, fname
            )

        return _ufunc_postprocess(torch_func(x1, x2), out, casting)

    wrapped.__name__ = wrapped.__qualname__ = torch_func.__name__
    return wrapped
|
|
|
|
|
|
# matmul's signature is _slightly_ different from other ufuncs:
# - no where=...
# - additional axis=..., axes=...
# - no NEP50 scalars in or out
@normalizer
def matmul(
    x1: ArrayLike,
    x2: ArrayLike,
    /,
    out: Optional[OutArray] = None,
    *,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
    axes: NotImplementedType = None,
    axis: NotImplementedType = None,
):
    """NumPy-style ``matmul``.

    Both operands are cast (honouring ``casting``) to ``dtype``, or to their
    common result type when ``dtype`` is None. The product comes from
    ``_binary_ufuncs_impl.matmul`` and is then cast/broadcast to ``out`` if
    one was given.
    """
    compute_dtype = (
        dtype if dtype is not None else _dtypes_impl.result_type_impl(x1, x2)
    )
    lhs, rhs = _util.typecast_tensors((x1, x2), compute_dtype, casting)
    product = _binary_ufuncs_impl.matmul(lhs, rhs)
    return _ufunc_postprocess(product, out, casting)
|
|
|
|
|
|
# ldexp casting is special: the dtype of the result == dtype of the 1st arg
@normalizer
def ldexp(
    x1: ArrayLikeOrScalar,
    x2: ArrayLikeOrScalar,
    /,
    out: Optional[OutArray] = None,
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    """NumPy-style ``ldexp`` delegating to ``_binary_ufuncs_impl.ldexp``.

    Only ``x1`` takes part in dtype selection: it is cast to ``dtype`` when
    given, otherwise a non-tensor ``x1`` is made a tensor and integer-to-float
    promoted. ``x2`` must be of an integer dtype.

    Raises:
        ValueError: if ``x2`` is not of an integer dtype.
    """
    if dtype is None:
        if not isinstance(x1, torch.Tensor):
            x1 = _util.cast_int_to_float(torch.as_tensor(x1))
    elif isinstance(x1, torch.Tensor):
        x1 = _util.typecast_tensor(x1, dtype, casting)
    else:
        x1 = torch.as_tensor(x1, dtype=dtype)

    # The exponent has to fall in the integer dtype category (category 1).
    x2 = torch.as_tensor(x2)
    if _dtypes_impl._category(x2.dtype) != 1:
        raise ValueError("ldexp 2nd arg must be integer")

    result = _binary_ufuncs_impl.ldexp(x1, x2)
    if x1.dtype == torch.float16:
        # torch.ldexp(f16, int) -> f32, undo it
        result = result.to(torch.float16)
    return _ufunc_postprocess(result, out, casting)
|
|
|
|
|
|
# nin=2, nout=2
@normalizer
def divmod(
    x1: ArrayLike,
    x2: ArrayLike,
    out1: Optional[OutArray] = None,
    out2: Optional[OutArray] = None,
    /,
    out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    """NumPy-style ``divmod`` returning ``(quotient, remainder)``.

    Outputs may be supplied positionally (``out1`` and ``out2``, both or
    neither) or via the ``out=`` tuple, never both ways at once. Operands are
    cast to ``dtype`` (or their common result type when ``dtype`` is None)
    before delegating to ``_binary_ufuncs_impl.divmod``.

    Raises:
        ValueError: if exactly one of ``out1``/``out2`` is given.
        TypeError: if positional outputs are combined with a non-empty ``out=``.
    """
    positional_count = (out1 is not None) + (out2 is not None)
    if positional_count == 1:
        raise ValueError("both out1 and out2 need to be provided")
    if positional_count == 2:
        kw_out1, kw_out2 = out
        if kw_out1 is not None or kw_out2 is not None:
            raise TypeError(
                "cannot specify 'out' as both a positional and keyword argument"
            )
    else:
        out1, out2 = out

    target = _dtypes_impl.result_type_impl(x1, x2) if dtype is None else dtype
    x1, x2 = _util.typecast_tensors((x1, x2), target, casting)

    quotient, remainder = _binary_ufuncs_impl.divmod(x1, x2)
    return (
        _ufunc_postprocess(quotient, out1, casting),
        _ufunc_postprocess(remainder, out2, casting),
    )
|
|
|
|
|
|
#
# Attach ufuncs to this module, for a further export to the public namespace in __init__.py
#
# The comprehension keeps its loop variable local: a plain module-level `for`
# loop would leave `name`/`ufunc` behind as stray attributes of this module.
globals().update(
    {
        fname: deco_binary_ufunc(getattr(_binary_ufuncs_impl, fname))
        for fname in _binary
    }
)
|
|
|
|
|
|
@normalizer
def modf(
    x: ArrayLike,
    out1: Optional[OutArray] = None,
    out2: Optional[OutArray] = None,
    /,
    out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    """Return the fractional and integral parts of ``x``, element-wise.

    Mirrors ``numpy.modf``. Both parts carry the sign of ``x``, e.g.
    ``modf(-3.5) == (-0.5, -3.0)``, which is truncation, not floor division.
    Integer/bool input is promoted to a floating dtype. The first output
    (``out1``) receives the fractional part, the second (``out2``) the
    integral part. For +-inf the fractional part is a signed zero.

    Output arrays are accepted exactly like in ``divmod``: positionally
    (``out1`` and ``out2``, both or neither) or via ``out=``, not both.

    Raises:
        ValueError: if exactly one of ``out1``/``out2`` is given.
        TypeError: if positional outputs are combined with a non-empty ``out=``.
    """
    num_outs = sum(o is not None for o in (out1, out2))
    if num_outs == 1:
        raise ValueError("both out1 and out2 need to be provided")
    elif num_outs == 2:
        o1, o2 = out
        if o1 is not None or o2 is not None:
            raise TypeError(
                "cannot specify 'out' as both a positional and keyword argument"
            )
    else:
        out1, out2 = out

    if dtype is not None:
        x = _util.typecast_tensor(x, dtype, casting)
    # modf is an int -> float ufunc; keep float inputs at their own precision
    # (no promotion through an integer `1` operand).
    x = _util.cast_int_to_float(x)

    # trunc/fmod both round toward zero, so the parts keep the sign of x.
    integral = torch.trunc(x)
    fractional = torch.where(
        torch.isinf(x),
        torch.copysign(torch.zeros_like(x), x),  # fmod(+-inf, 1) would be nan
        torch.fmod(x, 1),
    )

    fractional = _ufunc_postprocess(fractional, out1, casting)
    integral = _ufunc_postprocess(integral, out2, casting)
    return fractional, integral
|
|
|
|
|
|
# The hand-written wrappers defined above are exported as well.
_binary = [*_binary, "divmod", "modf", "matmul", "ldexp"]
|
|
|
|
|
|
# ############# Unary ufuncs ######################

# Every public name of _unary_ufuncs_impl except `torch` (presumably that
# module's own `import torch`).
_unary = [
    attr
    for attr in dir(_unary_ufuncs_impl)
    if attr != "torch" and not attr.startswith("_")
]
|
|
|
|
|
|
# these are ufunc(int) -> float
|
|
_fp_unary = [
|
|
"arccos",
|
|
"arccosh",
|
|
"arcsin",
|
|
"arcsinh",
|
|
"arctan",
|
|
"arctanh",
|
|
"cbrt",
|
|
"cos",
|
|
"cosh",
|
|
"deg2rad",
|
|
"degrees",
|
|
"exp",
|
|
"exp2",
|
|
"expm1",
|
|
"log",
|
|
"log10",
|
|
"log1p",
|
|
"log2",
|
|
"rad2deg",
|
|
"radians",
|
|
"reciprocal",
|
|
"sin",
|
|
"sinh",
|
|
"sqrt",
|
|
"square",
|
|
"tan",
|
|
"tanh",
|
|
"trunc",
|
|
]
|
|
|
|
|
|
def deco_unary_ufunc(torch_func):
    """Common infra for unary ufuncs.

    Normalize arguments, sort out type casting, broadcasting and delegate to
    the pytorch functions for the actual work.

    The returned wrapper (named after ``torch_func``):
      * casts ``x`` to ``dtype`` first when one is given (honouring ``casting``);
      * promotes integer input to float for the names listed in ``_fp_unary``;
      * casts/broadcasts the result to ``out`` when one is given.

    ``where``, ``order``, ``subok``, ``signature`` and ``extobj`` are not
    implemented by this layer. They carry the ``NotImplementedType``
    annotation, exactly like in the binary ufunc wrappers, so the
    ``normalizer`` handles them instead of the body silently ignoring them.
    That handling is presumably a rejection of non-default values.
    Without the annotation, e.g. a ``where=`` mask would be dropped and the
    unmasked result returned.
    """

    @normalizer
    def wrapped(
        x: ArrayLike,
        /,
        out: Optional[OutArray] = None,
        *,
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order: NotImplementedType = "K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature: NotImplementedType = None,
        extobj: NotImplementedType = None,
    ):
        if dtype is not None:
            x = _util.typecast_tensor(x, dtype, casting)

        # ufunc(int) -> float functions: promote integer input up front
        if torch_func.__name__ in _fp_unary:
            x = _util.cast_int_to_float(x)

        result = torch_func(x)
        result = _ufunc_postprocess(result, out, casting)
        return result

    wrapped.__qualname__ = torch_func.__name__
    wrapped.__name__ = torch_func.__name__

    return wrapped
|
|
|
|
|
|
#
# Attach ufuncs to this module, for a further export to the public namespace in __init__.py
#
# The comprehension keeps its loop variable local: a plain module-level `for`
# loop would leave `name`/`ufunc` behind as stray attributes of this module.
globals().update(
    {
        fname: deco_unary_ufunc(getattr(_unary_ufuncs_impl, fname))
        for fname in _unary
    }
)
|
|
|
|
|
|
# Public API: every wrapped binary ufunc followed by every wrapped unary ufunc.
__all__ = [*_binary, *_unary]  # noqa: PLE0605
|