From 43357d198440bbd0605f76baa43e73c7d286a2a1 Mon Sep 17 00:00:00 2001 From: Haozhi Han Date: Wed, 10 Jul 2024 21:39:48 +0800 Subject: [PATCH] Refactor: format hsolver_pw (#4633) * format hsolver_pw * fix build bug --- source/module_hsolver/hsolver_pw.cpp | 503 +++++++++++++-------------- source/module_hsolver/hsolver_pw.h | 45 +-- 2 files changed, 255 insertions(+), 293 deletions(-) diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index a9ea6728ec..62bf341168 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -4,6 +4,7 @@ #include "diago_cg.h" #include "diago_dav_subspace.h" #include "diago_david.h" + #include "module_base/global_variable.h" #include "module_base/parallel_global.h" // for MPI #include "module_base/timer.h" @@ -12,19 +13,20 @@ #include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_hamilt_pw/hamilt_pwdft/hamilt_pw.h" #include "module_hamilt_pw/hamilt_pwdft/wavefunc.h" -#include "module_hsolver/diagh.h" #include "module_hsolver/diag_comm_info.h" #include "module_hsolver/diago_iter_assist.h" #include + #ifdef USE_PAW #include "module_cell/module_paw/paw_cell.h" #endif -namespace hsolver { +namespace hsolver +{ template -HSolverPW::HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, - wavefunc* pwf_in) { +HSolverPW::HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pwf_in) +{ this->classname = "HSolverPW"; this->wfc_basis = wfc_basis_in; this->pwf = pwf_in; @@ -37,7 +39,8 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, psi::Psi& psi, elecstate::ElecState* pes, const std::string method_in, - const bool skip_charge) { + const bool skip_charge) +{ ModuleBase::TITLE("HSolverPW", "solve"); ModuleBase::timer::tick("HSolverPW", "solve"); // prepare for the precondition of diagonalization @@ -46,26 +49,34 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, // select the method of diagonalization this->method = method_in; // report if the specified diagonalization method is not supported - const std::initializer_list _methods - = {"cg", "dav", "dav_subspace", "bpcg"}; - if (std::find(std::begin(_methods), std::end(_methods), this->method) - == std::end(_methods)) { - ModuleBase::WARNING_QUIT("HSolverPW::solve", - "This method of DiagH is not supported!"); + const std::initializer_list _methods = {"cg", "dav", "dav_subspace", "bpcg"}; + if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods)) + { + ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!"); } std::vector eigenvalues(pes->ekb.nr * pes->ekb.nc, 0); - if (this->is_first_scf) { + if (this->is_first_scf) + { is_occupied.resize(psi.get_nk() * psi.get_nbands(), true); - } else { - if (this->diago_full_acc) { + } + else + { + if (this->diago_full_acc) + { is_occupied.assign(is_occupied.size(), true); - } else { - for (int i = 0; i < psi.get_nk(); i++) { - if (pes->klist->wk[i] > 0.0) { - for (int j = 0; j < psi.get_nbands(); j++) { - if (pes->wg(i, j) / pes->klist->wk[i] < 0.01) { + } + else + { + for (int i = 0; i < psi.get_nk(); i++) + { + if (pes->klist->wk[i] > 0.0) + { + for (int j = 0; j < psi.get_nbands(); j++) + { + if (pes->wg(i, j) / pes->klist->wk[i] < 0.01) + { is_occupied[i * psi.get_nbands() + j] = false; } } @@ -75,16 +86,18 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, } /// Loop over k points for solve Hamiltonian to charge density - for (int ik = 0; ik < this->wfc_basis->nks; ++ik) { + for (int ik = 0; ik < this->wfc_basis->nks; ++ik) + { /// update H(k) for each k point pHamilt->updateHk(ik); #ifdef USE_PAW - if (GlobalV::use_paw) { + if (GlobalV::use_paw) + { const int npw = this->wfc_basis->npwk[ik]; - ModuleBase::Vector3* _gk - = new ModuleBase::Vector3[npw]; - for (int ig = 0; ig < npw; ig++) { + ModuleBase::Vector3* _gk = new ModuleBase::Vector3[npw]; + for (int ig = 0; ig < npw; ig++) + { _gk[ig] = this->wfc_basis->getgpluskcar(ik, ig); } @@ -97,7 +110,8 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, double** gcar; kpg = new double*[npw]; gcar = new double*[npw]; - for (int ipw = 0; ipw < npw; ipw++) { + for (int ipw = 0; ipw < npw; ipw++) + { kpg[ipw] = new double[3]; kpg[ipw][0] = _gk[ipw].x; kpg[ipw][1] = _gk[ipw].y; @@ -120,7 +134,8 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, (const double**)gcar); std::vector().swap(kpt); - for (int ipw = 0; ipw < npw; ipw++) { + for (int ipw = 0; ipw < npw; ipw++) + { delete[] kpg[ipw]; delete[] gcar[ipw]; } @@ -143,53 +158,53 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, #endif /// solve eigenvector and eigenvalue for H(k) - this->hamiltSolvePsiK(pHamilt, - psi, - eigenvalues.data() + ik * pes->ekb.nc); - - if (skip_charge) { - GlobalV::ofs_running - << "Average iterative diagonalization steps for k-points " << ik - << " is: " << DiagoIterAssist::avg_iter - << " ; where current threshold is: " - << DiagoIterAssist::PW_DIAG_THR << " . " - << std::endl; + this->hamiltSolvePsiK(pHamilt, psi, eigenvalues.data() + ik * pes->ekb.nc); + + if (skip_charge) + { + GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik + << " is: " << DiagoIterAssist::avg_iter + << " ; where current threshold is: " << DiagoIterAssist::PW_DIAG_THR + << " . " << std::endl; DiagoIterAssist::avg_iter = 0.0; } /// calculate the contribution of Psi for charge density rho } // END Loop over k points - castmem_2d_2h_op()(cpu_ctx, - cpu_ctx, - pes->ekb.c, - eigenvalues.data(), - pes->ekb.nr * pes->ekb.nc); + base_device::memory::cast_memory_op()( + cpu_ctx, + cpu_ctx, + pes->ekb.c, + eigenvalues.data(), + pes->ekb.nr * pes->ekb.nc); this->is_first_scf = false; this->endDiagh(); - if (skip_charge) { + if (skip_charge) + { ModuleBase::timer::tick("HSolverPW", "solve"); return; } reinterpret_cast*>(pes)->psiToRho(psi); #ifdef USE_PAW - if (GlobalV::use_paw) { - if (typeid(Real) != typeid(double)) { - ModuleBase::WARNING_QUIT( - "HSolverPW::solve", - "PAW is only supported for double precision!"); + if (GlobalV::use_paw) + { + if (typeid(Real) != typeid(double)) + { + ModuleBase::WARNING_QUIT("HSolverPW::solve", "PAW is only supported for double precision!"); } GlobalC::paw_cell.reset_rhoij(); - for (int ik = 0; ik < this->wfc_basis->nks; ++ik) { + for (int ik = 0; ik < this->wfc_basis->nks; ++ik) + { const int npw = this->wfc_basis->npwk[ik]; - ModuleBase::Vector3* _gk - = new ModuleBase::Vector3[npw]; - for (int ig = 0; ig < npw; ig++) { + ModuleBase::Vector3* _gk = new ModuleBase::Vector3[npw]; + for (int ig = 0; ig < npw; ig++) + { _gk[ig] = this->wfc_basis->getgpluskcar(ik, ig); } @@ -202,7 +217,8 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, double** gcar; kpg = new double*[npw]; gcar = new double*[npw]; - for (int ipw = 0; ipw < npw; ipw++) { + for (int ipw = 0; ipw < npw; ipw++) + { kpg[ipw] = new double[3]; kpg[ipw][0] = _gk[ipw].x; kpg[ipw][1] = _gk[ipw].y; @@ -225,7 +241,8 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, (const double**)gcar); std::vector().swap(kpt); - for (int ipw = 0; ipw < npw; ipw++) { + for (int ipw = 0; ipw < npw; ipw++) + { delete[] kpg[ipw]; delete[] gcar[ipw]; } @@ -237,11 +254,10 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, psi.fix_k(ik); GlobalC::paw_cell.set_currentk(ik); int nbands = psi.get_nbands(); - for (int ib = 0; ib < nbands; ib++) { - GlobalC::paw_cell.accumulate_rhoij( - reinterpret_cast*>( - psi.get_pointer(ib)), - pes->wg(ik, ib)); + for (int ib = 0; ib < nbands; ib++) + { + GlobalC::paw_cell.accumulate_rhoij(reinterpret_cast*>(psi.get_pointer(ib)), + pes->wg(ik, ib)); } } @@ -250,10 +266,12 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, std::vector nrhoijsel; #ifdef __MPI - if (GlobalV::RANK_IN_POOL == 0) { + if (GlobalV::RANK_IN_POOL == 0) + { GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel); - for (int iat = 0; iat < GlobalC::ucell.nat; iat++) { + for (int iat = 0; iat < GlobalC::ucell.nat; iat++) + { GlobalC::paw_cell.set_rhoij(iat, nrhoijsel[iat], rhoijselect[iat].size(), @@ -264,7 +282,8 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, #else GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel); - for (int iat = 0; iat < GlobalC::ucell.nat; iat++) { + for (int iat = 0; iat < GlobalC::ucell.nat; iat++) + { GlobalC::paw_cell.set_rhoij(iat, nrhoijsel[iat], rhoijselect[iat].size(), @@ -282,15 +301,16 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, } template -void HSolverPW::endDiagh() { +void HSolverPW::endDiagh() +{ // in PW base, average iteration steps for each band and k-point should be // printing - if (DiagoIterAssist::avg_iter > 0.0) { - GlobalV::ofs_running - << "Average iterative diagonalization steps: " - << DiagoIterAssist::avg_iter / this->wfc_basis->nks - << " ; where current threshold is: " - << DiagoIterAssist::PW_DIAG_THR << " . " << std::endl; + if (DiagoIterAssist::avg_iter > 0.0) + { + GlobalV::ofs_running << "Average iterative diagonalization steps: " + << DiagoIterAssist::avg_iter / this->wfc_basis->nks + << " ; where current threshold is: " << DiagoIterAssist::PW_DIAG_THR << " . " + << std::endl; // std::cout << "avg_iter == " << DiagoIterAssist::avg_iter // << std::endl; @@ -299,59 +319,51 @@ void HSolverPW::endDiagh() { DiagoIterAssist::avg_iter = 0.0; } // psi only should be initialed once for PW - if (!this->initialed_psi) { + if (!this->initialed_psi) + { this->initialed_psi = true; } } template -void HSolverPW::updatePsiK(hamilt::Hamilt* pHamilt, - psi::Psi& psi, - const int ik) { +void HSolverPW::updatePsiK(hamilt::Hamilt* pHamilt, psi::Psi& psi, const int ik) +{ psi.fix_k(ik); - if (!GlobalV::psi_initializer && !this->initialed_psi && GlobalV::BASIS_TYPE == "pw") { - hamilt::diago_PAO_in_pw_k2(this->ctx, ik, psi, this->wfc_basis, this->pwf, pHamilt); -} + if (!GlobalV::psi_initializer && !this->initialed_psi && GlobalV::BASIS_TYPE == "pw") + { + hamilt::diago_PAO_in_pw_k2(this->ctx, ik, psi, this->wfc_basis, this->pwf, pHamilt); + } /* lcao_in_pw now is based on newly implemented psi initializer, so it does not appear here*/ - } template -void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, - psi::Psi& psi, - Real* eigenvalue) { - if (this->method == "cg") { +void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::Psi& psi, Real* eigenvalue) +{ + if (this->method == "cg") + { // warp the subspace_func into a lambda function auto ngk_pointer = psi.get_ngk_pointer(); - auto subspace_func = [this, ngk_pointer](const ct::Tensor& psi_in, - ct::Tensor& psi_out) { + auto subspace_func = [this, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) { // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim == 2, - "dims of psi_in should be less than or equal to 2"); + REQUIRES_OK(ndim == 2, "dims of psi_in should be less than or equal to 2"); // Convert a Tensor object to a psi::Psi object - auto psi_in_wrapper - = psi::Psi(psi_in.data(), - 1, - psi_in.shape().dim_size(0), - psi_in.shape().dim_size(1), - ngk_pointer); - auto psi_out_wrapper - = psi::Psi(psi_out.data(), - 1, - psi_out.shape().dim_size(0), - psi_out.shape().dim_size(1), - ngk_pointer); - auto eigen - = ct::Tensor(ct::DataTypeToEnum::value, - ct::DeviceType::CpuDevice, - ct::TensorShape({psi_in.shape().dim_size(0)})); - - DiagoIterAssist::diagH_subspace(hamilt_, - psi_in_wrapper, - psi_out_wrapper, - eigen.data()); + auto psi_in_wrapper = psi::Psi(psi_in.data(), + 1, + psi_in.shape().dim_size(0), + psi_in.shape().dim_size(1), + ngk_pointer); + auto psi_out_wrapper = psi::Psi(psi_out.data(), + 1, + psi_out.shape().dim_size(0), + psi_out.shape().dim_size(1), + ngk_pointer); + auto eigen = ct::Tensor(ct::DataTypeToEnum::value, + ct::DeviceType::CpuDevice, + ct::TensorShape({psi_in.shape().dim_size(0)})); + + DiagoIterAssist::diagH_subspace(hamilt_, psi_in_wrapper, psi_out_wrapper, eigen.data()); }; DiagoCG cg(GlobalV::BASIS_TYPE, GlobalV::CALCULATION, @@ -365,109 +377,95 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, using ct_Device = typename ct::PsiToContainer::type; // warp the hpsi_func and spsi_func into a lambda function - auto hpsi_func = [hm, ngk_pointer](const ct::Tensor& psi_in, - ct::Tensor& hpsi_out) { + auto hpsi_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& hpsi_out) { ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, - "dims of psi_in should be less than or equal to 2"); + REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); // Convert a Tensor object to a psi::Psi object - auto psi_wrapper = psi::Psi( - psi_in.data(), - 1, - ndim == 1 ? 1 : psi_in.shape().dim_size(0), - ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), - ngk_pointer); - psi::Range all_bands_range(true, - psi_wrapper.get_current_k(), - 0, - psi_wrapper.get_nbands() - 1); + auto psi_wrapper = psi::Psi(psi_in.data(), + 1, + ndim == 1 ? 1 : psi_in.shape().dim_size(0), + ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), + ngk_pointer); + psi::Range all_bands_range(true, psi_wrapper.get_current_k(), 0, psi_wrapper.get_nbands() - 1); using hpsi_info = typename hamilt::Operator::hpsi_info; hpsi_info info(&psi_wrapper, all_bands_range, hpsi_out.data()); hm->ops->hPsi(info); ModuleBase::timer::tick("DiagoCG_New", "hpsi_func"); }; - auto spsi_func = [this, hm](const ct::Tensor& psi_in, - ct::Tensor& spsi_out) { + auto spsi_func = [this, hm](const ct::Tensor& psi_in, ct::Tensor& spsi_out) { ModuleBase::timer::tick("DiagoCG_New", "spsi_func"); // psi_in should be a 2D tensor: // psi_in.shape() = [nbands, nbasis] const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim <= 2, - "dims of psi_in should be less than or equal to 2"); + REQUIRES_OK(ndim <= 2, "dims of psi_in should be less than or equal to 2"); - if (GlobalV::use_uspp) { + if (GlobalV::use_uspp) + { // Convert a Tensor object to a psi::Psi object hm->sPsi(psi_in.data(), spsi_out.data(), - ndim == 1 ? psi_in.NumElements() - : psi_in.shape().dim_size(1), - ndim == 1 ? psi_in.NumElements() - : psi_in.shape().dim_size(1), + ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), + ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1), ndim == 1 ? 1 : psi_in.shape().dim_size(0)); - } else { + } + else + { base_device::memory::synchronize_memory_op()( this->ctx, this->ctx, spsi_out.data(), psi_in.data(), - static_cast( - (ndim == 1 ? 1 : psi_in.shape().dim_size(0)) - * (ndim == 1 ? psi_in.NumElements() - : psi_in.shape().dim_size(1)))); + static_cast((ndim == 1 ? 1 : psi_in.shape().dim_size(0)) + * (ndim == 1 ? psi_in.NumElements() : psi_in.shape().dim_size(1)))); } ModuleBase::timer::tick("DiagoCG_New", "spsi_func"); }; - auto psi_tensor = ct::TensorMap( - psi.get_pointer(), - ct::DataTypeToEnum::value, - ct::DeviceTypeToEnum::value, - ct::TensorShape({psi.get_nbands(), psi.get_nbasis()})); - auto eigen_tensor - = ct::TensorMap(eigenvalue, - ct::DataTypeToEnum::value, - ct::DeviceTypeToEnum::value, - ct::TensorShape({psi.get_nbands()})); - auto prec_tensor - = ct::TensorMap( - precondition.data(), - ct::DataTypeToEnum::value, - ct::DeviceTypeToEnum::value, - ct::TensorShape({static_cast(precondition.size())})) - .to_device() - .slice({0}, {psi.get_current_nbas()}); + auto psi_tensor = ct::TensorMap(psi.get_pointer(), + ct::DataTypeToEnum::value, + ct::DeviceTypeToEnum::value, + ct::TensorShape({psi.get_nbands(), psi.get_nbasis()})); + auto eigen_tensor = ct::TensorMap(eigenvalue, + ct::DataTypeToEnum::value, + ct::DeviceTypeToEnum::value, + ct::TensorShape({psi.get_nbands()})); + auto prec_tensor = ct::TensorMap(precondition.data(), + ct::DataTypeToEnum::value, + ct::DeviceTypeToEnum::value, + ct::TensorShape({static_cast(precondition.size())})) + .to_device() + .slice({0}, {psi.get_current_nbas()}); cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor); // TODO: Double check tensormap's potential problem - ct::TensorMap(psi.get_pointer(), - psi_tensor, - {psi.get_nbands(), psi.get_nbasis()}) - .sync(psi_tensor); - } else if (this->method == "dav_subspace") { + ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor); + } + else if (this->method == "dav_subspace") + { #ifdef __MPI - const diag_comm_info comm_info - = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; + const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; #else - const diag_comm_info comm_info - = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; + const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; #endif - Diago_DavSubspace dav_subspace( - this->precondition, - psi.get_nbands(), - psi.get_k_first() ? psi.get_current_nbas() - : psi.get_nk() * psi.get_nbasis(), - GlobalV::PW_DIAG_NDIM, - DiagoIterAssist::PW_DIAG_THR, - DiagoIterAssist::PW_DIAG_NMAX, - DiagoIterAssist::need_subspace, - comm_info); + Diago_DavSubspace dav_subspace(this->precondition, + psi.get_nbands(), + psi.get_k_first() ? psi.get_current_nbas() + : psi.get_nk() * psi.get_nbasis(), + GlobalV::PW_DIAG_NDIM, + DiagoIterAssist::PW_DIAG_THR, + DiagoIterAssist::PW_DIAG_NMAX, + DiagoIterAssist::need_subspace, + comm_info); bool scf; - if (GlobalV::CALCULATION == "nscf") { + if (GlobalV::CALCULATION == "nscf") + { scf = false; - } else { + } + else + { scf = true; } @@ -482,11 +480,7 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, ModuleBase::timer::tick("DavSubspace", "hpsi_func"); // Convert "pointer data stucture" to a psi::Psi object - auto psi_iter_wrapper = psi::Psi(psi_in, - 1, - nband_in, - nbasis_in, - ngk_pointer); + auto psi_iter_wrapper = psi::Psi(psi_in, 1, nband_in, nbasis_in, ngk_pointer); psi::Range bands_range(true, 0, band_index1, band_index2); @@ -503,16 +497,8 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int nband_in, const int nbasis_max_in) { // Convert "pointer data stucture" to a psi::Psi object - auto psi_in_wrapper = psi::Psi(psi_in, - 1, - nband_in, - nbasis_max_in, - ngk_pointer); - auto psi_out_wrapper = psi::Psi(psi_out, - 1, - nband_in, - nbasis_max_in, - ngk_pointer); + auto psi_in_wrapper = psi::Psi(psi_in, 1, nband_in, nbasis_max_in, ngk_pointer); + auto psi_out_wrapper = psi::Psi(psi_out, 1, nband_in, nbasis_max_in, ngk_pointer); DiagoIterAssist::diagH_subspace(hm, psi_in_wrapper, @@ -521,24 +507,21 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, nband_in); }; - DiagoIterAssist::avg_iter - += static_cast(dav_subspace.diag(hpsi_func, - psi.get_pointer(), - psi.get_nbasis(), - eigenvalue, - is_occupied, - scf)); - } else if (this->method == "bpcg") { + DiagoIterAssist::avg_iter += static_cast( + dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, is_occupied, scf)); + } + else if (this->method == "bpcg") + { DiagoBPCG bpcg(precondition.data()); bpcg.init_iter(psi); bpcg.diag(hm, psi, eigenvalue); - } else if (this->method == "dav") { + } + else if (this->method == "dav") + { #ifdef __MPI - const diag_comm_info comm_info - = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; + const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; #else - const diag_comm_info comm_info - = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; + const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL}; #endif // Allow 5 tries at most. If ntry > ntry_max = 5, exit diag loop. const int ntry_max = 5; @@ -553,29 +536,16 @@ void HSolverPW::hamiltSolvePsiK(hamilt::Hamilt* hm, const int nband = psi.get_nbands(); const int ldPsi = psi.get_nbasis(); - DiagoDavid david(precondition.data(), - GlobalV::PW_DIAG_NDIM, - GlobalV::use_paw, - comm_info); - DiagoIterAssist::avg_iter - += static_cast(david.diag(hm, - dim, - nband, - ldPsi, - psi, - eigenvalue, - david_diag_thr, - david_maxiter, - ntry_max, - notconv_max)); + DiagoDavid david(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info); + DiagoIterAssist::avg_iter += static_cast( + david.diag(hm, dim, nband, ldPsi, psi, eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max)); } return; } template -void HSolverPW::update_precondition(std::vector& h_diag, - const int ik, - const int npw) { +void HSolverPW::update_precondition(std::vector& h_diag, const int ik, const int npw) +{ h_diag.assign(h_diag.size(), 1.0); int precondition_type = 2; const auto tpiba2 = static_cast(this->wfc_basis->tpiba2); @@ -585,46 +555,56 @@ void HSolverPW::update_precondition(std::vector& h_diag, // h_diag is the precondition matrix // h_diag(1:npw) = MAX( 1.0, g2kin(1:npw) ); //=========================================== - if (precondition_type == 1) { - for (int ig = 0; ig < npw; ig++) { - Real g2kin - = static_cast(this->wfc_basis->getgk2(ik, ig)) * tpiba2; + if (precondition_type == 1) + { + for (int ig = 0; ig < npw; ig++) + { + Real g2kin = static_cast(this->wfc_basis->getgk2(ik, ig)) * tpiba2; h_diag[ig] = std::max(static_cast(1.0), g2kin); } - } else if (precondition_type == 2) { - for (int ig = 0; ig < npw; ig++) { - Real g2kin - = static_cast(this->wfc_basis->getgk2(ik, ig)) * tpiba2; - - if (this->method == "dav_subspace") { + } + else if (precondition_type == 2) + { + for (int ig = 0; ig < npw; ig++) + { + Real g2kin = static_cast(this->wfc_basis->getgk2(ik, ig)) * tpiba2; + + if (this->method == "dav_subspace") + { h_diag[ig] = g2kin; - } else { + } + else + { h_diag[ig] = 1 + g2kin + sqrt(1 + (g2kin - 1) * (g2kin - 1)); } } } - if (GlobalV::NSPIN == 4) { + if (GlobalV::NSPIN == 4) + { const int size = h_diag.size(); - for (int ig = 0; ig < npw; ig++) { + for (int ig = 0; ig < npw; ig++) + { h_diag[ig + size / 2] = h_diag[ig]; } } } template -typename HSolverPW::Real HSolverPW::cal_hsolerror() { +typename HSolverPW::Real HSolverPW::cal_hsolerror() +{ return this->diag_ethr * static_cast(std::max(1.0, GlobalV::nelec)); } template -typename HSolverPW::Real - HSolverPW::set_diagethr(const int istep, - const int iter, - const Real drho) { +typename HSolverPW::Real HSolverPW::set_diagethr(const int istep, const int iter, const Real drho) +{ // It is too complex now and should be modified. - if (iter == 1) { - if (std::abs(this->diag_ethr - 1.0e-2) < 1.0e-6) { - if (GlobalV::init_chg == "file") { + if (iter == 1) + { + if (std::abs(this->diag_ethr - 1.0e-2) < 1.0e-6) + { + if (GlobalV::init_chg == "file") + { //====================================================== // if you think that the starting potential is good // do not spoil it with a louly first diagonalization: @@ -632,7 +612,9 @@ typename HSolverPW::Real // ()diago_the_init //====================================================== this->diag_ethr = 1.0e-5; - } else { + } + else + { //======================================================= // starting atomic potential is probably far from scf // don't waste iterations in the first diagonalization @@ -641,45 +623,46 @@ typename HSolverPW::Real } } // if (GlobalV::FINAL_SCF) this->diag_ethr = 1.0e-2; - if (GlobalV::CALCULATION == "md" || GlobalV::CALCULATION == "relax" - || GlobalV::CALCULATION == "cell-relax") { - this->diag_ethr = std::max(this->diag_ethr, - static_cast(GlobalV::PW_DIAG_THR)); + if (GlobalV::CALCULATION == "md" || GlobalV::CALCULATION == "relax" || GlobalV::CALCULATION == "cell-relax") + { + this->diag_ethr = std::max(this->diag_ethr, static_cast(GlobalV::PW_DIAG_THR)); } - } else { - if (iter == 2) { + } + else + { + if (iter == 2) + { this->diag_ethr = 1.e-2; } - this->diag_ethr - = std::min(this->diag_ethr, - static_cast(0.1) * drho - / std::max(static_cast(1.0), - static_cast(GlobalV::nelec))); + this->diag_ethr = std::min(this->diag_ethr, + static_cast(0.1) * drho + / std::max(static_cast(1.0), static_cast(GlobalV::nelec))); } // It is essential for single precision implementation to keep the diag_ethr // value less or equal to the single-precision limit of convergence(0.5e-4). // modified by denghuilu at 2023-05-15 - if (GlobalV::precision_flag == "single") { + if (GlobalV::precision_flag == "single") + { this->diag_ethr = std::max(this->diag_ethr, static_cast(0.5e-4)); } return this->diag_ethr; } template -typename HSolverPW::Real - HSolverPW::reset_diagethr(std::ofstream& ofs_running, - const Real hsover_error, - const Real drho) { +typename HSolverPW::Real HSolverPW::reset_diagethr(std::ofstream& ofs_running, + const Real hsover_error, + const Real drho) +{ ofs_running << " Notice: Threshold on eigenvalues was too large.\n"; ModuleBase::WARNING("scf", "Threshold on eigenvalues was too large."); - ofs_running << " hsover_error=" << hsover_error << " > DRHO=" << drho - << std::endl; + ofs_running << " hsover_error=" << hsover_error << " > DRHO=" << drho << std::endl; ofs_running << " Origin diag_ethr = " << this->diag_ethr << std::endl; this->diag_ethr = 0.1 * drho / GlobalV::nelec; // It is essential for single precision implementation to keep the diag_ethr // value less or equal to the single-precision limit of convergence(0.5e-4). // modified by denghuilu at 2023-05-15 - if (GlobalV::precision_flag == "single") { + if (GlobalV::precision_flag == "single") + { this->diag_ethr = std::max(this->diag_ethr, static_cast(0.5e-4)); } ofs_running << " New diag_ethr = " << this->diag_ethr << std::endl; diff --git a/source/module_hsolver/hsolver_pw.h b/source/module_hsolver/hsolver_pw.h index d4e2cc4eb2..8604ce6dd9 100644 --- a/source/module_hsolver/hsolver_pw.h +++ b/source/module_hsolver/hsolver_pw.h @@ -6,10 +6,12 @@ #include "module_basis/module_pw/pw_basis_k.h" #include "module_hamilt_pw/hamilt_pwdft/wavefunc.h" -namespace hsolver { +namespace hsolver +{ template -class HSolverPW : public HSolver { +class HSolverPW : public HSolver +{ private: bool is_first_scf = true; @@ -31,13 +33,6 @@ class HSolverPW : public HSolver { HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pwf_in); - /*void init( - const Basis* pbas - //const Input &in, - ) override; - void update(//Input &in - ) override;*/ - /// @brief solve function for pw /// @param pHamilt interface to hamilt /// @param psi reference to psi @@ -48,33 +43,26 @@ class HSolverPW : public HSolver { psi::Psi& psi, elecstate::ElecState* pes, const std::string method_in, - const bool skip_charge) override; + const bool skip_charge) override; virtual Real cal_hsolerror() override; - virtual Real - set_diagethr(const int istep, const int iter, const Real drho) override; - virtual Real reset_diagethr(std::ofstream& ofs_running, - const Real hsover_error, - const Real drho) override; + + virtual Real set_diagethr(const int istep, const int iter, const Real drho) override; + + virtual Real reset_diagethr(std::ofstream& ofs_running, const Real hsover_error, const Real drho) override; protected: // void initDiagh(const psi::Psi& psi_in); void endDiagh(); - void hamiltSolvePsiK(hamilt::Hamilt* hm, - psi::Psi& psi, - Real* eigenvalue); + void hamiltSolvePsiK(hamilt::Hamilt* hm, psi::Psi& psi, Real* eigenvalue); - void updatePsiK(hamilt::Hamilt* pHamilt, - psi::Psi& psi, - const int ik); + void updatePsiK(hamilt::Hamilt* pHamilt, psi::Psi& psi, const int ik); ModulePW::PW_Basis_K* wfc_basis = nullptr; wavefunc* pwf = nullptr; // calculate the precondition array for diagonalization in PW base - void update_precondition(std::vector& h_diag, - const int ik, - const int npw); + void update_precondition(std::vector& h_diag, const int ik, const int npw); std::vector precondition; std::vector eigenvalues; @@ -85,15 +73,6 @@ class HSolverPW : public HSolver { hamilt::Hamilt* hamilt_ = nullptr; Device* ctx = {}; - using resmem_var_op - = base_device::memory::resize_memory_op; - using delmem_var_op - = base_device::memory::delete_memory_op; - using castmem_2d_2h_op - = base_device::memory::cast_memory_op; }; template