From 000f7ba10736382a1572872d1faee7f68934c01d Mon Sep 17 00:00:00 2001 From: Haozhi Han Date: Thu, 11 Jul 2024 17:46:57 +0800 Subject: [PATCH] refactor paw func in hsolver --- source/module_hsolver/hsolver_pw.cpp | 272 +++++++++++---------------- source/module_hsolver/hsolver_pw.h | 3 +- 2 files changed, 113 insertions(+), 162 deletions(-) diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index c615756336..0c3051a383 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -102,169 +102,9 @@ void HSolverPW::call_paw_cell_set_currentk(const int ik) } } -#endif - template -void HSolverPW::solve(hamilt::Hamilt* pHamilt, - psi::Psi& psi, - elecstate::ElecState* pes, - const std::string method_in, - const bool skip_charge) +void HSolverPW::paw_func_after_kloop(psi::Psi& psi, elecstate::ElecState* pes) { - ModuleBase::TITLE("HSolverPW", "solve"); - ModuleBase::timer::tick("HSolverPW", "solve"); - // prepare for the precondition of diagonalization - this->precondition.resize(psi.get_nbasis()); - this->hamilt_ = pHamilt; - // select the method of diagonalization - this->method = method_in; - // report if the specified diagonalization method is not supported - const std::initializer_list _methods = {"cg", "dav", "dav_subspace", "bpcg"}; - if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods)) - { - ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!"); - } - - std::vector eigenvalues(pes->ekb.nr * pes->ekb.nc, 0); - - if (this->is_first_scf) - { - is_occupied.resize(psi.get_nk() * psi.get_nbands(), true); - } - else - { - if (this->diago_full_acc) - { - is_occupied.assign(is_occupied.size(), true); - } - else - { - for (int i = 0; i < psi.get_nk(); i++) - { - if (pes->klist->wk[i] > 0.0) - { - for (int j = 0; j < psi.get_nbands(); j++) - { - if (pes->wg(i, j) / pes->klist->wk[i] < 0.01) - { - is_occupied[i * psi.get_nbands() + j] = false; - } - } - } - } - } - } - - /// Loop over k points for solve Hamiltonian to charge density - for (int ik = 0; ik < this->wfc_basis->nks; ++ik) - { - /// update H(k) for each k point - pHamilt->updateHk(ik); - -#ifdef USE_PAW - // if (GlobalV::use_paw) - // { - // const int npw = this->wfc_basis->npwk[ik]; - // ModuleBase::Vector3* _gk = new ModuleBase::Vector3[npw]; - // for (int ig = 0; ig < npw; ig++) - // { - // _gk[ig] = this->wfc_basis->getgpluskcar(ik, ig); - // } - - // std::vector kpt(3, 0); - // kpt[0] = this->wfc_basis->kvec_c[ik].x; - // kpt[1] = this->wfc_basis->kvec_c[ik].y; - // kpt[2] = this->wfc_basis->kvec_c[ik].z; - - // double** kpg; - // double** gcar; - // kpg = new double*[npw]; - // gcar = new double*[npw]; - // for (int ipw = 0; ipw < npw; ipw++) - // { - // kpg[ipw] = new double[3]; - // kpg[ipw][0] = _gk[ipw].x; - // kpg[ipw][1] = _gk[ipw].y; - // kpg[ipw][2] = _gk[ipw].z; - - // gcar[ipw] = new double[3]; - // gcar[ipw][0] = this->wfc_basis->getgcar(ik, ipw).x; - // gcar[ipw][1] = this->wfc_basis->getgcar(ik, ipw).y; - // gcar[ipw][2] = this->wfc_basis->getgcar(ik, ipw).z; - // } - - // GlobalC::paw_cell.set_paw_k(npw, - // wfc_basis->npwk_max, - // kpt.data(), - // this->wfc_basis->get_ig2ix(ik).data(), - // this->wfc_basis->get_ig2iy(ik).data(), - // this->wfc_basis->get_ig2iz(ik).data(), - // (const double**)kpg, - // GlobalC::ucell.tpiba, - // (const double**)gcar); - - // std::vector().swap(kpt); - // for (int ipw = 0; ipw < npw; ipw++) - // { - // delete[] kpg[ipw]; - // delete[] gcar[ipw]; - // } - // delete[] kpg; - // delete[] gcar; - - // GlobalC::paw_cell.get_vkb(); - - // GlobalC::paw_cell.set_currentk(ik); - // } - - this->paw_func_in_kloop(ik); - -#endif - - this->updatePsiK(pHamilt, psi, ik); - - // template add precondition calculating here - update_precondition(precondition, ik, this->wfc_basis->npwk[ik]); - -#ifdef USE_PAW - // GlobalC::paw_cell.set_currentk(ik); - this->call_paw_cell_set_currentk(ik); -#endif - - /// solve eigenvector and eigenvalue for H(k) - this->hamiltSolvePsiK(pHamilt, psi, eigenvalues.data() + ik * pes->ekb.nc); - - if (skip_charge) - { - GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik - << " is: " << DiagoIterAssist::avg_iter - << " ; where current threshold is: " << DiagoIterAssist::PW_DIAG_THR - << " . " << std::endl; - DiagoIterAssist::avg_iter = 0.0; - } - /// calculate the contribution of Psi for charge density rho - } - // END Loop over k points - - base_device::memory::cast_memory_op()( - cpu_ctx, - cpu_ctx, - pes->ekb.c, - eigenvalues.data(), - pes->ekb.nr * pes->ekb.nc); - - this->is_first_scf = false; - - this->endDiagh(); - - if (skip_charge) - { - ModuleBase::timer::tick("HSolverPW", "solve"); - return; - } - reinterpret_cast*>(pes)->psiToRho(psi); - -#ifdef USE_PAW if (GlobalV::use_paw) { if (typeid(Real) != typeid(double)) @@ -369,7 +209,117 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, double* nhatgr; GlobalC::paw_cell.get_nhat(pes->charge->nhat, nhatgr); } +} + #endif + +template +void HSolverPW::solve(hamilt::Hamilt* pHamilt, + psi::Psi& psi, + elecstate::ElecState* pes, + const std::string method_in, + const bool skip_charge) +{ + ModuleBase::TITLE("HSolverPW", "solve"); + ModuleBase::timer::tick("HSolverPW", "solve"); + // prepare for the precondition of diagonalization + this->precondition.resize(psi.get_nbasis()); + this->hamilt_ = pHamilt; + // select the method of diagonalization + this->method = method_in; + // report if the specified diagonalization method is not supported + const std::initializer_list _methods = {"cg", "dav", "dav_subspace", "bpcg"}; + if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods)) + { + ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!"); + } + + std::vector eigenvalues(pes->ekb.nr * pes->ekb.nc, 0); + + if (this->is_first_scf) + { + is_occupied.resize(psi.get_nk() * psi.get_nbands(), true); + } + else + { + if (this->diago_full_acc) + { + is_occupied.assign(is_occupied.size(), true); + } + else + { + for (int i = 0; i < psi.get_nk(); i++) + { + if (pes->klist->wk[i] > 0.0) + { + for (int j = 0; j < psi.get_nbands(); j++) + { + if (pes->wg(i, j) / pes->klist->wk[i] < 0.01) + { + is_occupied[i * psi.get_nbands() + j] = false; + } + } + } + } + } + } + + /// Loop over k points for solve Hamiltonian to charge density + for (int ik = 0; ik < this->wfc_basis->nks; ++ik) + { + /// update H(k) for each k point + pHamilt->updateHk(ik); + +#ifdef USE_PAW + this->paw_func_in_kloop(ik); +#endif + + this->updatePsiK(pHamilt, psi, ik); + + // template add precondition calculating here + update_precondition(precondition, ik, this->wfc_basis->npwk[ik]); + +#ifdef USE_PAW + this->call_paw_cell_set_currentk(ik); +#endif + + /// solve eigenvector and eigenvalue for H(k) + this->hamiltSolvePsiK(pHamilt, psi, eigenvalues.data() + ik * pes->ekb.nc); + + if (skip_charge) + { + GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik + << " is: " << DiagoIterAssist::avg_iter + << " ; where current threshold is: " << DiagoIterAssist::PW_DIAG_THR + << " . " << std::endl; + DiagoIterAssist::avg_iter = 0.0; + } + /// calculate the contribution of Psi for charge density rho + } + // END Loop over k points + + base_device::memory::cast_memory_op()( + cpu_ctx, + cpu_ctx, + pes->ekb.c, + eigenvalues.data(), + pes->ekb.nr * pes->ekb.nc); + + this->is_first_scf = false; + + this->endDiagh(); + + if (skip_charge) + { + ModuleBase::timer::tick("HSolverPW", "solve"); + return; + } + reinterpret_cast*>(pes)->psiToRho(psi); + +#ifdef USE_PAW + this->paw_func_after_kloop(psi, pes); +#endif + ModuleBase::timer::tick("HSolverPW", "solve"); return; } diff --git a/source/module_hsolver/hsolver_pw.h b/source/module_hsolver/hsolver_pw.h index 5041492bef..78f5ee7543 100644 --- a/source/module_hsolver/hsolver_pw.h +++ b/source/module_hsolver/hsolver_pw.h @@ -78,8 +78,9 @@ class HSolverPW : public HSolver void paw_func_in_kloop(const int ik); void call_paw_cell_set_currentk(const int ik); -#endif + void paw_func_after_kloop(psi::Psi& psi, elecstate::ElecState* pes); +#endif }; template