Skip to content

Commit

Permalink
Refactor: replace sto_hchi by HamiltSdftPW::hPsi (#5298)
Browse files Browse the repository at this point in the history
* change sto_hchi to hamilt_sdft_pw

* add is_first_node parameter for act function

* optimize hPsi

* [pre-commit.ci lite] apply automatic fixes

* fix compile error

* fix compile error and add UTs for hamilt_sdft

* fix CUDA compile

* fix wrong setmem

* [pre-commit.ci lite] apply automatic fixes

* fix undefined hspi

* fix compile in sdft

* optimize for

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
Qianruipku and pre-commit-ci-lite[bot] authored Oct 23, 2024
1 parent 0bac03f commit 7194eb7
Show file tree
Hide file tree
Showing 41 changed files with 591 additions and 430 deletions.
2 changes: 1 addition & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ OBJS_GINT=gint.o\
init_orb.o\

OBJS_HAMILT=hamilt_pw.o\
hamilt_sdft_pw.o\
operator.o\
operator_pw.o\
ekinetic_pw.o\
Expand Down Expand Up @@ -648,7 +649,6 @@ OBJS_SRCPW=H_Ewald_pw.o\
structure_factor_k.o\
soc.o\
sto_iter.o\
sto_hchi.o\
sto_che.o\
sto_wf.o\
sto_func.o\
Expand Down
7 changes: 4 additions & 3 deletions source/module_base/math_chebyshev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ bool Chebyshev<REAL>::checkconverge(
std::function<void(std::complex<REAL>* in, std::complex<REAL>* out, const int)> funA,
std::complex<REAL>* wavein,
const int N,
const int LDA,
REAL& tmax,
REAL& tmin,
REAL stept)
Expand All @@ -584,9 +585,9 @@ bool Chebyshev<REAL>::checkconverge(
std::complex<REAL>* arrayn;
std::complex<REAL>* arrayn_1;

arraynp1 = new std::complex<REAL>[N];
arrayn = new std::complex<REAL>[N];
arrayn_1 = new std::complex<REAL>[N];
arraynp1 = new std::complex<REAL>[LDA];
arrayn = new std::complex<REAL>[LDA];
arrayn_1 = new std::complex<REAL>[LDA];

ModuleBase::GlobalFunc::DCOPY(wavein, arrayn_1, N);

Expand Down
1 change: 1 addition & 0 deletions source/module_base/math_chebyshev.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ class Chebyshev
bool checkconverge(std::function<void(std::complex<REAL>* in, std::complex<REAL>* out, const int)> funA,
std::complex<REAL>* wavein,
const int N,
const int LDA,
REAL& tmax, // trial number for upper bound
REAL& tmin, // trial number for lower bound
REAL stept); // tmax = max() + stept, tmin = min() - stept
Expand Down
14 changes: 7 additions & 7 deletions source/module_base/test/math_chebyshev_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,15 +346,15 @@ TEST_F(MathChebyshevTest, checkconverge)
double tmin = -1.1;
double tmax = 1.1;
bool converge;
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, tmax, tmin, 0.2);
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 0.2);
EXPECT_TRUE(converge);
converge = p_chetest->checkconverge(fun_sigma_y, v + 2, 2, tmax, tmin, 0.2);
converge = p_chetest->checkconverge(fun_sigma_y, v + 2, 2, 2, tmax, tmin, 0.2);
EXPECT_TRUE(converge);
EXPECT_NEAR(tmin, -1.1, 1e-8);
EXPECT_NEAR(tmax, 1.1, 1e-8);

tmax = -1.1;
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, tmax, tmin, 2.2);
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 2.2);
EXPECT_TRUE(converge);
EXPECT_NEAR(tmin, -1.1, 1e-8);
EXPECT_NEAR(tmax, 1.1, 1e-8);
Expand All @@ -363,12 +363,12 @@ TEST_F(MathChebyshevTest, checkconverge)
v[0] = std::complex<double>(0, 1), v[1] = 1;
fun.factor = 1.5;
tmin = -1.1, tmax = 1.1;
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, tmax, tmin, 0.2);
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 0.2);
EXPECT_FALSE(converge);

fun.factor = -1.5;
tmin = -1.1, tmax = 1.1;
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, tmax, tmin, 0.2);
converge = p_chetest->checkconverge(fun_sigma_y, v, 2, 2, tmax, tmin, 0.2);
EXPECT_FALSE(converge);
fun.factor = 1;

Expand Down Expand Up @@ -632,9 +632,9 @@ TEST_F(MathChebyshevTest, checkconverge_float)

auto fun_sigma_yf
= [&](std::complex<float>* in, std::complex<float>* out, const int m = 1) { fun.sigma_y(in, out, m); };
converge = p_fchetest->checkconverge(fun_sigma_yf, v, 2, tmax, tmin, 0.2);
converge = p_fchetest->checkconverge(fun_sigma_yf, v, 2, 2, tmax, tmin, 0.2);
EXPECT_TRUE(converge);
converge = p_fchetest->checkconverge(fun_sigma_yf, v + 2, 2, tmax, tmin, 0.2);
converge = p_fchetest->checkconverge(fun_sigma_yf, v + 2, 2, 2, tmax, tmin, 0.2);
EXPECT_TRUE(converge);
EXPECT_NEAR(tmin, -1.1, 1e-6);
EXPECT_NEAR(tmax, 1.1, 1e-6);
Expand Down
12 changes: 11 additions & 1 deletion source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell)
void ESolver_SDFT_PW::before_scf(const int istep)
{
ESolver_KS_PW::before_scf(istep);
delete reinterpret_cast<hamilt::HamiltPW<double>*>(this->p_hamilt);
this->p_hamilt = new hamilt::HamiltSdftPW<std::complex<double>>(this->pelec->pot,
this->pw_wfc,
&this->kv,
PARAM.globalv.npol,
&this->stoche.emin_sto,
&this->stoche.emax_sto);
this->p_hamilt_sto = static_cast<hamilt::HamiltSdftPW<std::complex<double>>*>(this->p_hamilt);

if (istep > 0 && PARAM.inp.nbands_sto != 0 && PARAM.inp.initsto_freq > 0 && istep % PARAM.inp.initsto_freq == 0)
{
Update_Sto_Orbitals(this->stowf, PARAM.inp.seed_sto);
Expand Down Expand Up @@ -177,7 +186,8 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)
this->pw_wfc,
&this->wf,
this->stowf,
this->stoche,
this->stoche,
this->p_hamilt_sto,
PARAM.inp.calculation,
PARAM.inp.basis_type,
PARAM.inp.ks_solver,
Expand Down
3 changes: 2 additions & 1 deletion source/module_esolver/esolver_sdft_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#define ESOLVER_SDFT_PW_H

#include "esolver_ks_pw.h"
#include "module_hamilt_pw/hamilt_stodft/sto_hchi.h"
#include "module_hamilt_pw/hamilt_stodft/sto_iter.h"
#include "module_hamilt_pw/hamilt_stodft/sto_wf.h"
#include "module_hamilt_pw/hamilt_stodft/sto_che.h"
#include "module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.h"

namespace ModuleESolver
{
Expand All @@ -27,6 +27,7 @@ class ESolver_SDFT_PW : public ESolver_KS_PW<std::complex<double>>
public:
Stochastic_WF stowf;
StoChe<double> stoche;
hamilt::HamiltSdftPW<std::complex<double>>* p_hamilt_sto = nullptr;

protected:
virtual void before_scf(const int istep) override;
Expand Down
10 changes: 5 additions & 5 deletions source/module_hamilt_general/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
this->hpsi = new psi::Psi<T, Device>(hpsi_pointer, *psi_input, 1, nbands / psi_input->npol);
}

auto call_act = [&, this](const Operator* op) -> void {
auto call_act = [&, this](const Operator* op, const bool& is_first_node) -> void {
// a "psi" with the bands of needed range
psi::Psi<T, Device> psi_wrapper(const_cast<T*>(tmpsi_in), 1, nbands, psi_input->get_nbasis());
switch (op->get_act_type())
Expand All @@ -69,17 +69,17 @@ typename Operator<T, Device>::hpsi_info Operator<T, Device>::hPsi(hpsi_info& inp
op->act(psi_wrapper, *this->hpsi, nbands);
break;
default:
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik));
op->act(nbands, psi_input->get_nbasis(), psi_input->npol, tmpsi_in, this->hpsi->get_pointer(), psi_input->get_ngk(op->ik), is_first_node);
break;
}
};

ModuleBase::timer::tick("Operator", "hPsi");
call_act(this);
call_act(this, true); // first node
Operator* node((Operator*)this->next_op);
while (node != nullptr)
{
call_act(node);
call_act(node, false); // other nodes
node = (Operator*)(node->next_op);
}
ModuleBase::timer::tick("Operator", "hPsi");
Expand Down Expand Up @@ -162,7 +162,7 @@ T* Operator<T, Device>::get_hpsi(const hpsi_info& info) const
size_t total_hpsi_size = nbands_range * this->hpsi->get_nbasis();
// ModuleBase::GlobalFunc::ZEROS(hpsi_pointer, total_hpsi_size);
// denghui replaced at 20221104
set_memory_op()(this->ctx, hpsi_pointer, 0, total_hpsi_size);
// set_memory_op()(this->ctx, hpsi_pointer, 0, total_hpsi_size);
return hpsi_pointer;
}

Expand Down
4 changes: 3 additions & 1 deletion source/module_hamilt_general/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ class Operator
///do operation : |hpsi_choosed> = V|psi_choosed>
///V is the target operator act on choosed psi, the consequence should be added to choosed hpsi
/// interface type 1: pointer-only (default)
/// @note PW: nbasis = max_npw * npol, nbands = nband * npol, npol = npol. Strange but PAY ATTENTION!!!
virtual void act(const int nbands,
const int nbasis,
const int npol,
const T* tmpsi_in,
T* tmhpsi,
const int ngk_ik = 0)const {};
const int ngk_ik = 0,
const bool is_first_node = false)const {};

/// developer-friendly interfaces for act() function
/// interface type 2: input and change the Psi-type HPsi
Expand Down
25 changes: 21 additions & 4 deletions source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/ekinetic_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,32 @@ template <typename FPTYPE>
__global__ void ekinetic_pw(
const int npw,
const int max_npw,
const bool is_first_node,
const FPTYPE tpiba2,
const FPTYPE* gk2,
thrust::complex<FPTYPE>* hpsi,
const thrust::complex<FPTYPE>* psi)
{
const int block_idx = blockIdx.x;
const int thread_idx = threadIdx.x;
for (int ii = thread_idx; ii < npw; ii+= blockDim.x) {
hpsi[block_idx * max_npw + ii]
+= gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
const int start_idx = block_idx * max_npw;
if(is_first_node)
{
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
{
hpsi[start_idx + ii] = gk2[ii] * tpiba2 * psi[start_idx + ii];
}
for (int ii = npw + thread_idx; ii < max_npw; ii += blockDim.x)
{
hpsi[start_idx + ii] = 0.0;
}
}
else
{
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
{
hpsi[start_idx + ii] += gk2[ii] * tpiba2 * psi[start_idx + ii];
}
}
}

Expand All @@ -31,6 +47,7 @@ void hamilt::ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const b
const int& nband,
const int& npw,
const int& max_npw,
const bool& is_first_node,
const FPTYPE& tpiba2,
const FPTYPE* gk2_ik,
std::complex<FPTYPE>* tmhpsi,
Expand All @@ -39,7 +56,7 @@ void hamilt::ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const b
// denghui implement 20221019
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
ekinetic_pw<FPTYPE><<<nband, THREADS_PER_BLOCK>>>(
npw, max_npw, tpiba2, // control params
npw, max_npw, is_first_node, tpiba2, // control params
gk2_ik, // array of data
reinterpret_cast<thrust::complex<FPTYPE>*>(tmhpsi), // array of data
reinterpret_cast<const thrust::complex<FPTYPE>*>(tmpsi_in)); // array of data
Expand Down
42 changes: 36 additions & 6 deletions source/module_hamilt_pw/hamilt_pwdft/kernels/ekinetic_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,50 @@ struct ekinetic_pw_op<FPTYPE, base_device::DEVICE_CPU>
const int& nband,
const int& npw,
const int& max_npw,
const bool& is_first_node,
const FPTYPE& tpiba2,
const FPTYPE* gk2_ik,
std::complex<FPTYPE>* tmhpsi,
const std::complex<FPTYPE>* tmpsi_in)
{
if (is_first_node)
{
for (int ib = 0; ib < nband; ++ib)
{
#ifdef _OPENMP
#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE))
#pragma omp parallel for
#endif
for (int ib = 0; ib < nband; ++ib) {
for (int ig = 0; ig < npw; ++ig) {
tmhpsi[ib * max_npw + ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ib * max_npw + ig];
}
for (int ig = 0; ig < npw; ++ig)
{
tmhpsi[ig] = gk2_ik[ig] * tpiba2 * tmpsi_in[ig];
}
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int ig = npw; ig < max_npw; ++ig)
{
tmhpsi[ig] = 0.0;
}
tmpsi_in += max_npw;
tmhpsi += max_npw;
}
}
else
{
for (int ib = 0; ib < nband; ++ib)
{
#ifdef _OPENMP
#pragma omp parallel for
#endif
for (int ig = 0; ig < npw; ++ig)
{
tmhpsi[ig] += gk2_ik[ig] * tpiba2 * tmpsi_in[ig];
}
tmpsi_in += max_npw;
tmhpsi += max_npw;
}
}
}
}
};

template struct ekinetic_pw_op<float, base_device::DEVICE_CPU>;
Expand Down
2 changes: 2 additions & 0 deletions source/module_hamilt_pw/hamilt_pwdft/kernels/ekinetic_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct ekinetic_pw_op {
const int& nband,
const int& npw,
const int& max_npw,
const bool& is_first_node,
const FPTYPE& tpiba2,
const FPTYPE* gk2_ik,
std::complex<FPTYPE>* tmhpsi,
Expand All @@ -41,6 +42,7 @@ struct ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>
const int& nband,
const int& npw,
const int& max_npw,
const bool& is_first_node,
const FPTYPE& tpiba2,
const FPTYPE* gk2_ik,
std::complex<FPTYPE>* tmhpsi,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,42 @@ template <typename FPTYPE>
__global__ void ekinetic_pw(
const int npw,
const int max_npw,
const bool is_first_node,
const FPTYPE tpiba2,
const FPTYPE* gk2,
thrust::complex<FPTYPE>* hpsi,
const thrust::complex<FPTYPE>* psi)
{
const int block_idx = blockIdx.x;
const int thread_idx = threadIdx.x;
for (int ii = thread_idx; ii < npw; ii+= blockDim.x) {
hpsi[block_idx * max_npw + ii]
+= gk2[ii] * tpiba2 * psi[block_idx * max_npw + ii];
const int start_idx = block_idx * max_npw;
if(is_first_node)
{
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
{
hpsi[start_idx + ii] = gk2[ii] * tpiba2 * psi[start_idx + ii];
}
for (int ii = npw + thread_idx; ii < max_npw; ii += blockDim.x)
{
hpsi[start_idx + ii] = 0.0;
}
}
else
{
for (int ii = thread_idx; ii < npw; ii += blockDim.x)
{
hpsi[start_idx + ii] += gk2[ii] * tpiba2 * psi[start_idx + ii];
}
}

}

template <typename FPTYPE>
void hamilt::ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* dev,
const int& nband,
const int& npw,
const int& max_npw,
const bool& is_first_node,
const FPTYPE& tpiba2,
const FPTYPE* gk2_ik,
std::complex<FPTYPE>* tmhpsi,
Expand All @@ -39,7 +57,7 @@ void hamilt::ekinetic_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const b
// denghui implement 20221019
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
hipLaunchKernelGGL(HIP_KERNEL_NAME(ekinetic_pw<FPTYPE>), dim3(nband), dim3(THREADS_PER_BLOCK), 0, 0,
npw, max_npw, tpiba2, // control params
npw, max_npw, is_first_node, tpiba2, // control params
gk2_ik, // array of data
reinterpret_cast<thrust::complex<FPTYPE>*>(tmhpsi), // array of data
reinterpret_cast<const thrust::complex<FPTYPE>*>(tmpsi_in)); // array of data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class TestModuleHamiltEkinetic : public ::testing::Test
TEST_F(TestModuleHamiltEkinetic, ekinetic_pw_op_cpu)
{
std::vector<std::complex<double> > hpsi(expected_hpsi.size(), std::complex<double>(0.0, 0.0));
ekinetic_cpu_op()(cpu_ctx, band, dim, dim, tpiba2, gk2.data(), hpsi.data(), psi.data());
ekinetic_cpu_op()(cpu_ctx, band, dim, dim, false, tpiba2, gk2.data(), hpsi.data(), psi.data());
for (int ii = 0; ii < hpsi.size(); ii++) {
EXPECT_LT(std::abs(hpsi[ii] - expected_hpsi[ii]), 1e-6);
}
Expand All @@ -89,7 +89,7 @@ TEST_F(TestModuleHamiltEkinetic, ekinetic_pw_op_gpu)
syncmem_d_h2d_op()(gpu_ctx, cpu_ctx, gk2_dev, gk2.data(), gk2.size());
syncmem_cd_h2d_op()(gpu_ctx, cpu_ctx, psi_dev, psi.data(), psi.size());
// ekinetic_cpu_op()(cpu_ctx, band, dim, dim, tpiba2, gk2.data(), hpsi.data(), psi.data());
ekinetic_gpu_op()(gpu_ctx, band, dim, dim, tpiba2, gk2_dev, hpsi_dev, psi_dev);
ekinetic_gpu_op()(gpu_ctx, band, dim, dim, false, tpiba2, gk2_dev, hpsi_dev, psi_dev);
syncmem_cd_d2h_op()(cpu_ctx, gpu_ctx, hpsi.data(), hpsi_dev, hpsi.size());

for (int ii = 0; ii < hpsi.size(); ii++) {
Expand Down
Loading

0 comments on commit 7194eb7

Please sign in to comment.