190 lines
5.8 KiB
C++
190 lines
5.8 KiB
C++
#pragma once
|
|
|
|
#include <ATen/OpMathType.h>
|
|
#include <ATen/native/DispatchStub.h>
|
|
#include <ATen/native/TransposeType.h>
|
|
#include <c10/util/complex.h>
|
|
#include <c10/core/ScalarType.h>
|
|
#include <c10/core/Scalar.h>
|
|
|
|
namespace at::native::cpublas {
|
|
|
|
namespace internal {
|
|
void normalize_last_dims(
|
|
TransposeType transa, TransposeType transb,
|
|
int64_t m, int64_t n, int64_t k,
|
|
int64_t *lda, int64_t *ldb, int64_t *ldc);
|
|
} // namespace internal
|
|
|
|
using gemm_fn = void(*)(
|
|
at::ScalarType type,
|
|
TransposeType transa, TransposeType transb,
|
|
int64_t m, int64_t n, int64_t k,
|
|
const Scalar& alpha,
|
|
const void *a, int64_t lda,
|
|
const void *b, int64_t ldb,
|
|
const Scalar& beta,
|
|
void *c, int64_t ldc);
|
|
|
|
DECLARE_DISPATCH(gemm_fn, gemm_stub);
|
|
|
|
template <typename scalar_t>
|
|
void gemm(
|
|
TransposeType transa, TransposeType transb,
|
|
int64_t m, int64_t n, int64_t k,
|
|
at::opmath_type<scalar_t> alpha,
|
|
const scalar_t *a, int64_t lda,
|
|
const scalar_t *b, int64_t ldb,
|
|
at::opmath_type<scalar_t> beta,
|
|
scalar_t *c, int64_t ldc) {
|
|
internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
|
gemm_stub(
|
|
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
|
|
transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
|
}
|
|
|
|
// Non-template gemm overload for `double` matrices with `double` alpha/beta.
// Declaration only; no body is provided here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    double alpha,
    const double* a,
    int64_t lda,
    const double* b,
    int64_t ldb,
    double beta,
    double* c,
    int64_t ldc);
|
|
|
|
// Non-template gemm overload for `float` matrices with `float` alpha/beta.
// Declaration only; no body is provided here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const float* a,
    int64_t lda,
    const float* b,
    int64_t ldb,
    float beta,
    float* c,
    int64_t ldc);
|
|
|
|
// Non-template gemm overload for `at::BFloat16` inputs and output; alpha and
// beta are taken as `float`. Declaration only; no body is provided here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* a,
    int64_t lda,
    const at::BFloat16* b,
    int64_t ldb,
    float beta,
    at::BFloat16* c,
    int64_t ldc);
|
|
|
|
// Non-template gemm overload with mixed precision: `at::BFloat16` inputs
// (a, b) and a `float` output matrix (c); alpha and beta are `float`.
// Declaration only; no body is provided here.
//
// alpha/beta are declared without a top-level `const`: on a by-value
// parameter in a declaration the qualifier is not part of the function type,
// so dropping it leaves the signature unchanged while matching the sibling
// overloads.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::BFloat16* a,
    int64_t lda,
    const at::BFloat16* b,
    int64_t ldb,
    float beta,
    float* c,
    int64_t ldc);
|
|
|
|
// Non-template gemm overload for `at::Half` inputs and output; alpha and
// beta are taken as `float`. Declaration only; no body is provided here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* a,
    int64_t lda,
    const at::Half* b,
    int64_t ldb,
    float beta,
    at::Half* c,
    int64_t ldc);
|
|
|
|
// Non-template gemm overload with mixed precision: `at::Half` inputs (a, b)
// and a `float` output matrix (c); alpha and beta are `float`.
// Declaration only; no body is provided here.
//
// alpha/beta are declared without a top-level `const`: on a by-value
// parameter in a declaration the qualifier is not part of the function type,
// so dropping it leaves the signature unchanged while matching the sibling
// overloads.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    float alpha,
    const at::Half* a,
    int64_t lda,
    const at::Half* b,
    int64_t ldb,
    float beta,
    float* c,
    int64_t ldc);
|
|
|
|
// Non-template gemm overload for `c10::complex<double>` matrices and
// scalars. Declaration only; no body is provided here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    c10::complex<double> alpha,
    const c10::complex<double>* a,
    int64_t lda,
    const c10::complex<double>* b,
    int64_t ldb,
    c10::complex<double> beta,
    c10::complex<double>* c,
    int64_t ldc);
|
|
|
|
// Non-template gemm overload for `c10::complex<float>` matrices and
// scalars. Declaration only; no body is provided here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    c10::complex<float> alpha,
    const c10::complex<float>* a,
    int64_t lda,
    const c10::complex<float>* b,
    int64_t ldb,
    c10::complex<float> beta,
    c10::complex<float>* c,
    int64_t ldc);
|
|
|
|
// Non-template gemm overload for `int64_t` matrices with `int64_t`
// alpha/beta. Declaration only; no body is provided here.
void gemm(
    TransposeType transa,
    TransposeType transb,
    int64_t m,
    int64_t n,
    int64_t k,
    int64_t alpha,
    const int64_t* a,
    int64_t lda,
    const int64_t* b,
    int64_t ldb,
    int64_t beta,
    int64_t* c,
    int64_t ldc);
|
|
|
|
template <typename scalar_t>
|
|
void gemm_batched(
|
|
TransposeType transa, TransposeType transb,
|
|
int64_t batch_size, int64_t m, int64_t n, int64_t k,
|
|
scalar_t alpha,
|
|
const scalar_t * const *a, int64_t lda,
|
|
const scalar_t * const *b, int64_t ldb,
|
|
const scalar_t beta,
|
|
scalar_t * const *c, int64_t ldc);
|
|
|
|
template <typename scalar_t>
|
|
void gemm_batched_with_stride(
|
|
TransposeType transa, TransposeType transb,
|
|
int64_t batch_size, int64_t m, int64_t n, int64_t k,
|
|
scalar_t alpha,
|
|
const scalar_t *a, int64_t lda, int64_t batch_stride_a,
|
|
const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
|
|
scalar_t beta,
|
|
scalar_t *c, int64_t ldc, int64_t batch_stride_c);
|
|
|
|
using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);
|
|
|
|
DECLARE_DISPATCH(axpy_fn, axpy_stub);
|
|
|
|
template<typename scalar_t>
|
|
void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
|
|
if(n == 1)
|
|
{
|
|
incx = 1;
|
|
incy = 1;
|
|
}
|
|
axpy_stub(
|
|
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
|
|
n, a, x, incx, y, incy);
|
|
}
|
|
|
|
void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
|
|
void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
|
|
void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
|
|
void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
|
|
|
|
using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);
|
|
|
|
DECLARE_DISPATCH(copy_fn, copy_stub);
|
|
|
|
template<typename scalar_t>
|
|
void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
|
|
if(n == 1)
|
|
{
|
|
incx = 1;
|
|
incy = 1;
|
|
}
|
|
copy_stub(
|
|
kCPU, c10::CppTypeToScalarType<scalar_t>::value,
|
|
n, x, incx, y, incy);
|
|
}
|
|
|
|
void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
|
|
void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
|
|
void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
|
|
void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
|
|
|
|
} // namespace at::native::cpublas
|