ai-content-maker/.venv/Lib/site-packages/torch/masked/maskedtensor/unary.py

189 lines
4.0 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from .core import _map_mt_args_kwargs, _wrap_result
# Star-imports of this module export nothing. The type-ignore silences mypy's
# "need type annotation" complaint about an empty list literal.
__all__ = [] # type: ignore[var-annotated]
# Names of out-of-place aten unary ops that MaskedTensor supports. Each one is
# wrapped by ``_unary_helper``, which applies the op to the data and reuses the
# input's mask for the result.
UNARY_NAMES = [
    "abs",
    "absolute",
    "acos",
    "arccos",
    "acosh",
    "arccosh",
    "angle",
    "asin",
    "arcsin",
    "asinh",
    "arcsinh",
    "atan",
    "arctan",
    "atanh",
    "arctanh",
    "bitwise_not",
    "ceil",
    "clamp",
    "clip",
    "conj_physical",
    "cos",
    "cosh",
    "deg2rad",
    "digamma",
    "erf",
    "erfc",
    "erfinv",
    "exp",
    "exp2",
    "expm1",
    "fix",
    "floor",
    "frac",
    "lgamma",
    "log",
    "log10",
    "log1p",
    "log2",
    "logit",
    "i0",
    "isnan",
    "nan_to_num",
    "neg",
    "negative",
    "positive",
    "pow",
    "rad2deg",
    "reciprocal",
    "round",
    "rsqrt",
    "sigmoid",
    "sign",
    "sgn",
    "signbit",
    "sin",
    "sinc",
    "sinh",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trunc",
]
# In-place ("<name>_") variants of UNARY_NAMES. The four excluded ops are
# presumably skipped because aten has no in-place overload for them (e.g. isnan
# and signbit return bool) -- confirm before extending the exclusion set.
# A filtered comprehension keeps UNARY_NAMES order, so the result is identical
# in every process instead of depending on string-hash randomization of a set.
INPLACE_UNARY_NAMES = [
    name + "_"
    for name in UNARY_NAMES
    if name not in {"angle", "positive", "signbit", "isnan"}
]
# Ops that MaskedTensor is known not to support yet, kept here on purpose.
# They are unsupported either because code generation is missing or because
# their semantics are hard to define for masked inputs.
UNARY_NAMES_UNSUPPORTED = [
    "atan2", "arctan2",
    "bitwise_left_shift", "bitwise_right_shift",
    "copysign", "float_power", "fmod", "frexp",
    "gradient", "imag", "ldexp", "lerp",
    "logical_not", "hypot",
    "igamma", "igammac", "mvlgamma", "nextafter", "polygamma",
    "real", "remainder", "true_divide", "xlogy",
]
def _unary_helper(fn, args, kwargs, inplace):
    """Apply the aten op ``fn`` to a MaskedTensor's data and reuse its mask.

    Args:
        fn: aten op (e.g. ``torch.ops.aten.cos``) applied to the underlying data.
        args: positional arguments. ``args[0]`` is the masked operand; its
            ``layout`` selects the code path and it receives the in-place
            result. Later arguments must not be tensors; they are presumably
            forwarded unchanged by ``_map_mt_args_kwargs``.
        kwargs: keyword arguments; must be empty.
        inplace: if True, store the result into ``args[0]`` and return
            ``args[0]``. Otherwise return a new wrapped result.

    Raises:
        ValueError: if ``kwargs`` is non-empty.
        TypeError: if any argument after the first is a tensor.
    """
    if kwargs:
        raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. "
                         "If you need support for this, please open an issue on Github.")
    for a in args[1:]:
        if torch.is_tensor(a):
            raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments")

    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask)
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data)

    layout = args[0].layout
    if layout == torch.sparse_coo:
        # Apply fn only to the specified values. Rebuilding with the same
        # indices and size keeps the result's sparsity pattern and shape
        # identical to the input's.
        data = data_args[0].coalesce()
        data_args[0] = data.values()
        result_data = torch.sparse_coo_tensor(
            data.indices(), fn(*data_args), size=data.size()
        )
    elif layout == torch.sparse_csr:
        data = data_args[0]
        data_args[0] = data.values()
        # Pass size explicitly. Without it, sparse_csr_tensor infers the
        # smallest shape that holds every specified element, which drops
        # trailing all-empty columns and desynchronizes data from the mask.
        result_data = torch.sparse_csr_tensor(
            data.crow_indices(), data.col_indices(), fn(*data_args), size=data.size()
        )
    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]
    return _wrap_result(result_data, mask_args[0])
def _torch_unary(fn_name):
    """Return an out-of-place MaskedTensor handler for ``torch.ops.aten.<fn_name>``.

    The aten op is resolved once, when the factory is called, and is captured
    by the returned closure.
    """
    aten_op = getattr(torch.ops.aten, fn_name)

    def unary_fn(*args, **kwargs):
        # Out-of-place: the helper wraps the result instead of mutating args[0].
        return _unary_helper(aten_op, args, kwargs, inplace=False)

    return unary_fn
def _torch_inplace_unary(fn_name):
    """Return an in-place MaskedTensor handler for ``torch.ops.aten.<fn_name>``.

    The returned closure writes the op's result back into its first argument
    and returns that argument.
    """
    aten_op = getattr(torch.ops.aten, fn_name)

    def unary_fn(*args, **kwargs):
        # In-place: the helper calls args[0]._set_data_mask and returns args[0].
        return _unary_helper(aten_op, args, kwargs, inplace=True)

    return unary_fn
# aten op packet -> MaskedTensor handler, for the out-of-place ops.
NATIVE_UNARY_MAP = {
    getattr(torch.ops.aten, unary_name): _torch_unary(unary_name)
    for unary_name in UNARY_NAMES
}
# aten op packet -> MaskedTensor handler, for the in-place ("<name>_") ops.
NATIVE_INPLACE_UNARY_MAP = {
    getattr(torch.ops.aten, unary_name): _torch_inplace_unary(unary_name)
    for unary_name in INPLACE_UNARY_NAMES
}
# The maps' keys, in insertion order.
NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP)
NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP)
def _is_native_unary(fn):
    """Return True if ``fn`` has an out-of-place or in-place MaskedTensor unary handler."""
    # Test membership on the maps (hash lookup) rather than on the *_FNS lists
    # (linear scan). The lists are built from exactly these maps' keys.
    return fn in NATIVE_UNARY_MAP or fn in NATIVE_INPLACE_UNARY_MAP
def _apply_native_unary(fn, *args, **kwargs):
    """Run the MaskedTensor unary handler registered for ``fn``.

    Out-of-place handlers are checked before in-place ones.

    Returns:
        The handler's result, or ``NotImplemented`` if ``fn`` has no handler.
    """
    # One dict lookup per map, instead of a linear list scan followed by a
    # second lookup. The *_FNS lists hold exactly these maps' keys, and the
    # map values are handler functions, never None.
    handler = NATIVE_UNARY_MAP.get(fn)
    if handler is None:
        handler = NATIVE_INPLACE_UNARY_MAP.get(fn)
    if handler is None:
        return NotImplemented
    return handler(*args, **kwargs)