Skip to content

Commit

Permalink
Merge branch 'develop' into rho-pcie
Browse files Browse the repository at this point in the history
  • Loading branch information
dzzz2001 authored Jul 3, 2024
2 parents b41cf13 + 3905226 commit 805086a
Show file tree
Hide file tree
Showing 26 changed files with 2,087 additions and 4,257 deletions.
3 changes: 3 additions & 0 deletions source/module_esolver/esolver_ks_lcao_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,9 @@ void ESolver_KS_LCAO<TK, TR>::nscf(void) {
this->cal_mag(istep, true);
}

/// write potential
this->create_Output_Potential(0).write();

return;
}

Expand Down
826 changes: 436 additions & 390 deletions source/module_esolver/esolver_ks_pw.cpp

Large diffs are not rendered by default.

640 changes: 330 additions & 310 deletions source/module_hsolver/hsolver_pw.cpp

Large diffs are not rendered by default.

55 changes: 34 additions & 21 deletions source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,26 @@
namespace hsolver {

template <typename T, typename Device = base_device::DEVICE_CPU>
class HSolverPW : public HSolver<T, Device>
{
class HSolverPW : public HSolver<T, Device> {
private:
bool is_first_scf = true;

// Note GetTypeReal<T>::type will
// return T if T is real type(float, double),
// Note GetTypeReal<T>::type will
// return T if T is real type(float, double),
// otherwise return the real type of T(complex<float>, complex<double>)
using Real = typename GetTypeReal<T>::type;

public:
/**
* @brief diago_full_acc
* If .TRUE. all the empty states are diagonalized at the same level of accuracy of the occupied ones.
* Otherwise the empty states are diagonalized using a larger threshold
* (this should not affect total energy, forces, and other ground-state properties).
*
* @brief diago_full_acc
* If .TRUE. all the empty states are diagonalized at the same level of
* accuracy of the occupied ones. Otherwise the empty states are
* diagonalized using a larger threshold (this should not affect total
* energy, forces, and other ground-state properties).
*
*/
static bool diago_full_acc;

HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pwf_in);

/*void init(
Expand All @@ -37,7 +37,7 @@ class HSolverPW : public HSolver<T, Device>
) override;
void update(//Input &in
) override;*/

/// @brief solve function for pw
/// @param pHamilt interface to hamilt
/// @param psi reference to psi
Expand All @@ -54,19 +54,25 @@ class HSolverPW : public HSolver<T, Device>
/// @param psi reference to psi
/// @param pes interface to elecstate
/// @param transform transformation matrix between lcao and pw
/// @param skip_charge
/// @param skip_charge
void solve(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
elecstate::ElecState* pes,
psi::Psi<T, Device>& transform,
const bool skip_charge) override;
virtual Real cal_hsolerror() override;
virtual Real set_diagethr(const int istep, const int iter, const Real drho) override;
virtual Real reset_diagethr(std::ofstream& ofs_running, const Real hsover_error, const Real drho) override;
virtual Real
set_diagethr(const int istep, const int iter, const Real drho) override;
virtual Real reset_diagethr(std::ofstream& ofs_running,
const Real hsover_error,
const Real drho) override;

protected:
void initDiagh(const psi::Psi<T, Device>& psi_in);
// void initDiagh(const psi::Psi<T, Device>& psi_in);
void endDiagh();
void hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm, psi::Psi<T, Device>& psi, Real* eigenvalue);
void hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
psi::Psi<T, Device>& psi,
Real* eigenvalue);

void updatePsiK(hamilt::Hamilt<T, Device>* pHamilt,
psi::Psi<T, Device>& psi,
Expand All @@ -76,7 +82,9 @@ class HSolverPW : public HSolver<T, Device>
wavefunc* pwf = nullptr;

// calculate the precondition array for diagonalization in PW base
void update_precondition(std::vector<Real> &h_diag, const int ik, const int npw);
void update_precondition(std::vector<Real>& h_diag,
const int ik,
const int npw);

std::vector<Real> precondition;
std::vector<Real> eigenvalues;
Expand All @@ -86,11 +94,16 @@ class HSolverPW : public HSolver<T, Device>

hamilt::Hamilt<T, Device>* hamilt_ = nullptr;

Device * ctx = {};
using resmem_var_op = base_device::memory::resize_memory_op<Real, base_device::DEVICE_CPU>;
using delmem_var_op = base_device::memory::delete_memory_op<Real, base_device::DEVICE_CPU>;
Device* ctx = {};
using resmem_var_op
= base_device::memory::resize_memory_op<Real, base_device::DEVICE_CPU>;
using delmem_var_op
= base_device::memory::delete_memory_op<Real, base_device::DEVICE_CPU>;
using castmem_2d_2h_op
= base_device::memory::cast_memory_op<double, Real, base_device::DEVICE_CPU, base_device::DEVICE_CPU>;
= base_device::memory::cast_memory_op<double,
Real,
base_device::DEVICE_CPU,
base_device::DEVICE_CPU>;
};

template <typename T, typename Device>
Expand Down
162 changes: 80 additions & 82 deletions source/module_hsolver/hsolver_pw_sdft.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
#include "hsolver_pw_sdft.h"

#include <algorithm>

#include "module_base/global_function.h"
#include "module_base/timer.h"
#include "module_base/tool_title.h"
#include "module_elecstate/module_charge/symmetry_rho.h"

namespace hsolver
{
#include <algorithm>

namespace hsolver {
void HSolverPW_SDFT::solve(hamilt::Hamilt<std::complex<double>>* pHamilt,
psi::Psi<std::complex<double>>& psi,
elecstate::ElecState* pes,
Expand All @@ -17,28 +16,32 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt<std::complex<double>>* pHamilt,
const int istep,
const int iter,
const std::string method_in,
const bool skip_charge)
{
const bool skip_charge) {
ModuleBase::TITLE(this->classname, "solve");
ModuleBase::timer::tick(this->classname, "solve");
const int npwx = psi.get_nbasis();
const int nbands = psi.get_nbands();
const int nks = psi.get_nk();

this->hamilt_ = pHamilt;
this->hamilt_ = pHamilt;
// prepare for the precondition of diagonalization
this->precondition.resize(psi.get_nbasis());

// select the method of diagonalization
this->method = method_in;
this->initDiagh(psi);
// report if the specified diagonalization method is not supported
const std::initializer_list<std::string> _methods
= {"cg", "dav", "dav_subspace", "bpcg"};
if (std::find(std::begin(_methods), std::end(_methods), this->method)
== std::end(_methods)) {
ModuleBase::WARNING_QUIT("HSolverPW::solve",
"This method of DiagH is not supported!");
}

// part of KSDFT to get KS orbitals
for (int ik = 0; ik < nks; ++ik)
{
for (int ik = 0; ik < nks; ++ik) {
pHamilt->updateHk(ik);
if (nbands > 0 && GlobalV::MY_STOGROUP == 0)
{
if (nbands > 0 && GlobalV::MY_STOGROUP == 0) {
this->updatePsiK(pHamilt, psi, ik);
// template add precondition calculating here
update_precondition(precondition, ik, this->wfc_basis->npwk[ik]);
Expand All @@ -48,82 +51,77 @@ void HSolverPW_SDFT::solve(hamilt::Hamilt<std::complex<double>>* pHamilt,
}

stoiter.stohchi.current_ik = ik;

#ifdef __MPI
if(nbands > 0)
{
MPI_Bcast(&psi(ik,0,0), npwx*nbands, MPI_DOUBLE_COMPLEX , 0, PARAPW_WORLD);
MPI_Bcast(&(pes->ekb(ik, 0)), nbands, MPI_DOUBLE, 0, PARAPW_WORLD);
}
if (nbands > 0) {
MPI_Bcast(&psi(ik, 0, 0),
npwx * nbands,
MPI_DOUBLE_COMPLEX,
0,
PARAPW_WORLD);
MPI_Bcast(&(pes->ekb(ik, 0)), nbands, MPI_DOUBLE, 0, PARAPW_WORLD);
}
#endif
stoiter.orthog(ik,psi,stowf);
stoiter.checkemm(ik,istep, iter, stowf); //check and reset emax & emin
}
stoiter.orthog(ik, psi, stowf);
stoiter.checkemm(ik, istep, iter, stowf); // check and reset emax & emin
}

this->endDiagh();
this->endDiagh();

for (int ik = 0;ik < nks;ik++)
{
//init k
if(nks > 1) pHamilt->updateHk(ik);
stoiter.stohchi.current_ik = ik;
stoiter.calPn(ik, stowf);
}
for (int ik = 0; ik < nks; ik++) {
// init k
if (nks > 1)
pHamilt->updateHk(ik);
stoiter.stohchi.current_ik = ik;
stoiter.calPn(ik, stowf);
}

stoiter.itermu(iter,pes);
stoiter.calHsqrtchi(stowf);
if(skip_charge)
{
ModuleBase::timer::tick(this->classname, "solve");
return;
}
//(5) calculate new charge density
// calculate KS rho.
if(nbands > 0)
{
pes->psiToRho(psi);
stoiter.itermu(iter, pes);
stoiter.calHsqrtchi(stowf);
if (skip_charge) {
ModuleBase::timer::tick(this->classname, "solve");
return;
}
//(5) calculate new charge density
// calculate KS rho.
if (nbands > 0) {
pes->psiToRho(psi);
#ifdef __MPI
MPI_Bcast(&pes->f_en.eband, 1, MPI_DOUBLE, 0, PARAPW_WORLD);
MPI_Bcast(&pes->f_en.eband, 1, MPI_DOUBLE, 0, PARAPW_WORLD);
#endif
}
else
{
for(int is=0; is < GlobalV::NSPIN; is++)
{
ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[is], pes->charge->nrxx);
}
}
// calculate stochastic rho
stoiter.sum_stoband(stowf,pes,pHamilt,wfc_basis);
} else {
for (int is = 0; is < GlobalV::NSPIN; is++) {
ModuleBase::GlobalFunc::ZEROS(pes->charge->rho[is],
pes->charge->nrxx);
}
}
// calculate stochastic rho
stoiter.sum_stoband(stowf, pes, pHamilt, wfc_basis);

//will do rho symmetry and energy calculation in esolver
ModuleBase::timer::tick(this->classname, "solve");
return;
// will do rho symmetry and energy calculation in esolver
ModuleBase::timer::tick(this->classname, "solve");
return;
}

double HSolverPW_SDFT::set_diagethr(const int istep,
const int iter,
const double drho) {
if (iter == 1) {
if (istep == 0) {
if (GlobalV::init_chg == "file") {
this->diag_ethr = 1.0e-5;
}
this->diag_ethr = std::max(this->diag_ethr, GlobalV::PW_DIAG_THR);
} else
this->diag_ethr = std::max(this->diag_ethr, 1.0e-5);
} else {
if (GlobalV::NBANDS > 0 && this->stoiter.KS_ne > 1e-6)
this->diag_ethr
= std::min(this->diag_ethr,
0.1 * drho / std::max(1.0, this->stoiter.KS_ne));
else
this->diag_ethr = 0.0;
}

double HSolverPW_SDFT::set_diagethr(const int istep, const int iter, const double drho)
{
if (iter == 1)
{
if(istep == 0)
{
if (GlobalV::init_chg == "file")
{
this->diag_ethr = 1.0e-5;
}
this->diag_ethr = std::max(this->diag_ethr, GlobalV::PW_DIAG_THR);
}
else
this->diag_ethr = std::max(this->diag_ethr, 1.0e-5);
}
else
{
if(GlobalV::NBANDS > 0 && this->stoiter.KS_ne > 1e-6)
this->diag_ethr = std::min(this->diag_ethr, 0.1 * drho / std::max(1.0, this->stoiter.KS_ne));
else
this->diag_ethr = 0.0;

}
return this->diag_ethr;
}
}
return this->diag_ethr;
}
} // namespace hsolver
Loading

0 comments on commit 805086a

Please sign in to comment.