From b50284b1f07624ff42a4a76528bb7fcbd36b8e22 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Thu, 19 Mar 2026 12:33:31 +0800 Subject: [PATCH] Add DSP gemm pack with auto memcpy to buffer --- .../source_base/kernels/dsp/dsp_connector.cpp | 120 +++++++++++++++++- .../source_base/kernels/dsp/dsp_connector.h | 61 ++++++--- .../module_external/blas_connector_matrix.cpp | 16 ++- 3 files changed, 177 insertions(+), 20 deletions(-) diff --git a/source/source_base/kernels/dsp/dsp_connector.cpp b/source/source_base/kernels/dsp/dsp_connector.cpp index 2baf73a4ec..7fa0f20ee7 100644 --- a/source/source_base/kernels/dsp/dsp_connector.cpp +++ b/source/source_base/kernels/dsp/dsp_connector.cpp @@ -403,6 +403,66 @@ void zgemm_mth_(const char* transa, free_ht(bet); } // zgemm that needn't malloc_ht or free_ht +void zgemm_pack_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id) +{ + const bool transa_not = (transa[0] == 'N' || transa[0] == 'n'); + const bool transb_not = (transb[0] == 'N' || transb[0] == 'n'); + // const size_t a_elems = static_cast(*lda) * (transa_not ? static_cast(*k) : static_cast(*m)); + // const size_t b_elems = static_cast(*ldb) * (transb_not ? static_cast(*n) : static_cast(*k)); + const size_t c_elems = static_cast(*ldc) * static_cast(*n); + + // std::complex* A_dsp = static_cast*>(malloc_ht(a_elems * sizeof(std::complex), cluster_id)); + // std::complex* B_dsp = static_cast*>(malloc_ht(b_elems * sizeof(std::complex), cluster_id)); + std::complex* C_dsp = static_cast*>(malloc_ht(c_elems * sizeof(std::complex), cluster_id)); + std::complex* alp = static_cast*>(malloc_ht(sizeof(std::complex), cluster_id)); + std::complex* bet = static_cast*>(malloc_ht(sizeof(std::complex), cluster_id)); + + // memcpy(A_dsp, a, a_elems * sizeof(std::complex)); + // memcpy(B_dsp, b, b_elems * sizeof(std::complex)); + memcpy(C_dsp, c, c_elems * sizeof(std::complex)); + *alp = *alpha; + *bet = *beta; + + mt_hthread_zgemm(CBLAS_ORDER::CblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + alp, + a, + // A_dsp, + *lda, + b, + // B_dsp, + *ldb, + bet, + // c, + C_dsp, + *ldc, + cluster_id); + memcpy(c, C_dsp, c_elems * sizeof(std::complex)); + + // free_ht(A_dsp); + // free_ht(B_dsp); + free_ht(C_dsp); + free_ht(alp); + free_ht(bet); +} + void cgemm_mth_(const char* transa, const char* transb, const int* m, @@ -443,6 +503,64 @@ void cgemm_mth_(const char* transa, free_ht(bet); } // cgemm that needn't malloc_ht or free_ht +void cgemm_pack_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id) +{ + const bool transa_not = (transa[0] == 'N' || transa[0] == 'n'); + const bool transb_not = (transb[0] == 'N' || transb[0] == 'n'); + const size_t a_elems = static_cast(*lda) * (transa_not ? static_cast(*k) : static_cast(*m)); + const size_t b_elems = static_cast(*ldb) * (transb_not ? static_cast(*n) : static_cast(*k)); + const size_t c_elems = static_cast(*ldc) * static_cast(*n); + + std::complex* A_dsp = static_cast*>(malloc_ht(a_elems * sizeof(std::complex), cluster_id)); + std::complex* B_dsp = static_cast*>(malloc_ht(b_elems * sizeof(std::complex), cluster_id)); + std::complex* C_dsp = static_cast*>(malloc_ht(c_elems * sizeof(std::complex), cluster_id)); + std::complex* alp = static_cast*>(malloc_ht(sizeof(std::complex), cluster_id)); + std::complex* bet = static_cast*>(malloc_ht(sizeof(std::complex), cluster_id)); + + memcpy(A_dsp, a, a_elems * sizeof(std::complex)); + memcpy(B_dsp, b, b_elems * sizeof(std::complex)); + memcpy(C_dsp, c, c_elems * sizeof(std::complex)); + *alp = *alpha; + *bet = *beta; + + mt_hthread_cgemm(CBLAS_ORDER::CblasColMajor, + convertBLASTranspose(transa), + convertBLASTranspose(transb), + *m, + *n, + *k, + (const void*)alp, + (const void*)A_dsp, + *lda, + (const void*)B_dsp, + *ldb, + (const void*)bet, + (void*)C_dsp, + *ldc, + cluster_id); + + memcpy(c, C_dsp, c_elems * sizeof(std::complex)); + + free_ht(A_dsp); + free_ht(B_dsp); + free_ht(C_dsp); + free_ht(alp); + free_ht(bet); +} + void sgemv_mth_(const char* transa, const int* m, const int* n, @@ -570,4 +688,4 @@ void cgemv_mth_(const char* transa, free_ht(alp); free_ht(bet); } -} // namespace mtfunc \ No newline at end of file +} // namespace mtfunc diff --git a/source/source_base/kernels/dsp/dsp_connector.h b/source/source_base/kernels/dsp/dsp_connector.h index 997a21de59..beb6614c25 100644 --- a/source/source_base/kernels/dsp/dsp_connector.h +++ b/source/source_base/kernels/dsp/dsp_connector.h @@ -61,20 +61,51 @@ void zgemm_mt_(const char* transa, const int* ldc, int cluster_id); -void cgemm_mt_(const char* transa, - const char* transb, - const int* m, - const int* n, - const int* k, - const std::complex* alpha, - const std::complex* a, - const int* lda, - const std::complex* b, - const int* ldb, - const std::complex* beta, - std::complex* c, - const int* ldc, - int cluster_id); +void zgemm_pack_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); + + +void cgemm_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); + +void cgemm_pack_mth_(const char* transa, + const char* transb, + const int* m, + const int* n, + const int* k, + const std::complex* alpha, + const std::complex* a, + const int* lda, + const std::complex* b, + const int* ldb, + const std::complex* beta, + std::complex* c, + const int* ldc, + int cluster_id); void sgemv_mt_(const char* transa, const int* m, @@ -282,4 +313,4 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv } // namespace mtfunc #endif -#endif \ No newline at end of file +#endif diff --git a/source/source_base/module_external/blas_connector_matrix.cpp b/source/source_base/module_external/blas_connector_matrix.cpp index 3b18d3ee3a..2becee24bf 100644 --- a/source/source_base/module_external/blas_connector_matrix.cpp +++ b/source/source_base/module_external/blas_connector_matrix.cpp @@ -107,7 +107,9 @@ void BlasConnector::gemm(const char transa, #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); + mtfunc::cgemm_pack_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); + // cgemm_mth_ for raw dsp mth; + // cgemm_pack_mth_ for dsp mth with memcpy to DSP buffer } #endif else if (device_type == base_device::AbacusDevice_t::GpuDevice) @@ -158,7 +160,9 @@ void BlasConnector::gemm(const char transa, #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); + mtfunc::zgemm_pack_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); + // zgemm_mth_ for raw dsp mth; + // zgemm_pack_mth_ for dsp mth with memcpy to DSP buffer } #endif else if (device_type == base_device::AbacusDevice_t::GpuDevice) @@ -277,7 +281,9 @@ void BlasConnector::gemm_cm(const char transa, #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); + mtfunc::cgemm_pack_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); + // cgemm_mth_ for raw dsp mth; + // cgemm_pack_mth_ for dsp mth with memcpy to DSP buffer } #endif #ifdef __CUDA @@ -328,7 +334,9 @@ void BlasConnector::gemm_cm(const char transa, #ifdef __DSP else if (device_type == base_device::AbacusDevice_t::DspDevice) { - mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); + mtfunc::zgemm_pack_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count); + // zgemm_mth_ for raw dsp mth; + // zgemm_pack_mth_ for dsp mth with memcpy to DSP buffer } #endif #ifdef __CUDA