From b083759b86afd16dd05df0c8a45a70f1c736234e Mon Sep 17 00:00:00 2001 From: Haozhi Han Date: Tue, 24 Sep 2024 12:47:02 +0800 Subject: [PATCH] Refactor: refactor hsolver-lcao func (#5148) * refactor hsolver-lcao func * remove useless value in hsolver_lcaopw --- source/module_esolver/esolver_ks_lcao.cpp | 210 ++++++++++-------- .../module_esolver/esolver_ks_lcao_tddft.cpp | 82 +++---- source/module_esolver/lcao_nscf.cpp | 2 +- .../module_deltaspin/cal_mw_from_lambda.cpp | 2 +- source/module_hsolver/hsolver_lcao.cpp | 173 +++++++-------- source/module_hsolver/hsolver_lcao.h | 12 +- source/module_hsolver/hsolver_lcaopw.cpp | 5 - source/module_hsolver/hsolver_lcaopw.h | 6 +- source/module_hsolver/hsolver_pw.h | 3 - source/module_hsolver/hsolver_pw_sdft.cpp | 2 +- source/module_hsolver/hsolver_pw_sdft.h | 20 +- source/module_lr/hsolver_lrtd.cpp | 5 - 12 files changed, 255 insertions(+), 267 deletions(-) diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index 735ef0dda1..eb8bb63c2c 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -19,8 +19,6 @@ #include "module_parameter/parameter.h" //--------------temporary---------------------------- -#include - #include "module_base/global_function.h" #include "module_cell/module_neighbor/sltk_grid_driver.h" #include "module_elecstate/module_charge/symmetry_rho.h" @@ -30,9 +28,11 @@ #include "module_hamilt_lcao/module_dftu/dftu.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_io/print_info.h" + +#include #ifdef __EXX -#include "module_ri/RPA_LRI.h" #include "module_io/restart_exx_csr.h" +#include "module_ri/RPA_LRI.h" #endif #ifdef __DEEPKS @@ -69,7 +69,7 @@ ESolver_KS_LCAO::ESolver_KS_LCAO() this->basisname = "LCAO"; #ifdef __EXX // 1. currently this initialization must be put in constructor rather than `before_all_runners()` - // because the latter is not reused by ESolver_LCAO_TDDFT, + // because the latter is not reused by ESolver_LCAO_TDDFT, // which cause the failure of the subsequent procedure reused by ESolver_LCAO_TDDFT // 2. always construct but only initialize when if(cal_exx) is true // because some members like two_level_step are used outside if(cal_exx) @@ -133,8 +133,7 @@ void ESolver_KS_LCAO::before_all_runners(const Input_para& inp, UnitCell } // 1.3) Setup k-points according to symmetry. - this->kv - .set(ucell.symm, PARAM.inp.kpoint_file, PARAM.inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running); + this->kv.set(ucell.symm, PARAM.inp.kpoint_file, PARAM.inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS"); Print_Info::setup_parameters(ucell, this->kv); @@ -151,26 +150,25 @@ void ESolver_KS_LCAO::before_all_runners(const Input_para& inp, UnitCell if (this->pelec == nullptr) { // TK stands for double and complex? - this->pelec = new elecstate::ElecStateLCAO( - &(this->chr), // use which parameter? - &(this->kv), - this->kv.get_nks(), - &(this->GG), // mohan add 2024-04-01 - &(this->GK), // mohan add 2024-04-01 - this->pw_rho, - this->pw_big); + this->pelec = new elecstate::ElecStateLCAO(&(this->chr), // use which parameter? + &(this->kv), + this->kv.get_nks(), + &(this->GG), // mohan add 2024-04-01 + &(this->GK), // mohan add 2024-04-01 + this->pw_rho, + this->pw_big); } // 3) init LCAO basis // reading the localized orbitals/projectors // construct the interpolation tables. - LCAO_domain::init_basis_lcao(this->pv, - inp.onsite_radius, - inp.lcao_ecut, - inp.lcao_dk, - inp.lcao_dr, - inp.lcao_rmax, - ucell, + LCAO_domain::init_basis_lcao(this->pv, + inp.onsite_radius, + inp.lcao_ecut, + inp.lcao_dk, + inp.lcao_dr, + inp.lcao_rmax, + ucell, two_center_bundle_, orb_); //------------------init Basis_lcao---------------------- @@ -178,8 +176,7 @@ void ESolver_KS_LCAO::before_all_runners(const Input_para& inp, UnitCell // 5) initialize density matrix // DensityMatrix is allocated here, DMK is also initialized here // DMR is not initialized here, it will be constructed in each before_scf - dynamic_cast*>(this->pelec) - ->init_DM(&this->kv, &(this->pv), PARAM.inp.nspin); + dynamic_cast*>(this->pelec)->init_DM(&this->kv, &(this->pv), PARAM.inp.nspin); // this function should be removed outside of the function if (PARAM.inp.calculation == "get_S") @@ -196,22 +193,28 @@ void ESolver_KS_LCAO::before_all_runners(const Input_para& inp, UnitCell #ifdef __EXX // 7) initialize exx // PLEASE simplify the Exx_Global interface - if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" - || PARAM.inp.calculation == "cell-relax" + if (PARAM.inp.calculation == "scf" || PARAM.inp.calculation == "relax" || PARAM.inp.calculation == "cell-relax" || PARAM.inp.calculation == "md") { if (GlobalC::exx_info.info_global.cal_exx) { XC_Functional::set_xc_first_loop(ucell); // initialize 2-center radial tables for EXX-LRI - if (GlobalC::exx_info.info_ri.real_number) { this->exx_lri_double->init(MPI_COMM_WORLD, this->kv, orb_); } - else { this->exx_lri_complex->init(MPI_COMM_WORLD, this->kv, orb_); } + if (GlobalC::exx_info.info_ri.real_number) + { + this->exx_lri_double->init(MPI_COMM_WORLD, this->kv, orb_); + } + else + { + this->exx_lri_complex->init(MPI_COMM_WORLD, this->kv, orb_); + } } } #endif // 8) initialize DFT+U - if (PARAM.inp.dft_plus_u) { + if (PARAM.inp.dft_plus_u) + { GlobalC::dftu.init(ucell, &this->pv, this->kv.get_nks(), orb_); } @@ -256,19 +259,19 @@ void ESolver_KS_LCAO::before_all_runners(const Input_para& inp, UnitCell { if (this->kv.get_nks() % GlobalV::KPAR_LCAO != 0) { - ModuleBase::WARNING("ESolver_KS_LCAO::before_all_runners", - "nks is not divisible by kpar."); + ModuleBase::WARNING("ESolver_KS_LCAO::before_all_runners", "nks is not divisible by kpar."); std::cout << "\n%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" - "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" - "%%%%%%%%%%%%%%%%%%%%%%%%%%" << std::endl; - std::cout << " Warning: nks (" << this->kv.get_nks() << ") is not divisible by kpar (" - << GlobalV::KPAR_LCAO << ")." << std::endl; + "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" + "%%%%%%%%%%%%%%%%%%%%%%%%%%" + << std::endl; + std::cout << " Warning: nks (" << this->kv.get_nks() << ") is not divisible by kpar (" << GlobalV::KPAR_LCAO + << ")." << std::endl; std::cout << " This may lead to poor load balance. It is strongly suggested to" << std::endl; std::cout << " set nks to be divisible by kpar, but if this is really what" << std::endl; std::cout << " you want, please ignore this warning." << std::endl; std::cout << "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" - "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" - "%%%%%%%%%%%%\n"; + "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%" + "%%%%%%%%%%%%\n"; } } @@ -450,25 +453,26 @@ void ESolver_KS_LCAO::after_all_runners() if (PARAM.inp.out_mat_xc) { ModuleIO::write_Vxc(PARAM.inp.nspin, - GlobalV::NLOCAL, - GlobalV::DRANK, - &this->pv, - *this->psi, - GlobalC::ucell, - this->sf, - *this->pw_rho, - *this->pw_rhod, - GlobalC::ppcell.vloc, - *this->pelec->charge, - this->GG, - this->GK, - this->kv, - orb_.cutoffs(), - this->pelec->wg, - GlobalC::GridD + GlobalV::NLOCAL, + GlobalV::DRANK, + &this->pv, + *this->psi, + GlobalC::ucell, + this->sf, + *this->pw_rho, + *this->pw_rhod, + GlobalC::ppcell.vloc, + *this->pelec->charge, + this->GG, + this->GK, + this->kv, + orb_.cutoffs(), + this->pelec->wg, + GlobalC::GridD #ifdef __EXX - , this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr - , this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr + , + this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr, + this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr #endif ); } @@ -476,26 +480,27 @@ void ESolver_KS_LCAO::after_all_runners() if (PARAM.inp.out_eband_terms) { ModuleIO::write_eband_terms(PARAM.inp.nspin, - GlobalV::NLOCAL, - GlobalV::DRANK, - &this->pv, - *this->psi, - GlobalC::ucell, - this->sf, - *this->pw_rho, - *this->pw_rhod, - GlobalC::ppcell.vloc, - *this->pelec->charge, - this->GG, - this->GK, - this->kv, - this->pelec->wg, - GlobalC::GridD, - orb_.cutoffs(), - this->two_center_bundle_ + GlobalV::NLOCAL, + GlobalV::DRANK, + &this->pv, + *this->psi, + GlobalC::ucell, + this->sf, + *this->pw_rho, + *this->pw_rhod, + GlobalC::ppcell.vloc, + *this->pelec->charge, + this->GG, + this->GK, + this->kv, + this->pelec->wg, + GlobalC::GridD, + orb_.cutoffs(), + this->two_center_bundle_ #ifdef __EXX - , this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr - , this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr + , + this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr, + this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr #endif ); } @@ -503,7 +508,6 @@ void ESolver_KS_LCAO::after_all_runners() ModuleBase::timer::tick("ESolver_KS_LCAO", "after_all_runners"); } - //------------------------------------------------------------------------------ //! the 10th function of ESolver_KS_LCAO: iter_init //! mohan add 2024-05-11 @@ -619,11 +623,15 @@ void ESolver_KS_LCAO::iter_init(const int istep, const int iter) // calculate exact-exchange if (GlobalC::exx_info.info_ri.real_number) { - this->exd->exx_eachiterinit(*dynamic_cast*>(this->pelec)->get_DM(), this->kv, iter); + this->exd->exx_eachiterinit(*dynamic_cast*>(this->pelec)->get_DM(), + this->kv, + iter); } else { - this->exc->exx_eachiterinit(*dynamic_cast*>(this->pelec)->get_DM(), this->kv, iter); + this->exc->exx_eachiterinit(*dynamic_cast*>(this->pelec)->get_DM(), + this->kv, + iter); } #endif @@ -706,8 +714,7 @@ void ESolver_KS_LCAO::hamilt2density(int istep, int iter, double ethr) this->pelec->f_en.demet = 0.0; hsolver::HSolverLCAO hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver); - hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, PARAM.inp.ks_solver, false); - + hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, false); if (PARAM.inp.out_bandgap) { @@ -818,7 +825,7 @@ void ESolver_KS_LCAO::update_pot(const int istep, const int iter) } for (int ik = 0; ik < this->kv.get_nks(); ++ik) { - if (PARAM.inp.out_mat_hs[0]|| PARAM.inp.deepks_v_delta) + if (PARAM.inp.out_mat_hs[0] || PARAM.inp.deepks_v_delta) { this->p_hamilt->updateHk(ik); } @@ -858,7 +865,7 @@ void ESolver_KS_LCAO::update_pot(const int istep, const int iter) GlobalV::DRANK); } #ifdef __DEEPKS - if(PARAM.inp.deepks_out_labels && PARAM.inp.deepks_v_delta) + if (PARAM.inp.deepks_out_labels && PARAM.inp.deepks_v_delta) { DeePKS_domain::save_h_mat(h_mat.p, this->pv.nloc); } @@ -880,7 +887,6 @@ void ESolver_KS_LCAO::update_pot(const int istep, const int iter) istep); } - if (!this->conv_elec) { if (PARAM.inp.nspin == 4) @@ -945,7 +951,7 @@ void ESolver_KS_LCAO::iter_finish(int& iter) Hexxk_save.set_zero_hk(); hamilt::OperatorEXX> opexx_save(&Hexxk_save, - nullptr, + nullptr, this->kv); opexx_save.contributeHk(ik); @@ -975,8 +981,12 @@ void ESolver_KS_LCAO::iter_finish(int& iter) { // Kerker mixing does not work for the density matrix. // In the separate loop case, it can still work in the subsequent inner loops where Hexx(DM) is fixed. - // In the non-separate loop case where Hexx(DM) is updated in every iteration of the 2nd loop, it should be closed. - if (!GlobalC::exx_info.info_global.separate_loop) { this->p_chgmix->close_kerker_gg0(); } + // In the non-separate loop case where Hexx(DM) is updated in every iteration of the 2nd loop, it should be + // closed. + if (!GlobalC::exx_info.info_global.separate_loop) + { + this->p_chgmix->close_kerker_gg0(); + } if (GlobalC::exx_info.info_ri.real_number) { this->conv_elec = this->exd->exx_after_converge( @@ -1126,12 +1136,16 @@ void ESolver_KS_LCAO::after_scf(const int istep) #ifdef __EXX // 4) write Hexx matrix for NSCF (see `out_chg` in docs/advanced/input_files/input-main.md) - if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.out_chg[0] && istep % PARAM.inp.out_interval == 0) // Peize Lin add if 2022.11.14 + if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.out_chg[0] + && istep % PARAM.inp.out_interval == 0) // Peize Lin add if 2022.11.14 { const std::string file_name_exx = PARAM.globalv.global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK); - if (GlobalC::exx_info.info_ri.real_number) { + if (GlobalC::exx_info.info_ri.real_number) + { ModuleIO::write_Hexxs_csr(file_name_exx, GlobalC::ucell, this->exd->get_Hexxs()); - } else { + } + else + { ModuleIO::write_Hexxs_csr(file_name_exx, GlobalC::ucell, this->exc->get_Hexxs()); } } @@ -1222,9 +1236,9 @@ void ESolver_KS_LCAO::after_scf(const int istep) // 15) write spin constrian MW? // spin constrain calculations, added by Tianqi Zhao. - if (PARAM.inp.sc_mag_switch) { - SpinConstrain& sc - = SpinConstrain::getScInstance(); + if (PARAM.inp.sc_mag_switch) + { + SpinConstrain& sc = SpinConstrain::getScInstance(); sc.cal_MW(istep, true); sc.print_Mag_Force(); } @@ -1254,14 +1268,14 @@ void ESolver_KS_LCAO::after_scf(const int istep) { hamilt::HS_Matrix_K hsk(&pv, true); hamilt::HContainer hR(&pv); - hamilt::Operator* ekinetic = - new hamilt::EkineticNew>(&hsk, - this->kv.kvec_d, - &hR, - &GlobalC::ucell, - orb_.cutoffs(), - &GlobalC::GridD, - two_center_bundle_.kinetic_orb.get()); + hamilt::Operator* ekinetic + = new hamilt::EkineticNew>(&hsk, + this->kv.kvec_d, + &hR, + &GlobalC::ucell, + orb_.cutoffs(), + &GlobalC::GridD, + two_center_bundle_.kinetic_orb.get()); const int nspin_k = (PARAM.inp.nspin == 2 ? 2 : 1); for (int ik = 0; ik < this->kv.get_nks() / nspin_k; ++ik) diff --git a/source/module_esolver/esolver_ks_lcao_tddft.cpp b/source/module_esolver/esolver_ks_lcao_tddft.cpp index 17d36bc46d..c1bf644d52 100644 --- a/source/module_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/module_esolver/esolver_ks_lcao_tddft.cpp @@ -11,8 +11,8 @@ //--------------temporary---------------------------- #include "module_base/blas_connector.h" #include "module_base/global_function.h" -#include "module_base/scalapack_connector.h" #include "module_base/lapack_connector.h" +#include "module_base/scalapack_connector.h" #include "module_elecstate/module_charge/symmetry_rho.h" #include "module_elecstate/occupy.h" #include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h" // need divide_HS_in_frag @@ -73,34 +73,34 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(const Input_para& inp, UnitCell& GlobalC::ppcell.init_vloc(GlobalC::ppcell.vloc, pw_rho); // 3) initialize the electronic states for TDDFT - if (this->pelec == nullptr) { - this->pelec = new elecstate::ElecStateLCAO_TDDFT( - &this->chr, - &kv, - kv.get_nks(), - &this->GK, // mohan add 2024-04-01 - this->pw_rho, - pw_big); + if (this->pelec == nullptr) + { + this->pelec = new elecstate::ElecStateLCAO_TDDFT(&this->chr, + &kv, + kv.get_nks(), + &this->GK, // mohan add 2024-04-01 + this->pw_rho, + pw_big); } // 4) read the local orbitals and construct the interpolation tables. // initialize the pv - LCAO_domain::init_basis_lcao(this->pv, - inp.onsite_radius, - inp.lcao_ecut, - inp.lcao_dk, - inp.lcao_dr, - inp.lcao_rmax, - ucell, + LCAO_domain::init_basis_lcao(this->pv, + inp.onsite_radius, + inp.lcao_ecut, + inp.lcao_dk, + inp.lcao_dr, + inp.lcao_rmax, + ucell, two_center_bundle_, orb_); // 5) allocate H and S matrices according to computational resources LCAO_domain::divide_HS_in_frag(PARAM.globalv.gamma_only_local, this->pv, kv.get_nks(), orb_); - // 6) initialize Density Matrix - dynamic_cast>*>(this->pelec)->init_DM(&kv, &this->pv, PARAM.inp.nspin); + dynamic_cast>*>(this->pelec) + ->init_DM(&kv, &this->pv, PARAM.inp.nspin); // 8) initialize the charge density this->pelec->charge->allocate(PARAM.inp.nspin); @@ -169,7 +169,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2density(const int istep, const int iter, cons if (this->psi != nullptr) { hsolver::HSolverLCAO> hsolver_lcao_obj(&this->pv, PARAM.inp.ks_solver); - hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec_td, PARAM.inp.ks_solver, false); + hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec_td, false); } } // else @@ -255,28 +255,28 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter) if (PARAM.inp.out_mat_hs[0]) { ModuleIO::save_mat(istep, - h_mat.p, - GlobalV::NLOCAL, - bit, - PARAM.inp.out_mat_hs[1], - 1, - PARAM.inp.out_app_flag, - "H", - "data-" + std::to_string(ik), - this->pv, - GlobalV::DRANK); + h_mat.p, + GlobalV::NLOCAL, + bit, + PARAM.inp.out_mat_hs[1], + 1, + PARAM.inp.out_app_flag, + "H", + "data-" + std::to_string(ik), + this->pv, + GlobalV::DRANK); ModuleIO::save_mat(istep, - s_mat.p, - GlobalV::NLOCAL, - bit, - PARAM.inp.out_mat_hs[1], - 1, - PARAM.inp.out_app_flag, - "S", - "data-" + std::to_string(ik), - this->pv, - GlobalV::DRANK); + s_mat.p, + GlobalV::NLOCAL, + bit, + PARAM.inp.out_mat_hs[1], + 1, + PARAM.inp.out_app_flag, + "S", + "data-" + std::to_string(ik), + this->pv, + GlobalV::DRANK); } } } @@ -416,8 +416,8 @@ void ESolver_KS_LCAO_TDDFT::after_scf(const int istep) } if (TD_Velocity::out_current == true) { - elecstate::DensityMatrix, double>* tmp_DM = - dynamic_cast>*>(this->pelec)->get_DM(); + elecstate::DensityMatrix, double>* tmp_DM + = dynamic_cast>*>(this->pelec)->get_DM(); ModuleIO::write_current(istep, this->psi, diff --git a/source/module_esolver/lcao_nscf.cpp b/source/module_esolver/lcao_nscf.cpp index 7904c0d6ea..56a3414696 100644 --- a/source/module_esolver/lcao_nscf.cpp +++ b/source/module_esolver/lcao_nscf.cpp @@ -53,7 +53,7 @@ void ESolver_KS_LCAO::nscf() { // istep becomes istep-1, this should be fixed in future int istep = 0; hsolver::HSolverLCAO hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver); - hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, PARAM.inp.ks_solver, true); + hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, true); time_t time_finish = std::time(nullptr); ModuleBase::GlobalFunc::OUT_TIME("cal_bands", time_start, time_finish); diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp index 4c5d5ab535..bda9ce7ebd 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp @@ -13,7 +13,7 @@ void SpinConstrain, base_device::DEVICE_CPU>::cal_mw_from_l // diagonalization without update charge hsolver::HSolverLCAO> hsolver_lcao_obj(this->ParaV, this->KS_SOLVER); - hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, this->KS_SOLVER, true); + hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, true); elecstate::ElecStateLCAO>* pelec_lcao = dynamic_cast>*>(this->pelec); diff --git a/source/module_hsolver/hsolver_lcao.cpp b/source/module_hsolver/hsolver_lcao.cpp index 833a26fe18..e312e3be96 100644 --- a/source/module_hsolver/hsolver_lcao.cpp +++ b/source/module_hsolver/hsolver_lcao.cpp @@ -1,64 +1,64 @@ #include "hsolver_lcao.h" -#include "module_parameter/parameter.h" -#include "diago_cg.h" - #ifdef __MPI #include "diago_scalapack.h" #else #include "diago_lapack.h" #endif -#include "module_base/timer.h" -#include "module_hsolver/diago_iter_assist.h" -#include "module_hsolver/kernels/math_kernel_op.h" -#include "module_io/write_HS.h" - -#include "module_base/global_variable.h" - -#include -#include -#include #ifdef __CUSOLVERMP #include "diago_cusolvermp.h" -#endif // __CUSOLVERMP +#endif + #ifdef __ELPA #include "diago_elpa.h" #include "diago_elpa_native.h" #endif + #ifdef __CUDA #include "diago_cusolver.h" #endif + #ifdef __PEXSI #include "diago_pexsi.h" #include "module_elecstate/elecstate_lcao.h" #endif +#include "diago_cg.h" +#include "module_base/global_variable.h" +#include "module_base/memory.h" #include "module_base/scalapack_connector.h" +#include "module_base/timer.h" +#include "module_hsolver/diago_iter_assist.h" +#include "module_hsolver/kernels/math_kernel_op.h" #include "module_hsolver/parallel_k2d.h" -#include "module_base/memory.h" +#include "module_io/write_HS.h" +#include "module_parameter/parameter.h" +#include +#include +#include #include -namespace hsolver { +namespace hsolver +{ template void HSolverLCAO::solve(hamilt::Hamilt* pHamilt, psi::Psi& psi, elecstate::ElecState* pes, - const std::string method_in, const bool skip_charge) { ModuleBase::TITLE("HSolverLCAO", "solve"); ModuleBase::timer::tick("HSolverLCAO", "solve"); - // select the method of diagonalization - this->method = method_in; #ifdef __PEXSI // other purification methods should follow this routine + // Zhang Xiaoyang : Please modify Pesxi usage later if (this->method == "pexsi") { DiagoPexsi pe(ParaV); - for (int ik = 0; ik < psi.get_nk(); ++ik) { + for (int ik = 0; ik < psi.get_nk(); ++ik) + { /// update H(k) for each k point pHamilt->updateHk(ik); psi.fix_k(ik); @@ -74,21 +74,21 @@ void HSolverLCAO::solve(hamilt::Hamilt* pHamilt, } #endif - // Zhang Xiaoyang : Please modify Pesxi usage later if (this->method == "cg_in_lcao") { this->precondition_lcao.resize(psi.get_nbasis()); using Real = typename GetTypeReal::type; // set precondition - for (size_t i = 0; i < precondition_lcao.size(); i++) { + for (size_t i = 0; i < precondition_lcao.size(); i++) + { precondition_lcao[i] = 1.0; } } #ifdef __MPI - if (GlobalV::KPAR_LCAO > 1 && - (this->method == "genelpa" || this->method == "elpa" || this->method == "scalapack_gvx")) + if (GlobalV::KPAR_LCAO > 1 + && (this->method == "genelpa" || this->method == "elpa" || this->method == "scalapack_gvx")) { this->parakSolve(pHamilt, psi, pes, GlobalV::KPAR_LCAO); } @@ -96,7 +96,8 @@ void HSolverLCAO::solve(hamilt::Hamilt* pHamilt, #endif { /// Loop over k points for solve Hamiltonian to charge density - for (int ik = 0; ik < psi.get_nk(); ++ik) { + for (int ik = 0; ik < psi.get_nk(); ++ik) + { /// update H(k) for each k point pHamilt->updateHk(ik); @@ -107,28 +108,18 @@ void HSolverLCAO::solve(hamilt::Hamilt* pHamilt, } } - if (this->method == "cg_in_lcao") { - this->is_first_scf = false; - } - - if (this->method != "genelpa" && this->method != "elpa" && this->method != "scalapack_gvx" && this->method != "lapack" - && this->method != "cusolver" && this->method != "cusolvermp" && this->method != "cg_in_lcao" - && this->method != "pexsi") + if (skip_charge) // used in nscf calculation { - //delete this->pdiagh; - //this->pdiagh = nullptr; + ModuleBase::timer::tick("HSolverLCAO", "solve"); } - - // used in nscf calculation - if (skip_charge) { + else // used in scf calculation + { + // calculate charge by psi + pes->psiToRho(psi); ModuleBase::timer::tick("HSolverLCAO", "solve"); - return; } - // calculate charge by psi - // called in scf calculation - pes->psiToRho(psi); - ModuleBase::timer::tick("HSolverLCAO", "solve"); + return; } template @@ -187,11 +178,11 @@ void HSolverLCAO::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::Psi& using ct_Device = typename ct::PsiToContainer::type; auto subspace_func = [](const ct::Tensor& psi_in, ct::Tensor& psi_out) { - // psi_in should be a 2D tensor: - // psi_in.shape() = [nbands, nbasis] - const auto ndim = psi_in.shape().ndim(); - REQUIRES_OK(ndim == 2, "dims of psi_in should be less than or equal to 2"); - }; + // psi_in should be a 2D tensor: + // psi_in.shape() = [nbands, nbasis] + const auto ndim = psi_in.shape().ndim(); + REQUIRES_OK(ndim == 2, "dims of psi_in should be less than or equal to 2"); + }; DiagoCG cg(PARAM.inp.basis_type, PARAM.inp.calculation, @@ -268,17 +259,17 @@ void HSolverLCAO::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::Psi& ModuleBase::timer::tick("DiagoCG_New", "spsi_func"); }; - if (this->is_first_scf) + // if (this->is_first_scf) + // { + for (size_t i = 0; i < psi.get_nbands(); i++) { - for (size_t i = 0; i < psi.get_nbands(); i++) + for (size_t j = 0; j < psi.get_nbasis(); j++) { - for (size_t j = 0; j < psi.get_nbasis(); j++) - { - psi(i, j) = *zero_; - } - psi(i, i) = *one_; + psi(i, j) = *zero_; } + psi(i, i) = *one_; } + // } auto psi_tensor = ct::TensorMap(psi.get_pointer(), ct::DataTypeToEnum::value, @@ -308,9 +299,9 @@ void HSolverLCAO::hamiltSolvePsiK(hamilt::Hamilt* hm, psi::Psi& template void HSolverLCAO::parakSolve(hamilt::Hamilt* pHamilt, - psi::Psi& psi, - elecstate::ElecState* pes, - int kpar) + psi::Psi& psi, + elecstate::ElecState* pes, + int kpar) { #ifdef __MPI ModuleBase::timer::tick("HSolverLCAO", "parakSolve"); @@ -320,32 +311,33 @@ void HSolverLCAO::parakSolve(hamilt::Hamilt* pHamilt, int nks = psi.get_nk(); int nrow = this->ParaV->get_global_row_size(); int nb2d = this->ParaV->get_block_size(); - k2d.set_para_env(psi.get_nk(), - nrow, - nb2d, - GlobalV::NPROC, - GlobalV::MY_RANK, - PARAM.inp.nspin); + k2d.set_para_env(psi.get_nk(), nrow, nb2d, GlobalV::NPROC, GlobalV::MY_RANK, PARAM.inp.nspin); /// set psi_pool const int zero = 0; - int ncol_bands_pool = numroc_(&(nbands), &(nb2d), &(k2d.get_p2D_pool()->coord[1]), &zero, &(k2d.get_p2D_pool()->dim1)); + int ncol_bands_pool + = numroc_(&(nbands), &(nb2d), &(k2d.get_p2D_pool()->coord[1]), &zero, &(k2d.get_p2D_pool()->dim1)); /// Loop over k points for solve Hamiltonian to charge density for (int ik = 0; ik < k2d.get_pKpoints()->get_max_nks_pool(); ++ik) { // if nks is not equal to the number of k points in the pool std::vector ik_kpar; int ik_avail = 0; - for (int i = 0; i < k2d.get_kpar(); i++) { - if (ik + k2d.get_pKpoints()->startk_pool[i] < nks && ik < k2d.get_pKpoints()->nks_pool[i]) { + for (int i = 0; i < k2d.get_kpar(); i++) + { + if (ik + k2d.get_pKpoints()->startk_pool[i] < nks && ik < k2d.get_pKpoints()->nks_pool[i]) + { ik_avail++; } } - if (ik_avail == 0) { - ModuleBase::WARNING_QUIT("HSolverLCAO::solve", - "ik_avail is 0!"); - } else { + if (ik_avail == 0) + { + ModuleBase::WARNING_QUIT("HSolverLCAO::solve", "ik_avail is 0!"); + } + else + { ik_kpar.resize(ik_avail); - for (int i = 0; i < ik_avail; i++) { + for (int i = 0; i < ik_avail; i++) + { ik_kpar[i] = ik + k2d.get_pKpoints()->startk_pool[i]; } } @@ -359,31 +351,35 @@ void HSolverLCAO::parakSolve(hamilt::Hamilt* pHamilt, /// local psi in pool psi_pool.fix_k(0); hamilt::MatrixBlock hk_pool = hamilt::MatrixBlock{k2d.hk_pool.data(), - (size_t)k2d.get_p2D_pool()->get_row_size(), (size_t)k2d.get_p2D_pool()->get_col_size(), k2d.get_p2D_pool()->desc}; + (size_t)k2d.get_p2D_pool()->get_row_size(), + (size_t)k2d.get_p2D_pool()->get_col_size(), + k2d.get_p2D_pool()->desc}; hamilt::MatrixBlock sk_pool = hamilt::MatrixBlock{k2d.sk_pool.data(), - (size_t)k2d.get_p2D_pool()->get_row_size(), (size_t)k2d.get_p2D_pool()->get_col_size(), k2d.get_p2D_pool()->desc}; + (size_t)k2d.get_p2D_pool()->get_row_size(), + (size_t)k2d.get_p2D_pool()->get_col_size(), + k2d.get_p2D_pool()->desc}; /// solve eigenvector and eigenvalue for H(k) if (this->method == "scalapack_gvx") { DiagoScalapack sa; - sa.diag_pool(hk_pool, sk_pool, psi_pool,&(pes->ekb(ik_global, 0)), k2d.POOL_WORLD_K2D); + sa.diag_pool(hk_pool, sk_pool, psi_pool, &(pes->ekb(ik_global, 0)), k2d.POOL_WORLD_K2D); } #ifdef __ELPA else if (this->method == "genelpa") { DiagoElpa el; - el.diag_pool(hk_pool, sk_pool, psi_pool,&(pes->ekb(ik_global, 0)), k2d.POOL_WORLD_K2D); + el.diag_pool(hk_pool, sk_pool, psi_pool, &(pes->ekb(ik_global, 0)), k2d.POOL_WORLD_K2D); } else if (this->method == "elpa") { DiagoElpaNative el; - el.diag_pool(hk_pool, sk_pool, psi_pool,&(pes->ekb(ik_global, 0)), k2d.POOL_WORLD_K2D); + el.diag_pool(hk_pool, sk_pool, psi_pool, &(pes->ekb(ik_global, 0)), k2d.POOL_WORLD_K2D); } #endif else { ModuleBase::WARNING_QUIT("HSolverLCAO::solve", - "This method of DiagH for k-parallelism diagnolization is not supported!"); + "This method of DiagH for k-parallelism diagnolization is not supported!"); } } MPI_Barrier(MPI_COMM_WORLD); @@ -394,21 +390,22 @@ void HSolverLCAO::parakSolve(hamilt::Hamilt* pHamilt, MPI_Bcast(&(pes->ekb(ik_kpar[ipool], 0)), nbands, MPI_DOUBLE, source, MPI_COMM_WORLD); int desc_pool[9]; std::copy(k2d.get_p2D_pool()->desc, k2d.get_p2D_pool()->desc + 9, desc_pool); - if (k2d.get_my_pool() != ipool) { + if (k2d.get_my_pool() != ipool) + { desc_pool[1] = -1; } psi.fix_k(ik_kpar[ipool]); Cpxgemr2d(nrow, - nbands, - psi_pool.get_pointer(), - 1, - 1, - desc_pool, - psi.get_pointer(), - 1, - 1, - k2d.get_p2D_global()->desc, - k2d.get_p2D_global()->blacs_ctxt); + nbands, + psi_pool.get_pointer(), + 1, + 1, + desc_pool, + psi.get_pointer(), + 1, + 1, + k2d.get_p2D_global()->desc, + k2d.get_p2D_global()->blacs_ctxt); } MPI_Barrier(MPI_COMM_WORLD); ModuleBase::timer::tick("HSolverLCAO", "collect_psi"); diff --git a/source/module_hsolver/hsolver_lcao.h b/source/module_hsolver/hsolver_lcao.h index 8dbb3e52f8..c2b0c92024 100644 --- a/source/module_hsolver/hsolver_lcao.h +++ b/source/module_hsolver/hsolver_lcao.h @@ -16,24 +16,20 @@ class HSolverLCAO void solve(hamilt::Hamilt* pHamilt, psi::Psi& psi, elecstate::ElecState* pes, - const std::string method_in, const bool skip_charge); private: void hamiltSolvePsiK(hamilt::Hamilt* hm, psi::Psi& psi, double* eigenvalue); - const Parallel_Orbitals* ParaV; - void parakSolve(hamilt::Hamilt* pHamilt, psi::Psi& psi, elecstate::ElecState* pes, int kpar); - bool is_first_scf = true; + const Parallel_Orbitals* ParaV; + + const std::string method; + // for cg_in_lcao using Real = typename GetTypeReal::type; std::vector precondition_lcao; - - DiagH* pdiagh = nullptr; // for single Hamiltonian matrix diagonal solver - - std::string method = "none"; }; } // namespace hsolver diff --git a/source/module_hsolver/hsolver_lcaopw.cpp b/source/module_hsolver/hsolver_lcaopw.cpp index c1261ff695..6782df33b7 100644 --- a/source/module_hsolver/hsolver_lcaopw.cpp +++ b/source/module_hsolver/hsolver_lcaopw.cpp @@ -193,11 +193,6 @@ void HSolverLIP::paw_func_after_kloop(psi::Psi& psi, elecstate::ElecState* } #endif -template -HSolverLIP::HSolverLIP(ModulePW::PW_Basis_K* wfc_basis_in) -{ - this->wfc_basis = wfc_basis_in; -} /* lcao_in_pw diff --git a/source/module_hsolver/hsolver_lcaopw.h b/source/module_hsolver/hsolver_lcaopw.h index 1694e6c62b..fd82e7b8eb 100644 --- a/source/module_hsolver/hsolver_lcaopw.h +++ b/source/module_hsolver/hsolver_lcaopw.h @@ -18,7 +18,7 @@ class HSolverLIP using Real = typename GetTypeReal::type; public: - HSolverLIP(ModulePW::PW_Basis_K* wfc_basis_in); + HSolverLIP(ModulePW::PW_Basis_K* wfc_basis_in) : wfc_basis(wfc_basis_in) {}; /// @brief solve function for lcao_in_pw /// @param pHamilt interface to hamilt @@ -33,9 +33,7 @@ class HSolverLIP const bool skip_charge); private: - ModulePW::PW_Basis_K* wfc_basis = nullptr; - - std::vector eigenvalues; + ModulePW::PW_Basis_K* wfc_basis; #ifdef USE_PAW void paw_func_in_kloop(const int ik); diff --git a/source/module_hsolver/hsolver_pw.h b/source/module_hsolver/hsolver_pw.h index e4b220fbc1..62bbc5120b 100644 --- a/source/module_hsolver/hsolver_pw.h +++ b/source/module_hsolver/hsolver_pw.h @@ -21,20 +21,17 @@ class HSolverPW public: HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pwf_in, - const std::string calculation_type_in, const std::string basis_type_in, const std::string method_in, const bool use_paw_in, const bool use_uspp_in, const int nspin_in, - const int scf_iter_in, const int diag_iter_max_in, const double diag_thr_in, const bool need_subspace_in, const bool initialed_psi_in) - : wfc_basis(wfc_basis_in), pwf(pwf_in), calculation_type(calculation_type_in), basis_type(basis_type_in), method(method_in), use_paw(use_paw_in), use_uspp(use_uspp_in), nspin(nspin_in), diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index 0f15879bb1..4f4c858cf4 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -64,7 +64,7 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt>* pHamilt, } this->output_iterInfo(); - + for (int ik = 0; ik < nks; ik++) { // init k diff --git a/source/module_hsolver/hsolver_pw_sdft.h b/source/module_hsolver/hsolver_pw_sdft.h index aa342b03e1..144a508985 100644 --- a/source/module_hsolver/hsolver_pw_sdft.h +++ b/source/module_hsolver/hsolver_pw_sdft.h @@ -12,21 +12,17 @@ class HSolverPW_SDFT : public HSolverPW> wavefunc* pwf_in, Stochastic_WF& stowf, StoChe& stoche, - const std::string calculation_type_in, const std::string basis_type_in, const std::string method_in, const bool use_paw_in, const bool use_uspp_in, const int nspin_in, - const int scf_iter_in, const int diag_iter_max_in, const double diag_thr_in, - const bool need_subspace_in, const bool initialed_psi_in) - : HSolverPW(wfc_basis_in, pwf_in, calculation_type_in, @@ -44,14 +40,14 @@ class HSolverPW_SDFT : public HSolverPW> stoiter.init(pkv, wfc_basis_in, stowf, stoche); } - virtual void solve(hamilt::Hamilt>* pHamilt, - psi::Psi>& psi, - elecstate::ElecState* pes, - ModulePW::PW_Basis_K* wfc_basis, - Stochastic_WF& stowf, - const int istep, - const int iter, - const bool skip_charge); + void solve(hamilt::Hamilt>* pHamilt, + psi::Psi>& psi, + elecstate::ElecState* pes, + ModulePW::PW_Basis_K* wfc_basis, + Stochastic_WF& stowf, + const int istep, + const int iter, + const bool skip_charge); Stochastic_Iter stoiter; }; diff --git a/source/module_lr/hsolver_lrtd.cpp b/source/module_lr/hsolver_lrtd.cpp index 638ae6e807..0cc95fb8f6 100644 --- a/source/module_lr/hsolver_lrtd.cpp +++ b/source/module_lr/hsolver_lrtd.cpp @@ -156,11 +156,6 @@ namespace LR std::vector(psi_k1_dav.get_nbands(), true), false /*scf*/)); } - // else if (this->method == "cg") - // { - // this->pdiagh = new DiagoCG(precondition.data()); - // this->pdiagh->method = this->method; - // } else {throw std::runtime_error("HSolverLR::solve: method not implemented");} }