#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/clone.h>
#endif

#include <utility>

namespace at::native {
// This fallback should only be used for operations that are self-inverse and have a corresponding tensor
// bit (internally implemented using a DispatchKey) that maintains this state on the tensor.
// Currently there are two tensor bits that trigger this fallback: the conjugate bit and the negative bit.
// The conjugate bit is set on a tensor when `.conj()` is called, and the negative bit is set when `.conj().imag` is called.

// NOTE: To use this fallback, `clone` and `copy_` must fully understand and correctly handle the semantics of your math bit.
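// For illustration, a concrete fallback might be wired up roughly as in the sketch below.
// The names here are illustrative; in PyTorch the actual conjugate/negative registrations
// live in ConjugateFallback.cpp and NegateFallback.cpp.
//
//   struct ConjFallback : MathOpFallback {
//     ConjFallback() : MathOpFallback(DispatchKey::Conjugate, "conjugate") {}
//     bool is_bit_set(const Tensor& tensor) override {
//       return tensor.is_conj();
//     }
//   };
//
//   static void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
//     ConjFallback object;
//     object.fallback_impl(op, dispatch_keys, stack);
//   }
//
//   TORCH_LIBRARY_IMPL(_, Conjugate, m) {
//     m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
//   }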
struct MathOpFallback {
  MathOpFallback(DispatchKey key_, std::string op_name_) : key(key_), op_name(std::move(op_name_)) {}
  virtual bool is_bit_set(const Tensor&) = 0;
  void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
    /*
      Situations to handle:
        1. Out-of-place operation. Easy: materialize all inputs and
           call it a day.
        2. In-place operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
           Materialize other inputs as in (1).
        3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2)).
           Materialize other inputs as in (1).

      It is important to be able to tell whether we READ from an argument and whether we
      WRITE to an argument. The conservative approach is to assume that we always
      READ from an argument, but in out= operations you can skip
      materializing inputs on entry that never get used. With the current schema we
      can't easily tell whether an operation is an in-place or an out= operation.

      Note:
      1. Mutable tensor lists containing tensors whose math bit is set are disallowed.
      2. Mutable tensors whose math bit is set are unconditionally cloned, to ensure
         correct behavior when the mutable tensor shares memory with non-mutable arguments.

         If we were to resolve the math bit in place for the mutable inputs, then non-mutable inputs
         sharing partial or full memory with these mutable inputs would read incorrect values when:
         1. the non-mutable inputs have their math bit set to false, or
         2. the math bit of the mutable input(s) is resolved before the non-mutable inputs (with the bit
            set and sharing memory with one or more mutable arguments) are cloned.
         At the end, the final values of the mutable arguments on the stack are copied back into the
         original mutable tensor inputs.
    */
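    // Boxed calling convention: the operator's arguments are the last `num_arguments`
    // entries on the stack, so `stack_start` is the index of the first argument.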
    const auto& arguments = op.schema().arguments();
    const auto num_arguments = arguments.size();
    const auto stack_start = stack->size() - num_arguments;

    c10::optional<bool> is_write;
    for (const auto i : c10::irange(num_arguments)) {
      // Three possible states:
      // 1. alias_info has no value --> out-of-place operation
      // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
      // 3. alias_info does have a value, alias_info->is_write=False --> view operation
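      //    (For example: `add.Tensor` arguments carry no alias_info, `add_.Tensor`
      //    annotates `self` as `Tensor(a!)` so is_write=True, and a view op like
      //    `expand` annotates `self` as `Tensor(a)` so is_write=False.)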
      const AliasInfo* alias_info = arguments[i].alias_info();
      if (alias_info != nullptr) {
        if (is_write.has_value()) {
          TORCH_CHECK(*is_write == alias_info->isWrite(),
              "Unsupported operator for ", op_name, " fallback: ", op.schema().name(),
              "; ", op_name, " fallback doesn't work for operators with a mix of "
              "mutable and non-mutable inputs that alias with outputs, "
              "this must be implemented manually. "
              "If you got this error on a core op, please report a bug to PyTorch.");
        } else {
          is_write = alias_info->isWrite();
        }
      }
    }

    if (is_write.has_value() && !*is_write) {
      // We assume that view operators automatically handle the math bit
      // correctly by propagating the dispatch key in key_set.
      // This is not necessarily always right, so you should test these cases.
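      // Masking `dispatch_keys` with DispatchKeySet(FULL_AFTER, key) drops `key` itself,
      // so the redispatch below will not re-enter this fallback.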
      op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
      return;
    }

    // Mutable inputs with math bit set to true and their clones
    std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones;
    for (const auto i : c10::irange(num_arguments)) {
      auto& ivalue = (*stack)[stack_start + i];
      if (!(ivalue.isTensor() || ivalue.isTensorList())) {
        continue;
      }
      const auto& argument = arguments[i];
      bool mut_arg = false;
      if (argument.alias_info()) {
        // Was already tested by is_write loop above
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
        mut_arg = true;
      }
      if (ivalue.isTensor()) {
        if (!is_bit_set(ivalue.toTensor())) {
          continue;
        }
        auto tensor = std::move(ivalue).toTensor();
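        // `clone` is expected to understand the math bit (see the NOTE at the top of this
        // file): cloning a tensor with the bit set yields a materialized tensor with the
        // bit cleared, which can safely replace the input on the stack.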
        auto resolved_tensor = at::clone(tensor);
        if (mut_arg) {
          TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensor with the ",
              op_name, " bit set to true.");
          mutable_inputs_with_their_clones.emplace_back(std::move(tensor), resolved_tensor);
        }
        (*stack)[stack_start + i] = std::move(resolved_tensor);
      } else if (ivalue.isTensorList()) {
        auto tensors = std::move(ivalue).toTensorList();
        for (const auto j : c10::irange(tensors.size())) {
          const auto& tensor = tensors[j];
          if (!is_bit_set(tensor)) {
            continue;
          }
          TORCH_CHECK(!mut_arg, op_name, " fallback doesn't currently support mutable TensorLists with ",
              op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ",
              op.schema().name());
          tensors[j] = at::clone(tensor);
        }
        (*stack)[stack_start + i] = std::move(tensors);
      }
    }

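    // At this point every input with the math bit set has been replaced on the stack by a
    // materialized clone, so the redispatched kernel only sees plain tensors.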
    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);

    TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1);

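    // The redispatched kernel wrote its result into the clone that replaced the mutable
    // input on the stack. Copy that result back into the original mutable tensor
    // (`copy_` understands the math bit, see the NOTE above) and put the original tensor
    // back on the stack as the returned output.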
    for (auto& mut_tensors : mutable_inputs_with_their_clones) {
      auto& mutable_input = mut_tensors.first;
      auto& cloned_mutable_input = mut_tensors.second;
      auto& ivalue = (*stack)[stack_start];
      auto returned_output = std::move(ivalue).toTensor();

      // sanity check to ensure that the tensor on the stack is the cloned_mutable_input
      TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output));

      // necessary for out= arg
      at::native::resize_output(mutable_input, returned_output.sizes());

      mutable_input.copy_(returned_output);
      (*stack)[stack_start] = std::move(mutable_input);
    }
  }

  virtual ~MathOpFallback() = default;

  DispatchKey key;
  std::string op_name;
};

} // namespace at::native