#include #include #include #include namespace at { Tensor TensorMaker::make_tensor() { AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove. tracer::impl::NoTracerDispatchMode tracer_guard{}; check_size_nonnegative(sizes_); TORCH_CHECK_VALUE( !deleter_ || !ctx_, "The deleter and context arguments are mutually exclusive."); if (device_ == nullopt) { device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type()); } if (opts_.device().has_index()) { // clang-format off TORCH_CHECK_VALUE( opts_.device() == *device_, "Specified device ", opts_.device(), " does not match device of data ", *device_); // clang-format on } std::size_t size_bytes = computeStorageSize(); DataPtr data_ptr{}; if (deleter_) { data_ptr = makeDataPtrFromDeleter(); } else { data_ptr = makeDataPtrFromContext(); } TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()"); Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizeable=*/resizeable_}; Tensor tensor = detail::make_tensor( std::move(storage), opts_.computeDispatchKey(), opts_.dtype()); TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); if (strides_) { tensor_impl->set_sizes_and_strides(sizes_, *strides_); } else { tensor_impl->set_sizes_contiguous(sizes_); } if (storage_offset_) { tensor_impl->set_storage_offset(*storage_offset_); } return tensor; } std::size_t TensorMaker::computeStorageSize() const noexcept { std::size_t itemsize = opts_.dtype().itemsize(); if (strides_) { auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize); if (storage_offset_) { storage_size += storage_offset_.value(); } return storage_size; } std::size_t size = 1; for (std::int64_t s : sizes_) { size *= static_cast(s); } auto storage_size = size * itemsize; if (storage_offset_) { storage_size += storage_offset_.value(); } return storage_size; } inline DataPtr TensorMaker::makeDataPtrFromDeleter() noexcept { return InefficientStdFunctionContext::makeDataPtr(data_, std::move(deleter_), *device_); } inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept { return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_}; } IntArrayRef TensorMaker::makeTempSizes() const noexcept { static std::int64_t zeros[5] = {0, 0, 0, 0, 0}; if (opts_.has_memory_format()) { MemoryFormat format = *opts_.memory_format_opt(); if (format == MemoryFormat::ChannelsLast) { return IntArrayRef(zeros, 4); } if (format == MemoryFormat::ChannelsLast3d) { return IntArrayRef(zeros, 5); } } return IntArrayRef(zeros, 1); } } // namespace at