Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: consistent order of hpsi #5134

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,

// compute h*psi_in_iter
// NOTE: bands after the first n_band should yield zero
hpsi_func(this->hphi, this->psi_in_iter, this->nbase_x, this->dim, 0, this->nbase_x - 1);
hpsi_func(this->psi_in_iter, this->hphi, this->nbase_x, this->dim, 0, this->nbase_x - 1);

// at this stage, notconv = n_band and nbase = 0
// note that nbase of cal_elem is an inout parameter: nbase := nbase + notconv
Expand Down Expand Up @@ -421,7 +421,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
}

// update hpsi[:, nbase:nbase+notconv]
hpsi_func(&hphi[nbase * this->dim], psi_iter, this->nbase_x, this->dim, nbase, nbase + notconv - 1);
hpsi_func(psi_iter, &hphi[nbase * this->dim], this->nbase_x, this->dim, nbase, nbase + notconv - 1);

ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
return;
Expand Down Expand Up @@ -886,7 +886,7 @@ void Diago_DavSubspace<T, Device>::diagH_subspace(T* psi_pointer, // [in] & [out

{
// do hPsi for all bands
hpsi_func(hphi, psi_pointer, n_band, dmax, 0, nstart - 1);
hpsi_func(psi_pointer, hphi, n_band, dmax, 0, nstart - 1);

gemm_op<T, Device>()(ctx,
'C',
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
// end of SchmidtOrth and calculate H|psi>
// hpsi_info dav_hpsi_in(&basis, psi::Range(true, 0, 0, nband - 1), this->hpsi);
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(this->hpsi, basis, nbase_x, dim, 0, nband - 1);
hpsi_func(basis, hpsi, nbase_x, dim, 0, nband - 1);

this->cal_elem(dim, nbase, nbase_x, this->notconv, this->hpsi, this->spsi, this->hcc, this->scc);

Expand Down Expand Up @@ -601,7 +601,7 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
// psi::Range(true, 0, nbase, nbase + notconv - 1),
// &hpsi[nbase * dim]); // &hp(nbase, 0)
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(&hpsi[nbase * dim], basis, nbase_x, dim, nbase, nbase + notconv - 1);
hpsi_func(basis, &hpsi[nbase * dim], nbase_x, dim, nbase, nbase + notconv - 1);

delmem_complex_op()(this->ctx, lagrange);
delmem_complex_op()(this->ctx, vc_ev_vector);
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class DiagoDavid : public DiagH<T, Device>
* For eigenvalue problem HX = λX or generalized eigenvalue problem HX = λSX,
* this function computes the product of the Hamiltonian matrix H and a blockvector X.
*
* @param[out] HX Pointer to output blockvector of type `T*`.
* @param[in] X Pointer to input blockvector of type `T*`.
* @param[out] X Pointer to input blockvector of type `T*`.
* @param[in] HX Pointer to output blockvector of type `T*`.
* @param[in] neig Number of eigebpairs required.
* @param[in] dim Dimension of matrix.
* @param[in] id_start Start index of blockvector.
Expand Down
10 changes: 5 additions & 5 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
else if (this->method == "dav_subspace")
{
auto ngk_pointer = psi.get_ngk_pointer();
auto hpsi_func = [hm, ngk_pointer](T* hpsi_out,
T* psi_in,
auto hpsi_func = [hm, ngk_pointer](T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down Expand Up @@ -492,9 +492,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,

auto ngk_pointer = psi.get_ngk_pointer();
/// wrap hpsi into lambda function, Matrix \times blockvector
/// hpsi(HX, X, nband, dim, band_index1, band_index2)
auto hpsi_func = [hm, ngk_pointer](T* hpsi_out,
T* psi_in,
/// hpsi(X, HX, nband, dim, band_index1, band_index2)
auto hpsi_func = [hm, ngk_pointer](T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class DiagoDavPrepare
#endif


auto hpsi_func = [phm](std::complex<float>* hpsi_out,std::complex<float>* psi_in,
auto hpsi_func = [phm](std::complex<float>* psi_in,std::complex<float>* hpsi_out,
const int nband_in, const int nbasis_in,
const int band_index1, const int band_index2)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_real_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class DiagoDavPrepare
#endif


auto hpsi_func = [phm](double* hpsi_out,double* psi_in,
auto hpsi_func = [phm](double* psi_in,double* hpsi_out,
const int nband_in, const int nbasis_in,
const int band_index1, const int band_index2)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class DiagoDavPrepare
#endif


auto hpsi_func = [phm](std::complex<double>* hpsi_out,std::complex<double>* psi_in,
auto hpsi_func = [phm](std::complex<double>* psi_in,std::complex<double>* hpsi_out,
const int nband_in, const int nbasis_in,
const int band_index1, const int band_index2)
{
Expand Down
4 changes: 2 additions & 2 deletions source/module_lr/hsolver_lrtd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ namespace LR
// do diag and add davidson iteration counts up to avg_iter

auto hpsi_func = [pHamilt](
T* hpsi_out,
T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down Expand Up @@ -119,8 +119,8 @@ namespace LR
comm_info);

std::function<void(T*, T*, const int, const int, const int, const int)> hpsi_func = [pHamilt](
T* hpsi_out,
T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down
Loading