Skip to content

Commit

Permalink
Refactor: move set_is_occupied func from hsolver to elecstate (#4847)
Browse files Browse the repository at this point in the history
* remove pelec frome hsolverpw solve func

* change set_is_occupied func

* fix test bug

* change the place of set_is_occupied func
  • Loading branch information
haozhihan authored Aug 1, 2024
1 parent 208a61c commit 880d003
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 110 deletions.
28 changes: 28 additions & 0 deletions source/module_elecstate/elecstate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,4 +356,32 @@ void ElecState::cal_nbands()

ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NBANDS", GlobalV::NBANDS);
}

void set_is_occupied(std::vector<bool>& is_occupied,
elecstate::ElecState* pes,
const int i_scf,
const int nk,
const int nband,
const bool diago_full_acc)
{
if (i_scf != 0 && diago_full_acc == false)
{
for (int i = 0; i < nk; i++)
{
if (pes->klist->wk[i] > 0.0)
{
for (int j = 0; j < nband; j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * nband + j] = false;
}
}
}
}
}
};



} // namespace elecstate
8 changes: 8 additions & 0 deletions source/module_elecstate/elecstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,13 @@ class ElecState
bool skip_weights = false;
};

// This is an independent function under the elecstate namespace and does not depend on any class.
void set_is_occupied(std::vector<bool>& is_occupied,
elecstate::ElecState* pes,
const int i_scf,
const int nk,
const int nband,
const bool diago_full_acc);

} // namespace elecstate
#endif
14 changes: 12 additions & 2 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,22 @@ void ESolver_KS_PW<T, Device>::hamilt2density(const int istep, const int iter, c
hsolver::DiagoIterAssist<T, Device>::SCF_ITER = iter;
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR = ethr;
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX = GlobalV::PW_DIAG_NMAX;


std::vector<bool> is_occupied(this->kspw_psi->get_nk() * this->kspw_psi->get_nbands(), true);

elecstate::set_is_occupied(is_occupied,
this->pelec,
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
this->kspw_psi->get_nk(),
this->kspw_psi->get_nbands(),
PARAM.inp.diago_full_acc);

hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi);
hsolver_pw_obj.solve(this->p_hamilt, // hamilt::Hamilt<T, Device>* pHamilt,
this->kspw_psi[0], // psi::Psi<T, Device>& psi,
this->pelec, // elecstate::ElecState<T, Device>* pelec,
this->pelec, // elecstate::ElecState<T, Device>* pelec,
this->pelec->ekb.c,
is_occupied,
PARAM.inp.ks_solver,
PARAM.inp.calculation,
PARAM.inp.basis_type,
Expand Down
1 change: 1 addition & 0 deletions source/module_esolver/esolver_ks_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>

using castmem_2d_d2h_op
= base_device::memory::cast_memory_op<std::complex<double>, T, base_device::DEVICE_CPU, Device>;

};
} // namespace ModuleESolver
#endif
42 changes: 25 additions & 17 deletions source/module_esolver/pw_fun.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,34 @@ void ESolver_KS_PW<T, Device>::hamilt2estates(const double ethr) {
hsolver::DiagoIterAssist<T, Device>::need_subspace = false;
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR = ethr;

hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi);
std::vector<bool> is_occupied(this->kspw_psi->get_nk() * this->kspw_psi->get_nbands(), true);

hsolver_pw_obj.solve(this->p_hamilt,
this->kspw_psi[0],
this->pelec,
this->pelec->ekb.c,
PARAM.inp.ks_solver,
PARAM.inp.calculation,
PARAM.inp.basis_type,
PARAM.inp.use_paw,
GlobalV::use_uspp,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,
elecstate::set_is_occupied(is_occupied,
this->pelec,
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
this->kspw_psi->get_nk(),
this->kspw_psi->get_nbands(),
PARAM.inp.diago_full_acc);

hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::need_subspace,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
hsolver::HSolverPW<T, Device> hsolver_pw_obj(this->pw_wfc, &this->wf, this->init_psi);

true);
hsolver_pw_obj.solve(this->p_hamilt,
this->kspw_psi[0],
this->pelec,
this->pelec->ekb.c,
is_occupied,
PARAM.inp.ks_solver,
PARAM.inp.calculation,
PARAM.inp.basis_type,
PARAM.inp.use_paw,
GlobalV::use_uspp,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::need_subspace,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
true);

this->init_psi = true;

Expand Down
14 changes: 5 additions & 9 deletions source/module_hsolver/hsolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,19 @@ class HSolver
virtual void solve(hamilt::Hamilt<T, Device>* phm,
psi::Psi<T, Device>& ppsi,
elecstate::ElecState* pes,

double* out_eigenvalues,

const std::vector<bool>& is_occupied_in,
const std::string method,

const std::string calculation_type_in,
const std::string basis_type_in,
const bool use_paw_in,
const bool use_uspp_in,
const int rank_in_pool_in,
const int nproc_in_pool_in,

const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,

const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,
const bool skip_charge)
{
return;
Expand Down
42 changes: 1 addition & 41 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,53 +220,23 @@ HSolverPW<T, Device>::HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in,
this->initialed_psi = initialed_psi_in;
}

template <typename T, typename Device>
void HSolverPW<T, Device>::set_isOccupied(std::vector<bool>& is_occupied,
elecstate::ElecState* pes,
const int i_scf,
const int nk,
const int nband,
const bool diago_full_acc_)
{
if (i_scf != 0 && diago_full_acc_ == false)
{
for (int i = 0; i < nk; i++)
{
if (pes->klist->wk[i] > 0.0)
{
for (int j = 0; j < nband; j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * nband + j] = false;
}
}
}
}
}
}

template <typename T, typename Device>
void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
elecstate::ElecState* pes,

double* out_eigenvalues,

const std::vector<bool>& is_occupied_in,
const std::string method_in,

const std::string calculation_type_in,
const std::string basis_type_in,
const bool use_paw_in,
const bool use_uspp_in,
const int rank_in_pool_in,
const int nproc_in_pool_in,

const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,

const bool skip_charge)
{
ModuleBase::TITLE("HSolverPW", "solve");
Expand Down Expand Up @@ -298,16 +268,6 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
// prepare for the precondition of diagonalization
std::vector<Real> precondition(psi.get_nbasis(), 0.0);
std::vector<Real> eigenvalues(pes->ekb.nr * pes->ekb.nc, 0.0);
std::vector<bool> is_occupied(psi.get_nk() * psi.get_nbands(), true);
if (this->method == "dav_subspace")
{
this->set_isOccupied(is_occupied,
pes,
this->scf_iter,
psi.get_nk(),
psi.get_nbands(),
this->diago_full_acc);
}

/// Loop over k points for solve Hamiltonian to charge density
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
Expand Down
26 changes: 1 addition & 25 deletions source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,6 @@ class HSolverPW : public HSolver<T, Device>
using Real = typename GetTypeReal<T>::type;

public:
/**
* @brief diago_full_acc
* If .TRUE. all the empty states are diagonalized at the same level of
* accuracy of the occupied ones. Otherwise the empty states are
* diagonalized using a larger threshold (this should not affect total
* energy, forces, and other ground-state properties).
*
*/
static bool diago_full_acc;

HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in,
wavefunc* pwf_in,
const bool initialed_psi_in);
Expand All @@ -42,23 +32,19 @@ class HSolverPW : public HSolver<T, Device>
void solve(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
elecstate::ElecState* pes,

double* out_eigenvalues,

const std::vector<bool>& is_occupied_in,
const std::string method_in,

const std::string calculation_type_in,
const std::string basis_type_in,
const bool use_paw_in,
const bool use_uspp_in,
const int rank_in_pool_in,
const int nproc_in_pool_in,

const int scf_iter_in,
const bool need_subspace_in,
const int diag_iter_max_in,
const double pw_diag_thr_in,

const bool skip_charge) override;

virtual Real cal_hsolerror(const Real diag_ethr_in) override;
Expand Down Expand Up @@ -125,14 +111,6 @@ class HSolverPW : public HSolver<T, Device>

int nspin = 1;

void set_isOccupied(std::vector<bool>& is_occupied,
elecstate::ElecState* pes,
const int i_scf,
const int nk,
const int nband,
const bool diago_full_acc);


#ifdef USE_PAW
void paw_func_in_kloop(const int ik);

Expand All @@ -142,8 +120,6 @@ class HSolverPW : public HSolver<T, Device>
#endif
};

template <typename T, typename Device>
bool HSolverPW<T, Device>::diago_full_acc = true;

} // namespace hsolver

Expand Down
22 changes: 14 additions & 8 deletions source/module_hsolver/test/test_hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,14 @@ TEST_F(TestHSolverPW, solve) {
// check solve()
EXPECT_EQ(this->hs_f.initialed_psi, false);
EXPECT_EQ(this->hs_d.initialed_psi, false);

std::vector<bool> is_occupied(1 * 2, true);

this->hs_f.solve(&hamilt_test_f,
psi_test_cf,
&elecstate_test,
elecstate_test.ekb.c,
is_occupied,
method_test,
"scf",
"pw",
Expand All @@ -89,10 +93,10 @@ TEST_F(TestHSolverPW, solve) {
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,

hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_THR,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_THR,

true);
// EXPECT_EQ(this->hs_f.initialed_psi, true);
Expand All @@ -108,6 +112,7 @@ TEST_F(TestHSolverPW, solve) {
psi_test_cd,
&elecstate_test,
elecstate_test.ekb.c,
is_occupied,
method_test,
"scf",
"pw",
Expand All @@ -116,12 +121,13 @@ TEST_F(TestHSolverPW, solve) {
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL,

hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_THR,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_THR,

true);

// EXPECT_EQ(this->hs_d.initialed_psi, true);
EXPECT_DOUBLE_EQ(hsolver::DiagoIterAssist<std::complex<double>>::avg_iter,
0.0);
Expand Down
8 changes: 0 additions & 8 deletions source/module_io/input_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,6 @@ void Input_Conv::Convert()
GlobalV::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
GlobalV::PW_DIAG_NDIM = PARAM.inp.pw_diag_ndim;

hsolver::HSolverPW<std::complex<float>, base_device::DEVICE_CPU>::diago_full_acc = PARAM.inp.diago_full_acc;
hsolver::HSolverPW<std::complex<double>, base_device::DEVICE_CPU>::diago_full_acc = PARAM.inp.diago_full_acc;

#if ((defined __CUDA) || (defined __ROCM))
hsolver::HSolverPW<std::complex<float>, base_device::DEVICE_GPU>::diago_full_acc = PARAM.inp.diago_full_acc;
hsolver::HSolverPW<std::complex<double>, base_device::DEVICE_GPU>::diago_full_acc = PARAM.inp.diago_full_acc;
#endif

GlobalV::PW_DIAG_THR = PARAM.inp.pw_diag_thr;
GlobalV::NB2D = PARAM.inp.nb2d;
GlobalV::TEST_FORCE = PARAM.inp.test_force;
Expand Down
8 changes: 8 additions & 0 deletions source/module_io/read_input_item_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,14 @@ void ReadInput::item_system()
{
Input_Item item("diago_full_acc");
item.annotation = "all the empty states are diagonalized";
/**
* @brief diago_full_acc
* If .TRUE. all the empty states are diagonalized at the same level of
* accuracy of the occupied ones. Otherwise the empty states are
* diagonalized using a larger threshold (this should not affect total
* energy, forces, and other ground-state properties).
*
*/
read_sync_bool(input.diago_full_acc);
this->add_item(item);
}
Expand Down

0 comments on commit 880d003

Please sign in to comment.