Skip to content

Commit

Permalink
Sorting out the calculation logic of pexsi in hsolver-lcao (#5299)
Browse files Browse the repository at this point in the history
  • Loading branch information
haozhihan authored Oct 18, 2024
1 parent 7cb7664 commit 255e104
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 47 deletions.
84 changes: 43 additions & 41 deletions source/module_hsolver/hsolver_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,50 @@ void HSolverLCAO<T, Device>::solve(hamilt::Hamilt<T>* pHamilt,
ModuleBase::TITLE("HSolverLCAO", "solve");
ModuleBase::timer::tick("HSolverLCAO", "solve");

#ifdef __PEXSI // other purification methods should follow this routine
// Zhang Xiaoyang : Please modify Pesxi usage later
if (this->method == "pexsi")
if (this->method != "pexsi")
{
if (GlobalV::KPAR_LCAO > 1
&& (this->method == "genelpa" || this->method == "elpa" || this->method == "scalapack_gvx"))
{
#ifdef __MPI
this->parakSolve(pHamilt, psi, pes, GlobalV::KPAR_LCAO);
#endif
}
else if (GlobalV::KPAR_LCAO == 1)
{
/// Loop over k points for solve Hamiltonian to eigenpairs(eigenvalues and eigenvectors).
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
/// update H(k) for each k point
pHamilt->updateHk(ik);

/// find psi pointer for each k point
psi.fix_k(ik);

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, &(pes->ekb(ik, 0)));
}
}
else
{
ModuleBase::WARNING_QUIT("HSolverLCAO::solve",
"This method and KPAR setting is not supported for lcao basis in ABACUS!");
}

if (!skip_charge)
{
// used in scf calculation
// calculate charge by eigenpairs(eigenvalues and eigenvectors)
pes->psiToRho(psi);
}
else
{
// used in nscf calculation
}
}
else if (this->method == "pexsi")
{
#ifdef __PEXSI // other purification methods should follow this routine
DiagoPexsi<T> pe(ParaV);
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
Expand All @@ -60,41 +100,7 @@ void HSolverLCAO<T, Device>::solve(hamilt::Hamilt<T>* pHamilt,
pes->f_en.eband = pe.totalFreeEnergy;
// maybe eferm could be dealt with in the future
_pes->dmToRho(pe.DM, pe.EDM);
ModuleBase::timer::tick("HSolverLCAO", "solve");
return;
}
#endif

if (GlobalV::KPAR_LCAO > 1
&& (this->method == "genelpa" || this->method == "elpa" || this->method == "scalapack_gvx"))
{
#ifdef __MPI
this->parakSolve(pHamilt, psi, pes, GlobalV::KPAR_LCAO);
#endif
}
else if (GlobalV::KPAR_LCAO == 1)
{
/// Loop over k points for solve Hamiltonian to eigenpairs(eigenvalues and eigenvectors).
for (int ik = 0; ik < psi.get_nk(); ++ik)
{
/// update H(k) for each k point
pHamilt->updateHk(ik);

/// find psi pointer for each k point
psi.fix_k(ik);

/// solve eigenvector and eigenvalue for H(k)
this->hamiltSolvePsiK(pHamilt, psi, &(pes->ekb(ik, 0)));
}
}

if (!skip_charge) // used in scf calculation
{
// calculate charge by eigenpairs(eigenvalues and eigenvectors)
pes->psiToRho(psi);
}
else // used in nscf calculation
{
}

ModuleBase::timer::tick("HSolverLCAO", "solve");
Expand All @@ -114,7 +120,6 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
sa.diag(hm, psi, eigenvalue);
#endif
}

#ifdef __ELPA
else if (this->method == "genelpa")
{
Expand All @@ -127,7 +132,6 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
el.diag(hm, psi, eigenvalue);
}
#endif

#ifdef __CUDA
else if (this->method == "cusolver")
{
Expand All @@ -142,15 +146,13 @@ void HSolverLCAO<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>&
}
#endif
#endif

#ifndef __MPI
else if (this->method == "lapack") // only for single core
{
DiagoLapack<T> la;
la.diag(hm, psi, eigenvalue);
}
#endif

else
{
ModuleBase::WARNING_QUIT("HSolverLCAO::solve", "This method is not supported for lcao basis in ABACUS!");
Expand Down
8 changes: 2 additions & 6 deletions source/module_hsolver/hsolver_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,13 @@ class HSolverLCAO
const bool skip_charge);

private:
void hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>& psi, double* eigenvalue);
void hamiltSolvePsiK(hamilt::Hamilt<T>* hm, psi::Psi<T>& psi, double* eigenvalue); // for kpar_lcao == 1

void parakSolve(hamilt::Hamilt<T>* pHamilt, psi::Psi<T>& psi, elecstate::ElecState* pes, int kpar);
void parakSolve(hamilt::Hamilt<T>* pHamilt, psi::Psi<T>& psi, elecstate::ElecState* pes, int kpar); // for kpar_lcao > 1

const Parallel_Orbitals* ParaV;

const std::string method;

// for cg_in_lcao
using Real = typename GetTypeReal<T>::type;
std::vector<Real> precondition_lcao;
};

} // namespace hsolver
Expand Down

0 comments on commit 255e104

Please sign in to comment.