175 lines
9.7 KiB
C++
175 lines
9.7 KiB
C++
#pragma once
|
|
|
|
#include <cublas_v2.h>
|
|
#include <cusparse.h>
|
|
#include <c10/macros/Export.h>
|
|
|
|
#ifdef CUDART_VERSION
|
|
#include <cusolver_common.h>
|
|
#endif
|
|
|
|
#include <ATen/Context.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/cuda/CUDAException.h>
|
|
|
|
|
|
namespace c10 {
|
|
|
|
class CuDNNError : public c10::Error {
|
|
using Error::Error;
|
|
};
|
|
|
|
} // namespace c10
|
|
|
|
// Checks the result object produced by a cuDNN Frontend call.
//
// EXPR is evaluated exactly once and stored in a local; the stored object
// must provide is_good() and get_message(). When is_good() returns false,
// CuDNNError is raised through TORCH_CHECK_WITH with the message
// "cuDNN Frontend error: " followed by get_message().
//
// The variadic arguments are accepted but are not used in the message.
//
// The definition ends at `while (0)` with no trailing line continuation, so
// the macro body cannot absorb whatever line happens to follow it.
#define AT_CUDNN_FRONTEND_CHECK(EXPR, ...) \
  do { \
    auto error_object = EXPR; \
    if (!error_object.is_good()) { \
      TORCH_CHECK_WITH( \
          CuDNNError, \
          false, \
          "cuDNN Frontend error: ", \
          error_object.get_message()); \
    } \
  } while (0)
|
|
|
|
// Forwards to AT_CUDNN_CHECK, inserting a "\n" message argument ahead of any
// extra arguments supplied by the caller.
#define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) \
  AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)
|
|
|
|
// See Note [CHECK macro]
//
// Evaluates EXPR once as a cudnnStatus_t. Any status other than
// CUDNN_STATUS_SUCCESS raises CuDNNError through TORCH_CHECK_WITH with the
// message "cuDNN error: " followed by cudnnGetErrorString(status) and then
// any extra arguments. For CUDNN_STATUS_NOT_SUPPORTED a hint about
// non-contiguous inputs is placed between the error string and the extra
// arguments.
#define AT_CUDNN_CHECK(EXPR, ...) \
  do { \
    cudnnStatus_t status = EXPR; \
    if (status == CUDNN_STATUS_NOT_SUPPORTED) { \
      TORCH_CHECK_WITH( \
          CuDNNError, \
          false, \
          "cuDNN error: ", \
          cudnnGetErrorString(status), \
          ". This error may appear if you passed in a non-contiguous input.", \
          ##__VA_ARGS__); \
    } else if (status != CUDNN_STATUS_SUCCESS) { \
      TORCH_CHECK_WITH( \
          CuDNNError, \
          false, \
          "cuDNN error: ", \
          cudnnGetErrorString(status), \
          ##__VA_ARGS__); \
    } \
  } while (0)
|
|
|
|
namespace at::cuda::blas {
|
|
C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
|
|
} // namespace at::cuda::blas
|
|
|
|
// Evaluates EXPR once as a cublasStatus_t and hands the comparison against
// CUBLAS_STATUS_SUCCESS to TORCH_CHECK. The failure message is
// "CUDA error: ", the text from at::cuda::blas::_cublasGetErrorEnum, and the
// stringified call expression wrapped in backticks.
#define TORCH_CUDABLAS_CHECK(EXPR) \
  do { \
    cublasStatus_t __err = EXPR; \
    TORCH_CHECK( \
        __err == CUBLAS_STATUS_SUCCESS, \
        "CUDA error: ", \
        at::cuda::blas::_cublasGetErrorEnum(__err), \
        " when calling `" #EXPR "`"); \
  } while (0)
|
|
|
|
// Declaration only. TORCH_CUDASPARSE_CHECK passes the failing status to this
// function and streams the returned text into its error message.
const char* cusparseGetErrorString(cusparseStatus_t status);
|
|
|
|
// Evaluates EXPR once as a cusparseStatus_t and hands the comparison against
// CUSPARSE_STATUS_SUCCESS to TORCH_CHECK. The failure message is
// "CUDA error: ", the text from cusparseGetErrorString, and the stringified
// call expression wrapped in backticks.
#define TORCH_CUDASPARSE_CHECK(EXPR) \
  do { \
    cusparseStatus_t __err = EXPR; \
    TORCH_CHECK( \
        __err == CUSPARSE_STATUS_SUCCESS, \
        "CUDA error: ", \
        cusparseGetErrorString(__err), \
        " when calling `" #EXPR "`"); \
  } while (0)
|
|
|
|
// cusolver related headers are only supported on cuda now
|
|
#ifdef CUDART_VERSION
|
|
|
|
namespace at::cuda::solver {
|
|
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
|
|
|
|
constexpr const char* _cusolver_backend_suggestion = \
|
|
"If you keep seeing this error, you may use " \
|
|
"`torch.backends.cuda.preferred_linalg_library()` to try " \
|
|
"linear algebra operators with other supported backends. " \
|
|
"See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
|
|
|
|
} // namespace at::cuda::solver
|
|
|
|
// When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
// When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
//
// Evaluates EXPR once as a cusolverStatus_t. One status, selected by
// CUDA_VERSION (EXECUTION_FAILED below 11500, INVALID_VALUE otherwise), is
// reported through TORCH_CHECK_LINALG with a note that the input may contain
// NaN. Every other status is compared against CUSOLVER_STATUS_SUCCESS by
// TORCH_CHECK. Both messages end with the backend suggestion text.
#define TORCH_CUSOLVER_CHECK(EXPR) \
  do { \
    cusolverStatus_t __err = EXPR; \
    if (__err == \
        (CUDA_VERSION < 11500 ? CUSOLVER_STATUS_EXECUTION_FAILED \
                              : CUSOLVER_STATUS_INVALID_VALUE)) { \
      TORCH_CHECK_LINALG( \
          false, \
          "cusolver error: ", \
          at::cuda::solver::cusolverGetErrorMessage(__err), \
          ", when calling `" #EXPR "`", \
          ". This error may appear if the input matrix contains NaN. ", \
          at::cuda::solver::_cusolver_backend_suggestion); \
    } else { \
      TORCH_CHECK( \
          __err == CUSOLVER_STATUS_SUCCESS, \
          "cusolver error: ", \
          at::cuda::solver::cusolverGetErrorMessage(__err), \
          ", when calling `" #EXPR "`. ", \
          at::cuda::solver::_cusolver_backend_suggestion); \
    } \
  } while (0)
|
|
|
|
#else
|
|
// Fallback used when CUDART_VERSION is not defined: the expression is
// expanded as written and its result is not inspected.
#define TORCH_CUSOLVER_CHECK(EXPR) \
  EXPR
|
|
#endif
|
|
|
|
// ATen-prefixed alias: forwards EXPR unchanged to C10_CUDA_CHECK.
#define AT_CUDA_CHECK(EXPR) \
  C10_CUDA_CHECK(EXPR)
|
|
|
|
// For CUDA Driver API
|
|
//
|
|
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
|
|
// in ATen, and we need to use its nvrtcGetErrorString.
|
|
// See NOTE [ USE OF NVRTC AND DRIVER API ].
|
|
#if !defined(USE_ROCM)
|
|
|
|
// Checks a CUresult returned by a CUDA Driver API call (variant compiled when
// USE_ROCM is not defined).
//
// EXPR is evaluated once. If the result is not CUDA_SUCCESS, the error text
// is looked up with cuGetErrorString obtained from
// at::globalContext().getNVRTC(). When that lookup fails, or leaves the
// string pointer null, the message is "CUDA driver error: unknown error";
// otherwise it is "CUDA driver error: " followed by the looked-up text.
//
// Errors are raised with TORCH_CHECK(false, ...), the same mechanism the
// cuBLAS/cuSPARSE check macros in this header use. err_str starts out null
// so it is never read uninitialized.
#define AT_CUDA_DRIVER_CHECK(EXPR) \
  do { \
    CUresult __err = EXPR; \
    if (__err != CUDA_SUCCESS) { \
      const char* err_str = nullptr; \
      CUresult get_error_str_err = \
          at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \
      if (get_error_str_err != CUDA_SUCCESS || err_str == nullptr) { \
        TORCH_CHECK(false, "CUDA driver error: unknown error"); \
      } else { \
        TORCH_CHECK(false, "CUDA driver error: ", err_str); \
      } \
    } \
  } while (0)
|
|
|
|
#else
|
|
|
|
// Checks a CUresult returned by a driver API call (variant compiled when
// USE_ROCM is defined).
//
// EXPR is evaluated once. If the result is not CUDA_SUCCESS, an error is
// raised with the message "CUDA driver error: " followed by the numeric
// value of the result code; no error-string lookup is attempted here.
// Errors are raised with TORCH_CHECK(false, ...), the same mechanism the
// cuBLAS/cuSPARSE check macros in this header use.
#define AT_CUDA_DRIVER_CHECK(EXPR) \
  do { \
    CUresult __err = EXPR; \
    if (__err != CUDA_SUCCESS) { \
      TORCH_CHECK(false, "CUDA driver error: ", static_cast<int>(__err)); \
    } \
  } while (0)
|
|
|
|
#endif
|
|
|
|
// For CUDA NVRTC
//
// Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
// incorrectly produces the error string "NVRTC unknown error."
// The following maps it correctly.
//
// This is here instead of in c10 because NVRTC is loaded dynamically via a stub
// in ATen, and we need to use its nvrtcGetErrorString.
// See NOTE [ USE OF NVRTC AND DRIVER API ].
//
// EXPR is evaluated once as an nvrtcResult. Any result other than
// NVRTC_SUCCESS raises an error through TORCH_CHECK(false, ...), the same
// mechanism the cuBLAS/cuSPARSE check macros in this header use. The message
// is "CUDA NVRTC error: " followed by the text from nvrtcGetErrorString,
// except for numeric code 7, which is reported by its enumerator name.
#define AT_CUDA_NVRTC_CHECK(EXPR) \
  do { \
    nvrtcResult __err = EXPR; \
    if (__err != NVRTC_SUCCESS) { \
      /* 7 == NVRTC_ERROR_BUILTIN_OPERATION_FAILURE, see the note above. */ \
      if (static_cast<int>(__err) != 7) { \
        TORCH_CHECK( \
            false, \
            "CUDA NVRTC error: ", \
            at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \
      } else { \
        TORCH_CHECK( \
            false, \
            "CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \
      } \
    } \
  } while (0)
|