diff --git a/source/module_elecstate/elecstate.cpp b/source/module_elecstate/elecstate.cpp index b11953caa8..3e0149ae29 100644 --- a/source/module_elecstate/elecstate.cpp +++ b/source/module_elecstate/elecstate.cpp @@ -356,4 +356,32 @@ void ElecState::cal_nbands() ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NBANDS", GlobalV::NBANDS); } + +void set_is_occupied(std::vector& is_occupied, + elecstate::ElecState* pes, + const int i_scf, + const int nk, + const int nband, + const bool diago_full_acc) +{ + if (i_scf != 0 && diago_full_acc == false) + { + for (int i = 0; i < nk; i++) + { + if (pes->klist->wk[i] > 0.0) + { + for (int j = 0; j < nband; j++) + { + if (pes->wg(i, j) / pes->klist->wk[i] < 0.01) + { + is_occupied[i * nband + j] = false; + } + } + } + } + } +}; + + + } // namespace elecstate diff --git a/source/module_elecstate/elecstate.h b/source/module_elecstate/elecstate.h index 07507f1385..38fcf58e28 100644 --- a/source/module_elecstate/elecstate.h +++ b/source/module_elecstate/elecstate.h @@ -177,5 +177,13 @@ class ElecState bool skip_weights = false; }; +// This is an independent function under the elecstate namespace and does not depend on any class. +void set_is_occupied(std::vector& is_occupied, + elecstate::ElecState* pes, + const int i_scf, + const int nk, + const int nband, + const bool diago_full_acc); + } // namespace elecstate #endif diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 86e745a08f..dfc65e8c90 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -353,12 +353,22 @@ void ESolver_KS_PW::hamilt2density(const int istep, const int iter, c hsolver::DiagoIterAssist::SCF_ITER = iter; hsolver::DiagoIterAssist::PW_DIAG_THR = ethr; hsolver::DiagoIterAssist::PW_DIAG_NMAX = GlobalV::PW_DIAG_NMAX; - + + std::vector is_occupied(this->kspw_psi->get_nk() * this->kspw_psi->get_nbands(), true); + + elecstate::set_is_occupied(is_occupied, + this->pelec, + hsolver::DiagoIterAssist::SCF_ITER, + this->kspw_psi->get_nk(), + this->kspw_psi->get_nbands(), + PARAM.inp.diago_full_acc); + hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi); hsolver_pw_obj.solve(this->p_hamilt, // hamilt::Hamilt* pHamilt, this->kspw_psi[0], // psi::Psi& psi, - this->pelec, // elecstate::ElecState* pelec, + this->pelec, // elecstate::ElecState* pelec, this->pelec->ekb.c, + is_occupied, PARAM.inp.ks_solver, PARAM.inp.calculation, PARAM.inp.basis_type, diff --git a/source/module_esolver/esolver_ks_pw.h b/source/module_esolver/esolver_ks_pw.h index 03c0ba3365..ab60961485 100644 --- a/source/module_esolver/esolver_ks_pw.h +++ b/source/module_esolver/esolver_ks_pw.h @@ -79,6 +79,7 @@ class ESolver_KS_PW : public ESolver_KS using castmem_2d_d2h_op = base_device::memory::cast_memory_op, T, base_device::DEVICE_CPU, Device>; + }; } // namespace ModuleESolver #endif diff --git a/source/module_esolver/pw_fun.cpp b/source/module_esolver/pw_fun.cpp index a7a4a5bea0..df517dd0c9 100644 --- a/source/module_esolver/pw_fun.cpp +++ b/source/module_esolver/pw_fun.cpp @@ -79,26 +79,34 @@ void ESolver_KS_PW::hamilt2estates(const double ethr) { hsolver::DiagoIterAssist::need_subspace = false; hsolver::DiagoIterAssist::PW_DIAG_THR = ethr; - hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi); + std::vector is_occupied(this->kspw_psi->get_nk() * this->kspw_psi->get_nbands(), true); - hsolver_pw_obj.solve(this->p_hamilt, - this->kspw_psi[0], - this->pelec, - this->pelec->ekb.c, - PARAM.inp.ks_solver, - PARAM.inp.calculation, - PARAM.inp.basis_type, - PARAM.inp.use_paw, - GlobalV::use_uspp, - GlobalV::RANK_IN_POOL, - GlobalV::NPROC_IN_POOL, + elecstate::set_is_occupied(is_occupied, + this->pelec, + hsolver::DiagoIterAssist::SCF_ITER, + this->kspw_psi->get_nk(), + this->kspw_psi->get_nbands(), + PARAM.inp.diago_full_acc); - hsolver::DiagoIterAssist::SCF_ITER, - hsolver::DiagoIterAssist::need_subspace, - hsolver::DiagoIterAssist::PW_DIAG_NMAX, - hsolver::DiagoIterAssist::PW_DIAG_THR, + hsolver::HSolverPW hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi); - true); + hsolver_pw_obj.solve(this->p_hamilt, + this->kspw_psi[0], + this->pelec, + this->pelec->ekb.c, + is_occupied, + PARAM.inp.ks_solver, + PARAM.inp.calculation, + PARAM.inp.basis_type, + PARAM.inp.use_paw, + GlobalV::use_uspp, + GlobalV::RANK_IN_POOL, + GlobalV::NPROC_IN_POOL, + hsolver::DiagoIterAssist::SCF_ITER, + hsolver::DiagoIterAssist::need_subspace, + hsolver::DiagoIterAssist::PW_DIAG_NMAX, + hsolver::DiagoIterAssist::PW_DIAG_THR, + true); this->init_psi = true; diff --git a/source/module_hsolver/hsolver.h b/source/module_hsolver/hsolver.h index 3f2ad82c35..f6af747083 100644 --- a/source/module_hsolver/hsolver.h +++ b/source/module_hsolver/hsolver.h @@ -49,23 +49,19 @@ class HSolver virtual void solve(hamilt::Hamilt* phm, psi::Psi& ppsi, elecstate::ElecState* pes, - double* out_eigenvalues, - + const std::vector& is_occupied_in, const std::string method, - const std::string calculation_type_in, const std::string basis_type_in, const bool use_paw_in, const bool use_uspp_in, const int rank_in_pool_in, const int nproc_in_pool_in, - - const int scf_iter_in, - const bool need_subspace_in, - const int diag_iter_max_in, - const double pw_diag_thr_in, - + const int scf_iter_in, + const bool need_subspace_in, + const int diag_iter_max_in, + const double pw_diag_thr_in, const bool skip_charge) { return; diff --git a/source/module_hsolver/hsolver_pw.cpp b/source/module_hsolver/hsolver_pw.cpp index 177d8d9f38..0d2b3c0a5b 100644 --- a/source/module_hsolver/hsolver_pw.cpp +++ b/source/module_hsolver/hsolver_pw.cpp @@ -220,53 +220,23 @@ HSolverPW::HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, this->initialed_psi = initialed_psi_in; } -template -void HSolverPW::set_isOccupied(std::vector& is_occupied, - elecstate::ElecState* pes, - const int i_scf, - const int nk, - const int nband, - const bool diago_full_acc_) -{ - if (i_scf != 0 && diago_full_acc_ == false) - { - for (int i = 0; i < nk; i++) - { - if (pes->klist->wk[i] > 0.0) - { - for (int j = 0; j < nband; j++) - { - if (pes->wg(i, j) / pes->klist->wk[i] < 0.01) - { - is_occupied[i * nband + j] = false; - } - } - } - } - } -} - template void HSolverPW::solve(hamilt::Hamilt* pHamilt, psi::Psi& psi, elecstate::ElecState* pes, - double* out_eigenvalues, - + const std::vector& is_occupied_in, const std::string method_in, - const std::string calculation_type_in, const std::string basis_type_in, const bool use_paw_in, const bool use_uspp_in, const int rank_in_pool_in, const int nproc_in_pool_in, - const int scf_iter_in, const bool need_subspace_in, const int diag_iter_max_in, const double pw_diag_thr_in, - const bool skip_charge) { ModuleBase::TITLE("HSolverPW", "solve"); @@ -298,16 +268,6 @@ void HSolverPW::solve(hamilt::Hamilt* pHamilt, // prepare for the precondition of diagonalization std::vector precondition(psi.get_nbasis(), 0.0); std::vector eigenvalues(pes->ekb.nr * pes->ekb.nc, 0.0); - std::vector is_occupied(psi.get_nk() * psi.get_nbands(), true); - if (this->method == "dav_subspace") - { - this->set_isOccupied(is_occupied, - pes, - this->scf_iter, - psi.get_nk(), - psi.get_nbands(), - this->diago_full_acc); - } /// Loop over k points for solve Hamiltonian to charge density for (int ik = 0; ik < this->wfc_basis->nks; ++ik) diff --git a/source/module_hsolver/hsolver_pw.h b/source/module_hsolver/hsolver_pw.h index ebdc5003b5..0282f7685d 100644 --- a/source/module_hsolver/hsolver_pw.h +++ b/source/module_hsolver/hsolver_pw.h @@ -19,16 +19,6 @@ class HSolverPW : public HSolver using Real = typename GetTypeReal::type; public: - /** - * @brief diago_full_acc - * If .TRUE. all the empty states are diagonalized at the same level of - * accuracy of the occupied ones. Otherwise the empty states are - * diagonalized using a larger threshold (this should not affect total - * energy, forces, and other ground-state properties). - * - */ - static bool diago_full_acc; - HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pwf_in, const bool initialed_psi_in); @@ -42,23 +32,19 @@ class HSolverPW : public HSolver void solve(hamilt::Hamilt* pHamilt, psi::Psi& psi, elecstate::ElecState* pes, - double* out_eigenvalues, - + const std::vector& is_occupied_in, const std::string method_in, - const std::string calculation_type_in, const std::string basis_type_in, const bool use_paw_in, const bool use_uspp_in, const int rank_in_pool_in, const int nproc_in_pool_in, - const int scf_iter_in, const bool need_subspace_in, const int diag_iter_max_in, const double pw_diag_thr_in, - const bool skip_charge) override; virtual Real cal_hsolerror(const Real diag_ethr_in) override; @@ -125,14 +111,6 @@ class HSolverPW : public HSolver int nspin = 1; - void set_isOccupied(std::vector& is_occupied, - elecstate::ElecState* pes, - const int i_scf, - const int nk, - const int nband, - const bool diago_full_acc); - - #ifdef USE_PAW void paw_func_in_kloop(const int ik); @@ -142,8 +120,6 @@ class HSolverPW : public HSolver #endif }; -template -bool HSolverPW::diago_full_acc = true; } // namespace hsolver diff --git a/source/module_hsolver/test/test_hsolver_pw.cpp b/source/module_hsolver/test/test_hsolver_pw.cpp index e01abd65b8..7b334d7b7b 100644 --- a/source/module_hsolver/test/test_hsolver_pw.cpp +++ b/source/module_hsolver/test/test_hsolver_pw.cpp @@ -77,10 +77,14 @@ TEST_F(TestHSolverPW, solve) { // check solve() EXPECT_EQ(this->hs_f.initialed_psi, false); EXPECT_EQ(this->hs_d.initialed_psi, false); + + std::vector is_occupied(1 * 2, true); + this->hs_f.solve(&hamilt_test_f, psi_test_cf, &elecstate_test, elecstate_test.ekb.c, + is_occupied, method_test, "scf", "pw", @@ -89,10 +93,10 @@ TEST_F(TestHSolverPW, solve) { GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::SCF_ITER, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::need_subspace, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_THR, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::SCF_ITER, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::need_subspace, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_NMAX, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_THR, true); // EXPECT_EQ(this->hs_f.initialed_psi, true); @@ -108,6 +112,7 @@ TEST_F(TestHSolverPW, solve) { psi_test_cd, &elecstate_test, elecstate_test.ekb.c, + is_occupied, method_test, "scf", "pw", @@ -116,12 +121,13 @@ TEST_F(TestHSolverPW, solve) { GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::SCF_ITER, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::need_subspace, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_THR, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::SCF_ITER, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::need_subspace, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_NMAX, + hsolver::DiagoIterAssist, base_device::DEVICE_CPU>::PW_DIAG_THR, true); + // EXPECT_EQ(this->hs_d.initialed_psi, true); EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist>::avg_iter, 0.0); diff --git a/source/module_io/input_conv.cpp b/source/module_io/input_conv.cpp index 55430e192d..75c920e3e4 100644 --- a/source/module_io/input_conv.cpp +++ b/source/module_io/input_conv.cpp @@ -301,14 +301,6 @@ void Input_Conv::Convert() GlobalV::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax; GlobalV::PW_DIAG_NDIM = PARAM.inp.pw_diag_ndim; - hsolver::HSolverPW, base_device::DEVICE_CPU>::diago_full_acc = PARAM.inp.diago_full_acc; - hsolver::HSolverPW, base_device::DEVICE_CPU>::diago_full_acc = PARAM.inp.diago_full_acc; - -#if ((defined __CUDA) || (defined __ROCM)) - hsolver::HSolverPW, base_device::DEVICE_GPU>::diago_full_acc = PARAM.inp.diago_full_acc; - hsolver::HSolverPW, base_device::DEVICE_GPU>::diago_full_acc = PARAM.inp.diago_full_acc; -#endif - GlobalV::PW_DIAG_THR = PARAM.inp.pw_diag_thr; GlobalV::NB2D = PARAM.inp.nb2d; GlobalV::TEST_FORCE = PARAM.inp.test_force; diff --git a/source/module_io/read_input_item_system.cpp b/source/module_io/read_input_item_system.cpp index a151c05b36..dc64db337e 100644 --- a/source/module_io/read_input_item_system.cpp +++ b/source/module_io/read_input_item_system.cpp @@ -451,6 +451,14 @@ void ReadInput::item_system() { Input_Item item("diago_full_acc"); item.annotation = "all the empty states are diagonalized"; + /** + * @brief diago_full_acc + * If .TRUE. all the empty states are diagonalized at the same level of + * accuracy of the occupied ones. Otherwise the empty states are + * diagonalized using a larger threshold (this should not affect total + * energy, forces, and other ground-state properties). + * + */ read_sync_bool(input.diago_full_acc); this->add_item(item); }