Skip to content

Commit

Permalink
Merge branch 'develop' into hsolver_pw
Browse files Browse the repository at this point in the history
  • Loading branch information
haozhihan authored Sep 21, 2024
2 parents dc2c921 + 80b2c75 commit f91d70b
Show file tree
Hide file tree
Showing 12 changed files with 31 additions and 24 deletions.
1 change: 1 addition & 0 deletions source/module_esolver/esolver_fp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ void ESolver_FP::after_scf(const int istep)
PARAM.inp.nspin,
GlobalC::ucell.GT,
rhog_tot,
GlobalV::MY_POOL,
GlobalV::RANK_IN_POOL,
GlobalV::NPROC_IN_POOL);
}
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,

// compute h*psi_in_iter
// NOTE: bands after the first n_band should yield zero
hpsi_func(this->hphi, this->psi_in_iter, this->nbase_x, this->dim, 0, this->nbase_x - 1);
hpsi_func(this->psi_in_iter, this->hphi, this->nbase_x, this->dim, 0, this->nbase_x - 1);

// at this stage, notconv = n_band and nbase = 0
// note that nbase of cal_elem is an inout parameter: nbase := nbase + notconv
Expand Down Expand Up @@ -421,7 +421,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
}

// update hpsi[:, nbase:nbase+notconv]
hpsi_func(&hphi[nbase * this->dim], psi_iter, this->nbase_x, this->dim, nbase, nbase + notconv - 1);
hpsi_func(psi_iter, &hphi[nbase * this->dim], this->nbase_x, this->dim, nbase, nbase + notconv - 1);

ModuleBase::timer::tick("Diago_DavSubspace", "cal_grad");
return;
Expand Down Expand Up @@ -886,7 +886,7 @@ void Diago_DavSubspace<T, Device>::diagH_subspace(T* psi_pointer, // [in] & [out

{
// do hPsi for all bands
hpsi_func(hphi, psi_pointer, n_band, dmax, 0, nstart - 1);
hpsi_func(psi_pointer, hphi, n_band, dmax, 0, nstart - 1);

gemm_op<T, Device>()(ctx,
'C',
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ int DiagoDavid<T, Device>::diag_once(const HPsiFunc& hpsi_func,
// end of SchmidtOrth and calculate H|psi>
// hpsi_info dav_hpsi_in(&basis, psi::Range(true, 0, 0, nband - 1), this->hpsi);
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(this->hpsi, basis, nbase_x, dim, 0, nband - 1);
hpsi_func(basis, hpsi, nbase_x, dim, 0, nband - 1);

this->cal_elem(dim, nbase, nbase_x, this->notconv, this->hpsi, this->spsi, this->hcc, this->scc);

Expand Down Expand Up @@ -601,7 +601,7 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
// psi::Range(true, 0, nbase, nbase + notconv - 1),
// &hpsi[nbase * dim]); // &hp(nbase, 0)
// phm_in->ops->hPsi(dav_hpsi_in);
hpsi_func(&hpsi[nbase * dim], basis, nbase_x, dim, nbase, nbase + notconv - 1);
hpsi_func(basis, &hpsi[nbase * dim], nbase_x, dim, nbase, nbase + notconv - 1);

delmem_complex_op()(this->ctx, lagrange);
delmem_complex_op()(this->ctx, vc_ev_vector);
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class DiagoDavid : public DiagH<T, Device>
* For eigenvalue problem HX = λX or generalized eigenvalue problem HX = λSX,
* this function computes the product of the Hamiltonian matrix H and a blockvector X.
*
* @param[out] HX Pointer to output blockvector of type `T*`.
* @param[in] X Pointer to input blockvector of type `T*`.
* @param[out] X Pointer to input blockvector of type `T*`.
* @param[in] HX Pointer to output blockvector of type `T*`.
* @param[in] neig Number of eigebpairs required.
* @param[in] dim Dimension of matrix.
* @param[in] id_start Start index of blockvector.
Expand Down
10 changes: 5 additions & 5 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
else if (this->method == "dav_subspace")
{
auto ngk_pointer = psi.get_ngk_pointer();
auto hpsi_func = [hm, ngk_pointer](T* hpsi_out,
T* psi_in,
auto hpsi_func = [hm, ngk_pointer](T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down Expand Up @@ -492,9 +492,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,

auto ngk_pointer = psi.get_ngk_pointer();
/// wrap hpsi into lambda function, Matrix \times blockvector
/// hpsi(HX, X, nband, dim, band_index1, band_index2)
auto hpsi_func = [hm, ngk_pointer](T* hpsi_out,
T* psi_in,
/// hpsi(X, HX, nband, dim, band_index1, band_index2)
auto hpsi_func = [hm, ngk_pointer](T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class DiagoDavPrepare
#endif


auto hpsi_func = [phm](std::complex<float>* hpsi_out,std::complex<float>* psi_in,
auto hpsi_func = [phm](std::complex<float>* psi_in,std::complex<float>* hpsi_out,
const int nband_in, const int nbasis_in,
const int band_index1, const int band_index2)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_real_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class DiagoDavPrepare
#endif


auto hpsi_func = [phm](double* hpsi_out,double* psi_in,
auto hpsi_func = [phm](double* psi_in,double* hpsi_out,
const int nband_in, const int nbasis_in,
const int band_index1, const int band_index2)
{
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/test/diago_david_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class DiagoDavPrepare
#endif


auto hpsi_func = [phm](std::complex<double>* hpsi_out,std::complex<double>* psi_in,
auto hpsi_func = [phm](std::complex<double>* psi_in,std::complex<double>* hpsi_out,
const int nband_in, const int nbasis_in,
const int band_index1, const int band_index2)
{
Expand Down
15 changes: 8 additions & 7 deletions source/module_io/restart_exx_csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ namespace ModuleIO
const SparseMatrix<Tdata>& matrix = csr.getMatrix(iR);
for (auto& ijv : matrix.getElements())
{
const int& i = ijv.first.first;
const int& j = ijv.first.second;
Hexxs.at(is).at(ucell.iwt2iat[i]).at({ ucell.iwt2iat[j], { R[0], R[1], R[2] } })(ucell.iwt2iw[i], ucell.iwt2iw[j]) = ijv.second;
const int& npol = ucell.get_npol();
const int& i = ijv.first.first * npol;
const int& j = ijv.first.second * npol;
Hexxs.at(is).at(ucell.iwt2iat[i]).at({ ucell.iwt2iat[j], { R[0], R[1], R[2] } })(ucell.iwt2iw[i] / npol, ucell.iwt2iw[j] / npol) = ijv.second;
}
}
}
Expand All @@ -64,17 +65,17 @@ namespace ModuleIO
int iat2 = a2R_data.first.first;
int nw1 = ucell.atoms[ucell.iat2it[iat1]].nw;
int nw2 = ucell.atoms[ucell.iat2it[iat2]].nw;
int start1 = ucell.atoms[ucell.iat2it[iat1]].stapos_wf + ucell.iat2ia[iat1] * nw1;
int start2 = ucell.atoms[ucell.iat2it[iat2]].stapos_wf + ucell.iat2ia[iat2] * nw2;
int start1 = ucell.atoms[ucell.iat2it[iat1]].stapos_wf / ucell.get_npol() + ucell.iat2ia[iat1] * nw1;
int start2 = ucell.atoms[ucell.iat2it[iat2]].stapos_wf / ucell.get_npol() + ucell.iat2ia[iat2] * nw2;

const TC& R = a2R_data.first.second;
auto& matrix = a2R_data.second;
Abfs::Vector3_Order<int> dR(R[0], R[1], R[2]);
for (int i = 0;i < nw1;++i) {
for (int j = 0;j < nw2;++j) {
target[dR][start1 + i][start2 + j] = ((std::abs(matrix(i, j)) > sparse_threshold) ? matrix(i, j) : static_cast<Tdata>(0));
}
}
}
}
}
}
return target;
Expand Down
4 changes: 4 additions & 0 deletions source/module_io/rhog_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,15 @@ bool ModuleIO::write_rhog(const std::string& fchg,
const int nspin, // GlobalV
const ModuleBase::Matrix3& GT, // from UnitCell, useful for calculating the miller
std::complex<double>** rhog,
const int ipool,
const int irank,
const int nrank)
{
ModuleBase::TITLE("ModuleIO", "write_rhog");
ModuleBase::timer::tick("ModuleIO", "write_rhog");
if (ipool != 0) { return true; }
// only one pool writes the rhog, because rhog in all pools are identical.

// for large-scale data, it is not wise to collect all distributed components to the
// master process and then write the data to the file. Instead, we can write the data
// processer by processer.
Expand Down
1 change: 1 addition & 0 deletions source/module_io/rhog_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ bool write_rhog(const std::string& fchg,
const int nspin, // GlobalV
const ModuleBase::Matrix3& GT, // from UnitCell, useful for calculating the miller
std::complex<double>** rhog,
const int ipool,
const int irank,
const int nrank);

Expand Down
4 changes: 2 additions & 2 deletions source/module_lr/hsolver_lrtd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ namespace LR
// do diag and add davidson iteration counts up to avg_iter

auto hpsi_func = [pHamilt](
T* hpsi_out,
T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down Expand Up @@ -119,8 +119,8 @@ namespace LR
comm_info);

std::function<void(T*, T*, const int, const int, const int, const int)> hpsi_func = [pHamilt](
T* hpsi_out,
T* psi_in,
T* hpsi_out,
const int nband_in,
const int nbasis_in,
const int band_index1,
Expand Down

0 comments on commit f91d70b

Please sign in to comment.