# Copyright (c) Meta Platforms, Inc. and affiliates

import warnings

import torch

from .core import is_masked_tensor
from .creation import as_masked_tensor, masked_tensor


__all__ = []  # type: ignore[var-annotated]


def _masked_all_all(data, mask=None):
    # Reduce over all elements; masked-out positions are filled with True so
    # they cannot flip the result of all().
    if mask is None:
        return data.all()
    return data.masked_fill(~mask, True).all()


def _masked_all_dim(data, dim, keepdim=False, mask=None):
    if mask is None:
        return torch.all(data, dim=dim, keepdim=keepdim)
    return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)


def _masked_all(*args, **kwargs):
    # Dispatch between the full and the per-dim variant of masked all().
    if len(args) == 1 and len(kwargs) == 1:
        return _masked_all_all(args[0], mask=kwargs["mask"])
    return _masked_all_dim(*args, **kwargs)


def _multidim_any(mask, dim, keepdim):
    # Reduce the mask with any() over each requested dim, highest dim first so
    # the indices of the dims still to be reduced stay valid.
    if isinstance(dim, int):
        return _multidim_any(mask, [dim], keepdim)
    for d in sorted(dim, reverse=True):
        mask = torch.any(mask, dim=d, keepdim=keepdim)
    return mask


def _get_masked_fn(fn):
    # torch.masked does not provide an "all" reduction, so fall back to the
    # helper defined above.
    if fn == "all":
        return _masked_all
    return getattr(torch.masked, fn)


def _torch_reduce_all(fn):
    def reduce_all(self):
        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask().values() if self.is_sparse else self.get_mask()
        if fn == "all":
            result_data = masked_fn(data, mask=mask)
        elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
            # For a reduction over all dims, torch.argmin/torch.argmax must
            # return the flat index of the min/max element, but that operation
            # isn't supported correctly for sparse layouts. Instead, the flat
            # index is reconstructed here from the element's COO index and the
            # row-major strides.
            sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
            indices = (
                data.to_sparse_coo().indices()
                if not self.is_sparse_coo()
                else data.indices()
            )
            idx = indices.unbind(1)[sparse_idx]
            stride = data.size().numel() / torch.tensor(
                data.size(), device=data.device
            ).cumprod(0)
            result_data = torch.sum(idx * stride)
        elif self.is_sparse:
            # For the remaining sparse COO/CSR reductions, simply pass in the values.
            result_data = masked_fn(masked_tensor(data.values(), mask))
        else:
            result_data = masked_fn(self, mask=mask)

        return as_masked_tensor(result_data, torch.any(mask))

    return reduce_all
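
# A minimal sketch (not part of the module, using made-up values) of the stride
# trick above: the flat index of an element is the dot product of its COO index
# with the row-major strides, where stride[d] = numel / cumprod(sizes)[d].
#
#   sizes = torch.tensor([2, 3])              # a hypothetical 2x3 tensor
#   stride = sizes.prod() / sizes.cumprod(0)  # -> [3., 1.]
#   coo_index = torch.tensor([1, 2])          # the element at row 1, col 2
#   flat = torch.sum(coo_index * stride)      # -> 5. == 1 * 3 + 2 * 1
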
def _torch_reduce_dim(fn):
    def reduce_dim(self, dim, keepdim=False, dtype=None):
        if self.is_sparse:
            msg = (
                f"The sparse version of {fn} is not implemented in reductions.\n"
                "If you would like this operator to be supported, please file an issue for a feature request at "
                "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
                "In the case that the semantics for the operator are not trivial, it would be appreciated "
                "to also include a proposal for the semantics."
            )
            warnings.warn(msg)
            return NotImplemented
        if not is_masked_tensor(self):
            raise TypeError("Input to reduce_dim must be a MaskedTensor")

        masked_fn = _get_masked_fn(fn)
        data = self.get_data()
        mask = self.get_mask()
        if fn == "all":
            result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
        else:
            result_data = masked_fn(
                self, dim=dim, keepdim=keepdim, dtype=dtype, mask=mask
            )

        return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))

    return reduce_dim


def _torch_reduce(fn):
    def reduce_fn(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            return _torch_reduce_all(fn)(args[0])
        return _torch_reduce_dim(fn)(*args, **kwargs)

    return reduce_fn


def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
    return input, dim, keepdim, dtype


def _torch_grad_reduce(fn):
    def grad_reduce(*args, **kwargs):
        if len(args) == 1 and len(kwargs) == 0:
            return _torch_reduce_all(fn)(args[0])
        # TODO: autograd.Function doesn't support kwargs, so normalize them
        # into positional arguments first.
        input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
        return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)

    return grad_reduce


REDUCE_NAMES = [
    "sum",
    "mean",
    "amin",
    "amax",
    "argmin",
    "argmax",
    "prod",
    "all",
    "norm",
    "var",
    "std",
]

# Each reduction is registered under three entry points: the low-level aten
# ops, the torch-level functions, and the Tensor methods.
NATIVE_REDUCE_MAP = {
    getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
}
TORCH_REDUCE_MAP = {
    getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}
TENSOR_REDUCE_MAP = {
    getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
}

NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())


def _is_reduction(fn):
    return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP


def _apply_reduction(fn, *args, **kwargs):
    if fn in NATIVE_REDUCE_MAP:
        return NATIVE_REDUCE_MAP[fn](*args, **kwargs)
    if fn in TORCH_REDUCE_MAP:
        return TORCH_REDUCE_MAP[fn](*args, **kwargs)
    if fn in TENSOR_REDUCE_MAP:
        return TENSOR_REDUCE_MAP[fn](*args, **kwargs)
    return NotImplemented
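
# A minimal usage sketch (not part of the module) of how the maps above resolve
# a reduction on a MaskedTensor: torch.sum is a key of TORCH_REDUCE_MAP, so
# _is_reduction(torch.sum) is True and _apply_reduction(torch.sum, mt) routes
# to _torch_reduce_all("sum"), i.e. torch.masked.sum over unmasked elements.
#
#   import torch
#   from torch.masked import masked_tensor
#
#   data = torch.tensor([1.0, 2.0, 3.0])
#   mask = torch.tensor([True, False, True])
#   mt = masked_tensor(data, mask)
#   print(torch.sum(mt))  # 1.0 + 3.0 -> a MaskedTensor holding 4. (mask True)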