# mypy: ignore-errors

from __future__ import annotations

from typing import Optional

import torch

from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util
from ._normalizations import (
    ArrayLike,
    ArrayLikeOrScalar,
    CastingModes,
    DTypeLike,
    normalizer,
    NotImplementedType,
    OutArray,
)


def _ufunc_postprocess(result, out, casting):
    """Adapt a ufunc result to the requested ``out`` array, if one was given.

    With ``out`` None the result is passed through unchanged. Otherwise it is
    cast to ``out``'s dtype under the ``casting`` rule and then broadcast to
    ``out``'s shape.
    """
    if out is None:
        return result
    cast_result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting)
    return torch.broadcast_to(cast_result, out.shape)


# ############# Binary ufuncs ######################


# Public names from _binary_ufuncs_impl that receive the generic wrapper from
# deco_binary_ufunc. Private names and the imported ``torch`` module are
# skipped, as are matmul/divmod/ldexp, which get dedicated wrappers.
_binary = [
    attr
    for attr in dir(_binary_ufuncs_impl)
    if not (attr.startswith("_") or attr in ["torch", "matmul", "divmod", "ldexp"])
]


# Binary ufuncs for which nep50_to_tensors is called with a True flag when a
# Python scalar is mixed with a tensor (presumably NEP 50 weak-scalar
# promotion; see _dtypes_impl.nep50_to_tensors).
NEP50_FUNCS = (
    "add",
    "subtract",
    "multiply",
    "floor_divide",
    "true_divide",
    "divide",
    "remainder",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "bitwise_left_shift",
    "bitwise_right_shift",
    "hypot",
    "arctan2",
    "logaddexp",
    "logaddexp2",
    "heaviside",
    "copysign",
    "fmax",
    "minimum",
    "fmin",
    "maximum",
    "fmod",
    "gcd",
    "lcm",
    "pow",
)


def deco_binary_ufunc(torch_func):
    """Turn a torch binary function into a NumPy-style binary ufunc.

    The returned callable normalizes its arguments and reconciles the two
    operands' dtypes before calling ``torch_func``. The result is then fitted
    to ``out`` if one is given. Dtype reconciliation takes one of three paths:

    * an explicit ``dtype``: both operands are cast/converted to it;
    * two tensors: both are cast to their common result type;
    * otherwise (at least one scalar): ``nep50_to_tensors`` decides.

    The wrapper is renamed after ``torch_func``.
    """
    func_name = torch_func.__name__
    # Constant for a given ufunc, so it is worked out once rather than per call.
    is_nep50 = func_name in NEP50_FUNCS

    @normalizer
    def wrapped(
        x1: ArrayLikeOrScalar,
        x2: ArrayLikeOrScalar,
        /,
        out: Optional[OutArray] = None,
        *,
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order: NotImplementedType = "K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature: NotImplementedType = None,
        extobj: NotImplementedType = None,
    ):
        if dtype is not None:
            # Tensors go through a casting-checked typecast; scalars are
            # materialized directly in the requested dtype.
            x1, x2 = (
                _util.typecast_tensor(operand, dtype, casting)
                if isinstance(operand, torch.Tensor)
                else torch.as_tensor(operand, dtype=dtype)
                for operand in (x1, x2)
            )
        elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
            common = _dtypes_impl.result_type_impl(x1, x2)
            x1, x2 = _util.typecast_tensors((x1, x2), common, casting)
        else:
            x1, x2 = _dtypes_impl.nep50_to_tensors(x1, x2, is_nep50, func_name)

        return _ufunc_postprocess(torch_func(x1, x2), out, casting)

    wrapped.__name__ = func_name
    wrapped.__qualname__ = func_name
    return wrapped


# matmul's signature is _slightly_ different from other ufuncs:
# - no where=...
# - additional axis=..., axes=...
# - no NEP50 scalars in or out
@normalizer
def matmul(
    x1: ArrayLike,
    x2: ArrayLike,
    /,
    out: Optional[OutArray] = None,
    *,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
    axes: NotImplementedType = None,
    axis: NotImplementedType = None,
):
    """NumPy-style ``matmul``.

    Both operands are cast to ``dtype`` (or to their common result type when
    ``dtype`` is None) and passed to ``_binary_ufuncs_impl.matmul``; the result
    is then fitted to ``out`` if one is given.
    """
    if dtype is None:
        dtype = _dtypes_impl.result_type_impl(x1, x2)
    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

    result = _binary_ufuncs_impl.matmul(x1, x2)

    result = _ufunc_postprocess(result, out, casting)
    return result


# ldexp casting is special: only the 1st argument is cast, so the result dtype
# follows the 1st argument's dtype.
@normalizer
def ldexp(
    x1: ArrayLikeOrScalar,
    x2: ArrayLikeOrScalar,
    /,
    out: Optional[OutArray] = None,
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    """NumPy-style ``ldexp``, delegating to ``_binary_ufuncs_impl.ldexp``.

    ``x1`` is cast to ``dtype`` when given, otherwise integer input is promoted
    to float. ``x2`` is converted to a tensor unchanged.

    Raises:
        ValueError: if ``x2`` is not of integer dtype category.
    """
    if dtype is not None:
        if isinstance(x1, torch.Tensor):
            x1 = _util.typecast_tensor(x1, dtype, casting)
        else:
            x1 = torch.as_tensor(x1, dtype=dtype)
    else:
        if not isinstance(x1, torch.Tensor):
            x1 = torch.as_tensor(x1)
        x1 = _util.cast_int_to_float(x1)

    x2 = torch.as_tensor(x2)
    # the second arg must be integer
    if _dtypes_impl._category(x2.dtype) != 1:
        raise ValueError("ldexp 2nd arg must be integer")

    result = _binary_ufuncs_impl.ldexp(x1, x2)

    if x1.dtype == torch.float16:
        # torch.ldexp(f16, int) -> f32, undo it
        result = result.to(torch.float16)

    return _ufunc_postprocess(result, out, casting)


def _resolve_two_outs(out1, out2, out):
    """Reconcile the output arguments of a two-output (nout=2) ufunc.

    The outputs may come either positionally (``out1`` and ``out2``, both or
    neither) or as the ``out=(o1, o2)`` keyword tuple, but not both.

    Returns:
        The pair of output arrays to use; either element may be None.

    Raises:
        ValueError: if exactly one of ``out1``/``out2`` is given.
        TypeError: if positional outputs are combined with a non-empty ``out``.
    """
    num_outs = sum(x is not None for x in (out1, out2))
    if num_outs == 1:
        raise ValueError("both out1 and out2 need to be provided")
    if num_outs == 2:
        o1, o2 = out
        if o1 is not None or o2 is not None:
            raise TypeError(
                "cannot specify 'out' as both a positional and keyword argument"
            )
        return out1, out2
    o1, o2 = out
    return o1, o2


# nin=2, nout=2
@normalizer
def divmod(
    x1: ArrayLike,
    x2: ArrayLike,
    out1: Optional[OutArray] = None,
    out2: Optional[OutArray] = None,
    /,
    out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    """NumPy-style ``divmod``: returns ``(quotient, remainder)``.

    Operands are cast to ``dtype`` (or their common result type) before calling
    ``_binary_ufuncs_impl.divmod``; each output is fitted to its ``out`` array.
    """
    out1, out2 = _resolve_two_outs(out1, out2, out)

    if dtype is None:
        dtype = _dtypes_impl.result_type_impl(x1, x2)
    x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting)

    quot, rem = _binary_ufuncs_impl.divmod(x1, x2)

    quot = _ufunc_postprocess(quot, out1, casting)
    rem = _ufunc_postprocess(rem, out2, casting)
    return quot, rem


#
# Attach ufuncs to this module, for a further export to the public namespace in __init__.py
#
for name in _binary:
    ufunc = getattr(_binary_ufuncs_impl, name)
    globals()[name] = deco_binary_ufunc(ufunc)


@normalizer
def modf(
    x: ArrayLike,
    out1: Optional[OutArray] = None,
    out2: Optional[OutArray] = None,
    /,
    out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None),
    *,
    where: NotImplementedType = True,
    casting: Optional[CastingModes] = "same_kind",
    order: NotImplementedType = "K",
    dtype: Optional[DTypeLike] = None,
    subok: NotImplementedType = False,
    signature: NotImplementedType = None,
    extobj: NotImplementedType = None,
):
    """NumPy-style ``modf``: returns ``(fractional, integral)`` parts.

    Both parts carry the sign of ``x`` (``modf(-1.5) == (-0.5, -1.0)``).
    Integer input is promoted to float. For ``x == ±inf`` the fractional part
    is a zero with the sign of ``x`` and the integral part is ``x``.

    This deliberately does NOT reuse ``divmod(x, 1)``, whose floor semantics
    give ``(0.5, -2)`` for ``-1.5``.
    """
    out1, out2 = _resolve_two_outs(out1, out2, out)

    if dtype is not None:
        x = _util.typecast_tensor(x, dtype, casting)
    x = _util.cast_int_to_float(x)

    integral = torch.trunc(x)
    # x - trunc(x) is exact in IEEE arithmetic but is NaN for ±inf, so infinities
    # get an explicit zero; copysign then restores the sign of x (including -0.0).
    frac = torch.copysign(
        torch.where(torch.isinf(x), torch.zeros_like(x), x - integral), x
    )

    frac = _ufunc_postprocess(frac, out1, casting)
    integral = _ufunc_postprocess(integral, out2, casting)
    return frac, integral


_binary = _binary + ["divmod", "modf", "matmul", "ldexp"]


# ############# Unary ufuncs ######################


_unary = [
    name
    for name in dir(_unary_ufuncs_impl)
    if not name.startswith("_") and name != "torch"
]


# these are ufunc(int) -> float
_fp_unary = [
    "arccos",
    "arccosh",
    "arcsin",
    "arcsinh",
    "arctan",
    "arctanh",
    "cbrt",
    "cos",
    "cosh",
    "deg2rad",
    "degrees",
    "exp",
    "exp2",
    "expm1",
    "log",
    "log10",
    "log1p",
    "log2",
    "rad2deg",
    "radians",
    "reciprocal",
    "sin",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]


def deco_unary_ufunc(torch_func):
    """Common infra for unary ufuncs.

    Normalize arguments, sort out type casting, broadcasting and delegate to
    the pytorch functions for the actual work.

    ``where`` is annotated ``NotImplementedType`` like the binary ufuncs. This
    lets ``normalizer`` presumably reject masks instead of silently computing
    the unmasked result. ``order``, ``signature`` and ``extobj`` are
    unannotated and are not read by the body.
    """
    func_name = torch_func.__name__
    # Whether integer input must be promoted to float; fixed per ufunc.
    promote_int_to_float = func_name in _fp_unary

    @normalizer
    def wrapped(
        x: ArrayLike,
        /,
        out: Optional[OutArray] = None,
        *,
        where: NotImplementedType = True,
        casting: Optional[CastingModes] = "same_kind",
        order="K",
        dtype: Optional[DTypeLike] = None,
        subok: NotImplementedType = False,
        signature=None,
        extobj=None,
    ):
        if dtype is not None:
            x = _util.typecast_tensor(x, dtype, casting)

        if promote_int_to_float:
            x = _util.cast_int_to_float(x)

        result = torch_func(x)
        result = _ufunc_postprocess(result, out, casting)
        return result

    wrapped.__qualname__ = func_name
    wrapped.__name__ = func_name

    return wrapped


#
# Attach ufuncs to this module, for a further export to the public namespace in __init__.py
#
for name in _unary:
    ufunc = getattr(_unary_ufuncs_impl, name)
    globals()[name] = deco_unary_ufunc(ufunc)


__all__ = _binary + _unary  # noqa: PLE0605