Skip to content

Commit

Permalink
remove the use of paraV in most part of esolver
Browse files Browse the repository at this point in the history
  • Loading branch information
kirk0830 committed May 30, 2024
1 parent a05a71f commit d40f130
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 717 deletions.
12 changes: 11 additions & 1 deletion source/module_elecstate/elecstate_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,23 @@ class ElecStateLCAO : public ElecState
{
public:
ElecStateLCAO(){} // will be called by ElecStateLCAO_TDDFT
/*
Note: on the removal of LOWF
The entire instance of Local_Orbital_wfc is not really necessary, because
this class only need it to do 2dbcd wavefunction gathering. Therefore,
what is critical is the 2dbcd handle which stores information about the
wavefunction, and another free function to do the 2dbcd gathering.
A future work would be replace the Local_Orbital_wfc with a 2dbcd handle.
A free gathering function will also be needed.
*/
ElecStateLCAO(Charge* chg_in ,
const K_Vectors* klist_in ,
int nks_in,
Local_Orbital_Charge* loc_in ,
Gint_Gamma* gint_gamma_in, //mohan add 2024-04-01
Gint_k* gint_k_in, //mohan add 2024-04-01
Local_Orbital_wfc* lowf_in ,
Local_Orbital_wfc* lowf_in,
ModulePW::PW_Basis* rhopw_in ,
ModulePW::PW_Basis_Big* bigpw_in )
{
Expand Down
1 change: 0 additions & 1 deletion source/module_elecstate/elecstate_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ void ElecStateLCAO_TDDFT::psiToRho_td(const psi::Psi<std::complex<double>>& psi)
}
}

//this->loc->cal_dk_k(*this->lowf->gridt, this->wg, *(this->klist));
for (int is = 0; is < GlobalV::NSPIN; is++)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); // mohan 2009-11-10
Expand Down
61 changes: 41 additions & 20 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,26 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(Input& inp, UnitCell& ucell)
this->gen_h.LM = &this->LM;

//! pass basis-pointer to EState and Psi
this->LOC.ParaV = this->LOWF.ParaV = this->LM.ParaV = &(this->orb_con.ParaV);
/*
Inform: on getting rid of ORB_control and Parallel_Orbitals
Have to say it is all the stories start, the ORB_control instance pass its Parallel_Orbitals instance to
Local_Orbital_Charge, Local_Orbital_Wfc and LCAO_Matrix, which is actually for getting information
of 2D block-cyclic distribution.
To remove LOC, LOWF and LM use in functions, one must make sure there is no more information imported
to those classes. Then places where to get information from them can be substituted to orb_con
Plan:
1. Specifically for paraV, the thing to do first is to replace the use of ParaV to the oen of ORB_control.
Then remove ORB_control and place paraV somewhere.
*/
this->LOC.ParaV = &(this->orb_con.ParaV);
this->LOWF.ParaV = &(this->orb_con.ParaV);
this->LM.ParaV = &(this->orb_con.ParaV);

// 5) initialize Density Matrix
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, this->LM.ParaV, GlobalV::NSPIN);
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, &(this->orb_con.ParaV), GlobalV::NSPIN);


if (GlobalV::CALCULATION == "get_S")
Expand Down Expand Up @@ -232,7 +248,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(Input& inp, UnitCell& ucell)
// 10) init HSolver
if (this->phsol == nullptr)
{
this->phsol = new hsolver::HSolverLCAO<TK>(this->LOWF.ParaV);
this->phsol = new hsolver::HSolverLCAO<TK>(&(this->orb_con.ParaV));
this->phsol->method = GlobalV::KS_SOLVER;
}

Expand Down Expand Up @@ -286,7 +302,14 @@ void ESolver_KS_LCAO<TK, TR>::init_after_vc(Input& inp, UnitCell& ucell)
ModuleBase::timer::tick("ESolver_KS_LCAO", "init_after_vc");

ESolver_KS<TK>::init_after_vc(inp, ucell);

/*
Notes: on the removal of LOWF
Following constructor of ElecStateLCAO requires LOWF. However, ElecState only need
LOWF to do wavefunction 2dbcd (2D BlockCyclicDistribution) gathering. So, a free
function is needed to replace the use of LOWF. The function indeed needs the information
about 2dbcd, therefore another instance storing the information is needed instead.
Then that instance will be the input of "the free function to gather".
*/
if (GlobalV::md_prec_level == 2)
{
delete this->pelec;
Expand All @@ -296,7 +319,7 @@ void ESolver_KS_LCAO<TK, TR>::init_after_vc(Input& inp, UnitCell& ucell)
&(this->LOC),
&(this->GG), // mohan add 2024-04-01
&(this->GK), // mohan add 2024-04-01
&(this->LOWF),
&(this->LOWF), // should be replaced by a 2dbcd handle, if insist the "print_psi" must be in ElecState class
this->pw_rho,
this->pw_big);

Expand Down Expand Up @@ -657,10 +680,10 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
// first need to calculate the weight according to
// electrons number.
if (istep == 0
&& this->wf.init_wfc == "file"
&& this->LOWF.error == 0)
{
if (iter == 1)
&& this->wf.init_wfc == "file" // Note: on the removal of LOWF
&& this->LOWF.error == 0) // this means the wavefunction is read without any error.
{ // However the I/O of wavefunction are nonsence to be implmented in different places.
if (iter == 1) // once the reading of wavefunction has any error, should exit immediately.
{
std::cout << " WAVEFUN -> CHARGE " << std::endl;

Expand Down Expand Up @@ -831,11 +854,11 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density(int istep, int iter, double ethr)
#ifdef __EXX
if (GlobalC::exx_info.info_ri.real_number)
{
this->exd->exx_hamilt2density(*this->pelec, *this->LOWF.ParaV, iter);
this->exd->exx_hamilt2density(*this->pelec, this->orb_con.ParaV, iter);
}
else
{
this->exc->exx_hamilt2density(*this->pelec, *this->LOWF.ParaV, iter);
this->exc->exx_hamilt2density(*this->pelec, this->orb_con.ParaV, iter);
}
#endif

Expand All @@ -861,8 +884,6 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density(int istep, int iter, double ethr)
#ifdef __DEEPKS
if (GlobalV::deepks_scf)
{
const Parallel_Orbitals* pv = this->LOWF.ParaV;

const std::vector<std::vector<TK>>& dm =
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM()->get_DMK_vector();

Expand Down Expand Up @@ -940,7 +961,7 @@ void ESolver_KS_LCAO<TK, TR>::update_pot(const int istep, const int iter)
GlobalV::out_app_flag,
"H",
"data-" + std::to_string(ik),
*this->LOWF.ParaV,
this->orb_con.ParaV,
GlobalV::DRANK);
ModuleIO::save_mat(istep,
s_mat.p,
Expand All @@ -951,7 +972,7 @@ void ESolver_KS_LCAO<TK, TR>::update_pot(const int istep, const int iter)
GlobalV::out_app_flag,
"S",
"data-" + std::to_string(ik),
*this->LOWF.ParaV,
this->orb_con.ParaV,
GlobalV::DRANK);
}
}
Expand Down Expand Up @@ -1048,7 +1069,7 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int iter)
if (GlobalC::restart.info_save.save_H && two_level_step > 0 &&
(!GlobalC::exx_info.info_global.separate_loop || iter == 1)) // to avoid saving the same value repeatedly
{
std::vector<TK> Hexxk_save(this->LOWF.ParaV->get_local_size());
std::vector<TK> Hexxk_save(this->orb_con.ParaV.get_local_size());
for (int ik = 0;ik < this->kv.nks;++ik)
{
ModuleBase::GlobalFunc::ZEROS(Hexxk_save.data(), Hexxk_save.size());
Expand All @@ -1057,7 +1078,7 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int iter)

opexx_save.contributeHk(ik);

GlobalC::restart.save_disk("Hexx", ik, this->LOWF.ParaV->get_local_size(), Hexxk_save.data());
GlobalC::restart.save_disk("Hexx", ik, this->orb_con.ParaV.get_local_size(), Hexxk_save.data());
}
if (GlobalV::MY_RANK == 0)
{
Expand Down Expand Up @@ -1234,7 +1255,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
GlobalC::ucell,
GlobalC::ORB,
GlobalC::GridD,
this->LOWF.ParaV,
&(this->orb_con.ParaV),
*(this->psi),
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM());

Expand All @@ -1251,7 +1272,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
RPA_LRI<TK, double> rpa_lri_double(GlobalC::exx_info.info_ri);
rpa_lri_double.cal_postSCF_exx(*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(), MPI_COMM_WORLD, this->kv);
rpa_lri_double.init(MPI_COMM_WORLD, this->kv);
rpa_lri_double.out_for_RPA(*(this->LOWF.ParaV), *(this->psi), this->pelec);
rpa_lri_double.out_for_RPA(this->orb_con.paraV, *(this->psi), this->pelec);
}
#endif

Expand Down Expand Up @@ -1418,7 +1439,7 @@ ModuleIO::Output_Mat_Sparse<TK> ESolver_KS_LCAO<TK, TR>::create_Output_Mat_Spars
INPUT.out_mat_r,
istep,
this->pelec->pot->get_effective_v(),
*this->LOWF.ParaV,
this->orb_con.ParaV,
this->gen_h, // mohan add 2024-04-06
this->GK, // mohan add 2024-04-01
this->LM,
Expand Down
14 changes: 6 additions & 8 deletions source/module_esolver/esolver_ks_lcao_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,23 +109,23 @@ void ESolver_KS_LCAO<TK, TR>::beforesolver(const int istep)
if (GlobalV::GAMMA_ONLY_LOCAL)
{
nsk = GlobalV::NSPIN;
ncol = this->LOWF.ParaV->ncol_bands;
ncol = this->orb_con.ParaV.ncol_bands;
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "lapack_gvx" || GlobalV::KS_SOLVER=="pexsi"
|| GlobalV::KS_SOLVER == "cusolver")
{
ncol = this->LOWF.ParaV->ncol;
ncol = this->orb_con.ParaV.ncol;
}
}
else
{
nsk = this->kv.nks;
#ifdef __MPI
ncol = this->LOWF.ParaV->ncol_bands;
ncol = this->orb_con.ParaV.ncol_bands;
#else
ncol = GlobalV::NBANDS;
#endif
}
this->psi = new psi::Psi<TK>(nsk, ncol, this->LOWF.ParaV->nrow, nullptr);
this->psi = new psi::Psi<TK>(nsk, ncol, this->orb_con.ParaV.nrow, nullptr);
}

// prepare grid in Gint
Expand Down Expand Up @@ -574,7 +574,6 @@ void ESolver_KS_LCAO<TK, TR>::nscf(void)
// Peize Lin add 2018-08-14
if (GlobalC::exx_info.info_global.cal_exx)
{
// GlobalC::exx_lcao.cal_exx_elec_nscf(this->LOWF.ParaV[0]);
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number)
{
Expand Down Expand Up @@ -670,7 +669,7 @@ void ESolver_KS_LCAO<TK, TR>::nscf(void)
this->sf,
this->kv,
this->psi,
this->LOWF.ParaV);
&(this->orb_con.ParaV));
}
else if (INPUT.wannier_method == 2)
{
Expand All @@ -682,7 +681,7 @@ void ESolver_KS_LCAO<TK, TR>::nscf(void)
INPUT.nnkpfile,
INPUT.wannier_spin);

myWannier.calculate(this->pelec->ekb, this->kv, *(this->psi), this->LOWF.ParaV);
myWannier.calculate(this->pelec->ekb, this->kv, *(this->psi), &(this->orb_con.ParaV));
}
#endif
}
Expand All @@ -696,7 +695,6 @@ void ESolver_KS_LCAO<TK, TR>::nscf(void)

// below is for DeePKS NSCF calculation
#ifdef __DEEPKS
const Parallel_Orbitals* pv = this->LOWF.ParaV;
if (GlobalV::deepks_out_labels || GlobalV::deepks_scf)
{
const elecstate::DensityMatrix<TK, double>* dm
Expand Down
13 changes: 7 additions & 6 deletions source/module_esolver/esolver_ks_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(Input& inp, UnitCell& ucell)
// pass Hamilt-pointer to Operator
this->gen_h.LM = &this->LM;
// pass basis-pointer to EState and Psi
this->LOC.ParaV = this->LOWF.ParaV = this->LM.ParaV;
this->LOC.ParaV = this->LM.ParaV;;
this->LOWF.ParaV = this->LM.ParaV;

// init DensityMatrix
dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)
Expand All @@ -111,7 +112,7 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(Input& inp, UnitCell& ucell)
// init Psi, HSolver, ElecState, Hamilt
if (this->phsol == nullptr)
{
this->phsol = new hsolver::HSolverLCAO<std::complex<double>>(this->LOWF.ParaV);
this->phsol = new hsolver::HSolverLCAO<std::complex<double>>(this->LM.ParaV);
this->phsol->method = GlobalV::KS_SOLVER;
}

Expand Down Expand Up @@ -274,7 +275,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
GlobalV::out_app_flag,
"H",
"data-" + std::to_string(ik),
*this->LOWF.ParaV,
*this->LM.ParaV,
GlobalV::DRANK);

ModuleIO::save_mat(istep,
Expand All @@ -286,7 +287,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
GlobalV::out_app_flag,
"S",
"data-" + std::to_string(ik),
*this->LOWF.ParaV,
*this->LM.ParaV,
GlobalV::DRANK);
}
}
Expand Down Expand Up @@ -333,8 +334,8 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
{
#ifdef __MPI
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.nks,
this->LOWF.ParaV->ncol_bands,
this->LOWF.ParaV->nrow,
this->LM.ParaV->ncol_bands,
this->LM.ParaV->nrow,
nullptr);
#else
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.nks, GlobalV::NBANDS, GlobalV::NLOCAL, nullptr);
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/io_npz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void ESolver_KS_LCAO<TK, TR>::read_mat_npz(std::string& zipname, hamilt::HContai
{
ModuleBase::TITLE("LCAO_Hamilt","read_mat_npz");

const Parallel_Orbitals* paraV = this->LOWF.ParaV;
const Parallel_Orbitals* paraV = &(this->orb_con.ParaV);

#ifdef __USECNPY

Expand Down
Loading

0 comments on commit d40f130

Please sign in to comment.