#pragma once

#include <ATen/cuda/Exceptions.h>
#include <c10/core/Device.h>

#include <cuda.h>
#include <cuda_runtime.h>

namespace at::cuda {

// Returns the CUDA device associated with the memory that `ptr` points to,
// as reported by cudaPointerGetAttributes().
//
// @param ptr  Pointer to query. It is only passed to the CUDA runtime for
//             inspection; it is never dereferenced here.
// @return     A Device of type CUDA whose index is the `device` field of the
//             queried cudaPointerAttributes.
// @throws     Raises an error via AT_CUDA_CHECK if cudaPointerGetAttributes
//             itself fails. On non-ROCm builds, also raises via TORCH_CHECK
//             when the pointer is plain host memory that is not registered
//             with CUDA.
inline Device getDeviceFromPtr(void* ptr) {
  // Value-initialize so every field has a defined value before the query.
  cudaPointerAttributes attr{};

  AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));

#if !defined(USE_ROCM)
  // Unregistered host memory is not associated with any CUDA device, so
  // attr.device would not be a usable device index; reject it explicitly.
  // NOTE(review): this check is compiled out for ROCm builds, presumably
  // because HIP reports pointer attributes differently — confirm.
  TORCH_CHECK(
      attr.type != cudaMemoryTypeUnregistered,
      "The specified pointer resides on host memory and is not registered with any CUDA device.");
#endif

  // attr.device is an int; Device stores the narrower DeviceIndex type.
  return {c10::DeviceType::CUDA, static_cast<DeviceIndex>(attr.device)};
}

} // namespace at::cuda