189 lines
4.0 KiB
Python
189 lines
4.0 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
|
|
import torch
|
|
|
|
from .core import _map_mt_args_kwargs, _wrap_result
|
|
|
|
# The export list is empty: `from <this module> import *` brings in no names.
__all__ = []  # type: ignore[var-annotated]
|
|
|
|
|
|
# Names of the aten ops that get an out-of-place masked unary wrapper.
# The order of this list determines the order of every table derived from it.
UNARY_NAMES = [
    "abs",
    "absolute",
    "acos",
    "arccos",
    "acosh",
    "arccosh",
    "angle",
    "asin",
    "arcsin",
    "asinh",
    "arcsinh",
    "atan",
    "arctan",
    "atanh",
    "arctanh",
    "bitwise_not",
    "ceil",
    "clamp",
    "clip",
    "conj_physical",
    "cos",
    "cosh",
    "deg2rad",
    "digamma",
    "erf",
    "erfc",
    "erfinv",
    "exp",
    "exp2",
    "expm1",
    "fix",
    "floor",
    "frac",
    "lgamma",
    "log",
    "log10",
    "log1p",
    "log2",
    "logit",
    "i0",
    "isnan",
    "nan_to_num",
    "neg",
    "negative",
    "positive",
    "pow",
    "rad2deg",
    "reciprocal",
    "round",
    "rsqrt",
    "sigmoid",
    "sign",
    "sgn",
    "signbit",
    "sin",
    "sinc",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]

# In-place counterparts ("<name>_") of the ops above, skipping the four names
# that are excluded from in-place wrapping. Filtering UNARY_NAMES directly
# (instead of iterating a set difference) keeps the order deterministic
# across interpreter runs; a set of strings iterates in hash order, which
# changes with hash randomization.
INPLACE_UNARY_NAMES = [
    n + "_"
    for n in UNARY_NAMES
    if n not in {"angle", "positive", "signbit", "isnan"}
]
|
|
|
|
# Ops that are known to be unsupported at the moment and are tracked here
# explicitly. The reason may be missing code generation or semantics that
# are too complex to handle yet.
UNARY_NAMES_UNSUPPORTED = [
    "atan2", "arctan2",
    "bitwise_left_shift", "bitwise_right_shift",
    "copysign", "float_power", "fmod", "frexp", "gradient",
    "imag", "ldexp", "lerp", "logical_not",
    "hypot", "igamma", "igammac", "mvlgamma",
    "nextafter", "polygamma", "real", "remainder",
    "true_divide", "xlogy",
]
|
|
|
|
|
|
def _unary_helper(fn, args, kwargs, inplace):
    """Run the unary op ``fn`` on the data of the first argument.

    Args:
        fn: callable applied to the unwrapped data arguments. For sparse
            layouts it is applied to the values of the first argument only.
        args: positional arguments of the call. ``args[0].layout`` selects
            the code path; entries after the first must not be tensors.
        kwargs: keyword arguments of the call; must be empty.
        inplace: if true, the result is stored back into ``args[0]`` through
            ``_set_data_mask``; otherwise a new wrapped result is built.

    Returns:
        ``args[0]`` when ``inplace`` is true, otherwise
        ``_wrap_result(result_data, mask)`` where ``mask`` is the mask
        extracted from the first argument.

    Raises:
        ValueError: if ``kwargs`` is not empty.
        TypeError: if any argument after the first is a tensor.
    """
    if kwargs:
        raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. "
                         "If you need support for this, please open an issue on Github.")
    if any(torch.is_tensor(extra) for extra in args[1:]):
        raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments")

    # Only the positional results are used; kwargs is known to be empty here.
    # NOTE(review): presumably these swap each MaskedTensor argument for its
    # mask / data tensor -- confirm against _map_mt_args_kwargs in .core.
    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)

    layout = args[0].layout
    if layout == torch.sparse_coo:
        # Coalesce once, apply fn to the values, then rebuild a tensor with
        # the same indices and the same size.
        coalesced = data_args[0].coalesce()
        size = coalesced.size()
        indices = coalesced.indices()
        data_args[0] = coalesced.values()
        values = fn(*data_args)
        result_data = torch.sparse_coo_tensor(indices, values, size=size)
    elif layout == torch.sparse_csr:
        csr = data_args[0]
        size = csr.size()
        crow = csr.crow_indices()
        col = csr.col_indices()
        data_args[0] = csr.values()
        values = fn(*data_args)
        # Pass the size explicitly: when it is omitted the shape is inferred
        # from the indices, which drops trailing empty columns and leaves the
        # result with a different shape than the mask.
        result_data = torch.sparse_csr_tensor(crow, col, values, size=size)
    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]
    return _wrap_result(result_data, mask_args[0])
|
|
|
|
|
|
def _torch_unary(fn_name):
    """Return the out-of-place wrapper for ``torch.ops.aten.<fn_name>``.

    The aten op is looked up once, when the wrapper is created; the returned
    callable forwards its arguments to ``_unary_helper`` with
    ``inplace=False``.
    """
    aten_op = getattr(torch.ops.aten, fn_name)

    def unary_fn(*args, **kwargs):
        # Positional and keyword arguments are handed over as a tuple/dict.
        return _unary_helper(aten_op, args, kwargs, inplace=False)

    return unary_fn
|
|
|
|
|
|
def _torch_inplace_unary(fn_name):
    """Return the in-place wrapper for ``torch.ops.aten.<fn_name>``.

    The aten op is looked up once, when the wrapper is created; the returned
    callable forwards its arguments to ``_unary_helper`` with
    ``inplace=True``.
    """
    aten_op = getattr(torch.ops.aten, fn_name)

    def unary_fn(*args, **kwargs):
        # Positional and keyword arguments are handed over as a tuple/dict.
        return _unary_helper(aten_op, args, kwargs, inplace=True)

    return unary_fn
|
|
|
|
|
|
# Maps each aten op named in UNARY_NAMES to its out-of-place wrapper.
NATIVE_UNARY_MAP = dict(
    (getattr(torch.ops.aten, op_name), _torch_unary(op_name))
    for op_name in UNARY_NAMES
)
|
|
# Maps each aten op named in INPLACE_UNARY_NAMES to its in-place wrapper.
NATIVE_INPLACE_UNARY_MAP = dict(
    (getattr(torch.ops.aten, op_name), _torch_inplace_unary(op_name))
    for op_name in INPLACE_UNARY_NAMES
)
|
|
|
|
# List views of the keys of the two dispatch tables, used for membership tests.
NATIVE_UNARY_FNS = [*NATIVE_UNARY_MAP]
NATIVE_INPLACE_UNARY_FNS = [*NATIVE_INPLACE_UNARY_MAP]
|
|
|
|
|
|
def _is_native_unary(fn):
    """Return True if ``fn`` is listed as an out-of-place or in-place unary op."""
    registries = (NATIVE_UNARY_FNS, NATIVE_INPLACE_UNARY_FNS)
    return any(fn in known_fns for known_fns in registries)
|
|
|
|
|
|
def _apply_native_unary(fn, *args, **kwargs):
    """Call the registered wrapper for ``fn`` with the given arguments.

    The out-of-place table is consulted first, then the in-place table.
    Returns ``NotImplemented`` when ``fn`` is in neither.
    """
    dispatch_tables = (
        (NATIVE_UNARY_FNS, NATIVE_UNARY_MAP),
        (NATIVE_INPLACE_UNARY_FNS, NATIVE_INPLACE_UNARY_MAP),
    )
    for known_fns, handlers in dispatch_tables:
        if fn in known_fns:
            return handlers[fn](*args, **kwargs)
    return NotImplemented
|