ai-content-maker/.venv/Lib/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h

87 lines
3.1 KiB
C++

#pragma once
#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/equal.h>
#endif
namespace at {
// Note [Tensor-subclass-like Tensors]
// Tensor-subclass-like is defined as:
// - a Tensor subclass (via __torch_dispatch__ in Python or extending
// TensorImpl in C++)
// - anything else that shares the same perils as Tensor subclasses.
// For example, many Tensor subclasses do not have storage and meta Tensors
// do not have storage either, so meta Tensors belong here.
//
// We should ensure that PyTorch internals support Tensor-subclass-like
// objects. In particular, Tensor-subclass-like objects struggle with two
// classes of operations that are problematic for Tensor subclasses:
// 1. Because some Tensor subclasses do not have storage, .item() or
// .data_ptr() calls should be avoided.
// 2. Certain in-place operations can silently lose the Tensor subclass type
// of their inputs. For example:
// >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
// If input is a Tensor subclass, then the above ends up either erroring out
// or returning a regular non-Tensor-subclass Tensor!
// Functionality keys of the functorch wrapper tensors (grad wrapper and
// batched) and of functionalization tensors. All of them are folded into
// kTensorSubclassLike below.
constexpr auto kFunctorchWrappedTensors = DispatchKeySet({
    DispatchKey::Functionalize,
    DispatchKey::FuncTorchBatched,
    DispatchKey::FuncTorchGradWrapper,
});
// Union of every key whose presence makes a tensor Tensor-subclass-like
// (see the Note above). It contains:
//  - the Meta backend bit (meta tensors have no storage),
//  - Python (__torch_dispatch__) subclasses, sparse/sparse-CSR and batched
//    functionality keys,
//  - the functorch/functionalization wrapper keys.
// NOTE(review): this set mixes a backend bit with a per-backend functionality
// key (Sparse). Test membership with `(ks & kTensorSubclassLike).empty()`,
// not `DispatchKeySet::has_any`, whose debug assertion rejects such sets.
constexpr auto kTensorSubclassLike = DispatchKeySet(BackendComponent::MetaBit) |
    DispatchKeySet(
        {// WARNING: DO NOT put combined backend component + functionality keys
         // here, you will incorrectly always match on the functionality key
         // no matter the backend component
         DispatchKey::Python,
         DispatchKey::SparseCsr,
         DispatchKey::Sparse,
         DispatchKey::Batched}) |
    kFunctorchWrappedTensors;
// Returns true if `tensor` must be handled as Tensor-subclass-like.
// This is the case when a dispatch mode is active (every tensor is then
// treated conservatively) or when the tensor's dispatch key set intersects
// kTensorSubclassLike.
inline bool isTensorSubclassLike(const Tensor& tensor) {
  if (c10::impl::dispatch_mode_enabled()) {
    return true;
  }
  const auto overlap =
      tensor.unsafeGetTensorImpl()->key_set() & kTensorSubclassLike;
  return !overlap.empty();
}
// Returns true if a dispatch mode is active, or if at least one tensor in
// `tensors` is Tensor-subclass-like. Because the dispatch-mode check comes
// first, an empty list still yields true while a mode is enabled.
inline bool areAnyTensorSubclassLike(TensorList tensors) {
  if (c10::impl::dispatch_mode_enabled()) {
    return true;
  }
  for (const Tensor& candidate : tensors) {
    if (isTensorSubclassLike(candidate)) {
      return true; // short-circuit on the first match
    }
  }
  return false;
}
// Like areAnyTensorSubclassLike, but for a list of optional tensors.
// Empty (nullopt) entries never count as Tensor-subclass-like. An active
// dispatch mode makes the result true regardless of the list contents.
inline bool areAnyOptionalTensorSubclassLike(
    const c10::List<c10::optional<Tensor>>& tensors) {
  if (c10::impl::dispatch_mode_enabled()) {
    return true;
  }
  // Each list element converts implicitly to an optional<Tensor> value.
  for (const c10::optional<Tensor>& maybe_tensor : tensors) {
    if (!maybe_tensor.has_value()) {
      continue;
    }
    if (isTensorSubclassLike(maybe_tensor.value())) {
      return true;
    }
  }
  return false;
}
// Helper function to deal testing truthfulness of a scalar tensor
// in a Composite Compliant manner.
// NOTE: This function expects a scalar tensor of boolean dtype.
// Eg.
// Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
// Composite Compliant Patter : is_salar_tensor_true((t == 0).all())
inline bool is_scalar_tensor_true(const Tensor& t) {
TORCH_INTERNAL_ASSERT(t.dim() == 0)
TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool)
return at::equal(t, t.new_ones({}, t.options()));
}
} // namespace at