Skip to content

Commit

Permalink
Refactor: consistent order of hpsi (#5134)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cstandardlib authored Sep 20, 2024
1 parent 320b07f commit c4a324d
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 17 deletions.
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,

// compute h*psi_in_iter
// NOTE: bands after the first n_band should yield zero
hpsi_func(this->hphi, this->psi_in_iter, this->nbase_x, this->dim, 0, this->nbase_x - 1);
hpsi_func(this->psi_in_iter, this->hphi, this->nbase_x, this->dim, 0, this->nbase_x - 1);

// at this stage, notconv = n_band and nbase = 0
// note that nbase of cal_elem is an inout parameter: nbase := nbase + notconv
Expand Down Expand Up @@ -421,7 +421,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
}

// update hpsi[:, nbase:nbase+notconv]
hpsi_func(&hphi[nbase * this->dim], psi_iter, this->nbase_x, this->dim, nbase, nbase + notconv - 1);
hpsi_func(psi_iter, &hphi[nbase * this->dim], this->nbase_x, this->dim, nbase, nbase + notconv - 1);

ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
return;
Expand Down Expand Up @@ -886,7 +886,7 @@ void Diago_DavSubspace<T, Device>::diagH_subspace(T* psi_pointer, // [in] & [out

{
// do hPsi for all bands
hpsi_func(hphi, psi_pointer, n_band, dmax, 0, nstart - 1);
hpsi_func(psi_pointer, hphi, n_band, dmax, 0, nstart - 1);

gemm_op<T, Device>()(ctx,
'C',
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
// end of SchmidtOrth and calculate H|psi>
// hpsi_info dav_hpsi_in(&basis, psi::Range(true, 0, 0, nband - 1), this->hpsi);
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(this->hpsi, basis, nbase_x, dim, 0, nband - 1);
hpsi_func(basis, hpsi, nbase_x, dim, 0, nband - 1);

this->cal_elem(dim, nbase, nbase_x, this->notconv, this->hpsi, this->spsi, this->hcc, this->scc);

Expand Down Expand Up @@ -601,7 +601,7 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
// psi::Range(true, 0, nbase, nbase + notconv - 1),
// &hpsi[nbase * dim]); // &hp(nbase, 0)
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(&hpsi[nbase * dim], basis, nbase_x, dim, nbase, nbase + notconv - 1);
hpsi_func(basis, &hpsi[nbase * dim], nbase_x, dim, nbase, nbase + notconv - 1);

delmem_complex_op()(this->ctx, lagrange);
delmem_complex_op()(this->ctx, vc_ev_vector);
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class DiagoDavid : public DiagH<T, Device>
* For eigenvalue problem HX = λX or generalized eigenvalue problem HX = λSX,
* this function computes the product of the Hamiltonian matrix H and a blockvector X.
*
* @param[out] HX Pointer to output blockvector of type `T*`.
* @param[in] X Pointer to input blockvector of type `T*`.
* @param[out] X Pointer to input blockvector of type `T*`.
* @param[in] HX Pointer to output blockvector of type `T*`.
* @param[in] neig Number of eigebpairs required.
* @param[in] dim Dimension of matrix.
* @param[in] id_start Start index of blockvector.
Expand Down
10 changes: 5 additions & 5 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
else if (this->method == "dav_subspace")
{
auto ngk_pointer = psi.get_ngk_pointer();
auto hpsi_func = [hm, ngk_pointer](T* hpsi_out,
T* psi_in,
auto hpsi_func = [hm, ngk_pointer](T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down Expand Up @@ -492,9 +492,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,

auto ngk_pointer = psi.get_ngk_pointer();
/// wrap hpsi into lambda function, Matrix \times blockvector
/// hpsi(HX, X, nband, dim, band_index1, band_index2)
auto hpsi_func = [hm, ngk_pointer](T* hpsi_out,
T* psi_in,
/// hpsi(X, HX, nband, dim, band_index1, band_index2)
auto hpsi_func = [hm, ngk_pointer](T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class DiagoDavPrepare
#endif


auto hpsi_func = [phm](std::complex<float>* hpsi_out,std::complex<float>* psi_in,
auto hpsi_func = [phm](std::complex<float>* psi_in,std::complex<float>* hpsi_out,
const int nband_in, const int nbasis_in,
const int band_index1, const int band_index2)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_real_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class DiagoDavPrepare
#endif


auto hpsi_func = [phm](double* hpsi_out,double* psi_in,
auto hpsi_func = [phm](double* psi_in,double* hpsi_out,
const int nband_in, const int nbasis_in,
const int band_index1, const int band_index2)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class DiagoDavPrepare
#endif


auto hpsi_func = [phm](std::complex<double>* hpsi_out,std::complex<double>* psi_in,
auto hpsi_func = [phm](std::complex<double>* psi_in,std::complex<double>* hpsi_out,
const int nband_in, const int nbasis_in,
const int band_index1, const int band_index2)
{
Expand Down
4 changes: 2 additions & 2 deletions source/module_lr/hsolver_lrtd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ namespace LR
// do diag and add davidson iteration counts up to avg_iter

auto hpsi_func = [pHamilt](
T* hpsi_out,
T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down Expand Up @@ -119,8 +119,8 @@ namespace LR
comm_info);

std::function<void(T*, T*, const int, const int, const int, const int)> hpsi_func = [pHamilt](
T* hpsi_out,
T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down

0 comments on commit c4a324d

Please sign in to comment.