193 lines
5.2 KiB
Python
193 lines
5.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
|
|
import torch
|
|
|
|
from .core import _map_mt_args_kwargs, _masks_match, _tensors_match, _wrap_result, is_masked_tensor
|
|
|
|
# This module exports no public names: `from <module> import *` imports nothing.
__all__ = [] # type: ignore[var-annotated]
|
|
|
|
# Names of the binary operators handled by this module. Each name is looked up
# as an attribute of ``torch.ops.aten`` when the dispatch tables are built.
BINARY_NAMES = [
    "add",
    "atan2",
    "arctan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "bitwise_left_shift",
    "bitwise_right_shift",
    "div",
    "divide",
    "floor_divide",
    "fmod",
    "logaddexp",
    "logaddexp2",
    "mul",
    "multiply",
    "nextafter",
    "remainder",
    "sub",
    "subtract",
    "true_divide",
    "eq",
    "ne",
    "le",
    "ge",
    "greater",
    "greater_equal",
    "gt",
    "less_equal",
    "lt",
    "less",
    "maximum",
    "minimum",
    "fmax",
    "fmin",
    "not_equal",
]

# In-place variants: every binary name with a trailing underscore, except the
# names listed below. The list is derived by filtering BINARY_NAMES in its
# declaration order, so the result is identical on every interpreter run
# (going through a set would make the order depend on string hashing).
# "equal" is not an entry of BINARY_NAMES, so excluding it is currently a no-op.
INPLACE_BINARY_NAMES = [
    name + "_"
    for name in BINARY_NAMES
    if name not in {
        "logaddexp",
        "logaddexp2",
        "equal",
        "fmin",
        "minimum",
        "maximum",
        "fmax",
    }
]
|
|
|
|
|
|
def _get_at_least_one_mask(a, b):
    """Return the mask of whichever of ``a`` / ``b`` is a MaskedTensor.

    ``a`` is preferred when it is a MaskedTensor; otherwise the mask of ``b``
    is returned.

    Raises:
        TypeError: if neither ``a`` nor ``b`` is a MaskedTensor.
        ValueError: if ``_masks_match(a, b)`` is false.
    """
    a_is_masked = is_masked_tensor(a)
    if not (a_is_masked or is_masked_tensor(b)):
        raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
    if not _masks_match(a, b):
        raise ValueError("a and b must have matching masks")
    source = a if a_is_masked else b
    return source.get_mask()
|
|
|
|
|
|
def _binary_helper(fn, args, kwargs, inplace):
    """Run the binary op ``fn`` on the data of ``args`` and re-attach a mask.

    Args:
        fn: callable invoked with the data arguments (``fn(*data_args)``).
        args: positional arguments of the op; ``args[0]`` and ``args[1]`` are
            the lhs and rhs. Any further arguments must not be tensors.
        kwargs: keyword arguments of the op; must be empty.
        inplace: when true, the result is written back into ``args[0]`` via
            ``_set_data_mask`` and ``args[0]`` is returned; otherwise a new
            result is built with ``_wrap_result``.

    Returns:
        ``args[0]`` for in-place ops, else ``_wrap_result(result_data, result_mask)``.

    Raises:
        ValueError: if ``kwargs`` is non-empty, the input masks do not match,
            or (for sparse layouts) the indices / sizes of lhs and rhs differ.
        TypeError: if any argument after the first two is a tensor.
    """
    if len(kwargs) != 0:
        raise ValueError("len(kwargs) must equal 0")
    for a in args[2:]:
        if torch.is_tensor(a):
            raise TypeError("MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs")

    if not _masks_match(*args[:2]):
        raise ValueError(
            "Input masks must match. If you need support for this, please open an issue on Github."
        )

    # Two parallel views of the arguments: one mapped through get_data(), one
    # through get_mask(). The mapped kwargs are unused because kwargs is empty.
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())

    args0_layout = data_args[0].layout
    # True only when the rhs is tensor-like and shares the lhs layout; a
    # non-tensor rhs (e.g. a Python scalar) is passed to ``fn`` unchanged.
    same_layout = (
        (torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1]))
        and (args0_layout == data_args[1].layout)
    )

    if args0_layout == torch.sparse_coo:
        if same_layout:
            if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
                raise ValueError(
                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
                )
            if data_args[0].size() != data_args[1].size():
                raise ValueError("input1 and input2 must have the same size for binary functions.")

            data_args[1] = data_args[1].values()

        # Apply ``fn`` to the stored values only, then rebuild a sparse tensor
        # with the lhs indices and size.
        i = data_args[0].indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_coo_tensor(i, v, size)

    elif args0_layout == torch.sparse_csr:
        if same_layout:
            if not (
                _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
                and _tensors_match(
                    data_args[0].col_indices(), data_args[1].col_indices()
                )
            ):
                raise ValueError(
                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
                )

            data_args[1] = data_args[1].values()

        crow = data_args[0].crow_indices()
        col = data_args[0].col_indices()
        # Capture the size before data_args[0] is replaced by its values.
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        # Pass the lhs size explicitly (as the sparse_coo branch does). Without
        # it the shape is inferred from the indices, which shrinks the result
        # whenever the trailing columns hold no specified elements.
        result_data = torch.sparse_csr_tensor(crow, col, v, size=size)

    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]

    result_mask = _get_at_least_one_mask(*args[:2])
    # sparse tensors don't have strides so we can only expand if the layout is strided
    if args0_layout == torch.strided:
        result_mask = result_mask.expand_as(result_data)
    return _wrap_result(result_data, result_mask)
|
|
|
|
|
|
def _torch_binary(fn_name):
    """Build the out-of-place handler for ``torch.ops.aten.<fn_name>``.

    The returned callable forwards its positional and keyword arguments to
    ``_binary_helper`` with ``inplace=False``.
    """
    aten_op = getattr(torch.ops.aten, fn_name)

    def binary_fn(*args, **kwargs):
        return _binary_helper(aten_op, args, kwargs, inplace=False)

    return binary_fn
|
|
|
|
|
|
def _torch_inplace_binary(fn_name):
    """Build the in-place handler for ``torch.ops.aten.<fn_name>``.

    The returned callable forwards its positional and keyword arguments to
    ``_binary_helper`` with ``inplace=True``.
    """
    aten_op = getattr(torch.ops.aten, fn_name)

    def binary_fn(*args, **kwargs):
        return _binary_helper(aten_op, args, kwargs, inplace=True)

    return binary_fn
|
|
|
|
|
|
# Dispatch tables: each key is the ``torch.ops.aten`` attribute for a name,
# each value is the handler built for that same name.
NATIVE_BINARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_binary(op_name)
    for op_name in BINARY_NAMES
}
NATIVE_INPLACE_BINARY_MAP = {
    getattr(torch.ops.aten, op_name): _torch_inplace_binary(op_name)
    for op_name in INPLACE_BINARY_NAMES
}

# Lists of the keys of the two tables above, used for membership tests.
NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP)
NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP)
|
|
|
|
|
|
def _is_native_binary(fn):
    """Return True if ``fn`` is in either the out-of-place or in-place op list."""
    return any(
        fn in known_fns
        for known_fns in (NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS)
    )
|
|
|
|
|
|
def _apply_native_binary(fn, *args, **kwargs):
    """Call the handler registered for ``fn`` with ``args`` and ``kwargs``.

    The out-of-place table is consulted first, then the in-place table.
    Returns ``NotImplemented`` when ``fn`` is in neither.
    """
    tables = (
        (NATIVE_BINARY_FNS, NATIVE_BINARY_MAP),
        (NATIVE_INPLACE_BINARY_FNS, NATIVE_INPLACE_BINARY_MAP),
    )
    for known_fns, handlers in tables:
        if fn in known_fns:
            return handlers[fn](*args, **kwargs)
    return NotImplemented
|