# cython: infer_types=True
# cython: boundscheck=False
# Copyright ExplosionAI GmbH, released under BSD.
#
# Thin Cython wrappers around the BLIS "expert" typed API (the `bli_?xxx_ex`
# functions).  Each wrapper is written against a fused type and dispatches at
# compile time to the single-precision (`bli_s...`) or double-precision
# (`bli_d...`) routine.  The `floats_t` / `doubles_t` specializations forward
# the argument unchanged, while the `float1d_t` / `double1d_t` specializations
# pass the address of the first element (`&x[0]`).
#
# The fused types (`reals_ft`, `real_ft`, ...), `dim_t`, `inc_t`, `trans_t`,
# `conj_t` and the un-prefixed enum constants (`NO_TRANSPOSE`, ...) are not
# defined in this file; they are expected to come from the companion
# declaration file.
import atexit


cdef extern from "blis.h" nogil:
    enum blis_err_t "err_t":
        pass

    cdef struct blis_cntx_t "cntx_t":
        pass

    cdef struct blis_rntm_t "rntm_s":
        pass

    ctypedef enum blis_trans_t "trans_t":
        BLIS_NO_TRANSPOSE
        BLIS_TRANSPOSE
        BLIS_CONJ_NO_TRANSPOSE
        BLIS_CONJ_TRANSPOSE

    ctypedef enum blis_conj_t "conj_t":
        BLIS_NO_CONJUGATE
        BLIS_CONJUGATE

    ctypedef enum blis_side_t "side_t":
        BLIS_LEFT
        BLIS_RIGHT

    ctypedef enum blis_uplo_t "uplo_t":
        BLIS_LOWER
        BLIS_UPPER
        BLIS_DENSE

    ctypedef enum blis_diag_t "diag_t":
        BLIS_NONUNIT_DIAG
        BLIS_UNIT_DIAG

    char* bli_info_get_int_type_size_str()

    blis_err_t bli_init()
    blis_err_t bli_finalize()

    blis_err_t bli_rntm_init(blis_rntm_t* rntm)

    # BLAS level 3 routines
    void bli_dgemm_ex(
        blis_trans_t transa,
        blis_trans_t transb,
        dim_t m,
        dim_t n,
        dim_t k,
        double* alpha,
        double* a, inc_t rsa, inc_t csa,
        double* b, inc_t rsb, inc_t csb,
        double* beta,
        double* c, inc_t rsc, inc_t csc,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_sgemm_ex(
        blis_trans_t transa,
        blis_trans_t transb,
        dim_t m,
        dim_t n,
        dim_t k,
        float* alpha,
        float* a, inc_t rsa, inc_t csa,
        float* b, inc_t rsb, inc_t csb,
        float* beta,
        float* c, inc_t rsc, inc_t csc,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # BLAS level 2 routines
    void bli_dger_ex(
        blis_conj_t conjx,
        blis_conj_t conjy,
        dim_t m,
        dim_t n,
        double* alpha,
        double* x, inc_t incx,
        double* y, inc_t incy,
        double* a, inc_t rsa, inc_t csa,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_sger_ex(
        blis_conj_t conjx,
        blis_conj_t conjy,
        dim_t m,
        dim_t n,
        float* alpha,
        float* x, inc_t incx,
        float* y, inc_t incy,
        float* a, inc_t rsa, inc_t csa,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_dgemv_ex(
        blis_trans_t transa,
        blis_conj_t conjx,
        dim_t m,
        dim_t n,
        double* alpha,
        double* a, inc_t rsa, inc_t csa,
        double* x, inc_t incx,
        double* beta,
        double* y, inc_t incy,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_sgemv_ex(
        blis_trans_t transa,
        blis_conj_t conjx,
        dim_t m,
        dim_t n,
        float* alpha,
        float* a, inc_t rsa, inc_t csa,
        float* x, inc_t incx,
        float* beta,
        float* y, inc_t incy,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # Vector (level 1) routines
    void bli_daxpyv_ex(
        blis_conj_t conjx,
        dim_t m,
        double* alpha,
        double* x, inc_t incx,
        double* y, inc_t incy,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_saxpyv_ex(
        blis_conj_t conjx,
        dim_t m,
        float* alpha,
        float* x, inc_t incx,
        float* y, inc_t incy,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_dscalv_ex(
        blis_conj_t conjalpha,
        dim_t m,
        double* alpha,
        double* x, inc_t incx,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_sscalv_ex(
        blis_conj_t conjalpha,
        dim_t m,
        float* alpha,
        float* x, inc_t incx,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_ddotv_ex(
        blis_conj_t conjx,
        blis_conj_t conjy,
        dim_t m,
        double* x, inc_t incx,
        double* y, inc_t incy,
        double* rho,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_sdotv_ex(
        blis_conj_t conjx,
        blis_conj_t conjy,
        dim_t m,
        float* x, inc_t incx,
        float* y, inc_t incy,
        float* rho,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # Vector norms
    void bli_snorm1v_ex(
        dim_t n,
        float* x, inc_t incx,
        float* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_dnorm1v_ex(
        dim_t n,
        double* x, inc_t incx,
        double* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_snormfv_ex(
        dim_t n,
        float* x, inc_t incx,
        float* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_dnormfv_ex(
        dim_t n,
        double* x, inc_t incx,
        double* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_snormiv_ex(
        dim_t n,
        float* x, inc_t incx,
        float* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_dnormiv_ex(
        dim_t n,
        double* x, inc_t incx,
        double* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # Vector utilities
    void bli_srandv_ex(
        dim_t m,
        float* x, inc_t incx,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_drandv_ex(
        dim_t m,
        double* x, inc_t incx,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # The enclosing extern block is already `nogil`, so no per-function
    # `nogil` marker is needed on these declarations.
    void bli_ssumsqv_ex(
        dim_t m,
        float* x, inc_t incx,
        float* scale,
        float* sumsq,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    void bli_dsumsqv_ex(
        dim_t m,
        double* x, inc_t incx,
        double* scale,
        double* sumsq,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )


# Initialise the BLIS library as soon as the module is imported.
bli_init()

# Module-wide runtime object handed to every `bli_*_ex` call below.
# NOTE(review): it is only initialised by `bli_rntm_init` inside `init()`;
# callers are presumably expected to run `init()` before using the wrappers.
cdef blis_rntm_t rntm


def init():
    """Initialise BLIS and the shared runtime object.

    Calls `bli_init()`, runs `bli_rntm_init` on the module-level `rntm`, and
    asserts that every `BLIS_*` enum value declared from `blis.h` equals the
    corresponding un-prefixed constant used by this package.
    """
    bli_init()
    bli_rntm_init(&rntm)
    assert BLIS_NO_TRANSPOSE == NO_TRANSPOSE
    assert BLIS_TRANSPOSE == TRANSPOSE
    assert BLIS_CONJ_NO_TRANSPOSE == CONJ_NO_TRANSPOSE
    assert BLIS_CONJ_TRANSPOSE == CONJ_TRANSPOSE
    assert BLIS_NO_CONJUGATE == NO_CONJUGATE
    assert BLIS_CONJUGATE == CONJUGATE
    assert BLIS_LEFT == LEFT
    assert BLIS_RIGHT == RIGHT
    assert BLIS_LOWER == LOWER
    assert BLIS_UPPER == UPPER
    assert BLIS_DENSE == DENSE
    assert BLIS_NONUNIT_DIAG == NONUNIT_DIAG
    assert BLIS_UNIT_DIAG == UNIT_DIAG


def get_int_type_size():
    """Return the string reported by `bli_info_get_int_type_size_str()`.

    The C string is decoded as a whole.  Formatting only its first byte with
    `%d` (as was done previously) yields the character's ASCII code rather
    than the text BLIS reports.
    """
    cdef char* int_size = bli_info_get_int_type_size_str()
    return int_size.decode("ascii")


# BLAS level 3 routines
cdef void gemm(
    trans_t trans_a,
    trans_t trans_b,
    dim_t m,
    dim_t n,
    dim_t k,
    double alpha,
    reals_ft a, inc_t rsa, inc_t csa,
    reals_ft b, inc_t rsb, inc_t csb,
    double beta,
    reals_ft c, inc_t rsc, inc_t csc
) nogil:
    # Forward to bli_sgemm_ex / bli_dgemm_ex depending on the fused type.
    # `alpha` and `beta` arrive as doubles; BLIS takes them by pointer in the
    # precision of the operands, so keep one local copy per precision.
    cdef float alpha_f = alpha
    cdef float beta_f = beta
    cdef double alpha_d = alpha
    cdef double beta_d = beta
    if reals_ft is floats_t:
        bli_sgemm_ex(
            trans_a, trans_b, m, n, k,
            &alpha_f, a, rsa, csa, b, rsb, csb,
            &beta_f, c, rsc, csc, NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_dgemm_ex(
            trans_a, trans_b, m, n, k,
            &alpha_d, a, rsa, csa, b, rsb, csb,
            &beta_d, c, rsc, csc, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_sgemm_ex(
            trans_a, trans_b, m, n, k,
            &alpha_f, &a[0], rsa, csa, &b[0], rsb, csb,
            &beta_f, &c[0], rsc, csc, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_dgemm_ex(
            trans_a, trans_b, m, n, k,
            &alpha_d, &a[0], rsa, csa, &b[0], rsb, csb,
            &beta_d, &c[0], rsc, csc, NULL, &rntm)
    else:
        # Impossible --- panic?
        pass


cdef void ger(
    conj_t conjx,
    conj_t conjy,
    dim_t m,
    dim_t n,
    double alpha,
    reals_ft x, inc_t incx,
    reals_ft y, inc_t incy,
    reals_ft a, inc_t rsa, inc_t csa
) nogil:
    # Forward to bli_sger_ex / bli_dger_ex depending on the fused type.
    cdef float alpha_f = alpha
    cdef double alpha_d = alpha
    if reals_ft is floats_t:
        bli_sger_ex(
            conjx, conjy, m, n,
            &alpha_f, x, incx, y, incy, a, rsa, csa, NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_dger_ex(
            conjx, conjy, m, n,
            &alpha_d, x, incx, y, incy, a, rsa, csa, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_sger_ex(
            conjx, conjy, m, n,
            &alpha_f, &x[0], incx, &y[0], incy, &a[0], rsa, csa, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_dger_ex(
            conjx, conjy, m, n,
            &alpha_d, &x[0], incx, &y[0], incy, &a[0], rsa, csa, NULL, &rntm)
    else:
        # Impossible --- panic?
        pass


cdef void gemv(
    trans_t transa,
    conj_t conjx,
    dim_t m,
    dim_t n,
    real_ft alpha,
    reals_ft a, inc_t rsa, inc_t csa,
    reals_ft x, inc_t incx,
    real_ft beta,
    reals_ft y, inc_t incy
) nogil:
    # Forward to bli_sgemv_ex / bli_dgemv_ex depending on the fused type.
    cdef float alpha_f = alpha
    cdef double alpha_d = alpha
    # The beta copies must come from `beta`; they were previously initialised
    # from `alpha`, which silently ignored the caller's beta.
    cdef float beta_f = beta
    cdef double beta_d = beta
    if reals_ft is floats_t:
        bli_sgemv_ex(
            transa, conjx, m, n,
            &alpha_f, a, rsa, csa, x, incx,
            &beta_f, y, incy, NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_dgemv_ex(
            transa, conjx, m, n,
            &alpha_d, a, rsa, csa, x, incx,
            &beta_d, y, incy, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_sgemv_ex(
            transa, conjx, m, n,
            &alpha_f, &a[0], rsa, csa, &x[0], incx,
            &beta_f, &y[0], incy, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_dgemv_ex(
            transa, conjx, m, n,
            &alpha_d, &a[0], rsa, csa, &x[0], incx,
            &beta_d, &y[0], incy, NULL, &rntm)
    else:
        # Impossible --- panic?
        pass


cdef void axpyv(
    conj_t conjx,
    dim_t m,
    real_ft alpha,
    reals_ft x, inc_t incx,
    reals_ft y, inc_t incy
) nogil:
    # Forward to bli_saxpyv_ex / bli_daxpyv_ex depending on the fused type.
    cdef float alpha_f = alpha
    cdef double alpha_d = alpha
    if reals_ft is floats_t:
        bli_saxpyv_ex(conjx, m, &alpha_f, x, incx, y, incy, NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_daxpyv_ex(conjx, m, &alpha_d, x, incx, y, incy, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_saxpyv_ex(conjx, m, &alpha_f, &x[0], incx, &y[0], incy, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_daxpyv_ex(conjx, m, &alpha_d, &x[0], incx, &y[0], incy, NULL, &rntm)
    else:
        # Impossible --- panic?
        pass


cdef void scalv(
    conj_t conjalpha,
    dim_t m,
    real_ft alpha,
    reals_ft x, inc_t incx
) nogil:
    # Forward to bli_sscalv_ex / bli_dscalv_ex depending on the fused type.
    cdef float alpha_f = alpha
    cdef double alpha_d = alpha
    if reals_ft is floats_t:
        bli_sscalv_ex(conjalpha, m, &alpha_f, x, incx, NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_dscalv_ex(conjalpha, m, &alpha_d, x, incx, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_sscalv_ex(conjalpha, m, &alpha_f, &x[0], incx, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_dscalv_ex(conjalpha, m, &alpha_d, &x[0], incx, NULL, &rntm)
    else:
        # Impossible --- panic?
        pass


cdef double norm_L1(
    dim_t n,
    reals_ft x, inc_t incx
) nogil:
    # Call bli_?norm1v_ex and return the result as a double; the
    # single-precision result is widened before returning.
    cdef double dnorm = 0
    cdef float snorm = 0
    if reals_ft is floats_t:
        bli_snorm1v_ex(n, x, incx, &snorm, NULL, &rntm)
        dnorm = snorm
    elif reals_ft is doubles_t:
        bli_dnorm1v_ex(n, x, incx, &dnorm, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_snorm1v_ex(n, &x[0], incx, &snorm, NULL, &rntm)
        dnorm = snorm
    elif reals_ft is double1d_t:
        bli_dnorm1v_ex(n, &x[0], incx, &dnorm, NULL, &rntm)
    else:
        # Impossible --- panic?
        pass
    return dnorm


cdef double norm_L2(
    dim_t n,
    reals_ft x, inc_t incx
) nogil:
    # Call bli_?normfv_ex and return the result as a double; the
    # single-precision result is widened before returning.
    cdef double dnorm = 0
    cdef float snorm = 0
    if reals_ft is floats_t:
        bli_snormfv_ex(n, x, incx, &snorm, NULL, &rntm)
        dnorm = snorm
    elif reals_ft is doubles_t:
        bli_dnormfv_ex(n, x, incx, &dnorm, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_snormfv_ex(n, &x[0], incx, &snorm, NULL, &rntm)
        dnorm = snorm
    elif reals_ft is double1d_t:
        bli_dnormfv_ex(n, &x[0], incx, &dnorm, NULL, &rntm)
    else:
        # Impossible --- panic?
        pass
    return dnorm


cdef double norm_inf(
    dim_t n,
    reals_ft x, inc_t incx
) nogil:
    # Call bli_?normiv_ex and return the result as a double; the
    # single-precision result is widened before returning.
    cdef double dnorm = 0
    cdef float snorm = 0
    if reals_ft is floats_t:
        bli_snormiv_ex(n, x, incx, &snorm, NULL, &rntm)
        dnorm = snorm
    elif reals_ft is doubles_t:
        bli_dnormiv_ex(n, x, incx, &dnorm, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_snormiv_ex(n, &x[0], incx, &snorm, NULL, &rntm)
        dnorm = snorm
    elif reals_ft is double1d_t:
        bli_dnormiv_ex(n, &x[0], incx, &dnorm, NULL, &rntm)
    else:
        # Impossible --- panic?
        pass
    return dnorm


cdef double dotv(
    conj_t conjx,
    conj_t conjy,
    dim_t m,
    reals_ft x,
    reals_ft y,
    inc_t incx,
    inc_t incy,
) nogil:
    # Call bli_?dotv_ex and return the value BLIS writes to `rho`.
    # Note the parameter order here: both vectors first, then both strides.
    cdef double rho_d = 0.0
    cdef float rho_f = 0.0
    if reals_ft is floats_t:
        bli_sdotv_ex(conjx, conjy, m, x, incx, y, incy, &rho_f, NULL, &rntm)
        return rho_f
    elif reals_ft is doubles_t:
        bli_ddotv_ex(conjx, conjy, m, x, incx, y, incy, &rho_d, NULL, &rntm)
        return rho_d
    elif reals_ft is float1d_t:
        bli_sdotv_ex(conjx, conjy, m, &x[0], incx, &y[0], incy, &rho_f, NULL, &rntm)
        return rho_f
    elif reals_ft is double1d_t:
        bli_ddotv_ex(conjx, conjy, m, &x[0], incx, &y[0], incy, &rho_d, NULL, &rntm)
        return rho_d
    else:
        # Raising needs the GIL inside a nogil function (same pattern as
        # randv / sumsqv).
        with gil:
            raise ValueError("Unhandled fused type")


cdef void randv(dim_t m, reals_ft x, inc_t incx) nogil:
    # Call bli_?randv_ex on `x`.
    # This must be ONE if/elif chain: with two separate chains the float
    # specializations fell through to the final `else` and raised after
    # having already done their work.
    if reals_ft is floats_t:
        bli_srandv_ex(m, x, incx, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_srandv_ex(m, &x[0], incx, NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_drandv_ex(m, x, incx, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_drandv_ex(m, &x[0], incx, NULL, &rntm)
    else:
        with gil:
            raise ValueError("Unhandled fused type")


cdef void sumsqv(dim_t m, reals_ft x, inc_t incx,
        reals_ft scale, reals_ft sumsq) nogil:
    # Call bli_?sumsqv_ex; `scale` and `sumsq` are passed through to BLIS
    # using the same fused type as `x`.
    # As in randv, a single if/elif chain keeps the float specializations
    # from reaching the error branch.
    if reals_ft is floats_t:
        bli_ssumsqv_ex(m, x, incx, scale, sumsq, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_ssumsqv_ex(m, &x[0], incx, &scale[0], &sumsq[0], NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_dsumsqv_ex(m, x, incx, scale, sumsq, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_dsumsqv_ex(m, &x[0], incx, &scale[0], &sumsq[0], NULL, &rntm)
    else:
        with gil:
            raise ValueError("Unhandled fused type")


cdef void dgemm(bint transA, bint transB, int M, int N, int K,
                double alpha, const double* A, int lda, const double* B,
                int ldb, double beta, double* C, int ldc) nogil:
    # BLAS-style double-precision entry point on top of `gemm`: the leading
    # dimension is passed as the row stride and the column stride is fixed
    # at 1.
    gemm(
        TRANSPOSE if transA else NO_TRANSPOSE,
        TRANSPOSE if transB else NO_TRANSPOSE,
        M, N, K,
        alpha,
        A, lda, 1,
        B, ldb, 1,
        beta,
        C, ldc, 1
    )


cdef void sgemm(bint transA, bint transB, int M, int N, int K,
                float alpha, const float* A, int lda, const float* B,
                int ldb, float beta, float* C, int ldc) nogil:
    # BLAS-style single-precision entry point on top of `gemm`: the leading
    # dimension is passed as the row stride and the column stride is fixed
    # at 1.
    gemm(
        TRANSPOSE if transA else NO_TRANSPOSE,
        TRANSPOSE if transB else NO_TRANSPOSE,
        M, N, K,
        alpha,
        A, lda, 1,
        B, ldb, 1,
        beta,
        C, ldc, 1
    )


cdef void saxpy(int N, float alpha, const float* X, int incX,
                float* Y, int incY) nogil:
    # BLAS-style single-precision axpy; always called without conjugation.
    axpyv(NO_CONJUGATE, N, alpha, X, incX, Y, incY)


cdef void daxpy(int N, double alpha, const double* X, int incX,
                double* Y, int incY) nogil:
    # BLAS-style double-precision axpy; always called without conjugation.
    axpyv(NO_CONJUGATE, N, alpha, X, incX, Y, incY)


@atexit.register
def finalize():
    """Shut BLIS down via `bli_finalize()`; registered to run at interpreter exit."""
    bli_finalize()