From b7ffff7caa794cf7d1cbd1ef65d181f70693fa5c Mon Sep 17 00:00:00 2001 From: Qianrui Liu <76200646+Qianruipku@users.noreply.github.com> Date: Sun, 11 Aug 2024 21:30:39 +0800 Subject: [PATCH] Refactor: Make Hsolver_sdft a local variable in hamilt_to_density function (#4941) * Make Hsolver_sdft a temporary obj in hamilt_to_density function * fix bug * store emin and emax * [pre-commit.ci lite] apply automatic fixes * fix bug in mDFT * remove useless memory cost * fix init_psi should be store in sdft * update results --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> --- source/Makefile.Objects | 1 + source/module_esolver/esolver_sdft_pw.cpp | 55 ++- source/module_esolver/esolver_sdft_pw.h | 2 + .../hamilt_stodft/CMakeLists.txt | 1 + .../hamilt_stodft/sto_che.cpp | 44 ++ .../module_hamilt_pw/hamilt_stodft/sto_che.h | 35 ++ .../hamilt_stodft/sto_dos.cpp | 72 ++- .../module_hamilt_pw/hamilt_stodft/sto_dos.h | 31 +- .../hamilt_stodft/sto_elecond.cpp | 67 ++- .../hamilt_stodft/sto_elecond.h | 5 +- .../hamilt_stodft/sto_func.cpp | 158 +++---- .../module_hamilt_pw/hamilt_stodft/sto_func.h | 17 +- .../hamilt_stodft/sto_hchi.cpp | 350 ++++++++------- .../module_hamilt_pw/hamilt_stodft/sto_hchi.h | 48 +- .../hamilt_stodft/sto_iter.cpp | 411 +++++++++--------- .../module_hamilt_pw/hamilt_stodft/sto_iter.h | 55 ++- .../hamilt_stodft/sto_tool.cpp | 55 ++- .../module_hamilt_pw/hamilt_stodft/sto_tool.h | 21 +- .../module_hamilt_pw/hamilt_stodft/sto_wf.cpp | 16 + .../module_hamilt_pw/hamilt_stodft/sto_wf.h | 9 +- source/module_hsolver/hsolver_pw_sdft.h | 12 +- .../module_hsolver/test/test_hsolver_sdft.cpp | 25 +- .../184_PW_BNDKPAR_SDFT_MALL/result.ref | 10 +- 23 files changed, 870 insertions(+), 630 deletions(-) create mode 100644 source/module_hamilt_pw/hamilt_stodft/sto_che.cpp create mode 100644 source/module_hamilt_pw/hamilt_stodft/sto_che.h diff --git a/source/Makefile.Objects b/source/Makefile.Objects index f34e6fa049..3cdb90f7eb 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -638,6 +638,7 @@ OBJS_SRCPW=H_Ewald_pw.o\ soc.o\ sto_iter.o\ sto_hchi.o\ + sto_che.o\ sto_wf.o\ sto_func.o\ sto_forces.o\ diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index 5a0393803d..c504636771 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -29,6 +29,7 @@ namespace ModuleESolver { ESolver_SDFT_PW::ESolver_SDFT_PW() + : stoche(PARAM.inp.nche_sto, PARAM.inp.method_sto, PARAM.inp.emax_sto, PARAM.inp.emin_sto) { classname = "ESolver_SDFT_PW"; basisname = "PW"; @@ -92,7 +93,6 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell) // 9) initialize the stochastic wave functions stowf.init(&kv, pw_wfc->npwk_max); - if (inp.nbands_sto != 0) { if (inp.initsto_ecut < inp.ecutwfc) @@ -108,6 +108,10 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell) { Init_Com_Orbitals(this->stowf); } + if (this->method_sto == 2) + { + stowf.allocate_chiallorder(this->nche_sto); + } size_t size = stowf.chi0->size(); @@ -123,7 +127,8 @@ void ESolver_SDFT_PW::before_all_runners(const Input_para& inp, UnitCell& ucell) } // 9) initialize the hsolver - this->phsol = new hsolver::HSolverPW_SDFT(&kv, pw_wfc, &wf, this->stowf, inp.method_sto); + // It should be removed after esolver_ks does not use phsol. + this->phsol = new hsolver::HSolverPW_SDFT(&this->kv, this->pw_wfc, &this->wf, this->stowf, this->stoche, this->init_psi); return; } @@ -169,21 +174,27 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr) hsolver::DiagoIterAssist>::PW_DIAG_NMAX = GlobalV::PW_DIAG_NMAX; - this->phsol->solve(this->p_hamilt, - this->psi[0], - this->pelec, - pw_wfc, - this->stowf, - istep, - iter, - GlobalV::KS_SOLVER, - - hsolver::DiagoIterAssist>::SCF_ITER, - hsolver::DiagoIterAssist>::need_subspace, - hsolver::DiagoIterAssist>::PW_DIAG_NMAX, - hsolver::DiagoIterAssist>::PW_DIAG_THR, - - false); + // hsolver only exists in this function + hsolver::HSolverPW_SDFT hsolver_pw_sdft_obj(&this->kv, this->pw_wfc, &this->wf, this->stowf, this->stoche, this->init_psi); + + hsolver_pw_sdft_obj.solve(this->p_hamilt, + this->psi[0], + this->pelec, + pw_wfc, + this->stowf, + istep, + iter, + GlobalV::KS_SOLVER, + hsolver::DiagoIterAssist>::SCF_ITER, + hsolver::DiagoIterAssist>::need_subspace, + hsolver::DiagoIterAssist>::PW_DIAG_NMAX, + hsolver::DiagoIterAssist>::PW_DIAG_THR, + false); + this->init_psi = true; + + // temporary + // set_diagethr need it + ((hsolver::HSolverPW_SDFT*)phsol)->set_KS_ne(hsolver_pw_sdft_obj.stoiter.KS_ne); if (GlobalV::MY_STOGROUP == 0) { @@ -243,8 +254,10 @@ void ESolver_SDFT_PW::after_all_runners() GlobalV::ofs_running << " --------------------------------------------\n\n" << std::endl; ModuleIO::write_istate_info(this->pelec->ekb, this->pelec->wg, kv, &(GlobalC::Pkpoints)); - ((hsolver::HSolverPW_SDFT*)phsol)->stoiter.cleanchiallorder(); // release lots of memories - + if (this->method_sto == 2) + { + stowf.clean_chiallorder(); // release lots of memories + } if (PARAM.inp.out_dos) { Sto_DOS sto_dos(this->pw_wfc, @@ -252,7 +265,7 @@ void ESolver_SDFT_PW::after_all_runners() this->pelec, this->psi, this->p_hamilt, - (hsolver::HSolverPW_SDFT*)phsol, + this->stoche, &stowf); sto_dos.decide_param(PARAM.inp.dos_nche, PARAM.inp.emin_sto, @@ -275,7 +288,7 @@ void ESolver_SDFT_PW::after_all_runners() this->psi, &GlobalC::ppcell, this->p_hamilt, - (hsolver::HSolverPW_SDFT*)phsol, + this->stoche, &stowf); sto_elecond.decide_nche(PARAM.inp.cond_dt, 1e-8, this->nche_sto, PARAM.inp.emin_sto, PARAM.inp.emax_sto); sto_elecond.sKG(PARAM.inp.cond_smear, diff --git a/source/module_esolver/esolver_sdft_pw.h b/source/module_esolver/esolver_sdft_pw.h index fef542e872..2d2d0231c8 100644 --- a/source/module_esolver/esolver_sdft_pw.h +++ b/source/module_esolver/esolver_sdft_pw.h @@ -5,6 +5,7 @@ #include "module_hamilt_pw/hamilt_stodft/sto_hchi.h" #include "module_hamilt_pw/hamilt_stodft/sto_iter.h" #include "module_hamilt_pw/hamilt_stodft/sto_wf.h" +#include "module_hamilt_pw/hamilt_stodft/sto_che.h" namespace ModuleESolver { @@ -25,6 +26,7 @@ class ESolver_SDFT_PW : public ESolver_KS_PW> public: Stochastic_WF stowf; + StoChe stoche; protected: virtual void before_scf(const int istep) override; diff --git a/source/module_hamilt_pw/hamilt_stodft/CMakeLists.txt b/source/module_hamilt_pw/hamilt_stodft/CMakeLists.txt index ff750f3399..60279531da 100644 --- a/source/module_hamilt_pw/hamilt_stodft/CMakeLists.txt +++ b/source/module_hamilt_pw/hamilt_stodft/CMakeLists.txt @@ -1,6 +1,7 @@ list(APPEND hamilt_stodft_srcs sto_iter.cpp sto_hchi.cpp + sto_che.cpp sto_wf.cpp sto_func.cpp sto_forces.cpp diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp new file mode 100644 index 0000000000..2f404b9c9c --- /dev/null +++ b/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp @@ -0,0 +1,44 @@ +#include "sto_che.h" +#include "module_base/blas_connector.h" + +template +StoChe::~StoChe() +{ + delete p_che; + delete[] spolyv; +} + +template +StoChe::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto) +{ + this->nche = nche; + this->method_sto = method; + p_che = new ModuleBase::Chebyshev(nche); + if (method == 1) + { + spolyv = new REAL[nche]; + } + else + { + spolyv = new REAL[nche * nche]; + } + + this->emax_sto = emax_sto; + this->emin_sto = emin_sto; +} + +template class StoChe; +// template class StoChe; + +double vTMv(const double* v, const double* M, const int n) +{ + const char normal = 'N'; + const double one = 1; + const int inc = 1; + const double zero = 0; + double* y = new double[n]; + dgemv_(&normal, &n, &n, &one, M, &n, v, &inc, &zero, y, &inc); + double result = BlasConnector::dot(n, y, 1, v, 1); + delete[] y; + return result; +} \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_che.h b/source/module_hamilt_pw/hamilt_stodft/sto_che.h new file mode 100644 index 0000000000..e2ffc1baca --- /dev/null +++ b/source/module_hamilt_pw/hamilt_stodft/sto_che.h @@ -0,0 +1,35 @@ +#ifndef STO_CHE_H +#define STO_CHE_H +#include "module_base/math_chebyshev.h" + +template +class StoChe +{ + public: + StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto); + ~StoChe(); + + public: + int nche = 0; ///< order of Chebyshev expansion + REAL* spolyv = nullptr; ///< coefficients of Chebyshev expansion + int method_sto = 0; ///< method for the stochastic calculation + + // Chebyshev expansion + // It stores the plan of FFTW and should be initialized at the beginning of the calculation + ModuleBase::Chebyshev* p_che = nullptr; + + REAL emax_sto = 0.0; ///< maximum energy for normalization + REAL emin_sto = 0.0; ///< minimum energy for normalization +}; + +/** + * @brief calculate v^T*M*v + * + * @param v v + * @param M M + * @param n the dimension of v + * @return double + */ +double vTMv(const double* v, const double* M, const int n); + +#endif \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp index a125f5cef8..5495eecf3c 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp @@ -3,35 +3,55 @@ #include "module_base/timer.h" #include "module_base/tool_title.h" #include "sto_tool.h" +Sto_DOS::~Sto_DOS() +{ +} -Sto_DOS::Sto_DOS(ModulePW::PW_Basis_K* p_wfcpw_in, K_Vectors* p_kv_in, elecstate::ElecState* p_elec_in, - psi::Psi>* p_psi_in, hamilt::Hamilt>* p_hamilt_in, - hsolver::HSolverPW_SDFT* p_hsol_in, Stochastic_WF* p_stowf_in) +Sto_DOS::Sto_DOS(ModulePW::PW_Basis_K* p_wfcpw_in, + K_Vectors* p_kv_in, + elecstate::ElecState* p_elec_in, + psi::Psi>* p_psi_in, + hamilt::Hamilt>* p_hamilt_in, + StoChe& stoche, + Stochastic_WF* p_stowf_in) { this->p_wfcpw = p_wfcpw_in; this->p_kv = p_kv_in; this->p_elec = p_elec_in; this->p_psi = p_psi_in; this->p_hamilt = p_hamilt_in; - this->p_hsol = p_hsol_in; this->p_stowf = p_stowf_in; this->nbands_ks = p_psi_in->get_nbands(); this->nbands_sto = p_stowf_in->nchi; + this->method_sto = stoche.method_sto; + this->stohchi.init(p_wfcpw_in, p_kv_in, &stoche.emin_sto, &stoche.emax_sto); + this->stofunc.set_E_range(&stoche.emin_sto, &stoche.emax_sto); } -void Sto_DOS::decide_param(const int& dos_nche, const double& emin_sto, const double& emax_sto, const bool& dos_setemin, - const bool& dos_setemax, const double& dos_emin_ev, const double& dos_emax_ev, +void Sto_DOS::decide_param(const int& dos_nche, + const double& emin_sto, + const double& emax_sto, + const bool& dos_setemin, + const bool& dos_setemax, + const double& dos_emin_ev, + const double& dos_emax_ev, const double& dos_scale) { this->dos_nche = dos_nche; - check_che(this->dos_nche, emin_sto, emax_sto, this->nbands_sto, this->p_kv, this->p_stowf, this->p_hamilt, - this->p_hsol); + check_che(this->dos_nche, + emin_sto, + emax_sto, + this->nbands_sto, + this->p_kv, + this->p_stowf, + this->p_hamilt, + this->stohchi); if (dos_setemax) { this->emax = dos_emax_ev; } else { - this->emax = p_hsol->stoiter.stohchi.Emax * ModuleBase::Ry_to_eV; + this->emax = *stohchi.Emax * ModuleBase::Ry_to_eV; } if (dos_setemin) { @@ -39,7 +59,7 @@ void Sto_DOS::decide_param(const int& dos_nche, const double& emin_sto, const do } else { - this->emin = p_hsol->stoiter.stohchi.Emin * ModuleBase::Ry_to_eV; + this->emin = *stohchi.Emin * ModuleBase::Ry_to_eV; } if (!dos_setemax && !dos_setemin) @@ -59,13 +79,11 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart) std::cout << "=========================" << std::endl; ModuleBase::Chebyshev che(dos_nche); const int nk = p_kv->get_nks(); - Stochastic_Iter& stoiter = p_hsol->stoiter; - Stochastic_hchi& stohchi = stoiter.stohchi; const int npwx = p_wfcpw->npwk_max; std::vector spolyv; std::vector> allorderchi; - if (stoiter.method == 1) + if (this->method_sto == 1) { spolyv.resize(dos_nche, 0); } @@ -99,7 +117,7 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart) p_stowf->chi0->fix_k(ik); pchi = p_stowf->chi0->get_pointer(); } - if (stoiter.method == 1) + if (this->method_sto == 1) { che.tracepolyA(&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, npwx, nchipk); for (int i = 0; i < dos_nche; ++i) @@ -125,7 +143,12 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart) } ModuleBase::GlobalFunc::ZEROS(allorderchi.data(), nchipk_new * npwx * dos_nche); std::complex* tmpchi = pchi + start_nchipk * npwx; - che.calpolyvec_complex(&stohchi, &Stochastic_hchi::hchi_norm, tmpchi, allorderchi.data(), npw, npwx, + che.calpolyvec_complex(&stohchi, + &Stochastic_hchi::hchi_norm, + tmpchi, + allorderchi.data(), + npw, + npwx, nchipk_new); double* vec_all = (double*)allorderchi.data(); int LDA = npwx * nchipk_new * 2; @@ -140,7 +163,7 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart) std::ofstream ofsdos; int ndos = int((emax - emin) / de) + 1; - stoiter.stofunc.sigma = sigmain / ModuleBase::Ry_to_eV; + this->stofunc.sigma = sigmain / ModuleBase::Ry_to_eV; ModuleBase::timer::tick("Sto_DOS", "Tracepoly"); std::cout << "2. Dos:" << std::endl; @@ -154,16 +177,16 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart) { double tmpks = 0; double tmpsto = 0; - stoiter.stofunc.targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV; - if (stoiter.method == 1) + this->stofunc.targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV; + if (this->method_sto == 1) { - che.calcoef_real(&stoiter.stofunc, &Sto_Func::ngauss); + che.calcoef_real(&this->stofunc, &Sto_Func::ngauss); tmpsto = BlasConnector::dot(dos_nche, che.coef_real, 1, spolyv.data(), 1); } else { - che.calcoef_real(&stoiter.stofunc, &Sto_Func::nroot_gauss); - tmpsto = stoiter.vTMv(che.coef_real, spolyv.data(), dos_nche); + che.calcoef_real(&this->stofunc, &Sto_Func::nroot_gauss); + tmpsto = vTMv(che.coef_real, spolyv.data(), dos_nche); } if (GlobalV::NBANDS > 0) { @@ -172,14 +195,14 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart) double* en = &(this->p_elec->ekb(ik, 0)); for (int ib = 0; ib < GlobalV::NBANDS; ++ib) { - tmpks += stoiter.stofunc.gauss(en[ib]) * p_kv->wk[ik] / 2; + tmpks += this->stofunc.gauss(en[ib]) * p_kv->wk[ik] / 2; } } } tmpks /= GlobalV::NPROC_IN_POOL; double tmperror = 0; - if (stoiter.method == 1) + if (this->method_sto == 1) { tmperror = che.coef_real[dos_nche - 1] * spolyv[dos_nche - 1]; } @@ -220,8 +243,9 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart) for (int ie = 0; ie < ndos; ++ie) { double tmperror = 2.0 * std::abs(error[ie]); - if (maxerror < tmperror) + if (maxerror < tmperror) { maxerror = tmperror; +} double dos = 2.0 * (ks_dos[ie] + sto_dos[ie]) / ModuleBase::Ry_to_eV; sum += dos; ofsdos << std::setw(8) << emin + ie * de << std::setw(20) << dos << std::setw(20) << sum * de diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_dos.h b/source/module_hamilt_pw/hamilt_stodft/sto_dos.h index 407ec3cad1..7941bc6bd9 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_dos.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_dos.h @@ -1,14 +1,24 @@ #ifndef STO_DOS #define STO_DOS +#include "module_elecstate/elecstate.h" +#include "module_hamilt_general/hamilt.h" +#include "module_hamilt_pw/hamilt_stodft/sto_che.h" +#include "module_hamilt_pw/hamilt_stodft/sto_func.h" +#include "module_hamilt_pw/hamilt_stodft/sto_hchi.h" #include "module_hamilt_pw/hamilt_stodft/sto_wf.h" -#include "module_hsolver/hsolver_pw_sdft.h" class Sto_DOS { public: - Sto_DOS(ModulePW::PW_Basis_K* p_wfcpw_in, K_Vectors* p_kv_in, elecstate::ElecState* p_elec_in, - psi::Psi>* p_psi_in, hamilt::Hamilt>* p_hamilt_in, - hsolver::HSolverPW_SDFT* p_hsol_in, Stochastic_WF* p_stowf_in); + Sto_DOS(ModulePW::PW_Basis_K* p_wfcpw_in, + K_Vectors* p_kv_in, + elecstate::ElecState* p_elec_in, + psi::Psi>* p_psi_in, + hamilt::Hamilt>* p_hamilt_in, + StoChe& stoche, + Stochastic_WF* p_stowf_in); + ~Sto_DOS(); + /** * @brief decide the parameters for the DOS calculation * @@ -21,8 +31,13 @@ class Sto_DOS * @param dos_emax_ev Emax input for DOS * @param dos_scale dos_scale input for DOS */ - void decide_param(const int& dos_nche, const double& emin_sto, const double& emax_sto, const bool& dos_setemin, - const bool& dos_setemax, const double& dos_emin_ev, const double& dos_emax_ev, + void decide_param(const int& dos_nche, + const double& emin_sto, + const double& emax_sto, + const bool& dos_setemin, + const bool& dos_setemax, + const double& dos_emin_ev, + const double& dos_emax_ev, const double& dos_scale); /** * @brief Calculate DOS using stochastic wavefunctions @@ -37,6 +52,7 @@ class Sto_DOS int nbands_ks = 0; ///< number of KS bands int nbands_sto = 0; ///< number of stochastic bands int dos_nche = 0; ///< number of Chebyshev orders for DOS + int method_sto = 1; ///< method for sDFT double emax = 0.0; ///< maximum energy double emin = 0.0; ///< minimum energy ModulePW::PW_Basis_K* p_wfcpw = nullptr; ///< pointer to the plane wave basis @@ -44,8 +60,9 @@ class Sto_DOS elecstate::ElecState* p_elec = nullptr; ///< pointer to the electronic state psi::Psi>* p_psi = nullptr; ///< pointer to the wavefunction hamilt::Hamilt>* p_hamilt; ///< pointer to the Hamiltonian - hsolver::HSolverPW_SDFT* p_hsol = nullptr; ///< pointer to the Hamiltonian solver Stochastic_WF* p_stowf = nullptr; ///< pointer to the stochastic wavefunctions + Stochastic_hchi stohchi; ///< stochastic hchi + Sto_Func stofunc; ///< functions }; #endif // STO_DOS \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp index a3c21f43e5..bb25ef7062 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp @@ -19,15 +19,16 @@ Sto_EleCond::Sto_EleCond(UnitCell* p_ucell_in, psi::Psi>* p_psi_in, pseudopot_cell_vnl* p_ppcell_in, hamilt::Hamilt>* p_hamilt_in, - hsolver::HSolverPW_SDFT* p_hsol_in, + StoChe& stoche, Stochastic_WF* p_stowf_in) : EleCond(p_ucell_in, p_kv_in, p_elec_in, p_wfcpw_in, p_psi_in, p_ppcell_in) { this->p_hamilt = p_hamilt_in; - this->p_hsol = p_hsol_in; this->p_stowf = p_stowf_in; this->nbands_ks = p_psi_in->get_nbands(); this->nbands_sto = p_stowf_in->nchi; + this->stohchi.init(p_wfcpw_in, p_kv_in, &stoche.emin_sto, &stoche.emax_sto); + this->stofunc.set_E_range(&stoche.emin_sto, &stoche.emax_sto); } void Sto_EleCond::decide_nche(const double dt, @@ -38,9 +39,8 @@ void Sto_EleCond::decide_nche(const double dt, { int nche_guess = 1000; ModuleBase::Chebyshev chet(nche_guess); - Stochastic_Iter& stoiter = p_hsol->stoiter; const double mu = this->p_elec->eferm.ef; - stoiter.stofunc.mu = mu; + this->stofunc.mu = mu; int& nbatch = this->cond_dtbatch; // try to find nbatch if (nbatch == 0) @@ -48,8 +48,8 @@ void Sto_EleCond::decide_nche(const double dt, for (int test_nbatch = 128; test_nbatch >= 1; test_nbatch /= 2) { nbatch = test_nbatch; - stoiter.stofunc.t = 0.5 * dt * nbatch; - chet.calcoef_pair(&stoiter.stofunc, &Sto_Func::ncos, &Sto_Func::n_sin); + this->stofunc.t = 0.5 * dt * nbatch; + chet.calcoef_pair(&this->stofunc, &Sto_Func::ncos, &Sto_Func::n_sin); double minerror = std::abs(chet.coef_complex[nche_guess - 1] / chet.coef_complex[0]); if (minerror < cond_thr) { @@ -70,9 +70,9 @@ void Sto_EleCond::decide_nche(const double dt, } // first try to find nche - stoiter.stofunc.t = 0.5 * dt * nbatch; + this->stofunc.t = 0.5 * dt * nbatch; auto getnche = [&](int& nche) { - chet.calcoef_pair(&stoiter.stofunc, &Sto_Func::ncos, &Sto_Func::n_sin); + chet.calcoef_pair(&this->stofunc, &Sto_Func::ncos, &Sto_Func::n_sin); for (int i = 1; i < nche_guess; ++i) { double error = std::abs(chet.coef_complex[i] / chet.coef_complex[0]); @@ -96,7 +96,7 @@ void Sto_EleCond::decide_nche(const double dt, this->p_kv, this->p_stowf, this->p_hamilt, - this->p_hsol); + this->stohchi); // second try to find nche with new Emin & Emax getnche(nche_new); @@ -104,8 +104,8 @@ void Sto_EleCond::decide_nche(const double dt, if (nche_new > nche_old * 2) { nche_old = nche_new; - try_emin = stoiter.stohchi.Emin; - try_emax = stoiter.stohchi.Emax; + try_emin = *stohchi.Emin; + try_emax = *stohchi.Emax; goto loop; } @@ -169,14 +169,13 @@ void Sto_EleCond::cal_jmatrix(const psi::Psi>& kspsi_all, const int allbands_sto = bandinfo[4]; const int allbands = bandinfo[5]; const int dim_jmatrix = perbands_ks * allbands_sto + perbands_sto * allbands; - Stochastic_Iter& stoiter = p_hsol->stoiter; psi::Psi> right_hchi(1, perbands_sto, npwx, p_kv->ngk.data()); psi::Psi> f_rightchi(1, perbands_sto, npwx, p_kv->ngk.data()); psi::Psi> f_right_hchi(1, perbands_sto, npwx, p_kv->ngk.data()); - stoiter.stohchi.hchi(leftchi.get_pointer(), left_hchi.get_pointer(), perbands_sto); - stoiter.stohchi.hchi(rightchi.get_pointer(), right_hchi.get_pointer(), perbands_sto); + this->stohchi.hchi(leftchi.get_pointer(), left_hchi.get_pointer(), perbands_sto); + this->stohchi.hchi(rightchi.get_pointer(), right_hchi.get_pointer(), perbands_sto); convert_psi(rightchi, f_rightchi); convert_psi(right_hchi, f_right_hchi); right_hchi.resize(1, 1, 1); @@ -533,8 +532,6 @@ void Sto_EleCond::sKG(const int& smear_type, ModuleBase::Chebyshev che(fd_nche); ModuleBase::Chebyshev chet(cond_nche); ModuleBase::Chebyshev chemt(cond_nche); - Stochastic_Iter& stoiter = p_hsol->stoiter; - Stochastic_hchi& stohchi = stoiter.stohchi; //------------------------------------------------------------------ // Calculate @@ -542,10 +539,10 @@ void Sto_EleCond::sKG(const int& smear_type, // Prepare Chebyshev coefficients for exp(-i H/\hbar t) const double mu = this->p_elec->eferm.ef; - stoiter.stofunc.mu = mu; - stoiter.stofunc.t = 0.5 * dt * nbatch; - chet.calcoef_pair(&stoiter.stofunc, &Sto_Func::ncos, &Sto_Func::nsin); - chemt.calcoef_pair(&stoiter.stofunc, &Sto_Func::ncos, &Sto_Func::n_sin); + this->stofunc.mu = mu; + this->stofunc.t = 0.5 * dt * nbatch; + chet.calcoef_pair(&this->stofunc, &Sto_Func::ncos, &Sto_Func::nsin); + chemt.calcoef_pair(&this->stofunc, &Sto_Func::ncos, &Sto_Func::n_sin); std::vector> batchcoef, batchmcoef; if (nbatch > 1) { @@ -562,16 +559,16 @@ void Sto_EleCond::sKG(const int& smear_type, { tmpcoef = batchcoef.data() + ib * cond_nche; tmpmcoef = batchmcoef.data() + ib * cond_nche; - stoiter.stofunc.t = 0.5 * dt * (ib + 1); - chet.calcoef_pair(&stoiter.stofunc, &Sto_Func::ncos, &Sto_Func::nsin); - chemt.calcoef_pair(&stoiter.stofunc, &Sto_Func::ncos, &Sto_Func::n_sin); + this->stofunc.t = 0.5 * dt * (ib + 1); + chet.calcoef_pair(&this->stofunc, &Sto_Func::ncos, &Sto_Func::nsin); + chemt.calcoef_pair(&this->stofunc, &Sto_Func::ncos, &Sto_Func::n_sin); for (int i = 0; i < cond_nche; ++i) { tmpcoef[i] = chet.coef_complex[i]; tmpmcoef[i] = chemt.coef_complex[i]; } } - stoiter.stofunc.t = 0.5 * dt * nbatch; + this->stofunc.t = 0.5 * dt * nbatch; } // ik loop @@ -586,14 +583,14 @@ void Sto_EleCond::sKG(const int& smear_type, { this->p_hamilt->updateHk(ik); } - stoiter.stohchi.current_ik = ik; + this->stohchi.current_ik = ik; const int npw = p_kv->ngk[ik]; // get allbands_ks int cutib0 = 0; if (this->nbands_ks > 0) { - double Emax_KS = std::max(stoiter.stofunc.Emin, this->p_elec->ekb(ik, this->nbands_ks - 1)); + double Emax_KS = std::max(*this->stofunc.Emin, this->p_elec->ekb(ik, this->nbands_ks - 1)); for (cutib0 = this->nbands_ks - 1; cutib0 >= 0; --cutib0) { if (Emax_KS - this->p_elec->ekb(ik, cutib0) > dEcut) @@ -602,17 +599,17 @@ void Sto_EleCond::sKG(const int& smear_type, } } ++cutib0; - double Emin_KS = (cutib0 < this->nbands_ks) ? this->p_elec->ekb(ik, cutib0) : stoiter.stofunc.Emin; - double dE = stoiter.stofunc.Emax - Emin_KS + wcut / ModuleBase::Ry_to_eV; + double Emin_KS = (cutib0 < this->nbands_ks) ? this->p_elec->ekb(ik, cutib0) : *this->stofunc.Emin; + double dE = *this->stofunc.Emax - Emin_KS + wcut / ModuleBase::Ry_to_eV; std::cout << "Emin_KS(" << cutib0 + 1 << "): " << Emin_KS * ModuleBase::Ry_to_eV - << " eV; Emax: " << stoiter.stofunc.Emax * ModuleBase::Ry_to_eV + << " eV; Emax: " << *this->stofunc.Emax * ModuleBase::Ry_to_eV << " eV; Recommended max dt: " << 2 * M_PI / dE << " a.u." << std::endl; } else { - double dE = stoiter.stofunc.Emax - stoiter.stofunc.Emin + wcut / ModuleBase::Ry_to_eV; - std::cout << "Emin: " << stoiter.stofunc.Emin * ModuleBase::Ry_to_eV - << " eV; Emax: " << stoiter.stofunc.Emax * ModuleBase::Ry_to_eV + double dE = *this->stofunc.Emax - *this->stofunc.Emin + wcut / ModuleBase::Ry_to_eV; + std::cout << "Emin: " << *this->stofunc.Emin * ModuleBase::Ry_to_eV + << " eV; Emax: " << *this->stofunc.Emax * ModuleBase::Ry_to_eV << " eV; Recommended max dt: " << 2 * M_PI / dE << " a.u." << std::endl; } // Parallel for bands @@ -712,7 +709,7 @@ void Sto_EleCond::sKG(const int& smear_type, { kspsi(0, ib, ig) = p_psi[0](ib0_ks + ib, ig); } - double fi = stoiter.stofunc.fd(en[ib]); + double fi = this->stofunc.fd(en[ib]); expmtmf_fact[ib] = 1 - fi; expmtf_fact[ib] = fi; } @@ -728,7 +725,7 @@ void Sto_EleCond::sKG(const int& smear_type, vkspsi.resize(1, 1, 1); } - che.calcoef_real(&stoiter.stofunc, &Sto_Func::nroot_fd); + che.calcoef_real(&this->stofunc, &Sto_Func::nroot_fd); che.calfinalvec_real(&stohchi, &Stochastic_hchi::hchi_norm, stopsi->get_pointer(), @@ -737,7 +734,7 @@ void Sto_EleCond::sKG(const int& smear_type, npwx, perbands_sto); - che.calcoef_real(&stoiter.stofunc, &Sto_Func::nroot_mfd); + che.calcoef_real(&this->stofunc, &Sto_Func::nroot_mfd); che.calfinalvec_real(&stohchi, &Stochastic_hchi::hchi_norm, diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.h b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.h index 6f06aca238..a7a4e0f7a6 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_elecond.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_elecond.h @@ -16,7 +16,7 @@ class Sto_EleCond : protected EleCond psi::Psi>* p_psi_in, pseudopot_cell_vnl* p_ppcell_in, hamilt::Hamilt>* p_hamilt_in, - hsolver::HSolverPW_SDFT* p_hsol_in, + StoChe& stoche, Stochastic_WF* p_stowf_in); ~Sto_EleCond(){}; /** @@ -59,8 +59,9 @@ class Sto_EleCond : protected EleCond int fd_nche = 0; ///< number of Chebyshev orders for Fermi-Dirac function int cond_dtbatch = 0; ///< number of time steps in a batch hamilt::Hamilt>* p_hamilt; ///< pointer to the Hamiltonian - hsolver::HSolverPW_SDFT* p_hsol = nullptr; ///< pointer to the Hamiltonian solver Stochastic_WF* p_stowf = nullptr; ///< pointer to the stochastic wavefunctions + Stochastic_hchi stohchi; ///< stochastic hchi + Sto_Func stofunc; ///< functions protected: /** diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_func.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_func.cpp index 0ad3f7f0fd..581e5ab7ef 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_func.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_func.cpp @@ -1,72 +1,80 @@ #include "sto_func.h" + #include "module_elecstate/occupy.h" #define TWOPI 6.283185307179586 -template +template Sto_Func::Sto_Func() { this->tem = Occupy::gaussian_parameter; } -template -REAL Sto_Func:: root_fd(REAL e) +template +void Sto_Func::set_E_range(REAL* Emin_in, REAL* Emax_in) +{ + Emin = Emin_in; + Emax = Emax_in; +} + +template +REAL Sto_Func::root_fd(REAL e) { - REAL e_mu = (e - mu) / this->tem ; - if(e_mu > 72) + REAL e_mu = (e - mu) / this->tem; + if (e_mu > 72) return 0; else return 1 / sqrt(1 + exp(e_mu)); } -template -REAL Sto_Func:: nroot_fd(REAL e) +template +REAL Sto_Func::nroot_fd(REAL e) { - REAL Ebar = (Emin + Emax)/2; - REAL DeltaE = (Emax - Emin)/2; - REAL ne_mu = (e * DeltaE + Ebar - mu) / this->tem ; - if(ne_mu > 72) + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; + REAL ne_mu = (e * DeltaE + Ebar - mu) / this->tem; + if (ne_mu > 72) return 0; else return 1 / sqrt(1 + exp(ne_mu)); } -template -REAL Sto_Func:: fd(REAL e) +template +REAL Sto_Func::fd(REAL e) { - REAL e_mu = (e - mu) / this->tem ; - if(e_mu > 36) + REAL e_mu = (e - mu) / this->tem; + if (e_mu > 36) return 0; else return 1 / (1 + exp(e_mu)); } -template -REAL Sto_Func:: nfd(REAL e) +template +REAL Sto_Func::nfd(REAL e) { - REAL Ebar = (Emin + Emax)/2; - REAL DeltaE = (Emax - Emin)/2; - REAL ne_mu = (e * DeltaE + Ebar - mu) / this->tem ; - if(ne_mu > 36) + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; + REAL ne_mu = (e * DeltaE + Ebar - mu) / this->tem; + if (ne_mu > 36) return 0; else return 1 / (1 + exp(ne_mu)); } -template -REAL Sto_Func:: nxfd(REAL rawe) +template +REAL Sto_Func::nxfd(REAL rawe) { - REAL Ebar = (Emin + Emax)/2; - REAL DeltaE = (Emax - Emin)/2; + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; REAL e = rawe * DeltaE + Ebar; - REAL ne_mu = (e - mu) / this->tem ; - if(ne_mu > 36) + REAL ne_mu = (e - mu) / this->tem; + if (ne_mu > 36) return 0; else return e / (1 + exp(ne_mu)); } -template -REAL Sto_Func:: fdlnfd(REAL e) +template +REAL Sto_Func::fdlnfd(REAL e) { REAL e_mu = (e - mu) / this->tem; if (e_mu > 36) @@ -83,11 +91,11 @@ REAL Sto_Func:: fdlnfd(REAL e) } } -template -REAL Sto_Func:: nfdlnfd(REAL rawe) +template +REAL Sto_Func::nfdlnfd(REAL rawe) { - REAL Ebar = (Emin + Emax) / 2; - REAL DeltaE = (Emax - Emin) / 2; + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; REAL ne_mu = (rawe * DeltaE + Ebar - mu) / this->tem; if (ne_mu > 36) return 0; @@ -103,15 +111,15 @@ REAL Sto_Func:: nfdlnfd(REAL rawe) } } -template -REAL Sto_Func:: n_root_fdlnfd(REAL rawe) +template +REAL Sto_Func::n_root_fdlnfd(REAL rawe) { - REAL Ebar = (Emin + Emax)/2; - REAL DeltaE = (Emax - Emin)/2; - REAL ne_mu = (rawe * DeltaE + Ebar - mu) / this->tem ; - if(ne_mu > 36) + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; + REAL ne_mu = (rawe * DeltaE + Ebar - mu) / this->tem; + if (ne_mu > 36) return 0; - else if(ne_mu < -36) + else if (ne_mu < -36) return 0; else { @@ -119,86 +127,86 @@ REAL Sto_Func:: n_root_fdlnfd(REAL rawe) if (f == 0 || f == 1) return 0; else - return sqrt(-f * log(f) - (1-f) * log(1-f)); + return sqrt(-f * log(f) - (1 - f) * log(1 - f)); } } -template +template REAL Sto_Func::nroot_mfd(REAL rawe) { - REAL Ebar = (Emin + Emax)/2; - REAL DeltaE = (Emax - Emin)/2; - REAL ne_mu = (rawe * DeltaE + Ebar - mu) / this->tem ; - if(ne_mu < -72) + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; + REAL ne_mu = (rawe * DeltaE + Ebar - mu) / this->tem; + if (ne_mu < -72) return 0; else return 1 / sqrt(1 + exp(-ne_mu)); } -template -REAL Sto_Func:: ncos(REAL rawe) +template +REAL Sto_Func::ncos(REAL rawe) { - REAL Ebar = (Emin + Emax)/2; - REAL DeltaE = (Emax - Emin)/2; + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; REAL e = rawe * DeltaE + Ebar; return cos(e * t); } -template -REAL Sto_Func:: nsin(REAL rawe) +template +REAL Sto_Func::nsin(REAL rawe) { - REAL Ebar = (Emin + Emax)/2; - REAL DeltaE = (Emax - Emin)/2; + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; REAL e = rawe * DeltaE + Ebar; return sin(e * t); } -template -REAL Sto_Func:: n_sin(REAL rawe) +template +REAL Sto_Func::n_sin(REAL rawe) { - REAL Ebar = (Emin + Emax)/2; - REAL DeltaE = (Emax - Emin)/2; + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; REAL e = rawe * DeltaE + Ebar; return -sin(e * t); } -template +template REAL Sto_Func::gauss(REAL e) { - REAL a = pow((targ_e-e),2)/2.0/pow(sigma,2); - if(a > 72) + REAL a = pow((targ_e - e), 2) / 2.0 / pow(sigma, 2); + if (a > 72) return 0; else - return exp(-a) /sqrt(TWOPI) / sigma ; + return exp(-a) / sqrt(TWOPI) / sigma; } -template +template REAL Sto_Func::ngauss(REAL rawe) { - REAL Ebar = (Emin + Emax)/2; - REAL DeltaE = (Emax - Emin)/2; + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; REAL e = rawe * DeltaE + Ebar; - REAL a = pow((targ_e-e),2)/2.0/pow(sigma,2); - if(a > 72) + REAL a = pow((targ_e - e), 2) / 2.0 / pow(sigma, 2); + if (a > 72) return 0; else - return exp(-a) /sqrt(TWOPI) / sigma ; + return exp(-a) / sqrt(TWOPI) / sigma; } -template +template REAL Sto_Func::nroot_gauss(REAL rawe) { - REAL Ebar = (Emin + Emax)/2; - REAL DeltaE = (Emax - Emin)/2; + REAL Ebar = (*Emin + *Emax) / 2; + REAL DeltaE = (*Emax - *Emin) / 2; REAL e = rawe * DeltaE + Ebar; - REAL a = pow((targ_e-e),2)/4.0/pow(sigma,2); - if(a > 72) + REAL a = pow((targ_e - e), 2) / 4.0 / pow(sigma, 2); + if (a > 72) return 0; else - return exp(-a) /sqrt(sqrt(TWOPI) * sigma) ; + return exp(-a) / sqrt(sqrt(TWOPI) * sigma); } -//we only have two examples: double and float. +// we only have two examples: double and float. template class Sto_Func; #ifdef __ENABLE_FLOAT_FFTW template class Sto_Func; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_func.h b/source/module_hamilt_pw/hamilt_stodft/sto_func.h index 3d521d3964..2eb67dd50e 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_func.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_func.h @@ -4,14 +4,16 @@ template class Sto_Func { -public: + public: Sto_Func(); ~Sto_Func(){}; - REAL tem; //temperature - REAL mu; //chemical potential - REAL Emin, Emax; + REAL tem; // temperature + REAL mu; // chemical potential + REAL* Emin = nullptr; + REAL* Emax = nullptr; + void set_E_range(REAL* Emin_in, REAL* Emax_in); -public: + public: REAL root_fd(REAL e); REAL fd(REAL e); REAL nroot_fd(REAL e); @@ -22,19 +24,18 @@ class Sto_Func REAL n_root_fdlnfd(REAL e); REAL nroot_mfd(REAL e); -public: + public: REAL t; REAL ncos(REAL e); REAL nsin(REAL e); REAL n_sin(REAL e); -public: + public: REAL sigma; REAL targ_e; REAL gauss(REAL e); REAL ngauss(REAL e); REAL nroot_gauss(REAL e); - }; #endif \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_hchi.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_hchi.cpp index f402630072..29a337f1ab 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_hchi.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_hchi.cpp @@ -1,176 +1,214 @@ -#include "module_hamilt_pw/hamilt_pwdft/global.h" -#include "sto_hchi.h" -#include "module_base/tool_title.h" -#include "module_base/timer.h" +#include "sto_hchi.h" + #include "module_base/parallel_reduce.h" +#include "module_base/timer.h" +#include "module_base/tool_title.h" #include "module_esolver/esolver_sdft_pw.h" - +#include "module_hamilt_pw/hamilt_pwdft/global.h" Stochastic_hchi::Stochastic_hchi() { - Emin = PARAM.inp.emin_sto; - Emax = PARAM.inp.emax_sto; } Stochastic_hchi::~Stochastic_hchi() { } -void Stochastic_hchi:: init(ModulePW::PW_Basis_K* wfc_basis, K_Vectors* pkv_in) +void Stochastic_hchi::init(ModulePW::PW_Basis_K* wfc_basis, K_Vectors* pkv_in, double* Emin_in, double* Emax_in) { - wfcpw = wfc_basis; - pkv = pkv_in; + wfcpw = wfc_basis; + pkv = pkv_in; + Emin = Emin_in; + Emax = Emax_in; } - -void Stochastic_hchi:: hchi(complex *chig, complex *hchig, const int m) +void Stochastic_hchi::hchi(complex* chig, complex* hchig, const int m) { - - - //--------------------------------------------------- - - const int ik = this->current_ik; - const int current_spin = pkv->isk[ik]; - const int npwx = this->wfcpw->npwk_max; - const int npw = this->wfcpw->npwk[ik]; - const int npm = GlobalV::NPOL * m; - const int inc = 1; - const double tpiba2 = GlobalC::ucell.tpiba2; - const int nrxx = this->wfcpw->nrxx; - //------------------------------------ - //(1) the kinetical energy. - //------------------------------------ - complex *chibg = chig; - complex *hchibg = hchig; - if(PARAM.inp.t_in_h) - { - for (int ib = 0; ib < m ; ++ib) - { - for (int ig = 0; ig < npw; ++ig) - { - hchibg[ig] = this->wfcpw->getgk2(ik,ig) * tpiba2 * chibg[ig]; - } - chibg += npwx; - hchibg += npwx; - } - } - - //------------------------------------ - //(2) the local potential. - //------------------------------------ - ModuleBase::timer::tick("Stochastic_hchi","vloc"); - std::complex* porter = new std::complex[nrxx]; - if(PARAM.inp.vl_in_h) - { - chibg = chig; - hchibg = hchig; - const double* pveff = &((*GlobalTemp::veff)(current_spin, 0)); - for(int ib = 0 ; ib < m ; ++ib) - { - this->wfcpw->recip2real(chibg, porter, ik); - for (int ir=0; ir< nrxx; ir++) - { - porter[ir] *= pveff[ir]; - } - this->wfcpw->real2recip(porter, hchibg, ik, true); - - chibg += npwx; - hchibg += npwx; - } - - } - delete[] porter; - ModuleBase::timer::tick("Stochastic_hchi","vloc"); - - - //------------------------------------ - // (3) the nonlocal pseudopotential. - //------------------------------------ - ModuleBase::timer::tick("Stochastic_hchi","vnl"); - if(PARAM.inp.vnl_in_h) - { - if ( GlobalC::ppcell.nkb > 0) - { - int nkb = GlobalC::ppcell.nkb; - complex *becp = new complex[ nkb * GlobalV::NPOL * m ]; - char transc = 'C'; - char transn = 'N'; - char transt = 'T'; - if(m==1 && GlobalV::NPOL ==1) - { - zgemv_(&transc, &npw, &nkb, &ModuleBase::ONE, GlobalC::ppcell.vkb.c, &npwx, chig, &inc, &ModuleBase::ZERO, becp, &inc); - } - else - { - zgemm_(&transc,&transn,&nkb,&npm,&npw,&ModuleBase::ONE,GlobalC::ppcell.vkb.c,&npwx,chig,&npwx,&ModuleBase::ZERO,becp,&nkb); - } + + //--------------------------------------------------- + + const int ik = this->current_ik; + const int current_spin = pkv->isk[ik]; + const int npwx = this->wfcpw->npwk_max; + const int npw = this->wfcpw->npwk[ik]; + const int npm = GlobalV::NPOL * m; + const int inc = 1; + const double tpiba2 = GlobalC::ucell.tpiba2; + const int nrxx = this->wfcpw->nrxx; + //------------------------------------ + //(1) the kinetical energy. + //------------------------------------ + complex* chibg = chig; + complex* hchibg = hchig; + if (PARAM.inp.t_in_h) + { + for (int ib = 0; ib < m; ++ib) + { + for (int ig = 0; ig < npw; ++ig) + { + hchibg[ig] = this->wfcpw->getgk2(ik, ig) * tpiba2 * chibg[ig]; + } + chibg += npwx; + hchibg += npwx; + } + } + + //------------------------------------ + //(2) the local potential. + //------------------------------------ + ModuleBase::timer::tick("Stochastic_hchi", "vloc"); + std::complex* porter = new std::complex[nrxx]; + if (PARAM.inp.vl_in_h) + { + chibg = chig; + hchibg = hchig; + const double* pveff = &((*GlobalTemp::veff)(current_spin, 0)); + for (int ib = 0; ib < m; ++ib) + { + this->wfcpw->recip2real(chibg, porter, ik); + for (int ir = 0; ir < nrxx; ir++) + { + porter[ir] *= pveff[ir]; + } + this->wfcpw->real2recip(porter, hchibg, ik, true); + + chibg += npwx; + hchibg += npwx; + } + } + delete[] porter; + ModuleBase::timer::tick("Stochastic_hchi", "vloc"); + + //------------------------------------ + // (3) the nonlocal pseudopotential. + //------------------------------------ + ModuleBase::timer::tick("Stochastic_hchi", "vnl"); + if (PARAM.inp.vnl_in_h) + { + if (GlobalC::ppcell.nkb > 0) + { + int nkb = GlobalC::ppcell.nkb; + complex* becp = new complex[nkb * GlobalV::NPOL * m]; + char transc = 'C'; + char transn = 'N'; + char transt = 'T'; + if (m == 1 && GlobalV::NPOL == 1) + { + zgemv_(&transc, + &npw, + &nkb, + &ModuleBase::ONE, + GlobalC::ppcell.vkb.c, + &npwx, + chig, + &inc, + &ModuleBase::ZERO, + becp, + &inc); + } + else + { + zgemm_(&transc, + &transn, + &nkb, + &npm, + &npw, + &ModuleBase::ONE, + GlobalC::ppcell.vkb.c, + &npwx, + chig, + &npwx, + &ModuleBase::ZERO, + becp, + &nkb); + } Parallel_Reduce::reduce_pool(becp, nkb * GlobalV::NPOL * m); - complex *Ps = new complex [nkb * GlobalV::NPOL * m]; - ModuleBase::GlobalFunc::ZEROS( Ps, GlobalV::NPOL * m * nkb); - - int sum = 0; - int iat = 0; - for (int it=0; it* Ps = new complex[nkb * GlobalV::NPOL * m]; + ModuleBase::GlobalFunc::ZEROS(Ps, GlobalV::NPOL * m * nkb); + + int sum = 0; + int iat = 0; + for (int it = 0; it < GlobalC::ucell.ntype; it++) + { + const int Nprojs = GlobalC::ucell.atoms[it].ncpp.nh; + for (int ia = 0; ia < GlobalC::ucell.atoms[it].na; ia++) + { + // each atom has Nprojs, means this is with structure factor; + // each projector (each atom) must multiply coefficient + // with all the other projectors. + for (int ip = 0; ip < Nprojs; ip++) + { + for (int ip2 = 0; ip2 < Nprojs; ip2++) + { + for (int ib = 0; ib < m; ++ib) + { + Ps[(sum + ip2) * m + ib] + += GlobalC::ppcell.deeq(current_spin, iat, ip, ip2) * becp[ib * nkb + sum + ip]; + } // end ib + } // end ih + } // end jh + sum += Nprojs; + ++iat; + } // end na + } // end nt + + if (GlobalV::NPOL == 1 && m == 1) + { + zgemv_(&transn, + &npw, + &nkb, + &ModuleBase::ONE, + GlobalC::ppcell.vkb.c, + &npwx, + Ps, + &inc, + &ModuleBase::ONE, + hchig, + &inc); + } + else + { + zgemm_(&transn, + &transt, + &npw, + &npm, + &nkb, + &ModuleBase::ONE, + GlobalC::ppcell.vkb.c, + &npwx, + Ps, + &npm, + &ModuleBase::ONE, + hchig, + &npwx); + } + + delete[] becp; + delete[] Ps; + } + } + ModuleBase::timer::tick("Stochastic_hchi", "vnl"); + + return; } -void Stochastic_hchi:: hchi_norm(complex *chig, complex *hchig, const int m) +void Stochastic_hchi::hchi_norm(complex* chig, complex* hchig, const int m) { - ModuleBase::timer::tick("Stochastic_hchi","hchi_norm"); - - this->hchi(chig,hchig,m); - - const int ik = this->current_ik; - const int npwx = this->wfcpw->npwk_max; - const int npw = this->wfcpw->npwk[ik]; - const double Ebar = (Emin + Emax)/2; - const double DeltaE = (Emax - Emin)/2; - for(int ib = 0 ; ib < m ; ++ib) - { - for(int ig = 0; ig < npw; ++ig) - { - hchig[ib*npwx+ig] = (hchig[ib*npwx+ig] - Ebar * chig[ib*npwx+ig]) / DeltaE; - } - } - ModuleBase::timer::tick("Stochastic_hchi","hchi_norm"); + ModuleBase::timer::tick("Stochastic_hchi", "hchi_norm"); + + this->hchi(chig, hchig, m); + + const int ik = this->current_ik; + const int npwx = this->wfcpw->npwk_max; + const int npw = this->wfcpw->npwk[ik]; + const double Ebar = (*Emin + *Emax) / 2; + const double DeltaE = (*Emax - *Emin) / 2; + for (int ib = 0; ib < m; ++ib) + { + for (int ig = 0; ig < npw; ++ig) + { + hchig[ib * npwx + ig] = (hchig[ib * npwx + ig] - Ebar * chig[ib * npwx + ig]) / DeltaE; + } + } + ModuleBase::timer::tick("Stochastic_hchi", "hchi_norm"); } \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_hchi.h b/source/module_hamilt_pw/hamilt_stodft/sto_hchi.h index 44a3f4f790..2f195fc7c4 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_hchi.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_hchi.h @@ -16,42 +16,34 @@ class Stochastic_hchi { - public: - + public: // constructor and deconstructor Stochastic_hchi(); ~Stochastic_hchi(); - void init(ModulePW::PW_Basis_K* wfc_basis, K_Vectors* pkv); - - - - double Emin; - double Emax; + void init(ModulePW::PW_Basis_K* wfc_basis, K_Vectors* pkv, double* Emin_in, double* Emax_in); - void orthogonal_to_psi_reciprocal( - std::complex* wfin, - std::complex *wfout, - const int& ikk); //wfin & wfout are wavefunctions in reciprocal space - void hchi( - std::complex *wfin, - std::complex *wfout, - const int m = 1); //wfin & wfout are wavefunctions in reciprocal space + double* Emin = nullptr; + double* Emax = nullptr; - void hchi_norm( - std::complex *wfin, - std::complex *wfout, - const int m = 1); //wfin & wfout are wavefunctions in reciprocal space + void orthogonal_to_psi_reciprocal(std::complex* wfin, + std::complex* wfout, + const int& ikk); // wfin & wfout are wavefunctions in reciprocal space + void hchi(std::complex* wfin, + std::complex* wfout, + const int m = 1); // wfin & wfout are wavefunctions in reciprocal space - public: - - int current_ik = 0; - ModulePW::PW_Basis_K* wfcpw = nullptr; - K_Vectors* pkv = nullptr; + void hchi_norm(std::complex* wfin, + std::complex* wfout, + const int m = 1); // wfin & wfout are wavefunctions in reciprocal space - // chi should be orthogonal to psi (generated by diaganolization methods, - // such as CG) + public: + int current_ik = 0; + ModulePW::PW_Basis_K* wfcpw = nullptr; + K_Vectors* pkv = nullptr; + // chi should be orthogonal to psi (generated by diaganolization methods, + // such as CG) }; -#endif// Eelectrons_hchi +#endif // Eelectrons_hchi diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index c4185181e2..b07b80803a 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -1,26 +1,11 @@ #include "sto_iter.h" -#include "module_base/blas_connector.h" + +#include "module_base/parallel_reduce.h" #include "module_base/timer.h" #include "module_base/tool_quit.h" #include "module_base/tool_title.h" -#include "module_base/parallel_reduce.h" -#include "module_base/blas_connector.h" -#include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_elecstate/occupy.h" - -double Stochastic_Iter::vTMv(const double *v, const double * M, const int n) -{ - const char normal = 'N'; - const double one = 1; - const int inc = 1; - const double zero = 0; - double *y = new double [n]; - dgemv_(&normal,&n,&n,&one,M,&n,v,&inc,&zero,y,&inc); - double result = BlasConnector::dot(n,y,1,v,1); - delete[] y; - return result; -} - +#include "module_hamilt_pw/hamilt_pwdft/global.h" Stochastic_Iter::Stochastic_Iter() { @@ -31,37 +16,22 @@ Stochastic_Iter::Stochastic_Iter() Stochastic_Iter::~Stochastic_Iter() { - delete p_che; - delete[] spolyv; - delete[] chiallorder; } -void Stochastic_Iter::init(const int method_in, K_Vectors* pkv_in, ModulePW::PW_Basis_K* wfc_basis, Stochastic_WF& stowf) +void Stochastic_Iter::init(K_Vectors* pkv_in, + ModulePW::PW_Basis_K* wfc_basis, + Stochastic_WF& stowf, + StoChe& stoche) { - p_che = new ModuleBase::Chebyshev(PARAM.inp.nche_sto); + p_che = stoche.p_che; + spolyv = stoche.spolyv; nchip = stowf.nchip; targetne = GlobalV::nelec; this->pkv = pkv_in; - stohchi.init(wfc_basis, pkv); - delete[] spolyv; - const int norder = p_che->norder; - const int nks = wfc_basis->nks; - this->method = method_in; - if(method == 1) { spolyv = new double [norder]; - } else { spolyv = new double [norder*norder]; -} - stofunc.Emin = PARAM.inp.emin_sto; - stofunc.Emax = PARAM.inp.emax_sto; - - if(this->method == 2) - { - this->chiallorder = new ModuleBase::ComplexMatrix[stowf.nks]; - for (int ik = 0; ik < nks; ++ik) - { - const int npwx = stowf.npwx; - chiallorder[ik].create(nchip[ik] * npwx, norder,true); - } - } + this->method = stoche.method_sto; + + this->stohchi.init(wfc_basis, pkv, &stoche.emin_sto, &stoche.emax_sto); + this->stofunc.set_E_range(&stoche.emin_sto, &stoche.emax_sto); } void Stochastic_Iter::orthog(const int& ik, psi::Psi>& psi, Stochastic_WF& stowf) @@ -70,54 +40,80 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi>& psi, // orthogonal part if (GlobalV::NBANDS > 0) { - const int nchipk=stowf.nchip[ik]; - const int npw = psi.get_current_nbas(); - const int npwx = psi.get_nbasis(); + const int nchipk = stowf.nchip[ik]; + const int npw = psi.get_current_nbas(); + const int npwx = psi.get_nbasis(); stowf.chi0->fix_k(ik); stowf.chiortho->fix_k(ik); - std::complex *wfgin = stowf.chi0->get_pointer(), *wfgout = stowf.chiortho->get_pointer(); - for(int ig = 0 ; ig < npwx * nchipk; ++ig) - { - wfgout[ig] = wfgin[ig]; - } - - //orthogonal part - std::complex *sum = new std::complex [GlobalV::NBANDS * nchipk]; - char transC='C'; - char transN='N'; - - //sum(b - zgemm_(&transC, &transN, &GlobalV::NBANDS, &nchipk, &npw, &ModuleBase::ONE, - &psi(ik,0,0), &npwx, wfgout, &npwx, &ModuleBase::ZERO, sum, &GlobalV::NBANDS); + std::complex*wfgin = stowf.chi0->get_pointer(), *wfgout = stowf.chiortho->get_pointer(); + for (int ig = 0; ig < npwx * nchipk; ++ig) + { + wfgout[ig] = wfgin[ig]; + } + + // orthogonal part + std::complex* sum = new std::complex[GlobalV::NBANDS * nchipk]; + char transC = 'C'; + char transN = 'N'; + + // sum(b + zgemm_(&transC, + &transN, + &GlobalV::NBANDS, + &nchipk, + &npw, + &ModuleBase::ONE, + &psi(ik, 0, 0), + &npwx, + wfgout, + &npwx, + &ModuleBase::ZERO, + sum, + &GlobalV::NBANDS); Parallel_Reduce::reduce_pool(sum, GlobalV::NBANDS * nchipk); - - //psi -= psi * sum - zgemm_(&transN, &transN, &npw, &nchipk, &GlobalV::NBANDS, &ModuleBase::NEG_ONE, - &psi(ik,0,0), &npwx, sum, &GlobalV::NBANDS, &ModuleBase::ONE, wfgout, &npwx); - delete[] sum; + + // psi -= psi * sum + zgemm_(&transN, + &transN, + &npw, + &nchipk, + &GlobalV::NBANDS, + &ModuleBase::NEG_ONE, + &psi(ik, 0, 0), + &npwx, + sum, + &GlobalV::NBANDS, + &ModuleBase::ONE, + wfgout, + &npwx); + delete[] sum; } } void Stochastic_Iter::checkemm(const int& ik, const int istep, const int iter, Stochastic_WF& stowf) { - ModuleBase::TITLE("Stochastic_Iter","checkemm"); - //iter = 1,2,... istep = 0,1,2,... - // if( istep%PARAM.inp.initsto_freq != 0 ) return; + ModuleBase::TITLE("Stochastic_Iter", "checkemm"); + // iter = 1,2,... istep = 0,1,2,... + // if( istep%PARAM.inp.initsto_freq != 0 ) return; const int npw = stowf.ngk[ik]; const int nks = stowf.nks; - if(istep == 0) + if (istep == 0) { - if(iter > 5) { return; -} + if (iter > 5) + { + return; + } } else { - if(iter > 1) { return; -} + if (iter > 1) + { + return; + } } - + const int norder = p_che->norder; - std::complex * pchi; + std::complex* pchi; int ntest = 1; if (nchip[ik] < ntest) @@ -138,15 +134,16 @@ void Stochastic_Iter::checkemm(const int& ik, const int istep, const int iter, S while (true) { bool converge; - converge = p_che->checkconverge( - &stohchi, &Stochastic_hchi::hchi_norm, - pchi, npw, - stohchi.Emax, - stohchi.Emin, - 5.0); - - if(!converge) - { + converge = p_che->checkconverge(&stohchi, + &Stochastic_hchi::hchi_norm, + pchi, + npw, + *stohchi.Emax, + *stohchi.Emin, + 5.0); + + if (!converge) + { change = true; } else @@ -158,15 +155,14 @@ void Stochastic_Iter::checkemm(const int& ik, const int istep, const int iter, S if (ik == nks - 1) { #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &stohchi.Emax, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD); - MPI_Allreduce(MPI_IN_PLACE, &stohchi.Emin, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE, stohchi.Emax, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE, stohchi.Emin, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD); MPI_Allreduce(MPI_IN_PLACE, &change, 1, MPI_CHAR, MPI_LOR, MPI_COMM_WORLD); #endif - stofunc.Emax = stohchi.Emax; - stofunc.Emin = stohchi.Emin; if (change) { - GlobalV::ofs_running << "New Emax Ry" << stohchi.Emax << " ; new Emin " << stohchi.Emin <<" Ry" << std::endl; + GlobalV::ofs_running << "New Emax Ry" << *stohchi.Emax << " ; new Emin " << *stohchi.Emin << " Ry" + << std::endl; } change = false; } @@ -175,43 +171,46 @@ void Stochastic_Iter::checkemm(const int& ik, const int istep, const int iter, S void Stochastic_Iter::check_precision(const double ref, const double thr, const std::string info) { //============================== - //precision check + // precision check //============================== double error = 0; - if(this->method == 1) + if (this->method == 1) { - error = p_che->coef_real[p_che->norder-1] * spolyv[p_che->norder-1]; + error = p_che->coef_real[p_che->norder - 1] * spolyv[p_che->norder - 1]; } else { const int norder = p_che->norder; - double last_coef = p_che->coef_real[norder-1]; - double last_spolyv = spolyv[norder*norder - 1]; - error = last_coef *(BlasConnector::dot(norder,p_che->coef_real,1,spolyv+norder*(norder-1),1) - + BlasConnector::dot(norder,p_che->coef_real,1,spolyv+norder-1,norder)-last_coef*last_spolyv); + double last_coef = p_che->coef_real[norder - 1]; + double last_spolyv = spolyv[norder * norder - 1]; + error = last_coef + * (BlasConnector::dot(norder, p_che->coef_real, 1, spolyv + norder * (norder - 1), 1) + + BlasConnector::dot(norder, p_che->coef_real, 1, spolyv + norder - 1, norder) + - last_coef * last_spolyv); } #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &error, 1, MPI_DOUBLE, MPI_SUM , MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &error, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); #endif - double relative_error = std::abs(error/ref); - GlobalV::ofs_running< thr) + double relative_error = std::abs(error / ref); + GlobalV::ofs_running << info << "Relative Chebyshev Precision: " << relative_error * 1e9 << "E-09" << std::endl; + if (relative_error > thr) { std::stringstream ss; - ss<>fractxt; + ss << relative_error; + std::string fractxt, tartxt; + ss >> fractxt; ss.clear(); - ss<>tartxt; - std::string warningtxt = "( "+info+" relative Chebyshev error = "+fractxt+" > threshold = "+tartxt+" ) Maybe you should increase the parameter \"nche_sto\" for more accuracy."; + ss << thr; + ss >> tartxt; + std::string warningtxt = "( " + info + " relative Chebyshev error = " + fractxt + " > threshold = " + tartxt + + " ) Maybe you should increase the parameter \"nche_sto\" for more accuracy."; ModuleBase::WARNING("Stochastic_Chebychev", warningtxt); } //=============================== } -void Stochastic_Iter::itermu(const int iter, elecstate::ElecState* pes) +void Stochastic_Iter::itermu(const int iter, elecstate::ElecState* pes) { ModuleBase::TITLE("Stochastic_Iter", "itermu"); ModuleBase::timer::tick("Stochastic_Iter", "itermu"); @@ -219,13 +218,13 @@ void Stochastic_Iter::itermu(const int iter, elecstate::ElecState* pes) if (iter == 1) { dmu = 2; - th_ne = 0.1 *PARAM.inp.scf_thr * GlobalV::nelec; + th_ne = 0.1 * PARAM.inp.scf_thr * GlobalV::nelec; // std::cout<<"th_ne "<stofunc.mu = mu0 - dmu; @@ -278,18 +277,19 @@ void Stochastic_Iter::itermu(const int iter, elecstate::ElecState* pes) { std::cout << "Fermi energy cannot be converged. Set THNE to " << th_ne << std::endl; th_ne *= 1e1; - if (th_ne > 1e1) { + if (th_ne > 1e1) + { ModuleBase::WARNING_QUIT("Stochastic_Iter", "Cannot converge feimi energy. Please retry with different random number"); -} + } } } pes->eferm.ef = this->stofunc.mu = mu0 = mu3; - GlobalV::ofs_running<<"Converge fermi energy = "<check_precision(targetne,10*PARAM.inp.scf_thr,"Ne"); - - //Set wf.wg - if(GlobalV::NBANDS > 0) + GlobalV::ofs_running << "Converge fermi energy = " << mu3 << " Ry in " << count << " steps." << std::endl; + this->check_precision(targetne, 10 * PARAM.inp.scf_thr, "Ne"); + + // Set wf.wg + if (GlobalV::NBANDS > 0) { for (int ikk = 0; ikk < this->pkv->get_nks(); ++ikk) { @@ -313,70 +313,79 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& stowf) const int nchip_ik = nchip[ik]; const int npw = stowf.ngk[ik]; const int npwx = stowf.npwx; - if(ik==0) + if (ik == 0) { - if(this->method == 1) { + if (this->method == 1) + { ModuleBase::GlobalFunc::ZEROS(spolyv, norder); - } else { - ModuleBase::GlobalFunc::ZEROS(spolyv, norder*norder); -} + } + else + { + ModuleBase::GlobalFunc::ZEROS(spolyv, norder * norder); + } } - std::complex * pchi; - if(GlobalV::NBANDS > 0) + std::complex* pchi; + if (GlobalV::NBANDS > 0) { stowf.chiortho->fix_k(ik); pchi = stowf.chiortho->get_pointer(); - } + } else - { + { stowf.chi0->fix_k(ik); pchi = stowf.chi0->get_pointer(); } - - if(this->method == 1) + + if (this->method == 1) { p_che->tracepolyA(&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, npwx, nchip_ik); - for(int i = 0 ; i < norder ; ++i) + for (int i = 0; i < norder; ++i) { spolyv[i] += p_che->polytrace[i] * this->pkv->wk[ik]; } } else { - p_che->calpolyvec_complex(&stohchi, &Stochastic_hchi::hchi_norm, pchi, this->chiallorder[ik].c, npw, npwx, nchip_ik); - double* vec_all= (double *) this->chiallorder[ik].c; + p_che->calpolyvec_complex(&stohchi, + &Stochastic_hchi::hchi_norm, + pchi, + stowf.chiallorder[ik].c, + npw, + npwx, + nchip_ik); + double* vec_all = (double*)stowf.chiallorder[ik].c; char trans = 'T'; char normal = 'N'; double one = 1; int LDA = npwx * nchip_ik * 2; - int M = npwx * nchip_ik * 2; //Do not use kv.ngk[ik] + int M = npwx * nchip_ik * 2; // Do not use kv.ngk[ik] int N = norder; double kweight = this->pkv->wk[ik]; - dgemm_(&trans,&normal, &N,&N,&M,&kweight,vec_all,&LDA,vec_all,&LDA,&one,spolyv,&N); + dgemm_(&trans, &normal, &N, &N, &M, &kweight, vec_all, &LDA, vec_all, &LDA, &one, spolyv, &N); } ModuleBase::timer::tick("Stochastic_Iter", "calPn"); return; } double Stochastic_Iter::calne(elecstate::ElecState* pes) -{ - ModuleBase::timer::tick("Stochastic_Iter","calne"); +{ + ModuleBase::timer::tick("Stochastic_Iter", "calne"); double totne = 0; KS_ne = 0; const int norder = p_che->norder; double sto_ne; - if(this->method == 1) + if (this->method == 1) { - //Note: spolyv contains kv.wk[ik] - p_che->calcoef_real(&stofunc,&Sto_Func::nfd); - sto_ne = BlasConnector::dot(norder,p_che->coef_real,1,spolyv,1); + // Note: spolyv contains kv.wk[ik] + p_che->calcoef_real(&stofunc, &Sto_Func::nfd); + sto_ne = BlasConnector::dot(norder, p_che->coef_real, 1, spolyv, 1); } else { - p_che->calcoef_real(&stofunc,&Sto_Func::nroot_fd); - sto_ne = vTMv(p_che->coef_real,spolyv,norder); + p_che->calcoef_real(&stofunc, &Sto_Func::nroot_fd); + sto_ne = vTMv(p_che->coef_real, spolyv, norder); } - if(GlobalV::NBANDS > 0) + if (GlobalV::NBANDS > 0) { for (int ikk = 0; ikk < this->pkv->get_nks(); ++ikk) { @@ -389,8 +398,8 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes) } KS_ne /= GlobalV::NPROC_IN_POOL; #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &KS_ne, 1, MPI_DOUBLE, MPI_SUM , STO_WORLD); - MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM , MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &KS_ne, 1, MPI_DOUBLE, MPI_SUM, STO_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); #endif totne = KS_ne + sto_ne; @@ -400,20 +409,21 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes) void Stochastic_Iter::calHsqrtchi(Stochastic_WF& stowf) { - p_che->calcoef_real(&stofunc,&Sto_Func::nroot_fd); - for(int ik = 0; ik < this->pkv->get_nks(); ++ik) + p_che->calcoef_real(&stofunc, &Sto_Func::nroot_fd); + for (int ik = 0; ik < this->pkv->get_nks(); ++ik) { - //init k - if(this->pkv->get_nks() > 1) + // init k + if (this->pkv->get_nks() > 1) { - if(GlobalC::ppcell.nkb > 0 && (GlobalV::BASIS_TYPE=="pw" || GlobalV::BASIS_TYPE=="lcao_in_pw")) //xiaohui add 2013-09-02. Attention... + if (GlobalC::ppcell.nkb > 0 + && (GlobalV::BASIS_TYPE == "pw" + || GlobalV::BASIS_TYPE == "lcao_in_pw")) // xiaohui add 2013-09-02. Attention... { GlobalC::ppcell.getvnl(ik, GlobalC::ppcell.vkb); } GlobalV::CURRENT_K = ik; - } stohchi.current_ik = ik; @@ -421,25 +431,28 @@ void Stochastic_Iter::calHsqrtchi(Stochastic_WF& stowf) } } -void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, elecstate::ElecState* pes,hamilt::Hamilt>* pHamilt, ModulePW::PW_Basis_K* wfc_basis) -{ - ModuleBase::TITLE("Stochastic_Iter","sum_stoband"); - ModuleBase::timer::tick("Stochastic_Iter","sum_stoband"); +void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, + elecstate::ElecState* pes, + hamilt::Hamilt>* pHamilt, + ModulePW::PW_Basis_K* wfc_basis) +{ + ModuleBase::TITLE("Stochastic_Iter", "sum_stoband"); + ModuleBase::timer::tick("Stochastic_Iter", "sum_stoband"); int nrxx = wfc_basis->nrxx; int npwx = wfc_basis->npwk_max; const int norder = p_che->norder; //---------------cal demet----------------------- double stodemet; - if(this->method == 1) + if (this->method == 1) { - p_che->calcoef_real(&stofunc,&Sto_Func::nfdlnfd); - stodemet = BlasConnector::dot(norder,p_che->coef_real,1,spolyv,1); + p_che->calcoef_real(&stofunc, &Sto_Func::nfdlnfd); + stodemet = BlasConnector::dot(norder, p_che->coef_real, 1, spolyv, 1); } else { - p_che->calcoef_real(&stofunc,&Sto_Func::n_root_fdlnfd); - stodemet = -vTMv(p_che->coef_real,spolyv,norder); + p_che->calcoef_real(&stofunc, &Sto_Func::n_root_fdlnfd); + stodemet = -vTMv(p_che->coef_real, spolyv, norder); } if (GlobalV::NBANDS > 0) @@ -457,7 +470,7 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, elecstate::ElecState* pe pes->f_en.demet /= GlobalV::NPROC_IN_POOL; #ifdef __MPI MPI_Allreduce(MPI_IN_PLACE, &pes->f_en.demet, 1, MPI_DOUBLE, MPI_SUM, STO_WORLD); - MPI_Allreduce(MPI_IN_PLACE, &stodemet,1,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &stodemet, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); #endif pes->f_en.demet += stodemet; this->check_precision(pes->f_en.demet, 1e-4, "TS"); @@ -465,17 +478,17 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, elecstate::ElecState* pe //--------------------cal eband------------------------ double sto_eband = 0; - if(this->method == 1) + if (this->method == 1) { - p_che->calcoef_real(&stofunc,&Sto_Func::nxfd); - sto_eband = BlasConnector::dot(norder,p_che->coef_real,1,spolyv,1); + p_che->calcoef_real(&stofunc, &Sto_Func::nxfd); + sto_eband = BlasConnector::dot(norder, p_che->coef_real, 1, spolyv, 1); } else { - for(int ik = 0; ik < this->pkv->get_nks(); ++ik) + for (int ik = 0; ik < this->pkv->get_nks(); ++ik) { const int nchip_ik = nchip[ik]; - if(this->pkv->get_nks() > 1) + if (this->pkv->get_nks() > 1) { pHamilt->updateHk(ik); stowf.shchi->fix_k(ik); @@ -483,25 +496,25 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, elecstate::ElecState* pe stohchi.current_ik = ik; const int npw = this->pkv->ngk[ik]; const double kweight = this->pkv->wk[ik]; - std::complex *hshchi = new std::complex [nchip_ik * npwx]; + std::complex* hshchi = new std::complex[nchip_ik * npwx]; std::complex* tmpin = stowf.shchi->get_pointer(); - std::complex *tmpout = hshchi; - stohchi.hchi(tmpin,tmpout,nchip_ik); - for(int ichi = 0; ichi < nchip_ik ; ++ichi) + std::complex* tmpout = hshchi; + stohchi.hchi(tmpin, tmpout, nchip_ik); + for (int ichi = 0; ichi < nchip_ik; ++ichi) { - sto_eband += kweight * ModuleBase::GlobalFunc::ddot_real(npw,tmpin,tmpout,false); - tmpin+=npwx; - tmpout+=npwx; + sto_eband += kweight * ModuleBase::GlobalFunc::ddot_real(npw, tmpin, tmpout, false); + tmpin += npwx; + tmpout += npwx; } delete[] hshchi; } } #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &sto_eband,1,MPI_DOUBLE,MPI_SUM,MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &sto_eband, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD); #endif pes->f_en.eband += sto_eband; //---------------------cal rho------------------------- - double *sto_rho = new double [nrxx]; + double* sto_rho = new double[nrxx]; double dr3 = GlobalC::ucell.omega / wfc_basis->nxyz; double tmprho, tmpne; @@ -524,8 +537,8 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, elecstate::ElecState* pe { const int nchip_ik = nchip[ik]; stowf.shchi->fix_k(ik); - std::complex *tmpout = stowf.shchi->get_pointer(); - for(int ichi = 0; ichi < nchip_ik ; ++ichi) + std::complex* tmpout = stowf.shchi->get_pointer(); + for (int ichi = 0; ichi < nchip_ik; ++ichi) { wfc_basis->recip2real(tmpout, porter, ik); for (int ir = 0; ir < nrxx; ++ir) @@ -537,7 +550,7 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, elecstate::ElecState* pe } delete[] porter; #ifdef __MPI - //temporary, rho_mpi should be rewrite as a tool function! Now it only treats pes->charge->rho + // temporary, rho_mpi should be rewrite as a tool function! Now it only treats pes->charge->rho pes->charge->rho_mpi(); #endif for (int ir = 0; ir < nrxx; ++ir) @@ -549,30 +562,34 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf, elecstate::ElecState* pe sto_ne *= dr3; #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE,&sto_ne,1,MPI_DOUBLE,MPI_SUM,POOL_WORLD); - MPI_Allreduce(MPI_IN_PLACE,&sto_ne,1,MPI_DOUBLE,MPI_SUM,PARAPW_WORLD); - MPI_Allreduce(MPI_IN_PLACE,sto_rho,nrxx,MPI_DOUBLE,MPI_SUM,PARAPW_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); + MPI_Allreduce(MPI_IN_PLACE, &sto_ne, 1, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD); + MPI_Allreduce(MPI_IN_PLACE, sto_rho, nrxx, MPI_DOUBLE, MPI_SUM, PARAPW_WORLD); #endif - double factor = targetne/(KS_ne+sto_ne); - if(std::abs(factor-1) > 1e-10) + double factor = targetne / (KS_ne + sto_ne); + if (std::abs(factor - 1) > 1e-10) { - GlobalV::ofs_running<<"Renormalize rho from ne = "<& stoche); void sum_stoband(Stochastic_WF& stowf, elecstate::ElecState* pes, @@ -44,44 +46,39 @@ class Stochastic_Iter void itermu(const int iter, elecstate::ElecState* pes); - void orthog(const int &ik, psi::Psi>& psi, Stochastic_WF& stowf); + void orthog(const int& ik, psi::Psi>& psi, Stochastic_WF& stowf); - void checkemm(const int &ik, const int istep, const int iter, Stochastic_WF& stowf); + void checkemm(const int& ik, const int istep, const int iter, Stochastic_WF& stowf); - void check_precision(const double ref,const double thr, const std::string info); + void check_precision(const double ref, const double thr, const std::string info); ModuleBase::Chebyshev* p_che = nullptr; Stochastic_hchi stohchi; Sto_Func stofunc; - double mu0; // chemical potential; unit in Ry + double mu0; // chemical potential; unit in Ry bool change; double targetne; - double *spolyv = nullptr; + double* spolyv = nullptr; - public: - - int * nchip = nullptr; + public: + int* nchip = nullptr; bool check = false; double th_ne; double KS_ne; - public: - int method; //different methods 1: slow, less memory 2: fast, more memory - ModuleBase::ComplexMatrix* chiallorder = nullptr; - //chiallorder cost too much memories and should be cleaned after scf. - void cleanchiallorder(); - //cal shchi = \sqrt{f(\hat{H})}|\chi> + + public: + int method; // different methods 1: slow, less memory 2: fast, more memory + // cal shchi = \sqrt{f(\hat{H})}|\chi> void calHsqrtchi(Stochastic_WF& stowf); - //cal Pn = \sum_\chi <\chi|Tn(\hat{h})|\chi> + // cal Pn = \sum_\chi <\chi|Tn(\hat{h})|\chi> void calPn(const int& ik, Stochastic_WF& stowf); - //cal Tnchi = \sum_n C_n*T_n(\hat{h})|\chi> + // cal Tnchi = \sum_n C_n*T_n(\hat{h})|\chi> void calTnchi_ik(const int& ik, Stochastic_WF& stowf); - //cal v^T*M*v - double vTMv(const double *v, const double * M, const int n); + private: K_Vectors* pkv; - }; -#endif// Eelectrons_Iter +#endif // Eelectrons_Iter diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp index 091092e953..2e5aeafbb4 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp @@ -1,14 +1,20 @@ #include "sto_tool.h" #include "module_base/timer.h" +#include "module_base/math_chebyshev.h" #ifdef __MPI #include "mpi.h" #endif #include -void check_che(const int& nche_in, const double& try_emin, const double& try_emax, const int& nbands_sto, - K_Vectors* p_kv, Stochastic_WF* p_stowf, hamilt::Hamilt>* p_hamilt, - hsolver::HSolverPW_SDFT* p_hsol) +void check_che(const int& nche_in, + const double& try_emin, + const double& try_emax, + const int& nbands_sto, + K_Vectors* p_kv, + Stochastic_WF* p_stowf, + hamilt::Hamilt>* p_hamilt, + Stochastic_hchi& stohchi) { //------------------------------ // Convergence test @@ -16,11 +22,9 @@ void check_che(const int& nche_in, const double& try_emin, const double& try_ema bool change = false; const int nk = p_kv->get_nks(); ModuleBase::Chebyshev chetest(nche_in); - Stochastic_Iter& stoiter = p_hsol->stoiter; - Stochastic_hchi& stohchi = stoiter.stohchi; int ntest0 = 5; - stohchi.Emax = try_emax; - stohchi.Emin = try_emin; + *stohchi.Emax = try_emax; + *stohchi.Emin = try_emin; // if (GlobalV::NBANDS > 0) // { // double tmpemin = 1e10; @@ -37,7 +41,7 @@ void check_che(const int& nche_in, const double& try_emin, const double& try_ema for (int ik = 0; ik < nk; ++ik) { p_hamilt->updateHk(ik); - stoiter.stohchi.current_ik = ik; + stohchi.current_ik = ik; const int npw = p_kv->ngk[ik]; std::complex* pchi = nullptr; std::vector> randchi; @@ -66,8 +70,13 @@ void check_che(const int& nche_in, const double& try_emin, const double& try_ema while (1) { bool converge; - converge = chetest.checkconverge(&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, stohchi.Emax, - stohchi.Emin, 2.0); + converge = chetest.checkconverge(&stohchi, + &Stochastic_hchi::hchi_norm, + pchi, + npw, + *stohchi.Emax, + *stohchi.Emin, + 2.0); if (!converge) { @@ -83,12 +92,10 @@ void check_che(const int& nche_in, const double& try_emin, const double& try_ema if (ik == nk - 1) { #ifdef __MPI - MPI_Allreduce(MPI_IN_PLACE, &stohchi.Emax, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD); - MPI_Allreduce(MPI_IN_PLACE, &stohchi.Emin, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE, stohchi.Emax, 1, MPI_DOUBLE, MPI_MAX, MPI_COMM_WORLD); + MPI_Allreduce(MPI_IN_PLACE, stohchi.Emin, 1, MPI_DOUBLE, MPI_MIN, MPI_COMM_WORLD); #endif - stoiter.stofunc.Emax = stohchi.Emax; - stoiter.stofunc.Emin = stohchi.Emin; - GlobalV::ofs_running << "New Emax " << stohchi.Emax << " Ry; new Emin " << stohchi.Emin << " Ry" + GlobalV::ofs_running << "New Emax " << *stohchi.Emax << " Ry; new Emin " << *stohchi.Emin << " Ry" << std::endl; change = false; } @@ -106,8 +113,12 @@ void convert_psi(const psi::Psi>& psi_in, psi::Psi>* gatherchi(psi::Psi>& chi, psi::Psi>& chi_all, - const int& npwx, int* nrecv_sto, int* displs_sto, const int perbands_sto) +psi::Psi>* gatherchi(psi::Psi>& chi, + psi::Psi>& chi_all, + const int& npwx, + int* nrecv_sto, + int* displs_sto, + const int perbands_sto) { psi::Psi>* p_chi; p_chi = χ @@ -116,8 +127,14 @@ psi::Psi>* gatherchi(psi::Psi>& chi, psi { p_chi = &chi_all; ModuleBase::timer::tick("sKG", "bands_gather"); - MPI_Allgatherv(chi.get_pointer(), perbands_sto * npwx, MPI_COMPLEX, chi_all.get_pointer(), nrecv_sto, - displs_sto, MPI_COMPLEX, PARAPW_WORLD); + MPI_Allgatherv(chi.get_pointer(), + perbands_sto * npwx, + MPI_COMPLEX, + chi_all.get_pointer(), + nrecv_sto, + displs_sto, + MPI_COMPLEX, + PARAPW_WORLD); ModuleBase::timer::tick("sKG", "bands_gather"); } #endif diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_tool.h b/source/module_hamilt_pw/hamilt_stodft/sto_tool.h index 90fa78a818..69e501cb53 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_tool.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_tool.h @@ -1,7 +1,7 @@ #include "module_cell/klist.h" #include "module_hamilt_general/hamilt.h" +#include "module_hamilt_pw/hamilt_stodft/sto_hchi.h" #include "module_hamilt_pw/hamilt_stodft/sto_wf.h" -#include "module_hsolver/hsolver_pw_sdft.h" #include "module_psi/psi.h" /** * @brief Check if Emin and Emax are converged @@ -11,9 +11,14 @@ * @param try_emax trial Emax * @param nbands_sto number of stochastic bands */ -void check_che(const int& nche_in, const double& try_emin, const double& try_emax, const int& nbands_sto, - K_Vectors* p_kv, Stochastic_WF* p_stowf, hamilt::Hamilt>* p_hamilt, - hsolver::HSolverPW_SDFT* p_hsol_in); +void check_che(const int& nche_in, + const double& try_emin, + const double& try_emax, + const int& nbands_sto, + K_Vectors* p_kv, + Stochastic_WF* p_stowf, + hamilt::Hamilt>* p_hamilt, + Stochastic_hchi& stohchi); #ifndef PARALLEL_DISTRIBUTION #define PARALLEL_DISTRIBUTION @@ -104,5 +109,9 @@ void convert_psi(const psi::Psi>& psi_in, psi::Psi> pointer to gathered stochastic wave function * */ -psi::Psi>* gatherchi(psi::Psi>& chi, psi::Psi>& chi_all, - const int& npwx, int* nrecv_sto, int* displs_sto, const int perbands_sto); +psi::Psi>* gatherchi(psi::Psi>& chi, + psi::Psi>& chi_all, + const int& npwx, + int* nrecv_sto, + int* displs_sto, + const int perbands_sto); diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp index 2067d0d16a..e5f96b6445 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp @@ -21,6 +21,7 @@ Stochastic_WF::~Stochastic_WF() delete shchi; delete chiortho; delete[] nchip; + delete[] chiallorder; } void Stochastic_WF::init(K_Vectors* p_kv, const int npwx_in) @@ -36,6 +37,21 @@ void Stochastic_WF::init(K_Vectors* p_kv, const int npwx_in) } } +void Stochastic_WF::allocate_chiallorder(const int& norder) +{ + this->chiallorder = new ModuleBase::ComplexMatrix[this->nks]; + for (int ik = 0; ik < this->nks; ++ik) + { + chiallorder[ik].create(this->nchip[ik] * this->npwx, norder,true); + } +} + +void Stochastic_WF::clean_chiallorder() +{ + delete[] chiallorder; + chiallorder = nullptr; +} + void Init_Sto_Orbitals(Stochastic_WF& stowf, const int seed_in) { if (seed_in == 0 || seed_in == -1) diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h index f198839f2d..f8811f8bb0 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.h @@ -2,9 +2,9 @@ #define STOCHASTIC_WF_H #include "module_base/complexmatrix.h" +#include "module_basis/module_pw/pw_basis_k.h" #include "module_cell/klist.h" #include "module_psi/psi.h" -#include "module_basis/module_pw/pw_basis_k.h" //---------------------------------------------- // Generate stochastic wave functions @@ -32,6 +32,13 @@ class Stochastic_WF int npwx = 0; ///< max ngk[ik] in all processors int nbands_diag = 0; ///< number of bands obtained from diagonalization int nbands_total = 0; ///< number of bands in total, nbands_total=nchi+nbands_diag; + public: + // Tn(H)|chi> + ModuleBase::ComplexMatrix* chiallorder = nullptr; + // allocate chiallorder + void allocate_chiallorder(const int& norder); + // chiallorder cost too much memories and should be cleaned after scf. + void clean_chiallorder(); }; // init stochastic orbitals void Init_Sto_Orbitals(Stochastic_WF& stowf, const int seed_in); diff --git a/source/module_hsolver/hsolver_pw_sdft.h b/source/module_hsolver/hsolver_pw_sdft.h index 677aa43275..595d8553d5 100644 --- a/source/module_hsolver/hsolver_pw_sdft.h +++ b/source/module_hsolver/hsolver_pw_sdft.h @@ -11,10 +11,11 @@ class HSolverPW_SDFT : public HSolverPW> ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pwf_in, Stochastic_WF& stowf, - const int method_sto) - : HSolverPW(wfc_basis_in, pwf_in, false) + StoChe& stoche, + const bool initialed_psi_in) + : HSolverPW(wfc_basis_in, pwf_in, initialed_psi_in) { - stoiter.init(method_sto, pkv, wfc_basis_in, stowf); + stoiter.init(pkv, wfc_basis_in, stowf, stoche); } virtual void solve(hamilt::Hamilt>* pHamilt, @@ -38,6 +39,11 @@ class HSolverPW_SDFT : public HSolverPW> return 0.0; } + void set_KS_ne(const double& KS_ne_in) + { + stoiter.KS_ne = KS_ne_in; + } + Stochastic_Iter stoiter; }; } // namespace hsolver diff --git a/source/module_hsolver/test/test_hsolver_sdft.cpp b/source/module_hsolver/test/test_hsolver_sdft.cpp index d90b4038e4..a7dfd180bd 100644 --- a/source/module_hsolver/test/test_hsolver_sdft.cpp +++ b/source/module_hsolver/test/test_hsolver_sdft.cpp @@ -18,6 +18,16 @@ Sto_Func::Sto_Func(){} template class Sto_Func; +template +StoChe::StoChe(const int& nche, const int& method, const REAL& emax_sto, const REAL& emin_sto) +{ + this->nche = nche; +} +template +StoChe::~StoChe(){} + +template class StoChe; + Stochastic_hchi::Stochastic_hchi(){}; Stochastic_hchi::~Stochastic_hchi(){}; @@ -28,18 +38,13 @@ Stochastic_Iter::Stochastic_Iter() method = 2; } -Stochastic_Iter::~Stochastic_Iter() -{ - delete p_che; - delete[] spolyv; - delete[] chiallorder; -} +Stochastic_Iter::~Stochastic_Iter(){}; -void Stochastic_Iter::init(const int method_in, K_Vectors* pkv, ModulePW::PW_Basis_K *wfc_basis, Stochastic_WF &stowf) +void Stochastic_Iter::init(K_Vectors* pkv, ModulePW::PW_Basis_K *wfc_basis, Stochastic_WF &stowf, StoChe &stoche) { this->nchip = stowf.nchip;; this->targetne = 1; - this->method = method_in; + this->method = stoche.method_sto; } void Stochastic_Iter::orthog(const int& ik, @@ -123,11 +128,13 @@ Charge::~Charge(){}; class TestHSolverPW_SDFT : public ::testing::Test { public: + TestHSolverPW_SDFT():stoche(8,1,0,0){} ModulePW::PW_Basis_K pwbk; Stochastic_WF stowf; K_Vectors kv; wavefunc wf; - hsolver::HSolverPW_SDFT hs_d = hsolver::HSolverPW_SDFT(&kv, &pwbk, &wf, stowf, 0); + StoChe stoche; + hsolver::HSolverPW_SDFT hs_d = hsolver::HSolverPW_SDFT(&kv, &pwbk, &wf, stowf, stoche, false); hamilt::Hamilt> hamilt_test_d; diff --git a/tests/integrate/184_PW_BNDKPAR_SDFT_MALL/result.ref b/tests/integrate/184_PW_BNDKPAR_SDFT_MALL/result.ref index 662394eb5c..7df4ad8afa 100644 --- a/tests/integrate/184_PW_BNDKPAR_SDFT_MALL/result.ref +++ b/tests/integrate/184_PW_BNDKPAR_SDFT_MALL/result.ref @@ -1,5 +1,5 @@ -etotref -103.9857316470281745 -etotperatomref -51.9928658235 -totalforceref 197.981112 -totalstressref 257669.179265 -totaltimeref 31.18 +etotref -103.9857254497947423 +etotperatomref -51.9928627249 +totalforceref 197.981114 +totalstressref 257669.185397 +totaltimeref 27.90