Skip to content

Commit

Permalink
Refactor: refactor paw func in hsolver_pw (#4646)
Browse files Browse the repository at this point in the history
* refactor paw func in hsolver

* refactor paw func in hsolver
  • Loading branch information
haozhihan authored Jul 11, 2024
1 parent 1fa1c58 commit d655c2a
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 140 deletions.
304 changes: 164 additions & 140 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "diago_cg.h"
#include "diago_dav_subspace.h"
#include "diago_david.h"

#include "module_base/global_variable.h"
#include "module_base/parallel_global.h" // for MPI
#include "module_base/timer.h"
Expand Down Expand Up @@ -34,163 +33,78 @@ HSolverPW<T, Device>::HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pw
/*this->init(pbas_in);*/
}

#ifdef USE_PAW
template <typename T, typename Device>
void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
elecstate::ElecState* pes,
const std::string method_in,
const bool skip_charge)
void HSolverPW<T, Device>::paw_func_in_kloop(const int ik)
{
ModuleBase::TITLE("HSolverPW", "solve");
ModuleBase::timer::tick("HSolverPW", "solve");
// prepare for the precondition of diagonalization
this->precondition.resize(psi.get_nbasis());
this->hamilt_ = pHamilt;
// select the method of diagonalization
this->method = method_in;
// report if the specified diagonalization method is not supported
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
{
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
}

std::vector<Real> eigenvalues(pes->ekb.nr * pes->ekb.nc, 0);

if (this->is_first_scf)
{
is_occupied.resize(psi.get_nk() * psi.get_nbands(), true);
}
else
if (GlobalV::use_paw)
{
if (this->diago_full_acc)
{
is_occupied.assign(is_occupied.size(), true);
}
else
const int npw = this->wfc_basis->npwk[ik];
ModuleBase::Vector3<double>* _gk = new ModuleBase::Vector3<double>[npw];
for (int ig = 0; ig < npw; ig++)
{
for (int i = 0; i < psi.get_nk(); i++)
{
if (pes->klist->wk[i] > 0.0)
{
for (int j = 0; j < psi.get_nbands(); j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * psi.get_nbands() + j] = false;
}
}
}
}
_gk[ig] = this->wfc_basis->getgpluskcar(ik, ig);
}
}

/// Loop over k points for solve Hamiltonian to charge density
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
{
/// update H(k) for each k point
pHamilt->updateHk(ik);
std::vector<double> kpt(3, 0);
kpt[0] = this->wfc_basis->kvec_c[ik].x;
kpt[1] = this->wfc_basis->kvec_c[ik].y;
kpt[2] = this->wfc_basis->kvec_c[ik].z;

#ifdef USE_PAW
if (GlobalV::use_paw)
double** kpg;
double** gcar;
kpg = new double*[npw];
gcar = new double*[npw];
for (int ipw = 0; ipw < npw; ipw++)
{
const int npw = this->wfc_basis->npwk[ik];
ModuleBase::Vector3<double>* _gk = new ModuleBase::Vector3<double>[npw];
for (int ig = 0; ig < npw; ig++)
{
_gk[ig] = this->wfc_basis->getgpluskcar(ik, ig);
}

std::vector<double> kpt(3, 0);
kpt[0] = this->wfc_basis->kvec_c[ik].x;
kpt[1] = this->wfc_basis->kvec_c[ik].y;
kpt[2] = this->wfc_basis->kvec_c[ik].z;

double** kpg;
double** gcar;
kpg = new double*[npw];
gcar = new double*[npw];
for (int ipw = 0; ipw < npw; ipw++)
{
kpg[ipw] = new double[3];
kpg[ipw][0] = _gk[ipw].x;
kpg[ipw][1] = _gk[ipw].y;
kpg[ipw][2] = _gk[ipw].z;

gcar[ipw] = new double[3];
gcar[ipw][0] = this->wfc_basis->getgcar(ik, ipw).x;
gcar[ipw][1] = this->wfc_basis->getgcar(ik, ipw).y;
gcar[ipw][2] = this->wfc_basis->getgcar(ik, ipw).z;
}

GlobalC::paw_cell.set_paw_k(npw,
wfc_basis->npwk_max,
kpt.data(),
this->wfc_basis->get_ig2ix(ik).data(),
this->wfc_basis->get_ig2iy(ik).data(),
this->wfc_basis->get_ig2iz(ik).data(),
(const double**)kpg,
GlobalC::ucell.tpiba,
(const double**)gcar);

std::vector<double>().swap(kpt);
for (int ipw = 0; ipw < npw; ipw++)
{
delete[] kpg[ipw];
delete[] gcar[ipw];
}
delete[] kpg;
delete[] gcar;

GlobalC::paw_cell.get_vkb();

GlobalC::paw_cell.set_currentk(ik);
kpg[ipw] = new double[3];
kpg[ipw][0] = _gk[ipw].x;
kpg[ipw][1] = _gk[ipw].y;
kpg[ipw][2] = _gk[ipw].z;

gcar[ipw] = new double[3];
gcar[ipw][0] = this->wfc_basis->getgcar(ik, ipw).x;
gcar[ipw][1] = this->wfc_basis->getgcar(ik, ipw).y;
gcar[ipw][2] = this->wfc_basis->getgcar(ik, ipw).z;
}
#endif

this->updatePsiK(pHamilt, psi, ik);

// template add precondition calculating here
update_precondition(precondition, ik, this->wfc_basis->npwk[ik]);

#ifdef USE_PAW
GlobalC::paw_cell.set_currentk(ik);
#endif

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, eigenvalues.data() + ik * pes->ekb.nc);

if (skip_charge)
GlobalC::paw_cell.set_paw_k(npw,
wfc_basis->npwk_max,
kpt.data(),
this->wfc_basis->get_ig2ix(ik).data(),
this->wfc_basis->get_ig2iy(ik).data(),
this->wfc_basis->get_ig2iz(ik).data(),
(const double**)kpg,
GlobalC::ucell.tpiba,
(const double**)gcar);

std::vector<double>().swap(kpt);
for (int ipw = 0; ipw < npw; ipw++)
{
GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik
<< " is: " << DiagoIterAssist<T, Device>::avg_iter
<< " ; where current threshold is: " << DiagoIterAssist<T, Device>::PW_DIAG_THR
<< " . " << std::endl;
DiagoIterAssist<T, Device>::avg_iter = 0.0;
delete[] kpg[ipw];
delete[] gcar[ipw];
}
/// calculate the contribution of Psi for charge density rho
}
// END Loop over k points
delete[] kpg;
delete[] gcar;

base_device::memory::cast_memory_op<double, Real, base_device::DEVICE_CPU, base_device::DEVICE_CPU>()(
cpu_ctx,
cpu_ctx,
pes->ekb.c,
eigenvalues.data(),
pes->ekb.nr * pes->ekb.nc);
GlobalC::paw_cell.get_vkb();

this->is_first_scf = false;

this->endDiagh();
GlobalC::paw_cell.set_currentk(ik);
}
}

if (skip_charge)
template <typename T, typename Device>
void HSolverPW<T, Device>::call_paw_cell_set_currentk(const int ik)
{
if (GlobalV::use_paw)
{
ModuleBase::timer::tick("HSolverPW", "solve");
return;
GlobalC::paw_cell.set_currentk(ik);
}
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);
}

#ifdef USE_PAW
template <typename T, typename Device>
void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes)
{
if (GlobalV::use_paw)
{
if (typeid(Real) != typeid(double))
Expand Down Expand Up @@ -295,7 +209,117 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
double* nhatgr;
GlobalC::paw_cell.get_nhat(pes->charge->nhat, nhatgr);
}
}

#endif

template <typename T, typename Device>
void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
elecstate::ElecState* pes,
const std::string method_in,
const bool skip_charge)
{
ModuleBase::TITLE("HSolverPW", "solve");
ModuleBase::timer::tick("HSolverPW", "solve");
// prepare for the precondition of diagonalization
this->precondition.resize(psi.get_nbasis());
this->hamilt_ = pHamilt;
// select the method of diagonalization
this->method = method_in;
// report if the specified diagonalization method is not supported
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
{
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
}

std::vector<Real> eigenvalues(pes->ekb.nr * pes->ekb.nc, 0);

if (this->is_first_scf)
{
is_occupied.resize(psi.get_nk() * psi.get_nbands(), true);
}
else
{
if (this->diago_full_acc)
{
is_occupied.assign(is_occupied.size(), true);
}
else
{
for (int i = 0; i < psi.get_nk(); i++)
{
if (pes->klist->wk[i] > 0.0)
{
for (int j = 0; j < psi.get_nbands(); j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * psi.get_nbands() + j] = false;
}
}
}
}
}
}

/// Loop over k points for solve Hamiltonian to charge density
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
{
/// update H(k) for each k point
pHamilt->updateHk(ik);

#ifdef USE_PAW
this->paw_func_in_kloop(ik);
#endif

this->updatePsiK(pHamilt, psi, ik);

// template add precondition calculating here
update_precondition(precondition, ik, this->wfc_basis->npwk[ik]);

#ifdef USE_PAW
this->call_paw_cell_set_currentk(ik);
#endif

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, eigenvalues.data() + ik * pes->ekb.nc);

if (skip_charge)
{
GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik
<< " is: " << DiagoIterAssist<T, Device>::avg_iter
<< " ; where current threshold is: " << DiagoIterAssist<T, Device>::PW_DIAG_THR
<< " . " << std::endl;
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
/// calculate the contribution of Psi for charge density rho
}
// END Loop over k points

base_device::memory::cast_memory_op<double, Real, base_device::DEVICE_CPU, base_device::DEVICE_CPU>()(
cpu_ctx,
cpu_ctx,
pes->ekb.c,
eigenvalues.data(),
pes->ekb.nr * pes->ekb.nc);

this->is_first_scf = false;

this->endDiagh();

if (skip_charge)
{
ModuleBase::timer::tick("HSolverPW", "solve");
return;
}
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);

#ifdef USE_PAW
this->paw_func_after_kloop(psi, pes);
#endif

ModuleBase::timer::tick("HSolverPW", "solve");
return;
}
Expand Down
8 changes: 8 additions & 0 deletions source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ class HSolverPW : public HSolver<T, Device>
hamilt::Hamilt<T, Device>* hamilt_ = nullptr;

Device* ctx = {};

#ifdef USE_PAW
void paw_func_in_kloop(const int ik);

void call_paw_cell_set_currentk(const int ik);

void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes);
#endif
};

template <typename T, typename Device>
Expand Down

0 comments on commit d655c2a

Please sign in to comment.