#pragma once #include namespace c10 { class Scalar; } namespace at { struct TensorIterator; struct TensorIteratorBase; namespace native { #if defined(__CUDACC__) || defined(__HIPCC__) #define HOST_DEVICE __host__ __device__ #else #define HOST_DEVICE #endif // integral power in pytorch allows for negative exponents, giving truncated integral results. // e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent is the // only non-zero result. template ::value, T>::type* = nullptr> static inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) { T result = 1; while (b) { if (b & 1) { result *= a; } b /= 2; a *= a; } return result; } template ::value && !std::is_signed::value, T>::type* = nullptr> static inline HOST_DEVICE T powi(T a, T b) { return powi_impl(a, b); } template ::value && std::is_signed::value, T>::type* = nullptr> static inline HOST_DEVICE T powi(T a, T b) { if ( b < 0 ) { if ( a == 1 ) { return 1; } else if ( a == -1 ) { auto negative = (-b) % static_cast(2); return negative ? -1 : 1; } else { return 0; } } return powi_impl(a, b); } using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&); using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&); DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub); DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub); } // namespace native } // namespace at