# Copyright (c) Meta Platforms, Inc. and affiliates
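
# Reductions for MaskedTensor: a reduction consumes only the elements whose
# mask is True, and the output is masked out wherever no input element along
# the reduced dimension(s) was valid. The maps at the bottom of this file route
# the torch.*, torch.Tensor.*, and torch.ops.aten.* reduction entry points to
# the masked implementations defined here.
#
# A minimal usage sketch (via the public torch.masked API):
#
#   >>> import torch
#   >>> from torch.masked import masked_tensor
#   >>> mt = masked_tensor(torch.tensor([1.0, 2.0, 3.0]),
#   ...                    torch.tensor([True, False, True]))
#   >>> mt.sum()  # only the unmasked 1.0 and 3.0 participate -> 4.0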

import warnings

import torch

from .core import is_masked_tensor
from .creation import as_masked_tensor, masked_tensor

__all__ = []  # type: ignore[var-annotated]


def _masked_all_all(data, mask=None):
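    # Masked `all` over every element: masked-out entries are filled with True,
    # the identity of `all`, so they cannot change the result.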
    if mask is None:
        return data.all()
    return data.masked_fill(~mask, True).all()


def _masked_all_dim(data, dim, keepdim=False, mask=None):
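    # Masked `all` along `dim`: the same identity-fill trick, reduced over the
    # given dimension(s).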
    if mask is None:
        return torch.all(data, dim=dim, keepdim=keepdim)
    return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)


def _masked_all(*args, **kwargs):
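    # torch.masked has no native `all`, so dispatch to the local helpers:
    # the (data, mask=...) calling convention means a full reduction; any other
    # signature carries a `dim` and goes to the dimension-wise variant.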
    if len(args) == 1 and len(kwargs) == 1:
        return _masked_all_all(args[0], mask=kwargs["mask"])
    return _masked_all_dim(*args, **kwargs)


def _multidim_any(mask, dim, keepdim):
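    # `any` over one or several dimensions of a boolean mask; used to build the
    # output mask of a dimension-wise reduction. Reducing the highest dimension
    # first keeps the remaining dimension indices valid as the tensor shrinks.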
    if isinstance(dim, int):
        return _multidim_any(mask, [dim], keepdim)
    for d in sorted(dim, reverse=True):
        mask = torch.any(mask, dim=d, keepdim=keepdim)
    return mask


def _get_masked_fn(fn):
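    # Resolve a reduction name to its masked implementation: `all` is handled
    # locally above, everything else comes from torch.masked.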
    if fn == "all":
        return _masked_all
    return getattr(torch.masked, fn)


def _torch_reduce_all(fn):
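    # Factory for full (all-elements) reductions: returns a closure that
    # reduces an entire MaskedTensor with the masked version of `fn`.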
    def reduce_all(self):
        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask().values() if self.is_sparse else self.get_mask()
        # For a full reduction, torch.argmin/torch.argmax must return the flat
        # index of the element corresponding to the min/max, but that operation
        # isn't supported correctly for sparse layouts. Therefore, this
        # implementation reconstructs the flat index from the sparse indices
        # and the dense row-major strides.
        if fn == "all":
            result_data = masked_fn(data, mask=mask)

        elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
            sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
            indices = (
                data.to_sparse_coo().indices()
                if not self.is_sparse_coo()
                else data.indices()
            )
            idx = indices.unbind(1)[sparse_idx]
            stride = data.size().numel() / torch.tensor(
                data.size(), device=data.device
            ).cumprod(0)
            result_data = torch.sum(idx * stride)

        # for other reductions on sparse COO/CSR tensors, simply pass in the values
        elif self.is_sparse:
            result_data = masked_fn(masked_tensor(data.values(), mask))

        else:
            result_data = masked_fn(self, mask=mask)

        return as_masked_tensor(result_data, torch.any(mask))

    return reduce_all


def _torch_reduce_dim(fn):
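    # Factory for dimension-wise reductions: returns a closure that reduces a
    # MaskedTensor along `dim` and masks out output positions where no input
    # element along the reduced dimension(s) was valid.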
    def reduce_dim(self, dim, keepdim=False, dtype=None):
        if self.is_sparse:
            msg = (
                f"The sparse version of {fn} is not implemented in reductions.\n"
                "If you would like this operator to be supported, please file an issue for a feature request at "
                "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
                "In the case that the semantics for the operator are not trivial, it would be appreciated "
                "to also include a proposal for the semantics."
            )
            warnings.warn(msg)
            return NotImplemented
        if not is_masked_tensor(self):
            raise TypeError("Input to reduce_dim must be a MaskedTensor")

        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask()
        if fn == "all":
            result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
        else:
            result_data = masked_fn(
                self, dim=dim, keepdim=keepdim, dtype=dtype, mask=mask
            )
        return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))

    return reduce_dim


def _torch_reduce(fn):
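    # Dispatch on the calling convention: a bare tensor argument means a full
    # reduction; anything else is treated as a dimension-wise call.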
    def reduce_fn(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            return _torch_reduce_all(fn)(args[0])
        return _torch_reduce_dim(fn)(*args, **kwargs)

    return reduce_fn


def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
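    # Normalize mixed positional/keyword arguments into a fixed positional
    # order so they can be forwarded through autograd.Function, which does not
    # accept keyword arguments.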
    return input, dim, keepdim, dtype


def _torch_grad_reduce(fn):
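    # Same dispatch as _torch_reduce, but flattens keyword arguments to
    # positional ones for the autograd-facing torch.* and torch.Tensor.* entry
    # points.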
    def grad_reduce(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            return _torch_reduce_all(fn)(args[0])
        # TODO: autograd.Function doesn't support kwargs
        input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
        return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)

    return grad_reduce


REDUCE_NAMES = [
    "sum",
    "mean",
    "amin",
    "amax",
    "argmin",
    "argmax",
    "prod",
    "all",
    "norm",
    "var",
    "std",
]

NATIVE_REDUCE_MAP = {
    getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
}
TORCH_REDUCE_MAP = {
    getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}
TENSOR_REDUCE_MAP = {
    getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}

NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())


def _is_reduction(fn):
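    # True if `fn` is any of the reduction entry points handled by this module.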
    return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP


def _apply_reduction(fn, *args, **kwargs):
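    # Route an intercepted reduction to its masked implementation; anything
    # this module does not handle falls through to NotImplemented.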
    if fn in NATIVE_REDUCE_MAP:
        return NATIVE_REDUCE_MAP[fn](*args, **kwargs)
    if fn in TORCH_REDUCE_MAP:
        return TORCH_REDUCE_MAP[fn](*args, **kwargs)
    if fn in TENSOR_REDUCE_MAP:
        return TENSOR_REDUCE_MAP[fn](*args, **kwargs)
    return NotImplemented