188 lines
6.0 KiB
C++
188 lines
6.0 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/MemoryFormat.h>
|
|
#include <c10/core/SymIntArrayRef.h>
|
|
#include <c10/core/TensorImpl.h>
|
|
#include <c10/util/Exception.h>
|
|
|
|
namespace at {
|
|
|
|
// An "Opaque" TensorImpl -- there are no strides and (for now)
// even data() is not supported (thus no pointer arithmetic).
//
// NOTE: We could allow data() in the future, but would have to ensure pointer
// arithmetic code is properly guarded.
//
// NOTE: This does not support resize_ (and other metadata-changing ops) because
// of `shallow_copy_and_detach`. We would need to define an interface to
// "shallow copy" in order to add support.
|
|
|
|
template <typename OpaqueHandle>
|
|
struct TORCH_API OpaqueTensorImpl : public TensorImpl {
|
|
// public constructor for now...
|
|
OpaqueTensorImpl(
|
|
at::DispatchKeySet key_set,
|
|
const caffe2::TypeMeta data_type,
|
|
c10::Device device,
|
|
OpaqueHandle opaque_handle,
|
|
c10::IntArrayRef sizes,
|
|
bool is_non_overlapping_and_dense = true)
|
|
: TensorImpl(key_set, data_type, device),
|
|
opaque_handle_(std::move(opaque_handle)) {
|
|
set_storage_access_should_throw();
|
|
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
|
|
sizes_and_strides_.set_sizes(sizes);
|
|
refresh_numel();
|
|
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
|
|
is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
|
|
}
|
|
|
|
// Destructor doesn't call release_resources because it's
|
|
// unnecessary; don't forget to change that if needed!
|
|
void release_resources() override {
|
|
TensorImpl::release_resources();
|
|
opaque_handle_ = {};
|
|
}
|
|
|
|
void set_size(int64_t dim, int64_t new_size) override {
|
|
AT_ERROR("opaque tensors do not have set_size");
|
|
}
|
|
|
|
void set_stride(int64_t dim, int64_t new_stride) override {
|
|
AT_ERROR("opaque tensors do not have set_stride");
|
|
}
|
|
|
|
void set_storage_offset(int64_t storage_offset) override {
|
|
AT_ERROR("opaque tensors do not have set_storage_offset");
|
|
}
|
|
|
|
#ifdef DEBUG
|
|
bool has_storage() const override {
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
|
!storage_, "OpaqueTensorImpl assumes that storage_ is never set");
|
|
return false;
|
|
}
|
|
#endif
|
|
|
|
/**
|
|
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
|
*
|
|
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
|
* see NOTE [ TensorImpl Shallow-Copying ].
|
|
*/
|
|
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
|
const c10::VariableVersion& version_counter,
|
|
bool allow_tensor_metadata_change) const override {
|
|
auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
|
|
key_set(),
|
|
dtype(),
|
|
device(),
|
|
opaque_handle_,
|
|
sizes_and_strides_.sizes_arrayref());
|
|
copy_tensor_metadata(
|
|
/*src_opaque_impl=*/this,
|
|
/*dest_opaque_impl=*/impl.get(),
|
|
/*version_counter=*/version_counter,
|
|
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
|
impl->refresh_numel();
|
|
return impl;
|
|
}
|
|
|
|
/**
|
|
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
|
|
*
|
|
* For usage of `version_counter` and `allow_tensor_metadata_change`,
|
|
* see NOTE [ TensorImpl Shallow-Copying ].
|
|
*/
|
|
c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
|
|
c10::VariableVersion&& version_counter,
|
|
bool allow_tensor_metadata_change) const override {
|
|
auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
|
|
key_set(),
|
|
dtype(),
|
|
device(),
|
|
opaque_handle_,
|
|
sizes_and_strides_.sizes_arrayref());
|
|
copy_tensor_metadata(
|
|
/*src_opaque_impl=*/this,
|
|
/*dest_opaque_impl=*/impl.get(),
|
|
/*version_counter=*/std::move(version_counter),
|
|
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
|
impl->refresh_numel();
|
|
return impl;
|
|
}
|
|
|
|
/**
|
|
* Shallow-copies data from another TensorImpl into this TensorImpl.
|
|
*
|
|
* For why this function doesn't check this TensorImpl's
|
|
* `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
|
|
*/
|
|
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
|
|
AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
|
|
auto opaque_impl =
|
|
static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
|
|
copy_tensor_metadata(
|
|
/*src_impl=*/opaque_impl,
|
|
/*dest_impl=*/this,
|
|
/*version_counter=*/version_counter(),
|
|
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
|
|
refresh_numel();
|
|
}
|
|
|
|
const OpaqueHandle& opaque_handle() const {
|
|
return opaque_handle_;
|
|
}
|
|
|
|
OpaqueHandle& unsafe_opaque_handle() {
|
|
return opaque_handle_;
|
|
}
|
|
|
|
protected:
|
|
/**
|
|
* Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
|
|
* storage_offset) from one TensorImpl to another TensorImpl.
|
|
*
|
|
* For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
|
|
* [ TensorImpl Shallow-Copying ].
|
|
*/
|
|
static void copy_tensor_metadata(
|
|
const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
|
|
OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
|
|
const c10::VariableVersion& version_counter,
|
|
bool allow_tensor_metadata_change) {
|
|
TensorImpl::copy_tensor_metadata(
|
|
src_opaque_impl,
|
|
dest_opaque_impl,
|
|
version_counter,
|
|
allow_tensor_metadata_change);
|
|
|
|
// OpaqueTensorImpl-specific fields.
|
|
dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
|
|
}
|
|
|
|
static void copy_tensor_metadata(
|
|
const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
|
|
OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
|
|
c10::VariableVersion&& version_counter,
|
|
bool allow_tensor_metadata_change) {
|
|
TensorImpl::copy_tensor_metadata(
|
|
src_opaque_impl,
|
|
dest_opaque_impl,
|
|
std::move(version_counter),
|
|
allow_tensor_metadata_change);
|
|
|
|
// OpaqueTensorImpl-specific fields.
|
|
dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
|
|
}
|
|
|
|
private:
|
|
const char* tensorimpl_type_name() const override {
|
|
return "OpaqueTensorImpl";
|
|
}
|
|
|
|
OpaqueHandle opaque_handle_;
|
|
};
|
|
|
|
} // namespace at
|