Skip to content

Commit

Permalink
Refactor: remove template<class T> in Chebyshev (deepmodeling#4972)
Browse files Browse the repository at this point in the history
* Refactor: remove template<class T> in Chebyshev

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
Qianruipku and pre-commit-ci-lite[bot] authored Aug 17, 2024
1 parent d0b25ea commit ffdc617
Show file tree
Hide file tree
Showing 8 changed files with 1,239 additions and 1,127 deletions.
619 changes: 577 additions & 42 deletions source/module_base/math_chebyshev.cpp

Large diffs are not rendered by default.

353 changes: 180 additions & 173 deletions source/module_base/math_chebyshev.h

Large diffs are not rendered by default.

526 changes: 0 additions & 526 deletions source/module_base/math_chebyshev_def.h

This file was deleted.

685 changes: 397 additions & 288 deletions source/module_base/test/math_chebyshev_test.cpp

Large diffs are not rendered by default.

21 changes: 14 additions & 7 deletions source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,14 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
p_stowf->chi0->fix_k(ik);
pchi = p_stowf->chi0->get_pointer();
}
auto hchi_norm = std::bind(&Stochastic_hchi::hchi_norm,
&stohchi,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3);
if (this->method_sto == 1)
{
che.tracepolyA(&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, npwx, nchipk);
che.tracepolyA(hchi_norm, pchi, npw, npwx, nchipk);
for (int i = 0; i < dos_nche; ++i)
{
spolyv[i] += che.polytrace[i] * p_kv->wk[ik] / 2;
Expand All @@ -143,8 +148,7 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
}
ModuleBase::GlobalFunc::ZEROS(allorderchi.data(), nchipk_new * npwx * dos_nche);
std::complex<double>* tmpchi = pchi + start_nchipk * npwx;
che.calpolyvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
che.calpolyvec_complex(hchi_norm,
tmpchi,
allorderchi.data(),
npw,
Expand Down Expand Up @@ -180,12 +184,14 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
this->stofunc.targ_e = (emin + ie * de) / ModuleBase::Ry_to_eV;
if (this->method_sto == 1)
{
che.calcoef_real(&this->stofunc, &Sto_Func<double>::ngauss);
auto ngauss = std::bind(&Sto_Func<double>::ngauss, &this->stofunc, std::placeholders::_1);
che.calcoef_real(ngauss);
tmpsto = BlasConnector::dot(dos_nche, che.coef_real, 1, spolyv.data(), 1);
}
else
{
che.calcoef_real(&this->stofunc, &Sto_Func<double>::nroot_gauss);
auto nroot_gauss = std::bind(&Sto_Func<double>::nroot_gauss, &this->stofunc, std::placeholders::_1);
che.calcoef_real(nroot_gauss);
tmpsto = vTMv(che.coef_real, spolyv.data(), dos_nche);
}
if (GlobalV::NBANDS > 0)
Expand Down Expand Up @@ -243,9 +249,10 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
for (int ie = 0; ie < ndos; ++ie)
{
double tmperror = 2.0 * std::abs(error[ie]);
if (maxerror < tmperror) {
if (maxerror < tmperror)
{
maxerror = tmperror;
}
}
double dos = 2.0 * (ks_dos[ie] + sto_dos[ie]) / ModuleBase::Ry_to_eV;
sum += dos;
ofsdos << std::setw(8) << emin + ie * de << std::setw(20) << dos << std::setw(20) << sum * de
Expand Down
92 changes: 32 additions & 60 deletions source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,16 @@ void Sto_EleCond::decide_nche(const double dt,
const double mu = this->p_elec->eferm.ef;
this->stofunc.mu = mu;
int& nbatch = this->cond_dtbatch;
auto ncos = std::bind(&Sto_Func<double>::ncos, &this->stofunc, std::placeholders::_1);
auto n_sin = std::bind(&Sto_Func<double>::n_sin, &this->stofunc, std::placeholders::_1);
// try to find nbatch
if (nbatch == 0)
{
for (int test_nbatch = 128; test_nbatch >= 1; test_nbatch /= 2)
{
nbatch = test_nbatch;
this->stofunc.t = 0.5 * dt * nbatch;
chet.calcoef_pair(&this->stofunc, &Sto_Func<double>::ncos, &Sto_Func<double>::n_sin);
chet.calcoef_pair(ncos, n_sin);
double minerror = std::abs(chet.coef_complex[nche_guess - 1] / chet.coef_complex[0]);
if (minerror < cond_thr)
{
Expand All @@ -72,7 +74,7 @@ void Sto_EleCond::decide_nche(const double dt,
// first try to find nche
this->stofunc.t = 0.5 * dt * nbatch;
auto getnche = [&](int& nche) {
chet.calcoef_pair(&this->stofunc, &Sto_Func<double>::ncos, &Sto_Func<double>::n_sin);
chet.calcoef_pair(ncos, n_sin);
for (int i = 1; i < nche_guess; ++i)
{
double error = std::abs(chet.coef_complex[i] / chet.coef_complex[0]);
Expand Down Expand Up @@ -541,8 +543,11 @@ void Sto_EleCond::sKG(const int& smear_type,
const double mu = this->p_elec->eferm.ef;
this->stofunc.mu = mu;
this->stofunc.t = 0.5 * dt * nbatch;
chet.calcoef_pair(&this->stofunc, &Sto_Func<double>::ncos, &Sto_Func<double>::nsin);
chemt.calcoef_pair(&this->stofunc, &Sto_Func<double>::ncos, &Sto_Func<double>::n_sin);
auto ncos = std::bind(&Sto_Func<double>::ncos, &this->stofunc, std::placeholders::_1);
auto nsin = std::bind(&Sto_Func<double>::nsin, &this->stofunc, std::placeholders::_1);
auto n_sin = std::bind(&Sto_Func<double>::n_sin, &this->stofunc, std::placeholders::_1);
chet.calcoef_pair(ncos, nsin);
chemt.calcoef_pair(ncos, n_sin);
std::vector<std::complex<double>> batchcoef, batchmcoef;
if (nbatch > 1)
{
Expand All @@ -560,8 +565,8 @@ void Sto_EleCond::sKG(const int& smear_type,
tmpcoef = batchcoef.data() + ib * cond_nche;
tmpmcoef = batchmcoef.data() + ib * cond_nche;
this->stofunc.t = 0.5 * dt * (ib + 1);
chet.calcoef_pair(&this->stofunc, &Sto_Func<double>::ncos, &Sto_Func<double>::nsin);
chemt.calcoef_pair(&this->stofunc, &Sto_Func<double>::ncos, &Sto_Func<double>::n_sin);
chet.calcoef_pair(ncos, nsin);
chemt.calcoef_pair(ncos, n_sin);
for (int i = 0; i < cond_nche; ++i)
{
tmpcoef[i] = chet.coef_complex[i];
Expand Down Expand Up @@ -725,24 +730,19 @@ void Sto_EleCond::sKG(const int& smear_type,
vkspsi.resize(1, 1, 1);
}

che.calcoef_real(&this->stofunc, &Sto_Func<double>::nroot_fd);
che.calfinalvec_real(&stohchi,
&Stochastic_hchi::hchi_norm,
stopsi->get_pointer(),
sfchi.get_pointer(),
npw,
npwx,
perbands_sto);

che.calcoef_real(&this->stofunc, &Sto_Func<double>::nroot_mfd);

che.calfinalvec_real(&stohchi,
&Stochastic_hchi::hchi_norm,
stopsi->get_pointer(),
smfchi.get_pointer(),
npw,
npwx,
perbands_sto);
auto nroot_fd = std::bind(&Sto_Func<double>::nroot_fd, &this->stofunc, std::placeholders::_1);
che.calcoef_real(nroot_fd);
auto hchi_norm = std::bind(&Stochastic_hchi::hchi_norm,
&stohchi,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3);
che.calfinalvec_real(hchi_norm, stopsi->get_pointer(), sfchi.get_pointer(), npw, npwx, perbands_sto);

auto nroot_mfd = std::bind(&Sto_Func<double>::nroot_mfd, &this->stofunc, std::placeholders::_1);
che.calcoef_real(nroot_mfd);

che.calfinalvec_real(hchi_norm, stopsi->get_pointer(), smfchi.get_pointer(), npw, npwx, perbands_sto);

//------------------------ allocate ------------------------
psi::Psi<std::complex<double>>& expmtsfchi = sfchi;
Expand Down Expand Up @@ -821,29 +821,25 @@ void Sto_EleCond::sKG(const int& smear_type,
// Sto
if (nbatch == 1)
{
chemt.calfinalvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
chemt.calfinalvec_complex(hchi_norm,
expmtsfchi.get_pointer(),
expmtsfchi.get_pointer(),
npw,
npwx,
perbands_sto);
chemt.calfinalvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
chemt.calfinalvec_complex(hchi_norm,
expmtsmfchi.get_pointer(),
expmtsmfchi.get_pointer(),
npw,
npwx,
perbands_sto);
chet.calfinalvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
chet.calfinalvec_complex(hchi_norm,
exptsfchi.get_pointer(),
exptsfchi.get_pointer(),
npw,
npwx,
perbands_sto);
chet.calfinalvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
chet.calfinalvec_complex(hchi_norm,
exptsmfchi.get_pointer(),
exptsmfchi.get_pointer(),
npw,
Expand All @@ -862,34 +858,10 @@ void Sto_EleCond::sKG(const int& smear_type,
std::complex<double>* stoexptsmfchi = exptsmfchi.get_pointer();
if ((it - 1) % nbatch == 0)
{
chet.calpolyvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
stoexptsfchi,
tmppolyexptsfchi,
npw,
npwx,
perbands_sto);
chet.calpolyvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
stoexptsmfchi,
tmppolyexptsmfchi,
npw,
npwx,
perbands_sto);
chemt.calpolyvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
stoexpmtsfchi,
tmppolyexpmtsfchi,
npw,
npwx,
perbands_sto);
chemt.calpolyvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
stoexpmtsmfchi,
tmppolyexpmtsmfchi,
npw,
npwx,
perbands_sto);
chet.calpolyvec_complex(hchi_norm, stoexptsfchi, tmppolyexptsfchi, npw, npwx, perbands_sto);
chet.calpolyvec_complex(hchi_norm, stoexptsmfchi, tmppolyexptsmfchi, npw, npwx, perbands_sto);
chemt.calpolyvec_complex(hchi_norm, stoexpmtsfchi, tmppolyexpmtsfchi, npw, npwx, perbands_sto);
chemt.calpolyvec_complex(hchi_norm, stoexpmtsmfchi, tmppolyexpmtsmfchi, npw, npwx, perbands_sto);
}

std::complex<double>* tmpcoef = batchcoef.data() + (it - 1) % nbatch * cond_nche;
Expand Down
53 changes: 31 additions & 22 deletions source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,12 @@ void Stochastic_Iter::checkemm(const int& ik, const int istep, const int iter, S
while (true)
{
bool converge;
converge = p_che->checkconverge(&stohchi,
&Stochastic_hchi::hchi_norm,
pchi,
npw,
*stohchi.Emax,
*stohchi.Emin,
5.0);
auto hchi_norm = std::bind(&Stochastic_hchi::hchi_norm,
&stohchi,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3);
converge = p_che->checkconverge(hchi_norm, pchi, npw, *stohchi.Emax, *stohchi.Emin, 5.0);

if (!converge)
{
Expand Down Expand Up @@ -336,23 +335,22 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& stowf)
pchi = stowf.chi0->get_pointer();
}

auto hchi_norm = std::bind(&Stochastic_hchi::hchi_norm,
&stohchi,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3);
if (this->method == 1)
{
p_che->tracepolyA(&stohchi, &Stochastic_hchi::hchi_norm, pchi, npw, npwx, nchip_ik);
p_che->tracepolyA(hchi_norm, pchi, npw, npwx, nchip_ik);
for (int i = 0; i < norder; ++i)
{
spolyv[i] += p_che->polytrace[i] * this->pkv->wk[ik];
}
}
else
{
p_che->calpolyvec_complex(&stohchi,
&Stochastic_hchi::hchi_norm,
pchi,
stowf.chiallorder[ik].c,
npw,
npwx,
nchip_ik);
p_che->calpolyvec_complex(hchi_norm, pchi, stowf.chiallorder[ik].c, npw, npwx, nchip_ik);
double* vec_all = (double*)stowf.chiallorder[ik].c;
char trans = 'T';
char normal = 'N';
Expand All @@ -377,12 +375,14 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes)
if (this->method == 1)
{
// Note: spolyv contains kv.wk[ik]
p_che->calcoef_real(&stofunc, &Sto_Func<double>::nfd);
auto nfd = std::bind(&Sto_Func<double>::nfd, &this->stofunc, std::placeholders::_1);
p_che->calcoef_real(nfd);
sto_ne = BlasConnector::dot(norder, p_che->coef_real, 1, spolyv, 1);
}
else
{
p_che->calcoef_real(&stofunc, &Sto_Func<double>::nroot_fd);
auto nroot_fd = std::bind(&Sto_Func<double>::nroot_fd, &this->stofunc, std::placeholders::_1);
p_che->calcoef_real(nroot_fd);
sto_ne = vTMv(p_che->coef_real, spolyv, norder);
}
if (GlobalV::NBANDS > 0)
Expand All @@ -409,7 +409,8 @@ double Stochastic_Iter::calne(elecstate::ElecState* pes)

void Stochastic_Iter::calHsqrtchi(Stochastic_WF& stowf)
{
p_che->calcoef_real(&stofunc, &Sto_Func<double>::nroot_fd);
auto nroot_fd = std::bind(&Sto_Func<double>::nroot_fd, &this->stofunc, std::placeholders::_1);
p_che->calcoef_real(nroot_fd);
for (int ik = 0; ik < this->pkv->get_nks(); ++ik)
{
// init k
Expand Down Expand Up @@ -446,12 +447,14 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf,
double stodemet;
if (this->method == 1)
{
p_che->calcoef_real(&stofunc, &Sto_Func<double>::nfdlnfd);
auto nfdlnfd = std::bind(&Sto_Func<double>::nfdlnfd, &this->stofunc, std::placeholders::_1);
p_che->calcoef_real(nfdlnfd);
stodemet = BlasConnector::dot(norder, p_che->coef_real, 1, spolyv, 1);
}
else
{
p_che->calcoef_real(&stofunc, &Sto_Func<double>::n_root_fdlnfd);
auto nroot_fdlnfd = std::bind(&Sto_Func<double>::n_root_fdlnfd, &this->stofunc, std::placeholders::_1);
p_che->calcoef_real(nroot_fdlnfd);
stodemet = -vTMv(p_che->coef_real, spolyv, norder);
}

Expand Down Expand Up @@ -480,7 +483,8 @@ void Stochastic_Iter::sum_stoband(Stochastic_WF& stowf,
double sto_eband = 0;
if (this->method == 1)
{
p_che->calcoef_real(&stofunc, &Sto_Func<double>::nxfd);
auto nxfd = std::bind(&Sto_Func<double>::nxfd, &this->stofunc, std::placeholders::_1);
p_che->calcoef_real(nxfd);
sto_eband = BlasConnector::dot(norder, p_che->coef_real, 1, spolyv, 1);
}
else
Expand Down Expand Up @@ -642,6 +646,11 @@ void Stochastic_Iter::calTnchi_ik(const int& ik, Stochastic_WF& stowf)
}
else
{
p_che->calfinalvec_real(&stohchi, &Stochastic_hchi::hchi_norm, pchi, out, npw, npwx, nchip[ik]);
auto hchi_norm = std::bind(&Stochastic_hchi::hchi_norm,
&stohchi,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3);
p_che->calfinalvec_real(hchi_norm, pchi, out, npw, npwx, nchip[ik]);
}
}
17 changes: 8 additions & 9 deletions source/module_hamilt_pw/hamilt_stodft/sto_tool.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "sto_tool.h"

#include "module_base/timer.h"
#include "module_base/math_chebyshev.h"
#include "module_base/timer.h"
#ifdef __MPI
#include "mpi.h"
#endif
Expand Down Expand Up @@ -67,16 +67,15 @@ void check_che(const int& nche_in,
{
pchi = &p_stowf->chi0[0](ik, i, 0);
}
while (1)
while (true)
{
bool converge;
converge = chetest.checkconverge(&stohchi,
&Stochastic_hchi::hchi_norm,
pchi,
npw,
*stohchi.Emax,
*stohchi.Emin,
2.0);
auto hchi_norm = std::bind(&Stochastic_hchi::hchi_norm,
&stohchi,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3);
converge = chetest.checkconverge(hchi_norm, pchi, npw, *stohchi.Emax, *stohchi.Emin, 2.0);

if (!converge)
{
Expand Down

0 comments on commit ffdc617

Please sign in to comment.