diff --git a/source/source_base/module_device/memory_op.cpp b/source/source_base/module_device/memory_op.cpp index 2ef4be588a..bff9234f64 100644 --- a/source/source_base/module_device/memory_op.cpp +++ b/source/source_base/module_device/memory_op.cpp @@ -471,6 +471,19 @@ struct resize_memory_op_mt } }; +template +struct set_memory_op_mt +{ + void operator()(FPTYPE* arr, const int var, const size_t size) + { + ModuleBase::OMP_PARALLEL([&](int num_thread, int thread_id) { + int beg = 0, len = 0; + ModuleBase::BLOCK_TASK_DIST_1D(num_thread, thread_id, size, (size_t)4096 / sizeof(FPTYPE), beg, len); + memset(arr + beg, var, sizeof(FPTYPE) * len); + }); + } +}; + template struct delete_memory_op_mt { @@ -487,6 +500,12 @@ template struct resize_memory_op_mt; template struct resize_memory_op_mt, base_device::DEVICE_CPU>; template struct resize_memory_op_mt, base_device::DEVICE_CPU>; +template struct set_memory_op_mt; +template struct set_memory_op_mt; +template struct set_memory_op_mt; +template struct set_memory_op_mt, base_device::DEVICE_CPU>; +template struct set_memory_op_mt, base_device::DEVICE_CPU>; + template struct delete_memory_op_mt; template struct delete_memory_op_mt; template struct delete_memory_op_mt; diff --git a/source/source_base/module_device/memory_op.h b/source/source_base/module_device/memory_op.h index c24acbb024..004468f410 100644 --- a/source/source_base/module_device/memory_op.h +++ b/source/source_base/module_device/memory_op.h @@ -234,6 +234,20 @@ struct resize_memory_op_mt void operator()(FPTYPE*& arr, const size_t size, const char* record_in = nullptr); }; +template +struct set_memory_op_mt +{ + /// @brief memset for DSP memory allocated by mt allocator. + /// + /// Input Parameters + /// \param var : the specified constant byte value + /// \param size : array size + /// + /// Output Parameters + /// \param arr : output array initialized by the input value + void operator()(FPTYPE* arr, const int var, const size_t size); +}; + template struct delete_memory_op_mt { diff --git a/source/source_pw/module_pwdft/op_pw_nl.h b/source/source_pw/module_pwdft/op_pw_nl.h index 829bb31e93..dcdbf889a8 100644 --- a/source/source_pw/module_pwdft/op_pw_nl.h +++ b/source/source_pw/module_pwdft/op_pw_nl.h @@ -88,14 +88,15 @@ class Nonlocal> : public OperatorPW using gemv_op = ModuleBase::gemv_op; using gemm_op = ModuleBase::gemm_op; using nonlocal_op = nonlocal_pw_op; - using setmem_complex_op = base_device::memory::set_memory_op; - #ifdef __DSP +#ifdef __DSP + using setmem_complex_op = base_device::memory::set_memory_op_mt; using resmem_complex_op = base_device::memory::resize_memory_op_mt; using delmem_complex_op = base_device::memory::delete_memory_op_mt; - #else +#else + using setmem_complex_op = base_device::memory::set_memory_op; using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; - #endif +#endif using syncmem_complex_h2d_op = base_device::memory::synchronize_memory_op; T one{1, 0}; @@ -104,4 +105,4 @@ class Nonlocal> : public OperatorPW } // namespace hamilt -#endif \ No newline at end of file +#endif diff --git a/source/source_pw/module_pwdft/vnl_pw.cpp b/source/source_pw/module_pwdft/vnl_pw.cpp index 3a1fdda873..0ac8ef9b95 100644 --- a/source/source_pw/module_pwdft/vnl_pw.cpp +++ b/source/source_pw/module_pwdft/vnl_pw.cpp @@ -64,6 +64,13 @@ void pseudopot_cell_vnl::release_memory() delmem_ch_op()(this->c_deeq_nc); delmem_ch_op()(this->c_vkb); delmem_ch_op()(this->c_qq_so); +#ifdef __DSP + if (this->z_vkb != nullptr) + { + base_device::memory::delete_memory_op_mt, base_device::DEVICE_CPU>()(this->z_vkb); + this->z_vkb = nullptr; + } +#endif // There's no need to delete double precision pointers while in a CPU environment. } memory_released = true; @@ -273,13 +280,13 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell, resmem_sh_op()(s_tab, this->tab.getSize()); resmem_ch_op()(c_vkb, nkb * npwx); } - #ifdef __DSP +#ifdef __DSP base_device::memory::resize_memory_op_mt, base_device::DEVICE_CPU>() - (this->z_vkb, this->vkb.size, "Nonlocal::ps"); - memcpy(this->z_vkb,this->vkb.c,this->vkb.size*16); - #else + (this->z_vkb, this->vkb.size, "VNL::z_vkb"); + // memcpy(this->z_vkb,this->vkb.c,this->vkb.size*16); +#else this->z_vkb = this->vkb.c; - #endif +#endif this->d_tab = this->tab.ptr; // There's no need to delete double precision pointers while in a CPU environment. } @@ -293,12 +300,12 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell, // with structure factor, for all atoms, in reciprocal space //---------------------------------------------------------- template -void pseudopot_cell_vnl::getvnl(Device* ctx, +void pseudopot_cell_vnl::getvnl(Device* ctx, const UnitCell& ucell, - const int& ik, + const int& ik, std::complex* vkb_in) const { - if (PARAM.inp.test_pp) + if (PARAM.inp.test_pp) { ModuleBase::TITLE("pseudopot_cell_vnl", "getvnl"); } @@ -732,10 +739,10 @@ void pseudopot_cell_vnl::init_vnl(UnitCell& cell, const ModulePW::PW_Basis* rho_ for (int iq = 0; iq < PARAM.globalv.nqx; iq++) { const double q = iq * PARAM.globalv.dq; - ModuleBase::Sphbes::Spherical_Bessel(kkbeta, cell.atoms[it].ncpp.r.data(), q, l, jl); + ModuleBase::Sphbes::Spherical_Bessel(kkbeta, cell.atoms[it].ncpp.r.data(), q, l, jl); for (int ir = 0; ir < kkbeta; ir++) - { - aux[ir] = cell.atoms[it].ncpp.betar(ib, ir) * jl[ir] * cell.atoms[it].ncpp.r[ir]; + { + aux[ir] = cell.atoms[it].ncpp.betar(ib, ir) * jl[ir] * cell.atoms[it].ncpp.r[ir]; } double vqint=0.0; ModuleBase::Integral::Simpson_Integral(kkbeta, aux, cell.atoms[it].ncpp.rab.data(), vqint); @@ -1723,7 +1730,7 @@ template void pseudopot_cell_vnl::getvnl(base_de int const&, std::complex*) const; template void pseudopot_cell_vnl::getvnl(base_device::DEVICE_CPU*, - const UnitCell&, + const UnitCell&, int const&, std::complex*) const; #if defined(__CUDA) || defined(__ROCM)