Skip to content

Commit

Permalink
refactor paw func in hsolver
Browse files Browse the repository at this point in the history
  • Loading branch information
haozhihan committed Jul 11, 2024
1 parent e552848 commit 000f7ba
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 162 deletions.
272 changes: 111 additions & 161 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,169 +102,9 @@ void HSolverPW<T, Device>::call_paw_cell_set_currentk(const int ik)
}
}

#endif

template <typename T, typename Device>
void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
elecstate::ElecState* pes,
const std::string method_in,
const bool skip_charge)
void HSolverPW<T, Device>::paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes)
{
ModuleBase::TITLE("HSolverPW", "solve");
ModuleBase::timer::tick("HSolverPW", "solve");
// prepare for the precondition of diagonalization
this->precondition.resize(psi.get_nbasis());
this->hamilt_ = pHamilt;
// select the method of diagonalization
this->method = method_in;
// report if the specified diagonalization method is not supported
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
{
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
}

std::vector<Real> eigenvalues(pes->ekb.nr * pes->ekb.nc, 0);

if (this->is_first_scf)
{
is_occupied.resize(psi.get_nk() * psi.get_nbands(), true);
}
else
{
if (this->diago_full_acc)
{
is_occupied.assign(is_occupied.size(), true);
}
else
{
for (int i = 0; i < psi.get_nk(); i++)
{
if (pes->klist->wk[i] > 0.0)
{
for (int j = 0; j < psi.get_nbands(); j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * psi.get_nbands() + j] = false;
}
}
}
}
}
}

/// Loop over k points for solve Hamiltonian to charge density
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
{
/// update H(k) for each k point
pHamilt->updateHk(ik);

#ifdef USE_PAW
// if (GlobalV::use_paw)
// {
// const int npw = this->wfc_basis->npwk[ik];
// ModuleBase::Vector3<double>* _gk = new ModuleBase::Vector3<double>[npw];
// for (int ig = 0; ig < npw; ig++)
// {
// _gk[ig] = this->wfc_basis->getgpluskcar(ik, ig);
// }

// std::vector<double> kpt(3, 0);
// kpt[0] = this->wfc_basis->kvec_c[ik].x;
// kpt[1] = this->wfc_basis->kvec_c[ik].y;
// kpt[2] = this->wfc_basis->kvec_c[ik].z;

// double** kpg;
// double** gcar;
// kpg = new double*[npw];
// gcar = new double*[npw];
// for (int ipw = 0; ipw < npw; ipw++)
// {
// kpg[ipw] = new double[3];
// kpg[ipw][0] = _gk[ipw].x;
// kpg[ipw][1] = _gk[ipw].y;
// kpg[ipw][2] = _gk[ipw].z;

// gcar[ipw] = new double[3];
// gcar[ipw][0] = this->wfc_basis->getgcar(ik, ipw).x;
// gcar[ipw][1] = this->wfc_basis->getgcar(ik, ipw).y;
// gcar[ipw][2] = this->wfc_basis->getgcar(ik, ipw).z;
// }

// GlobalC::paw_cell.set_paw_k(npw,
// wfc_basis->npwk_max,
// kpt.data(),
// this->wfc_basis->get_ig2ix(ik).data(),
// this->wfc_basis->get_ig2iy(ik).data(),
// this->wfc_basis->get_ig2iz(ik).data(),
// (const double**)kpg,
// GlobalC::ucell.tpiba,
// (const double**)gcar);

// std::vector<double>().swap(kpt);
// for (int ipw = 0; ipw < npw; ipw++)
// {
// delete[] kpg[ipw];
// delete[] gcar[ipw];
// }
// delete[] kpg;
// delete[] gcar;

// GlobalC::paw_cell.get_vkb();

// GlobalC::paw_cell.set_currentk(ik);
// }

this->paw_func_in_kloop(ik);

#endif

this->updatePsiK(pHamilt, psi, ik);

// template add precondition calculating here
update_precondition(precondition, ik, this->wfc_basis->npwk[ik]);

#ifdef USE_PAW
// GlobalC::paw_cell.set_currentk(ik);
this->call_paw_cell_set_currentk(ik);
#endif

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, eigenvalues.data() + ik * pes->ekb.nc);

if (skip_charge)
{
GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik
<< " is: " << DiagoIterAssist<T, Device>::avg_iter
<< " ; where current threshold is: " << DiagoIterAssist<T, Device>::PW_DIAG_THR
<< " . " << std::endl;
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
/// calculate the contribution of Psi for charge density rho
}
// END Loop over k points

base_device::memory::cast_memory_op<double, Real, base_device::DEVICE_CPU, base_device::DEVICE_CPU>()(
cpu_ctx,
cpu_ctx,
pes->ekb.c,
eigenvalues.data(),
pes->ekb.nr * pes->ekb.nc);

this->is_first_scf = false;

this->endDiagh();

if (skip_charge)
{
ModuleBase::timer::tick("HSolverPW", "solve");
return;
}
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);

#ifdef USE_PAW
if (GlobalV::use_paw)
{
if (typeid(Real) != typeid(double))
Expand Down Expand Up @@ -369,7 +209,117 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
double* nhatgr;
GlobalC::paw_cell.get_nhat(pes->charge->nhat, nhatgr);
}
}

#endif

template <typename T, typename Device>
void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
elecstate::ElecState* pes,
const std::string method_in,
const bool skip_charge)
{
ModuleBase::TITLE("HSolverPW", "solve");
ModuleBase::timer::tick("HSolverPW", "solve");
// prepare for the precondition of diagonalization
this->precondition.resize(psi.get_nbasis());
this->hamilt_ = pHamilt;
// select the method of diagonalization
this->method = method_in;
// report if the specified diagonalization method is not supported
const std::initializer_list<std::string> _methods = {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method) == std::end(_methods))
{
ModuleBase::WARNING_QUIT("HSolverPW::solve", "This method of DiagH is not supported!");
}

std::vector<Real> eigenvalues(pes->ekb.nr * pes->ekb.nc, 0);

if (this->is_first_scf)
{
is_occupied.resize(psi.get_nk() * psi.get_nbands(), true);
}
else
{
if (this->diago_full_acc)
{
is_occupied.assign(is_occupied.size(), true);
}
else
{
for (int i = 0; i < psi.get_nk(); i++)
{
if (pes->klist->wk[i] > 0.0)
{
for (int j = 0; j < psi.get_nbands(); j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * psi.get_nbands() + j] = false;
}
}
}
}
}
}

/// Loop over k points for solve Hamiltonian to charge density
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
{
/// update H(k) for each k point
pHamilt->updateHk(ik);

#ifdef USE_PAW
this->paw_func_in_kloop(ik);
#endif

this->updatePsiK(pHamilt, psi, ik);

// template add precondition calculating here
update_precondition(precondition, ik, this->wfc_basis->npwk[ik]);

#ifdef USE_PAW
this->call_paw_cell_set_currentk(ik);
#endif

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, eigenvalues.data() + ik * pes->ekb.nc);

if (skip_charge)
{
GlobalV::ofs_running << "Average iterative diagonalization steps for k-points " << ik
<< " is: " << DiagoIterAssist<T, Device>::avg_iter
<< " ; where current threshold is: " << DiagoIterAssist<T, Device>::PW_DIAG_THR
<< " . " << std::endl;
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
/// calculate the contribution of Psi for charge density rho
}
// END Loop over k points

base_device::memory::cast_memory_op<double, Real, base_device::DEVICE_CPU, base_device::DEVICE_CPU>()(
cpu_ctx,
cpu_ctx,
pes->ekb.c,
eigenvalues.data(),
pes->ekb.nr * pes->ekb.nc);

this->is_first_scf = false;

this->endDiagh();

if (skip_charge)
{
ModuleBase::timer::tick("HSolverPW", "solve");
return;
}
reinterpret_cast<elecstate::ElecStatePW<T, Device>*>(pes)->psiToRho(psi);

#ifdef USE_PAW
this->paw_func_after_kloop(psi, pes);
#endif

ModuleBase::timer::tick("HSolverPW", "solve");
return;
}
Expand Down
3 changes: 2 additions & 1 deletion source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ class HSolverPW : public HSolver<T, Device>
void paw_func_in_kloop(const int ik);

void call_paw_cell_set_currentk(const int ik);
#endif

void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes);
#endif
};

template <typename T, typename Device>
Expand Down

0 comments on commit 000f7ba

Please sign in to comment.