Skip to content

Commit

Permalink
Refactor: reorganized HSolverPW<T, Device>::solve function in `HSol…
Browse files Browse the repository at this point in the history
…verPW` (#4675)

* refactor hsolver_pw

* refactor hamiltSolvePsiK

* fix build bug

* [pre-commit.ci lite] apply automatic fixes

* fix build bug

* fix build bug

* solve conflicts

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
Co-authored-by: Mohan Chen <[email protected]>
  • Loading branch information
3 people authored Jul 16, 2024
1 parent 7a1faa6 commit 234675f
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 191 deletions.
3 changes: 2 additions & 1 deletion source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
this->notconv = 0;
for (int m = 0; m < this->n_band; m++)
{
if (is_occupied[m])
if (is_occupied[m]) // always true
{
convflag[m] = (std::abs(eigenvalue_iter[m] - eigenvalue_in_hsolver[m]) < this->diag_thr);
}
Expand Down Expand Up @@ -740,6 +740,7 @@ int Diago_DavSubspace<T, Device>::diag(const HPsiFunc& hpsi_func,

int sum_iter = 0;
int ntry = 0;

do
{
if (this->is_subspace || ntry > 0)
Expand Down
226 changes: 103 additions & 123 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "module_hsolver/diago_iter_assist.h"

#include <algorithm>
#include <vector>

#ifdef USE_PAW
#include "module_cell/module_paw/paw_cell.h"
Expand All @@ -30,7 +31,6 @@ HSolverPW<T, Device>::HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pw
this->wfc_basis = wfc_basis_in;
this->pwf = pwf_in;
this->diag_ethr = GlobalV::PW_DIAG_THR;
/*this->init(pbas_in);*/
}

#ifdef USE_PAW
Expand Down Expand Up @@ -213,6 +213,32 @@ void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecst

#endif

template <typename T, typename Device>
void HSolverPW<T, Device>::set_isOccupied(std::vector<bool>& is_occupied,
elecstate::ElecState* pes,
const int i_scf,
const int nk,
const int nband,
const bool diago_full_acc_)
{
if (i_scf != 0 && diago_full_acc_ == false)
{
for (int i = 0; i < nk; i++)
{
if (pes->klist->wk[i] > 0.0)
{
for (int j = 0; j < nband; j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * nband + j] = false;
}
}
}
}
}
}

template <typename T, typename Device>
void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
Expand All @@ -222,46 +248,29 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
{
ModuleBase::TITLE("HSolverPW", "solve");
ModuleBase::timer::tick("HSolverPW", "solve");
// prepare for the precondition of diagonalization
this->precondition.resize(psi.get_nbasis());
this->hamilt_ = pHamilt;

// select the method of diagonalization
this->method = method_in;

// report if the specified diagonalization method is not supported
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
{
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
}

std::vector<Real> eigenvalues(pes->ekb.nr * pes->ekb.nc, 0);

if (this->is_first_scf)
{
is_occupied.resize(psi.get_nk() * psi.get_nbands(), true);
}
else
// prepare for the precondition of diagonalization
std::vector<Real> precondition(psi.get_nbasis(), 0.0);
std::vector<Real> eigenvalues(pes->ekb.nr * pes->ekb.nc, 0.0);
std::vector<bool> is_occupied(psi.get_nk() * psi.get_nbands(), true);
if (this->method == "dav_subspace")
{
if (this->diago_full_acc)
{
is_occupied.assign(is_occupied.size(), true);
}
else
{
for (int i = 0; i < psi.get_nk(); i++)
{
if (pes->klist->wk[i] > 0.0)
{
for (int j = 0; j < psi.get_nbands(); j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * psi.get_nbands() + j] = false;
}
}
}
}
}
this->set_isOccupied(is_occupied,
pes,
DiagoIterAssist<T, Device>::SCF_ITER,
psi.get_nk(),
psi.get_nbands(),
this->diago_full_acc);
}

/// Loop over k points for solve Hamiltonian to charge density
Expand All @@ -284,7 +293,7 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
#endif

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, eigenvalues.data() + ik * pes->ekb.nc);
this->hamiltSolvePsiK(pHamilt, psi, precondition, eigenvalues.data() + ik * pes->ekb.nc);

if (skip_charge)
{
Expand All @@ -298,54 +307,35 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
}
// END Loop over k points

// copy eigenvalues to pes->ekb in ElecState
base_device::memory::cast_memory_op<double, Real, base_device::DEVICE_CPU, base_device::DEVICE_CPU>()(
cpu_ctx,
cpu_ctx,
pes->ekb.c,
eigenvalues.data(),
pes->ekb.nr * pes->ekb.nc);

this->is_first_scf = false;

this->endDiagh();
// psi only should be initialed once for PW
if (!this->initialed_psi)
{
this->initialed_psi = true;
}

if (skip_charge)
{
ModuleBase::timer::tick("HSolverPW", "solve");
return;
}
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);
else
{
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);

#ifdef USE_PAW
this->paw_func_after_kloop(psi, pes);
this->paw_func_after_kloop(psi, pes);
#endif

ModuleBase::timer::tick("HSolverPW", "solve");
return;
}

template <typename T, typename Device>
void HSolverPW<T, Device>::endDiagh()
{
// in PW base, average iteration steps for each band and k-point should be
// printing
if (DiagoIterAssist<T, Device>::avg_iter > 0.0)
{
GlobalV::ofs_running << "Average iterative diagonalization steps: "
<< DiagoIterAssist<T, Device>::avg_iter / this->wfc_basis->nks
<< " ; where current threshold is: " << DiagoIterAssist<T, Device>::PW_DIAG_THR << " . "
<< std::endl;

// std::cout << "avg_iter == " << DiagoIterAssist<T, Device>::avg_iter
// << std::endl;

// reset avg_iter
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
// psi only should be initialed once for PW
if (!this->initialed_psi)
{
this->initialed_psi = true;
ModuleBase::timer::tick("HSolverPW", "solve");
return;
}
}

Expand All @@ -361,13 +351,22 @@ void HSolverPW<T, Device>::updatePsiK(hamilt::Hamilt<T, Device>* pHamilt, psi::P
}

template <typename T, typename Device>
void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::Psi<T, Device>& psi, Real* eigenvalue)
void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
psi::Psi<T, Device>& psi,
std::vector<Real>& pre_condition,
Real* eigenvalue)
{
#ifdef __MPI
const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#else
const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#endif

if (this->method == "cg")
{
// warp the subspace_func into a lambda function
auto ngk_pointer = psi.get_ngk_pointer();
auto subspace_func = [this, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
auto subspace_func = [hm, ngk_pointer](const ct::Tensor& psi_in, ct::Tensor& psi_out) {
// psi_in should be a 2D tensor:
// psi_in.shape() = [nbands, nbasis]
const auto ndim = psi_in.shape().ndim();
Expand All @@ -387,7 +386,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
ct::DeviceType::CpuDevice,
ct::TensorShape({psi_in.shape().dim_size(0)}));

DiagoIterAssist<T, Device>::diagH_subspace(hamilt_, psi_in_wrapper, psi_out_wrapper, eigen.data<Real>());
DiagoIterAssist<T, Device>::diagH_subspace(hm, psi_in_wrapper, psi_out_wrapper, eigen.data<Real>());
};
DiagoCG<T, Device> cg(GlobalV::BASIS_TYPE,
GlobalV::CALCULATION,
Expand Down Expand Up @@ -456,45 +455,26 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
ct::DataTypeToEnum<Real>::value,
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
ct::TensorShape({psi.get_nbands()}));
auto prec_tensor = ct::TensorMap(precondition.data(),
auto prec_tensor = ct::TensorMap(pre_condition.data(),
ct::DataTypeToEnum<Real>::value,
ct::DeviceTypeToEnum<ct::DEVICE_CPU>::value,
ct::TensorShape({static_cast<int>(precondition.size())}))
ct::TensorShape({static_cast<int>(pre_condition.size())}))
.to_device<ct_Device>()
.slice({0}, {psi.get_current_nbas()});

cg.diag(hpsi_func, spsi_func, psi_tensor, eigen_tensor, prec_tensor);
// TODO: Double check tensormap's potential problem
ct::TensorMap(psi.get_pointer(), psi_tensor, {psi.get_nbands(), psi.get_nbasis()}).sync(psi_tensor);
}
else if (this->method == "bpcg")
{
DiagoBPCG<T, Device> bpcg(pre_condition.data());
bpcg.init_iter(psi);
bpcg.diag(hm, psi, eigenvalue);
}
else if (this->method == "dav_subspace")
{
#ifdef __MPI
const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#else
const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#endif
Diago_DavSubspace<T, Device> dav_subspace(this->precondition,
psi.get_nbands(),
psi.get_k_first() ? psi.get_current_nbas()
: psi.get_nk() * psi.get_nbasis(),
GlobalV::PW_DIAG_NDIM,
DiagoIterAssist<T, Device>::PW_DIAG_THR,
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
DiagoIterAssist<T, Device>::need_subspace,
comm_info);
bool scf;
if (GlobalV::CALCULATION == "nscf")
{
scf = false;
}
else
{
scf = true;
}

auto ngk_pointer = psi.get_ngk_pointer();

auto hpsi_func = [hm, ngk_pointer](T* hpsi_out,
T* psi_in,
const int nband_in,
Expand All @@ -514,40 +494,26 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P

ModuleBase::timer::tick("DavSubspace", "hpsi_func");
};
bool scf = GlobalV::CALCULATION == "nscf" ? false : true;
const std::vector<bool> is_occupied(psi.get_nbands(), true);

auto subspace_func = [hm, ngk_pointer](T* psi_out,
T* psi_in,
Real* eigenvalue_in_hsolver,
const int nband_in,
const int nbasis_max_in) {
// Convert "pointer data stucture" to a psi::Psi object
auto psi_in_wrapper = psi::Psi<T, Device>(psi_in, 1, nband_in, nbasis_max_in, ngk_pointer);
auto psi_out_wrapper = psi::Psi<T, Device>(psi_out, 1, nband_in, nbasis_max_in, ngk_pointer);

DiagoIterAssist<T, Device>::diagH_subspace(hm,
psi_in_wrapper,
psi_out_wrapper,
eigenvalue_in_hsolver,
nband_in);
};
Diago_DavSubspace<T, Device> dav_subspace(pre_condition,
psi.get_nbands(),
psi.get_k_first() ? psi.get_current_nbas()
: psi.get_nk() * psi.get_nbasis(),
GlobalV::PW_DIAG_NDIM,
DiagoIterAssist<T, Device>::PW_DIAG_THR,
DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
DiagoIterAssist<T, Device>::need_subspace,
comm_info);

DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, is_occupied, scf));
}
else if (this->method == "bpcg")
{
DiagoBPCG<T, Device> bpcg(precondition.data());
bpcg.init_iter(psi);
bpcg.diag(hm, psi, eigenvalue);
DiagoIterAssist<T, Device>::avg_iter
+= static_cast<double>(dav_subspace.diag(hpsi_func, psi.get_pointer(), psi.get_nbasis(), eigenvalue, is_occupied, scf));
}
else if (this->method == "dav")
{
#ifdef __MPI
const diag_comm_info comm_info = {POOL_WORLD, GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#else
const diag_comm_info comm_info = {GlobalV::RANK_IN_POOL, GlobalV::NPROC_IN_POOL};
#endif
// Davidson iter parameters

// Allow 5 tries at most. If ntry > ntry_max = 5, exit diag loop.
const int ntry_max = 5;
// In non-self consistent calculation, do until totally converged. Else
Expand All @@ -561,7 +527,6 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
const int dim = psi.get_current_nbas();
const int nband = psi.get_nbands();
const int ldPsi = psi.get_nbasis();


auto ngk_pointer = psi.get_ngk_pointer();
/// wrap for hpsi function, Matrix \times blockvector
Expand Down Expand Up @@ -604,7 +569,7 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::P
1);
*/

DiagoDavid<T, Device> david(precondition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info);
DiagoDavid<T, Device> david(pre_condition.data(), GlobalV::PW_DIAG_NDIM, GlobalV::use_paw, comm_info);
DiagoIterAssist<T, Device>::avg_iter += static_cast<double>(
david.diag(hpsi_func, spsi_func, dim, nband, ldPsi, psi, eigenvalue, david_diag_thr, david_maxiter, ntry_max, notconv_max));
}
Expand Down Expand Up @@ -657,6 +622,21 @@ void HSolverPW<T, Device>::update_precondition(std::vector<Real>& h_diag, const
}
}

template <typename T, typename Device>
void HSolverPW<T, Device>::output_iterInfo()
{
// in PW base, average iteration steps for each band and k-point should be printing
if (DiagoIterAssist<T, Device>::avg_iter > 0.0)
{
GlobalV::ofs_running << "Average iterative diagonalization steps: "
<< DiagoIterAssist<T, Device>::avg_iter / this->wfc_basis->nks
<< " ; where current threshold is: " << DiagoIterAssist<T, Device>::PW_DIAG_THR << " . "
<< std::endl;
// reset avg_iter
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
}

template <typename T, typename Device>
typename HSolverPW<T, Device>::Real HSolverPW<T, Device>::cal_hsolerror()
{
Expand Down
Loading

0 comments on commit 234675f

Please sign in to comment.