#pragma once #include #include #include #ifdef CUDART_VERSION #include #endif #include #include #include namespace c10 { class CuDNNError : public c10::Error { using Error::Error; }; } // namespace c10 #define AT_CUDNN_FRONTEND_CHECK(EXPR, ...) \ do { \ auto error_object = EXPR; \ if (!error_object.is_good()) { \ TORCH_CHECK_WITH(CuDNNError, false, \ "cuDNN Frontend error: ", error_object.get_message()); \ } \ } while (0) \ #define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__) // See Note [CHECK macro] #define AT_CUDNN_CHECK(EXPR, ...) \ do { \ cudnnStatus_t status = EXPR; \ if (status != CUDNN_STATUS_SUCCESS) { \ if (status == CUDNN_STATUS_NOT_SUPPORTED) { \ TORCH_CHECK_WITH(CuDNNError, false, \ "cuDNN error: ", \ cudnnGetErrorString(status), \ ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \ } else { \ TORCH_CHECK_WITH(CuDNNError, false, \ "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__); \ } \ } \ } while (0) namespace at::cuda::blas { C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error); } // namespace at::cuda::blas #define TORCH_CUDABLAS_CHECK(EXPR) \ do { \ cublasStatus_t __err = EXPR; \ TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS, \ "CUDA error: ", \ at::cuda::blas::_cublasGetErrorEnum(__err), \ " when calling `" #EXPR "`"); \ } while (0) const char *cusparseGetErrorString(cusparseStatus_t status); #define TORCH_CUDASPARSE_CHECK(EXPR) \ do { \ cusparseStatus_t __err = EXPR; \ TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS, \ "CUDA error: ", \ cusparseGetErrorString(__err), \ " when calling `" #EXPR "`"); \ } while (0) // cusolver related headers are only supported on cuda now #ifdef CUDART_VERSION namespace at::cuda::solver { C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status); constexpr const char* _cusolver_backend_suggestion = \ "If you keep seeing this error, you may use " \ "`torch.backends.cuda.preferred_linalg_library()` to try " \ "linear 
algebra operators with other supported backends. " \ "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library"; } // namespace at::cuda::solver // When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan. // When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue. #define TORCH_CUSOLVER_CHECK(EXPR) \ do { \ cusolverStatus_t __err = EXPR; \ if ((CUDA_VERSION < 11500 && \ __err == CUSOLVER_STATUS_EXECUTION_FAILED) || \ (CUDA_VERSION >= 11500 && \ __err == CUSOLVER_STATUS_INVALID_VALUE)) { \ TORCH_CHECK_LINALG( \ false, \ "cusolver error: ", \ at::cuda::solver::cusolverGetErrorMessage(__err), \ ", when calling `" #EXPR "`", \ ". This error may appear if the input matrix contains NaN. ", \ at::cuda::solver::_cusolver_backend_suggestion); \ } else { \ TORCH_CHECK( \ __err == CUSOLVER_STATUS_SUCCESS, \ "cusolver error: ", \ at::cuda::solver::cusolverGetErrorMessage(__err), \ ", when calling `" #EXPR "`. ", \ at::cuda::solver::_cusolver_backend_suggestion); \ } \ } while (0) #else #define TORCH_CUSOLVER_CHECK(EXPR) EXPR #endif #define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR) // For CUDA Driver API // // This is here instead of in c10 because NVRTC is loaded dynamically via a stub // in ATen, and we need to use its nvrtcGetErrorString. // See NOTE [ USE OF NVRTC AND DRIVER API ]. 
#if !defined(USE_ROCM)

// Evaluates EXPR once as a CUresult. On any value other than CUDA_SUCCESS the
// error string is looked up through the NVRTC stub held by the global context
// and reported with AT_ERROR; if that lookup itself fails, a generic
// "unknown error" message is reported instead.
#define AT_CUDA_DRIVER_CHECK(EXPR)                                                                               \
  do {                                                                                                           \
    CUresult __err = EXPR;                                                                                       \
    if (__err != CUDA_SUCCESS) {                                                                                 \
      const char* err_str = nullptr;                                                                             \
      CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str);  \
      if (get_error_str_err != CUDA_SUCCESS) {                                                                   \
        AT_ERROR("CUDA driver error: unknown error");                                                            \
      } else {                                                                                                   \
        AT_ERROR("CUDA driver error: ", err_str);                                                                \
      }                                                                                                          \
    }                                                                                                            \
  } while (0)

#else

// ROCm variant: no error-string lookup is performed; the numeric status value
// is reported instead.
#define AT_CUDA_DRIVER_CHECK(EXPR)                                      \
  do {                                                                  \
    CUresult __err = EXPR;                                              \
    if (__err != CUDA_SUCCESS) {                                        \
      AT_ERROR("CUDA driver error: ", static_cast<int>(__err));         \
    }                                                                   \
  } while (0)

#endif

// For CUDA NVRTC
//
// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
// incorrectly produces the error string "NVRTC unknown error."
// The following maps it correctly.
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
//
// Evaluates EXPR once as an nvrtcResult and reports any value other than
// NVRTC_SUCCESS with AT_ERROR. Code 7 is given a hard-coded name; every other
// code is translated through the NVRTC stub's nvrtcGetErrorString.
#define AT_CUDA_NVRTC_CHECK(EXPR)                                                                   \
  do {                                                                                              \
    nvrtcResult __err = EXPR;                                                                       \
    if (__err != NVRTC_SUCCESS) {                                                                   \
      if (static_cast<int>(__err) != 7) {                                                           \
        AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err));  \
      } else {                                                                                      \
        AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE");                        \
      }                                                                                             \
    }                                                                                               \
  } while (0)