Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions source/source_base/module_device/memory_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,19 @@ struct resize_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
}
};

/// @brief CPU specialization of set_memory_op_mt: byte-wise fill (memset) of an array,
///        with the work divided among the workers of ModuleBase::OMP_PARALLEL.
///
/// Each invocation of the lambda asks ModuleBase::BLOCK_TASK_DIST_1D for its own
/// [beg, beg + len) element range and memset()s only that range. The block size
/// handed to the distributor is 4096 bytes expressed in FPTYPE elements.
///
/// @param arr  array to fill
/// @param var  byte value forwarded to memset (every byte of the range is set to it)
/// @param size number of FPTYPE elements in arr
template <typename FPTYPE>
struct set_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
{
    void operator()(FPTYPE* arr, const int var, const size_t size)
    {
        ModuleBase::OMP_PARALLEL([&](int num_thread, int thread_id) {
            // Offsets and lengths are kept as size_t: `size` is a size_t, and an int
            // would silently truncate once the array exceeds INT_MAX elements.
            // NOTE(review): assumes BLOCK_TASK_DIST_1D deduces its output type from
            // beg/len (template parameter) -- confirm against tool_threading.h.
            size_t beg = 0;
            size_t len = 0;
            ModuleBase::BLOCK_TASK_DIST_1D(num_thread, thread_id, size, (size_t)4096 / sizeof(FPTYPE), beg, len);
            memset(arr + beg, var, sizeof(FPTYPE) * len);
        });
    }
};

template <typename FPTYPE>
struct delete_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
{
Expand All @@ -487,6 +500,12 @@ template struct resize_memory_op_mt<double, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;

// Explicit instantiations of the CPU set_memory_op_mt specialization for
// int, float, double, std::complex<float> and std::complex<double>.
template struct set_memory_op_mt<int, base_device::DEVICE_CPU>;
template struct set_memory_op_mt<float, base_device::DEVICE_CPU>;
template struct set_memory_op_mt<double, base_device::DEVICE_CPU>;
template struct set_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
template struct set_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;

template struct delete_memory_op_mt<int, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<float, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<double, base_device::DEVICE_CPU>;
Expand Down
14 changes: 14 additions & 0 deletions source/source_base/module_device/memory_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,20 @@ struct resize_memory_op_mt
void operator()(FPTYPE*& arr, const size_t size, const char* record_in = nullptr);
};

template <typename FPTYPE, typename Device>
struct set_memory_op_mt
{
/// @brief memset for DSP memory allocated by mt allocator.
///
/// Only the call operator is declared here; the behavior is provided by
/// per-device specializations.
///
/// Input Parameters
/// \param var : the specified constant byte value (memset semantics: each byte
///              of the range is set to this value, not each element)
/// \param size : array size, counted in FPTYPE elements rather than bytes
///               (the CPU specialization multiplies it by sizeof(FPTYPE))
///
/// Output Parameters
/// \param arr : output array initialized by the input value
void operator()(FPTYPE* arr, const int var, const size_t size);
};

template <typename FPTYPE, typename Device>
struct delete_memory_op_mt
{
Expand Down
11 changes: 6 additions & 5 deletions source/source_pw/module_pwdft/op_pw_nl.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@ class Nonlocal<OperatorPW<T, Device>> : public OperatorPW<T, Device>
using gemv_op = ModuleBase::gemv_op<T, Device>;
using gemm_op = ModuleBase::gemm_op<T, Device>;
using nonlocal_op = nonlocal_pw_op<Real, Device>;
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
#ifdef __DSP
#ifdef __DSP
using setmem_complex_op = base_device::memory::set_memory_op_mt<T, Device>;
using resmem_complex_op = base_device::memory::resize_memory_op_mt<T, Device>;
using delmem_complex_op = base_device::memory::delete_memory_op_mt<T, Device>;
#else
#else
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
#endif
#endif
using syncmem_complex_h2d_op = base_device::memory::synchronize_memory_op<T, Device, base_device::DEVICE_CPU>;

T one{1, 0};
Expand All @@ -104,4 +105,4 @@ class Nonlocal<OperatorPW<T, Device>> : public OperatorPW<T, Device>

} // namespace hamilt

#endif
#endif
31 changes: 19 additions & 12 deletions source/source_pw/module_pwdft/vnl_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ void pseudopot_cell_vnl::release_memory()
delmem_ch_op()(this->c_deeq_nc);
delmem_ch_op()(this->c_vkb);
delmem_ch_op()(this->c_qq_so);
#ifdef __DSP
if (this->z_vkb != nullptr)
{
base_device::memory::delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>()(this->z_vkb);
this->z_vkb = nullptr;
}
#endif
// There's no need to delete double precision pointers while in a CPU environment.
}
memory_released = true;
Expand Down Expand Up @@ -273,13 +280,13 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell,
resmem_sh_op()(s_tab, this->tab.getSize());
resmem_ch_op()(c_vkb, nkb * npwx);
}
#ifdef __DSP
#ifdef __DSP
base_device::memory::resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>()
(this->z_vkb, this->vkb.size, "Nonlocal<PW>::ps");
memcpy(this->z_vkb,this->vkb.c,this->vkb.size*16);
#else
(this->z_vkb, this->vkb.size, "VNL::z_vkb");
// memcpy(this->z_vkb,this->vkb.c,this->vkb.size*16);
#else
this->z_vkb = this->vkb.c;
#endif
#endif
Comment on lines +283 to +289
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This possibility should be considered; what's your opinion?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this->d_tab = this->tab.ptr;
// There's no need to delete double precision pointers while in a CPU environment.
}
Expand All @@ -293,12 +300,12 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell,
// with structure factor, for all atoms, in reciprocal space
//----------------------------------------------------------
template <typename FPTYPE, typename Device>
void pseudopot_cell_vnl::getvnl(Device* ctx,
void pseudopot_cell_vnl::getvnl(Device* ctx,
const UnitCell& ucell,
const int& ik,
const int& ik,
std::complex<FPTYPE>* vkb_in) const
{
if (PARAM.inp.test_pp)
if (PARAM.inp.test_pp)
{
ModuleBase::TITLE("pseudopot_cell_vnl", "getvnl");
}
Expand Down Expand Up @@ -732,10 +739,10 @@ void pseudopot_cell_vnl::init_vnl(UnitCell& cell, const ModulePW::PW_Basis* rho_
for (int iq = 0; iq < PARAM.globalv.nqx; iq++)
{
const double q = iq * PARAM.globalv.dq;
ModuleBase::Sphbes::Spherical_Bessel(kkbeta, cell.atoms[it].ncpp.r.data(), q, l, jl);
ModuleBase::Sphbes::Spherical_Bessel(kkbeta, cell.atoms[it].ncpp.r.data(), q, l, jl);
for (int ir = 0; ir < kkbeta; ir++)
{
aux[ir] = cell.atoms[it].ncpp.betar(ib, ir) * jl[ir] * cell.atoms[it].ncpp.r[ir];
{
aux[ir] = cell.atoms[it].ncpp.betar(ib, ir) * jl[ir] * cell.atoms[it].ncpp.r[ir];
}
double vqint=0.0;
ModuleBase::Integral::Simpson_Integral(kkbeta, aux, cell.atoms[it].ncpp.rab.data(), vqint);
Expand Down Expand Up @@ -1723,7 +1730,7 @@ template void pseudopot_cell_vnl::getvnl<float, base_device::DEVICE_CPU>(base_de
int const&,
std::complex<float>*) const;
template void pseudopot_cell_vnl::getvnl<double, base_device::DEVICE_CPU>(base_device::DEVICE_CPU*,
const UnitCell&,
const UnitCell&,
int const&,
std::complex<double>*) const;
#if defined(__CUDA) || defined(__ROCM)
Expand Down
Loading