From 87deb1d4a1dcfbf6bcc3520cd7854bc41a1e2135 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Sat, 14 Mar 2026 13:49:06 +0800 Subject: [PATCH 1/2] Fix dsp setmem op --- .../source_base/module_device/memory_op.cpp | 19 +++++++++++++++++++ source/source_base/module_device/memory_op.h | 14 ++++++++++++++ source/source_pw/module_pwdft/op_pw_nl.h | 4 ++++ source/source_pw/module_pwdft/vnl_pw.cpp | 14 +++++++++++++- source/source_pw/module_pwdft/vnl_pw.h | 1 + 5 files changed, 51 insertions(+), 1 deletion(-) diff --git a/source/source_base/module_device/memory_op.cpp b/source/source_base/module_device/memory_op.cpp index 2ef4be588a..bff9234f64 100644 --- a/source/source_base/module_device/memory_op.cpp +++ b/source/source_base/module_device/memory_op.cpp @@ -471,6 +471,19 @@ struct resize_memory_op_mt } }; +template +struct set_memory_op_mt +{ + void operator()(FPTYPE* arr, const int var, const size_t size) + { + ModuleBase::OMP_PARALLEL([&](int num_thread, int thread_id) { + int beg = 0, len = 0; + ModuleBase::BLOCK_TASK_DIST_1D(num_thread, thread_id, size, (size_t)4096 / sizeof(FPTYPE), beg, len); + memset(arr + beg, var, sizeof(FPTYPE) * len); + }); + } +}; + template struct delete_memory_op_mt { @@ -487,6 +500,12 @@ template struct resize_memory_op_mt; template struct resize_memory_op_mt, base_device::DEVICE_CPU>; template struct resize_memory_op_mt, base_device::DEVICE_CPU>; +template struct set_memory_op_mt; +template struct set_memory_op_mt; +template struct set_memory_op_mt; +template struct set_memory_op_mt, base_device::DEVICE_CPU>; +template struct set_memory_op_mt, base_device::DEVICE_CPU>; + template struct delete_memory_op_mt; template struct delete_memory_op_mt; template struct delete_memory_op_mt; diff --git a/source/source_base/module_device/memory_op.h b/source/source_base/module_device/memory_op.h index c24acbb024..004468f410 100644 --- a/source/source_base/module_device/memory_op.h +++ b/source/source_base/module_device/memory_op.h @@ -234,6 +234,20 @@ struct resize_memory_op_mt void operator()(FPTYPE*& arr, const size_t size, const char* record_in = nullptr); }; +template +struct set_memory_op_mt +{ + /// @brief memset for DSP memory allocated by mt allocator. + /// + /// Input Parameters + /// \param var : the specified constant byte value + /// \param size : array size + /// + /// Output Parameters + /// \param arr : output array initialized by the input value + void operator()(FPTYPE* arr, const int var, const size_t size); +}; + template struct delete_memory_op_mt { diff --git a/source/source_pw/module_pwdft/op_pw_nl.h b/source/source_pw/module_pwdft/op_pw_nl.h index 829bb31e93..9887f25706 100644 --- a/source/source_pw/module_pwdft/op_pw_nl.h +++ b/source/source_pw/module_pwdft/op_pw_nl.h @@ -88,7 +88,11 @@ class Nonlocal> : public OperatorPW using gemv_op = ModuleBase::gemv_op; using gemm_op = ModuleBase::gemm_op; using nonlocal_op = nonlocal_pw_op; + #ifdef __DSP + using setmem_complex_op = base_device::memory::set_memory_op_mt; + #else using setmem_complex_op = base_device::memory::set_memory_op; + #endif #ifdef __DSP using resmem_complex_op = base_device::memory::resize_memory_op_mt; using delmem_complex_op = base_device::memory::delete_memory_op_mt; diff --git a/source/source_pw/module_pwdft/vnl_pw.cpp b/source/source_pw/module_pwdft/vnl_pw.cpp index 3a1fdda873..3077dedce1 100644 --- a/source/source_pw/module_pwdft/vnl_pw.cpp +++ b/source/source_pw/module_pwdft/vnl_pw.cpp @@ -47,6 +47,7 @@ void pseudopot_cell_vnl::release_memory() delmem_zd_op()(this->z_qq_so); delmem_dd_op()(this->d_deeq); delmem_zd_op()(this->z_vkb); + this->z_vkb_mt_allocated_ = false; delmem_dd_op()(this->d_tab); delmem_dd_op()(this->d_indv); delmem_dd_op()(this->d_nhtol); @@ -64,6 +65,14 @@ void pseudopot_cell_vnl::release_memory() delmem_ch_op()(this->c_deeq_nc); delmem_ch_op()(this->c_vkb); delmem_ch_op()(this->c_qq_so); +#ifdef __DSP + if (this->z_vkb_mt_allocated_ && this->z_vkb != nullptr) + { + base_device::memory::delete_memory_op_mt, base_device::DEVICE_CPU>()(this->z_vkb); + this->z_vkb = nullptr; + this->z_vkb_mt_allocated_ = false; + } +#endif // There's no need to delete double precision pointers while in a CPU environment. } memory_released = true; @@ -264,6 +273,7 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell, resmem_cd_op()(c_vkb, nkb * npwx); } resmem_zd_op()(z_vkb, nkb * npwx); + this->z_vkb_mt_allocated_ = false; resmem_dd_op()(d_tab, this->tab.getSize()); } else @@ -275,10 +285,12 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell, } #ifdef __DSP base_device::memory::resize_memory_op_mt, base_device::DEVICE_CPU>() - (this->z_vkb, this->vkb.size, "Nonlocal::ps"); + (this->z_vkb, this->vkb.size, "VNL::z_vkb"); memcpy(this->z_vkb,this->vkb.c,this->vkb.size*16); + this->z_vkb_mt_allocated_ = true; #else this->z_vkb = this->vkb.c; + this->z_vkb_mt_allocated_ = false; #endif this->d_tab = this->tab.ptr; // There's no need to delete double precision pointers while in a CPU environment. diff --git a/source/source_pw/module_pwdft/vnl_pw.h b/source/source_pw/module_pwdft/vnl_pw.h index 93a593e925..6e5501fa53 100644 --- a/source/source_pw/module_pwdft/vnl_pw.h +++ b/source/source_pw/module_pwdft/vnl_pw.h @@ -195,6 +195,7 @@ class pseudopot_cell_vnl double* d_indv = nullptr; double* d_tab = nullptr; std::complex* z_vkb = nullptr; + bool z_vkb_mt_allocated_ = false; const ModulePW::PW_Basis_K* wfcpw = nullptr; From 3117be996a115c3f8ac6090fe3a7090f7ad06a70 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:28:14 +0800 Subject: [PATCH 2/2] Clean up the code --- source/source_pw/module_pwdft/op_pw_nl.h | 13 ++++------- source/source_pw/module_pwdft/vnl_pw.cpp | 29 ++++++++++-------------- source/source_pw/module_pwdft/vnl_pw.h | 1 - 3 files changed, 17 insertions(+), 26 deletions(-) diff --git a/source/source_pw/module_pwdft/op_pw_nl.h b/source/source_pw/module_pwdft/op_pw_nl.h index 9887f25706..dcdbf889a8 100644 --- a/source/source_pw/module_pwdft/op_pw_nl.h +++ b/source/source_pw/module_pwdft/op_pw_nl.h @@ -88,18 +88,15 @@ class Nonlocal> : public OperatorPW using gemv_op = ModuleBase::gemv_op; using gemm_op = ModuleBase::gemm_op; using nonlocal_op = nonlocal_pw_op; - #ifdef __DSP +#ifdef __DSP using setmem_complex_op = base_device::memory::set_memory_op_mt; - #else - using setmem_complex_op = base_device::memory::set_memory_op; - #endif - #ifdef __DSP using resmem_complex_op = base_device::memory::resize_memory_op_mt; using delmem_complex_op = base_device::memory::delete_memory_op_mt; - #else +#else + using setmem_complex_op = base_device::memory::set_memory_op; using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; - #endif +#endif using syncmem_complex_h2d_op = base_device::memory::synchronize_memory_op; T one{1, 0}; @@ -108,4 +105,4 @@ class Nonlocal> : public OperatorPW } // namespace hamilt -#endif \ No newline at end of file +#endif diff --git a/source/source_pw/module_pwdft/vnl_pw.cpp b/source/source_pw/module_pwdft/vnl_pw.cpp index 3077dedce1..0ac8ef9b95 100644 --- a/source/source_pw/module_pwdft/vnl_pw.cpp +++ b/source/source_pw/module_pwdft/vnl_pw.cpp @@ -47,7 +47,6 @@ void pseudopot_cell_vnl::release_memory() delmem_zd_op()(this->z_qq_so); delmem_dd_op()(this->d_deeq); delmem_zd_op()(this->z_vkb); - this->z_vkb_mt_allocated_ = false; delmem_dd_op()(this->d_tab); delmem_dd_op()(this->d_indv); delmem_dd_op()(this->d_nhtol); @@ -66,11 +65,10 @@ void pseudopot_cell_vnl::release_memory() delmem_ch_op()(this->c_vkb); delmem_ch_op()(this->c_qq_so); #ifdef __DSP - if (this->z_vkb_mt_allocated_ && this->z_vkb != nullptr) + if (this->z_vkb != nullptr) { base_device::memory::delete_memory_op_mt, base_device::DEVICE_CPU>()(this->z_vkb); this->z_vkb = nullptr; - this->z_vkb_mt_allocated_ = false; } #endif // There's no need to delete double precision pointers while in a CPU environment. @@ -273,7 +271,6 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell, resmem_cd_op()(c_vkb, nkb * npwx); } resmem_zd_op()(z_vkb, nkb * npwx); - this->z_vkb_mt_allocated_ = false; resmem_dd_op()(d_tab, this->tab.getSize()); } else @@ -283,15 +280,13 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell, resmem_sh_op()(s_tab, this->tab.getSize()); resmem_ch_op()(c_vkb, nkb * npwx); } - #ifdef __DSP +#ifdef __DSP base_device::memory::resize_memory_op_mt, base_device::DEVICE_CPU>() (this->z_vkb, this->vkb.size, "VNL::z_vkb"); - memcpy(this->z_vkb,this->vkb.c,this->vkb.size*16); - this->z_vkb_mt_allocated_ = true; - #else + // memcpy(this->z_vkb,this->vkb.c,this->vkb.size*16); +#else this->z_vkb = this->vkb.c; - this->z_vkb_mt_allocated_ = false; - #endif +#endif this->d_tab = this->tab.ptr; // There's no need to delete double precision pointers while in a CPU environment. } @@ -305,12 +300,12 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell, // with structure factor, for all atoms, in reciprocal space //---------------------------------------------------------- template -void pseudopot_cell_vnl::getvnl(Device* ctx, +void pseudopot_cell_vnl::getvnl(Device* ctx, const UnitCell& ucell, - const int& ik, + const int& ik, std::complex* vkb_in) const { - if (PARAM.inp.test_pp) + if (PARAM.inp.test_pp) { ModuleBase::TITLE("pseudopot_cell_vnl", "getvnl"); } @@ -744,10 +739,10 @@ void pseudopot_cell_vnl::init_vnl(UnitCell& cell, const ModulePW::PW_Basis* rho_ for (int iq = 0; iq < PARAM.globalv.nqx; iq++) { const double q = iq * PARAM.globalv.dq; - ModuleBase::Sphbes::Spherical_Bessel(kkbeta, cell.atoms[it].ncpp.r.data(), q, l, jl); + ModuleBase::Sphbes::Spherical_Bessel(kkbeta, cell.atoms[it].ncpp.r.data(), q, l, jl); for (int ir = 0; ir < kkbeta; ir++) - { - aux[ir] = cell.atoms[it].ncpp.betar(ib, ir) * jl[ir] * cell.atoms[it].ncpp.r[ir]; + { + aux[ir] = cell.atoms[it].ncpp.betar(ib, ir) * jl[ir] * cell.atoms[it].ncpp.r[ir]; } double vqint=0.0; ModuleBase::Integral::Simpson_Integral(kkbeta, aux, cell.atoms[it].ncpp.rab.data(), vqint); @@ -1735,7 +1730,7 @@ template void pseudopot_cell_vnl::getvnl(base_de int const&, std::complex*) const; template void pseudopot_cell_vnl::getvnl(base_device::DEVICE_CPU*, - const UnitCell&, + const UnitCell&, int const&, std::complex*) const; #if defined(__CUDA) || defined(__ROCM) diff --git a/source/source_pw/module_pwdft/vnl_pw.h b/source/source_pw/module_pwdft/vnl_pw.h index 6e5501fa53..93a593e925 100644 --- a/source/source_pw/module_pwdft/vnl_pw.h +++ b/source/source_pw/module_pwdft/vnl_pw.h @@ -195,7 +195,6 @@ class pseudopot_cell_vnl double* d_indv = nullptr; double* d_tab = nullptr; std::complex* z_vkb = nullptr; - bool z_vkb_mt_allocated_ = false; const ModulePW::PW_Basis_K* wfcpw = nullptr;