Skip to content

Fix: Add DSP gemm pack with auto memcpy to buffer#7060

Open
Cstandardlib wants to merge 2 commits intodeepmodeling:developfrom
Cstandardlib:dsp/zgemm-pack-mth
Open

Fix: Add DSP gemm pack with auto memcpy to buffer#7060
Cstandardlib wants to merge 2 commits intodeepmodeling:developfrom
Cstandardlib:dsp/zgemm-pack-mth

Conversation

@Cstandardlib
Copy link
Collaborator

What's changed?

  • Temporary solution to make it work for nonlocal gemm on DSP.
  • Add a gemm pack version, with auto memory buffer inside.

Copilot AI review requested due to automatic review settings March 19, 2026 04:37
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the DSP GEMM path for complex types to use new “pack” variants that internally allocate DSP buffers and perform host↔DSP copies, aiming to make nonlocal GEMM work on DSP as a temporary solution.

Changes:

  • Switched DSP complex GEMM calls to cgemm_pack_mth_ / zgemm_pack_mth_ in the BLAS connector.
  • Added DSP connector declarations and implementations for the new packed GEMM entry points.
  • Adjusted DSP connector header guard trailing #endifs.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

File Description
source/source_base/module_external/blas_connector_matrix.cpp Routes DSP complex GEMM calls through new packed DSP GEMM helpers.
source/source_base/kernels/dsp/dsp_connector.h Declares the new packed DSP GEMM APIs (and adjusts header guard endings).
source/source_base/kernels/dsp/dsp_connector.cpp Implements packed DSP GEMM helpers with DSP-side allocations and memcpy.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +421 to +422
const bool transa_not = (transa[0] == 'N' || transa[0] == 'n');
const bool transb_not = (transb[0] == 'N' || transb[0] == 'n');
Comment on lines +439 to +451
mt_hthread_zgemm(CBLAS_ORDER::CblasColMajor,
convertBLASTranspose(transa),
convertBLASTranspose(transb),
*m,
*n,
*k,
alp,
a,
// A_dsp,
*lda,
b,
// B_dsp,
*ldb,
Comment on lines +433 to +437
// memcpy(A_dsp, a, a_elems * sizeof(std::complex<double>));
// memcpy(B_dsp, b, b_elems * sizeof(std::complex<double>));
memcpy(C_dsp, c, c_elems * sizeof(std::complex<double>));
*alp = *alpha;
*bet = *beta;
Comment on lines +527 to +532
std::complex<float>* A_dsp = static_cast<std::complex<float>*>(malloc_ht(a_elems * sizeof(std::complex<float>), cluster_id));
std::complex<float>* B_dsp = static_cast<std::complex<float>*>(malloc_ht(b_elems * sizeof(std::complex<float>), cluster_id));
std::complex<float>* C_dsp = static_cast<std::complex<float>*>(malloc_ht(c_elems * sizeof(std::complex<float>), cluster_id));
std::complex<float>* alp = static_cast<std::complex<float>*>(malloc_ht(sizeof(std::complex<float>), cluster_id));
std::complex<float>* bet = static_cast<std::complex<float>*>(malloc_ht(sizeof(std::complex<float>), cluster_id));

Comment on lines +79 to +94

void cgemm_mth_(const char* transa,
const char* transb,
const int* m,
const int* n,
const int* k,
const std::complex<float>* alpha,
const std::complex<float>* a,
const int* lda,
const std::complex<float>* b,
const int* ldb,
const std::complex<float>* beta,
std::complex<float>* c,
const int* ldc,
int cluster_id);

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants