Skip to content

Commit

Permalink
Refactor: Make Hsolver_sdft a local variable in hamilt_to_density fun…
Browse files Browse the repository at this point in the history
…ction (#4941)

* Make Hsolver_sdft a temporary obj in hamilt_to_density function

* fix bug

* store emin and emax

* [pre-commit.ci lite] apply automatic fixes

* fix bug in mDFT

* remove useless memory cost

* fix init_psi should be store in sdft

* update results

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
Qianruipku and pre-commit-ci-lite[bot] authored Aug 11, 2024
1 parent 111073b commit b7ffff7
Show file tree
Hide file tree
Showing 23 changed files with 870 additions and 630 deletions.
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ OBJS_SRCPW=H_Ewald_pw.o\
soc.o\
sto_iter.o\
sto_hchi.o\
sto_che.o\
sto_wf.o\
sto_func.o\
sto_forces.o\
Expand Down
55 changes: 34 additions & 21 deletions source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace ModuleESolver
{

ESolver_SDFT_PW::ESolver_SDFT_PW()
: stoche(PARAM.inp.nche_sto, PARAM.inp.method_sto, PARAM.inp.emax_sto, PARAM.inp.emin_sto)
{
classname = "ESolver_SDFT_PW";
basisname = "PW";
Expand Down Expand Up @@ -92,7 +93,6 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell)

// 9) initialize the stochastic wave functions
stowf.init(&kv, pw_wfc->npwk_max);

if (inp.nbands_sto != 0)
{
if (inp.initsto_ecut < inp.ecutwfc)
Expand All @@ -108,6 +108,10 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell)
{
Init_Com_Orbitals(this->stowf);
}
if (this->method_sto == 2)
{
stowf.allocate_chiallorder(this->nche_sto);
}

size_t size = stowf.chi0->size();

Expand All @@ -123,7 +127,8 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell)
}

// 9) initialize the hsolver
this->phsol = new hsolver::HSolverPW_SDFT(&kv, pw_wfc, &wf, this->stowf, inp.method_sto);
// It should be removed after esolver_ks does not use phsol.
this->phsol = new hsolver::HSolverPW_SDFT(&this->kv, this->pw_wfc, &this->wf, this->stowf, this->stoche, this->init_psi);

return;
}
Expand Down Expand Up @@ -169,21 +174,27 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr)

hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX = GlobalV::PW_DIAG_NMAX;

this->phsol->solve(this->p_hamilt,
this->psi[0],
this->pelec,
pw_wfc,
this->stowf,
istep,
iter,
GlobalV::KS_SOLVER,

hsolver::DiagoIterAssist<std::complex<double>>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>>::need_subspace,
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR,

false);
// hsolver only exists in this function
hsolver::HSolverPW_SDFT hsolver_pw_sdft_obj(&this->kv, this->pw_wfc, &this->wf, this->stowf, this->stoche, this->init_psi);

hsolver_pw_sdft_obj.solve(this->p_hamilt,
this->psi[0],
this->pelec,
pw_wfc,
this->stowf,
istep,
iter,
GlobalV::KS_SOLVER,
hsolver::DiagoIterAssist<std::complex<double>>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>>::need_subspace,
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR,
false);
this->init_psi = true;

// temporary
// set_diagethr need it
((hsolver::HSolverPW_SDFT*)phsol)->set_KS_ne(hsolver_pw_sdft_obj.stoiter.KS_ne);

if (GlobalV::MY_STOGROUP == 0)
{
Expand Down Expand Up @@ -243,16 +254,18 @@ void ESolver_SDFT_PW::after_all_runners()
GlobalV::ofs_running << " --------------------------------------------\n\n" << std::endl;
ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, kv, &(GlobalC::Pkpoints));

((hsolver::HSolverPW_SDFT*)phsol)->stoiter.cleanchiallorder(); // release lots of memories

if (this->method_sto == 2)
{
stowf.clean_chiallorder(); // release lots of memories
}
if (PARAM.inp.out_dos)
{
Sto_DOS sto_dos(this->pw_wfc,
&this->kv,
this->pelec,
this->psi,
this->p_hamilt,
(hsolver::HSolverPW_SDFT*)phsol,
this->stoche,
&stowf);
sto_dos.decide_param(PARAM.inp.dos_nche,
PARAM.inp.emin_sto,
Expand All @@ -275,7 +288,7 @@ void ESolver_SDFT_PW::after_all_runners()
this->psi,
&GlobalC::ppcell,
this->p_hamilt,
(hsolver::HSolverPW_SDFT*)phsol,
this->stoche,
&stowf);
sto_elecond.decide_nche(PARAM.inp.cond_dt, 1e-8, this->nche_sto, PARAM.inp.emin_sto, PARAM.inp.emax_sto);
sto_elecond.sKG(PARAM.inp.cond_smear,
Expand Down
2 changes: 2 additions & 0 deletions source/module_esolver/esolver_sdft_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "module_hamilt_pw/hamilt_stodft/sto_hchi.h"
#include "module_hamilt_pw/hamilt_stodft/sto_iter.h"
#include "module_hamilt_pw/hamilt_stodft/sto_wf.h"
#include "module_hamilt_pw/hamilt_stodft/sto_che.h"

namespace ModuleESolver
{
Expand All @@ -25,6 +26,7 @@ class ESolver_SDFT_PW : public ESolver_KS_PW<std::complex<double>>

public:
Stochastic_WF stowf;
StoChe<double> stoche;

protected:
virtual void before_scf(const int istep) override;
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_pw/hamilt_stodft/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
list(APPEND hamilt_stodft_srcs
sto_iter.cpp
sto_hchi.cpp
sto_che.cpp
sto_wf.cpp
sto_func.cpp
sto_forces.cpp
Expand Down
44 changes: 44 additions & 0 deletions source/module_hamilt_pw/hamilt_stodft/sto_che.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include "sto_che.h"
#include "module_base/blas_connector.h"

template <typename REAL>
StoChe<REAL>::~StoChe()
{
delete p_che;
delete[] spolyv;
}

template <typename REAL>
StoChe<REAL>::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto)
{
this->nche = nche;
this->method_sto = method;
p_che = new ModuleBase::Chebyshev<REAL>(nche);
if (method == 1)
{
spolyv = new REAL[nche];
}
else
{
spolyv = new REAL[nche * nche];
}

this->emax_sto = emax_sto;
this->emin_sto = emin_sto;
}

template class StoChe<double>;
// template class StoChe<float>;

double vTMv(const double* v, const double* M, const int n)
{
const char normal = 'N';
const double one = 1;
const int inc = 1;
const double zero = 0;
double* y = new double[n];
dgemv_(&normal, &n, &n, &one, M, &n, v, &inc, &zero, y, &inc);
double result = BlasConnector::dot(n, y, 1, v, 1);
delete[] y;
return result;
}
35 changes: 35 additions & 0 deletions source/module_hamilt_pw/hamilt_stodft/sto_che.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef STO_CHE_H
#define STO_CHE_H
#include "module_base/math_chebyshev.h"

template <typename REAL>
class StoChe
{
public:
StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto);
~StoChe();

public:
int nche = 0; ///< order of Chebyshev expansion
REAL* spolyv = nullptr; ///< coefficients of Chebyshev expansion
int method_sto = 0; ///< method for the stochastic calculation

// Chebyshev expansion
// It stores the plan of FFTW and should be initialized at the beginning of the calculation
ModuleBase::Chebyshev<REAL>* p_che = nullptr;

REAL emax_sto = 0.0; ///< maximum energy for normalization
REAL emin_sto = 0.0; ///< minimum energy for normalization
};

/**
* @brief calculate v^T*M*v
*
* @param v v
* @param M M
* @param n the dimension of v
* @return double
*/
double vTMv(const double* v, const double* M, const int n);

#endif
72 changes: 48 additions & 24 deletions source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,63 @@
#include "module_base/timer.h"
#include "module_base/tool_title.h"
#include "sto_tool.h"
Sto_DOS::~Sto_DOS()
{
}

Sto_DOS::Sto_DOS(ModulePW::PW_Basis_K* p_wfcpw_in, K_Vectors* p_kv_in, elecstate::ElecState* p_elec_in,
psi::Psi<std::complex<double>>* p_psi_in, hamilt::Hamilt<std::complex<double>>* p_hamilt_in,
hsolver::HSolverPW_SDFT* p_hsol_in, Stochastic_WF* p_stowf_in)
Sto_DOS::Sto_DOS(ModulePW::PW_Basis_K* p_wfcpw_in,
K_Vectors* p_kv_in,
elecstate::ElecState* p_elec_in,
psi::Psi<std::complex<double>>* p_psi_in,
hamilt::Hamilt<std::complex<double>>* p_hamilt_in,
StoChe<double>& stoche,
Stochastic_WF* p_stowf_in)
{
this->p_wfcpw = p_wfcpw_in;
this->p_kv = p_kv_in;
this->p_elec = p_elec_in;
this->p_psi = p_psi_in;
this->p_hamilt = p_hamilt_in;
this->p_hsol = p_hsol_in;
this->p_stowf = p_stowf_in;
this->nbands_ks = p_psi_in->get_nbands();
this->nbands_sto = p_stowf_in->nchi;
this->method_sto = stoche.method_sto;
this->stohchi.init(p_wfcpw_in, p_kv_in, &stoche.emin_sto, &stoche.emax_sto);
this->stofunc.set_E_range(&stoche.emin_sto, &stoche.emax_sto);
}
void Sto_DOS::decide_param(const int& dos_nche, const double& emin_sto, const double& emax_sto, const bool& dos_setemin,
const bool& dos_setemax, const double& dos_emin_ev, const double& dos_emax_ev,
void Sto_DOS::decide_param(const int& dos_nche,
const double& emin_sto,
const double& emax_sto,
const bool& dos_setemin,
const bool& dos_setemax,
const double& dos_emin_ev,
const double& dos_emax_ev,
const double& dos_scale)
{
this->dos_nche = dos_nche;
check_che(this->dos_nche, emin_sto, emax_sto, this->nbands_sto, this->p_kv, this->p_stowf, this->p_hamilt,
this->p_hsol);
check_che(this->dos_nche,
emin_sto,
emax_sto,
this->nbands_sto,
this->p_kv,
this->p_stowf,
this->p_hamilt,
this->stohchi);
if (dos_setemax)
{
this->emax = dos_emax_ev;
}
else
{
this->emax = p_hsol->stoiter.stohchi.Emax * ModuleBase::Ry_to_eV;
this->emax = *stohchi.Emax * ModuleBase::Ry_to_eV;
}
if (dos_setemin)
{
this->emin = dos_emin_ev;
}
else
{
this->emin = p_hsol->stoiter.stohchi.Emin * ModuleBase::Ry_to_eV;
this->emin = *stohchi.Emin * ModuleBase::Ry_to_eV;
}

if (!dos_setemax && !dos_setemin)
Expand All @@ -59,13 +79,11 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
std::cout << "=========================" << std::endl;
ModuleBase::Chebyshev<double> che(dos_nche);
const int nk = p_kv->get_nks();
Stochastic_Iter& stoiter = p_hsol->stoiter;
Stochastic_hchi& stohchi = stoiter.stohchi;
const int npwx = p_wfcpw->npwk_max;

std::vector<double> spolyv;
std::vector<std::complex<double>> allorderchi;
if (stoiter.method == 1)
if (this->method_sto == 1)
{
spolyv.resize(dos_nche, 0);
}
Expand Down Expand Up @@ -99,7 +117,7 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
p_stowf->chi0->fix_k(ik);
pchi = p_stowf->chi0->get_pointer();
}
if (stoiter.method == 1)
if (this->method_sto == 1)
{
che.tracepolyA(&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, npwx, nchipk);
for (int i = 0; i < dos_nche; ++i)
Expand All @@ -125,7 +143,12 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
}
ModuleBase::GlobalFunc::ZEROS(allorderchi.data(), nchipk_new * npwx * dos_nche);
std::complex<double>* tmpchi = pchi + start_nchipk * npwx;
che.calpolyvec_complex(&stohchi, &Stochastic_hchi::hchi_norm, tmpchi, allorderchi.data(), npw, npwx,
che.calpolyvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
tmpchi,
allorderchi.data(),
npw,
npwx,
nchipk_new);
double* vec_all = (double*)allorderchi.data();
int LDA = npwx * nchipk_new * 2;
Expand All @@ -140,7 +163,7 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)

std::ofstream ofsdos;
int ndos = int((emax - emin) / de) + 1;
stoiter.stofunc.sigma = sigmain / ModuleBase::Ry_to_eV;
this->stofunc.sigma = sigmain / ModuleBase::Ry_to_eV;
ModuleBase::timer::tick("Sto_DOS", "Tracepoly");

std::cout << "2. Dos:" << std::endl;
Expand All @@ -154,16 +177,16 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
{
double tmpks = 0;
double tmpsto = 0;
stoiter.stofunc.targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV;
if (stoiter.method == 1)
this->stofunc.targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV;
if (this->method_sto == 1)
{
che.calcoef_real(&stoiter.stofunc, &Sto_Func<double>::ngauss);
che.calcoef_real(&this->stofunc, &Sto_Func<double>::ngauss);
tmpsto = BlasConnector::dot(dos_nche, che.coef_real, 1, spolyv.data(), 1);
}
else
{
che.calcoef_real(&stoiter.stofunc, &Sto_Func<double>::nroot_gauss);
tmpsto = stoiter.vTMv(che.coef_real, spolyv.data(), dos_nche);
che.calcoef_real(&this->stofunc, &Sto_Func<double>::nroot_gauss);
tmpsto = vTMv(che.coef_real, spolyv.data(), dos_nche);
}
if (GlobalV::NBANDS > 0)
{
Expand All @@ -172,14 +195,14 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
double* en = &(this->p_elec->ekb(ik, 0));
for (int ib = 0; ib < GlobalV::NBANDS; ++ib)
{
tmpks += stoiter.stofunc.gauss(en[ib]) * p_kv->wk[ik] / 2;
tmpks += this->stofunc.gauss(en[ib]) * p_kv->wk[ik] / 2;
}
}
}
tmpks /= GlobalV::NPROC_IN_POOL;

double tmperror = 0;
if (stoiter.method == 1)
if (this->method_sto == 1)
{
tmperror = che.coef_real[dos_nche - 1] * spolyv[dos_nche - 1];
}
Expand Down Expand Up @@ -220,8 +243,9 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
for (int ie = 0; ie < ndos; ++ie)
{
double tmperror = 2.0 * std::abs(error[ie]);
if (maxerror < tmperror)
if (maxerror < tmperror) {
maxerror = tmperror;
}
double dos = 2.0 * (ks_dos[ie] + sto_dos[ie]) / ModuleBase::Ry_to_eV;
sum += dos;
ofsdos << std::setw(8) << emin + ie * de << std::setw(20) << dos << std::setw(20) << sum * de
Expand Down
Loading

0 comments on commit b7ffff7

Please sign in to comment.