104 lines
3.0 KiB
C++
104 lines
3.0 KiB
C++
#include <array>
|
|
|
|
#include <ATen/Functions.h>
|
|
#include <ATen/Utils.h>
|
|
#include <c10/core/Allocator.h>
|
|
|
|
namespace at {
|
|
|
|
/// Builds a Tensor that views the externally supplied buffer `data_`.
///
/// Steps, in order:
///   1. validate the configuration (non-negative sizes, deleter/context
///      exclusivity, agreement between the requested and the data's device);
///   2. wrap `data_` in a DataPtr whose cleanup comes from `deleter_` if set,
///      otherwise from `ctx_`;
///   3. build a Storage and TensorImpl, then install sizes, optional strides and
///      optional storage offset.
///
/// Throws through TORCH_CHECK_VALUE / TORCH_CHECK on inconsistent arguments.
Tensor TensorMaker::make_tensor() {
  // Both guards are RAII scopes that must cover the entire function body.
  AutoDispatchBelowADInplaceOrView dispatch_guard{}; // TODO: Remove.
  tracer::impl::NoTracerDispatchMode no_tracer_guard{};

  check_size_nonnegative(sizes_);

  TORCH_CHECK_VALUE(
      !deleter_ || !ctx_,
      "The deleter and context arguments are mutually exclusive.");

  // No device pinned by the caller: ask the global context where the pointer lives.
  if (!device_.has_value()) {
    device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
  }

  // An explicitly indexed device in the options must match the data's device.
  if (opts_.device().has_index()) {
    // clang-format off
    TORCH_CHECK_VALUE(
        opts_.device() == *device_,
        "Specified device ", opts_.device(), " does not match device of data ", *device_);
    // clang-format on
  }

  const std::size_t nbytes = computeStorageSize();

  // The deleter takes precedence when present; otherwise ctx_ (which may be
  // empty) supplies the cleanup.
  DataPtr owned_data = deleter_ ? makeDataPtrFromDeleter() : makeDataPtrFromContext();

  // NOTE(review): this check runs after owned_data has been built, so on
  // failure owned_data is destroyed during unwinding and its cleanup presumably
  // runs. Confirm this ordering is intended.
  TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()");
  Storage storage{Storage::use_byte_size_t{}, nbytes, std::move(owned_data), /*allocator=*/allocator_, /*resizeable=*/resizeable_};

  Tensor result = detail::make_tensor<TensorImpl>(
      std::move(storage), opts_.computeDispatchKey(), opts_.dtype());

  TensorImpl* const impl = result.unsafeGetTensorImpl();
  if (!strides_) {
    impl->set_sizes_contiguous(sizes_);
  } else {
    impl->set_sizes_and_strides(sizes_, *strides_);
  }
  if (storage_offset_) {
    impl->set_storage_offset(*storage_offset_);
  }

  return result;
}
|
|
|
|
std::size_t TensorMaker::computeStorageSize() const noexcept {
|
|
std::size_t itemsize = opts_.dtype().itemsize();
|
|
|
|
if (strides_) {
|
|
auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
|
|
if (storage_offset_) {
|
|
storage_size += storage_offset_.value();
|
|
}
|
|
return storage_size;
|
|
}
|
|
|
|
std::size_t size = 1;
|
|
for (std::int64_t s : sizes_) {
|
|
size *= static_cast<std::size_t>(s);
|
|
}
|
|
auto storage_size = size * itemsize;
|
|
if (storage_offset_) {
|
|
storage_size += storage_offset_.value();
|
|
}
|
|
return storage_size;
|
|
}
|
|
|
|
/// Wraps `data_` in a DataPtr whose cleanup runs the user-supplied `deleter_`
/// through InefficientStdFunctionContext.
///
/// `deleter_` is passed as an rvalue, so it may be left empty afterwards.
/// `device_` must already hold a value; make_tensor fills it in before
/// calling this.
inline DataPtr TensorMaker::makeDataPtrFromDeleter() noexcept {
  const auto& target_device = *device_;
  return InefficientStdFunctionContext::makeDataPtr(
      data_, std::move(deleter_), target_device);
}
|
|
|
|
/// Wraps `data_` in a DataPtr whose cleanup is driven by the context held in
/// `ctx_`.
///
/// Ownership of the context pointer (which may be null) moves out of `ctx_`
/// into the returned DataPtr, leaving `ctx_` empty. `device_` must already hold
/// a value.
inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept {
  // release() gives up only the pointer; the deleter stays readable either way.
  // It is read first here so the hand-off order is explicit.
  const auto context_deleter = ctx_.get_deleter();
  const auto context = ctx_.release();
  return DataPtr{data_, context, context_deleter, *device_};
}
|
|
|
|
/// Returns an all-zero sizes view whose length depends on the requested memory
/// format:
///   - 4 for MemoryFormat::ChannelsLast;
///   - 5 for MemoryFormat::ChannelsLast3d;
///   - 1 otherwise, including when no memory format was specified.
///
/// The view points into a function-local constant, so it never dangles and
/// cannot be used to modify state shared between calls.
IntArrayRef TensorMaker::makeTempSizes() const noexcept {
  // constexpr gives read-only, constant-initialized storage. The previous
  // mutable static was never written but was exposed through views.
  static constexpr std::array<std::int64_t, 5> kZeros{};

  if (opts_.has_memory_format()) {
    const MemoryFormat format = *opts_.memory_format_opt();
    if (format == MemoryFormat::ChannelsLast) {
      return IntArrayRef(kZeros.data(), 4);
    }
    if (format == MemoryFormat::ChannelsLast3d) {
      return IntArrayRef(kZeros.data(), 5);
    }
  }

  return IntArrayRef(kZeros.data(), 1);
}
|
|
|
|
} // namespace at
|