Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: eliminate most use of LOWF instance in esolver - step 1 #4273

Merged
merged 13 commits into from
Jun 3, 2024
Merged
12 changes: 11 additions & 1 deletion source/module_elecstate/elecstate_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,23 @@ class ElecStateLCAO : public ElecState
{
public:
ElecStateLCAO(){} // will be called by ElecStateLCAO_TDDFT
/*
Note: on the removal of LOWF
The entire instance of Local_Orbital_wfc is not really necessary, because
this class only need it to do 2dbcd wavefunction gathering. Therefore,
what is critical is the 2dbcd handle which stores information about the
wavefunction, and another free function to do the 2dbcd gathering.

A future work would be replace the Local_Orbital_wfc with a 2dbcd handle.
A free gathering function will also be needed.
*/
ElecStateLCAO(Charge* chg_in ,
const K_Vectors* klist_in ,
int nks_in,
Local_Orbital_Charge* loc_in ,
Gint_Gamma* gint_gamma_in, //mohan add 2024-04-01
Gint_k* gint_k_in, //mohan add 2024-04-01
Local_Orbital_wfc* lowf_in ,
Local_Orbital_wfc* lowf_in,
ModulePW::PW_Basis* rhopw_in ,
ModulePW::PW_Basis_Big* bigpw_in )
{
Expand Down
1 change: 0 additions & 1 deletion source/module_elecstate/elecstate_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ void ElecStateLCAO_TDDFT::psiToRho_td(const psi::Psi<std::complex<double>>& psi)
}
}

//this->loc->cal_dk_k(*this->lowf->gridt, this->wg, *(this->klist));
for (int is = 0; is < GlobalV::NSPIN; is++)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); // mohan 2009-11-10
Expand Down
63 changes: 42 additions & 21 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,26 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(Input& inp, UnitCell& ucell)
this->gen_h.LM = &this->LM;

//! pass basis-pointer to EState and Psi
this->LOC.ParaV = this->LOWF.ParaV = this->LM.ParaV = &(this->orb_con.ParaV);
/*
Inform: on getting rid of ORB_control and Parallel_Orbitals

Have to say it is all the stories start, the ORB_control instance pass its Parallel_Orbitals instance to
Local_Orbital_Charge, Local_Orbital_Wfc and LCAO_Matrix, which is actually for getting information
of 2D block-cyclic distribution.

To remove LOC, LOWF and LM use in functions, one must make sure there is no more information imported
to those classes. Then places where to get information from them can be substituted to orb_con

Plan:
1. Specifically for paraV, the thing to do first is to replace the use of ParaV to the oen of ORB_control.
Then remove ORB_control and place paraV somewhere.
*/
this->LOC.ParaV = &(this->orb_con.ParaV);
this->LOWF.ParaV = &(this->orb_con.ParaV);
this->LM.ParaV = &(this->orb_con.ParaV);

// 5) initialize Density Matrix
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, this->LM.ParaV, GlobalV::NSPIN);
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, &(this->orb_con.ParaV), GlobalV::NSPIN);


if (GlobalV::CALCULATION == "get_S")
Expand Down Expand Up @@ -232,7 +248,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(Input& inp, UnitCell& ucell)
// 10) init HSolver
if (this->phsol == nullptr)
{
this->phsol = new hsolver::HSolverLCAO<TK>(this->LOWF.ParaV);
this->phsol = new hsolver::HSolverLCAO<TK>(&(this->orb_con.ParaV));
this->phsol->method = GlobalV::KS_SOLVER;
}

Expand Down Expand Up @@ -286,7 +302,14 @@ void ESolver_KS_LCAO<TK, TR>::init_after_vc(Input& inp, UnitCell& ucell)
ModuleBase::timer::tick("ESolver_KS_LCAO", "init_after_vc");

ESolver_KS<TK>::init_after_vc(inp, ucell);

/*
Notes: on the removal of LOWF
Following constructor of ElecStateLCAO requires LOWF. However, ElecState only need
LOWF to do wavefunction 2dbcd (2D BlockCyclicDistribution) gathering. So, a free
function is needed to replace the use of LOWF. The function indeed needs the information
about 2dbcd, therefore another instance storing the information is needed instead.
Then that instance will be the input of "the free function to gather".
*/
if (GlobalV::md_prec_level == 2)
{
delete this->pelec;
Expand All @@ -296,7 +319,7 @@ void ESolver_KS_LCAO<TK, TR>::init_after_vc(Input& inp, UnitCell& ucell)
&(this->LOC),
&(this->GG), // mohan add 2024-04-01
&(this->GK), // mohan add 2024-04-01
&(this->LOWF),
&(this->LOWF), // should be replaced by a 2dbcd handle, if insist the "print_psi" must be in ElecState class
this->pw_rho,
this->pw_big);

Expand Down Expand Up @@ -657,10 +680,10 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
// first need to calculate the weight according to
// electrons number.
if (istep == 0
&& this->wf.init_wfc == "file"
&& this->LOWF.error == 0)
{
if (iter == 1)
&& this->wf.init_wfc == "file" // Note: on the removal of LOWF
&& this->LOWF.error == 0) // this means the wavefunction is read without any error.
{ // However the I/O of wavefunction are nonsence to be implmented in different places.
if (iter == 1) // once the reading of wavefunction has any error, should exit immediately.
{
std::cout << " WAVEFUN -> CHARGE " << std::endl;

Expand Down Expand Up @@ -831,11 +854,11 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density(int istep, int iter, double ethr)
#ifdef __EXX
if (GlobalC::exx_info.info_ri.real_number)
{
this->exd->exx_hamilt2density(*this->pelec, *this->LOWF.ParaV, iter);
this->exd->exx_hamilt2density(*this->pelec, this->orb_con.ParaV, iter);
}
else
{
this->exc->exx_hamilt2density(*this->pelec, *this->LOWF.ParaV, iter);
this->exc->exx_hamilt2density(*this->pelec, this->orb_con.ParaV, iter);
}
#endif

Expand All @@ -861,8 +884,6 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density(int istep, int iter, double ethr)
#ifdef __DEEPKS
if (GlobalV::deepks_scf)
{
const Parallel_Orbitals* pv = this->LOWF.ParaV;

const std::vector<std::vector<TK>>& dm =
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM()->get_DMK_vector();

Expand Down Expand Up @@ -940,7 +961,7 @@ void ESolver_KS_LCAO<TK, TR>::update_pot(const int istep, const int iter)
GlobalV::out_app_flag,
"H",
"data-" + std::to_string(ik),
*this->LOWF.ParaV,
this->orb_con.ParaV,
GlobalV::DRANK);
ModuleIO::save_mat(istep,
s_mat.p,
Expand All @@ -951,7 +972,7 @@ void ESolver_KS_LCAO<TK, TR>::update_pot(const int istep, const int iter)
GlobalV::out_app_flag,
"S",
"data-" + std::to_string(ik),
*this->LOWF.ParaV,
this->orb_con.ParaV,
GlobalV::DRANK);
}
}
Expand Down Expand Up @@ -1048,16 +1069,16 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int iter)
if (GlobalC::restart.info_save.save_H && two_level_step > 0 &&
(!GlobalC::exx_info.info_global.separate_loop || iter == 1)) // to avoid saving the same value repeatedly
{
std::vector<TK> Hexxk_save(this->LOWF.ParaV->get_local_size());
for (int ik = 0;ik < this->kv.get_nks();++ik)
std::vector<TK> Hexxk_save(this->orb_con.ParaV.get_local_size());
for (int ik = 0; ik < this->kv.get_nks(); ++ik)
{
ModuleBase::GlobalFunc::ZEROS(Hexxk_save.data(), Hexxk_save.size());

hamilt::OperatorEXX<hamilt::OperatorLCAO<TK, TR>> opexx_save(&this->LM, nullptr, &Hexxk_save, this->kv);

opexx_save.contributeHk(ik);

GlobalC::restart.save_disk("Hexx", ik, this->LOWF.ParaV->get_local_size(), Hexxk_save.data());
GlobalC::restart.save_disk("Hexx", ik, this->orb_con.ParaV.get_local_size(), Hexxk_save.data());
}
if (GlobalV::MY_RANK == 0)
{
Expand Down Expand Up @@ -1234,7 +1255,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
GlobalC::ucell,
GlobalC::ORB,
GlobalC::GridD,
this->LOWF.ParaV,
&(this->orb_con.ParaV),
*(this->psi),
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM());

Expand All @@ -1251,7 +1272,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
RPA_LRI<TK, double> rpa_lri_double(GlobalC::exx_info.info_ri);
rpa_lri_double.cal_postSCF_exx(*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(), MPI_COMM_WORLD, this->kv);
rpa_lri_double.init(MPI_COMM_WORLD, this->kv);
rpa_lri_double.out_for_RPA(*(this->LOWF.ParaV), *(this->psi), this->pelec);
rpa_lri_double.out_for_RPA(this->orb_con.ParaV, *(this->psi), this->pelec);
}
#endif

Expand Down Expand Up @@ -1418,7 +1439,7 @@ ModuleIO::Output_Mat_Sparse<TK> ESolver_KS_LCAO<TK, TR>::create_Output_Mat_Spars
INPUT.out_mat_r,
istep,
this->pelec->pot->get_effective_v(),
*this->LOWF.ParaV,
this->orb_con.ParaV,
this->gen_h, // mohan add 2024-04-06
this->GK, // mohan add 2024-04-01
this->LM,
Expand Down
14 changes: 6 additions & 8 deletions source/module_esolver/esolver_ks_lcao_elec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,23 +109,23 @@ void ESolver_KS_LCAO<TK, TR>::beforesolver(const int istep)
if (GlobalV::GAMMA_ONLY_LOCAL)
{
nsk = GlobalV::NSPIN;
ncol = this->LOWF.ParaV->ncol_bands;
ncol = this->orb_con.ParaV.ncol_bands;
if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "lapack_gvx" || GlobalV::KS_SOLVER == "pexsi"
|| GlobalV::KS_SOLVER == "cusolver")
{
ncol = this->LOWF.ParaV->ncol;
ncol = this->orb_con.ParaV.ncol;
}
}
else
{
nsk = this->kv.get_nks();
#ifdef __MPI
ncol = this->LOWF.ParaV->ncol_bands;
ncol = this->orb_con.ParaV.ncol_bands;
#else
ncol = GlobalV::NBANDS;
#endif
}
this->psi = new psi::Psi<TK>(nsk, ncol, this->LOWF.ParaV->nrow, nullptr);
this->psi = new psi::Psi<TK>(nsk, ncol, this->orb_con.ParaV.nrow, nullptr);
}

// prepare grid in Gint
Expand Down Expand Up @@ -504,7 +504,6 @@ void ESolver_KS_LCAO<TK, TR>::nscf(void)
// Peize Lin add 2018-08-14
if (GlobalC::exx_info.info_global.cal_exx)
{
// GlobalC::exx_lcao.cal_exx_elec_nscf(this->LOWF.ParaV[0]);
const std::string file_name_exx = GlobalV::global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
if (GlobalC::exx_info.info_ri.real_number)
{
Expand Down Expand Up @@ -600,7 +599,7 @@ void ESolver_KS_LCAO<TK, TR>::nscf(void)
this->sf,
this->kv,
this->psi,
this->LOWF.ParaV);
&(this->orb_con.ParaV));
}
else if (INPUT.wannier_method == 2)
{
Expand All @@ -612,7 +611,7 @@ void ESolver_KS_LCAO<TK, TR>::nscf(void)
INPUT.nnkpfile,
INPUT.wannier_spin);

myWannier.calculate(this->pelec->ekb, this->kv, *(this->psi), this->LOWF.ParaV);
myWannier.calculate(this->pelec->ekb, this->kv, *(this->psi), &(this->orb_con.ParaV));
}
#endif
}
Expand All @@ -626,7 +625,6 @@ void ESolver_KS_LCAO<TK, TR>::nscf(void)

// below is for DeePKS NSCF calculation
#ifdef __DEEPKS
const Parallel_Orbitals* pv = this->LOWF.ParaV;
if (GlobalV::deepks_out_labels || GlobalV::deepks_scf)
{
const elecstate::DensityMatrix<TK, double>* dm
Expand Down
13 changes: 7 additions & 6 deletions source/module_esolver/esolver_ks_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(Input& inp, UnitCell& ucell)
// pass Hamilt-pointer to Operator
this->gen_h.LM = &this->LM;
// pass basis-pointer to EState and Psi
this->LOC.ParaV = this->LOWF.ParaV = this->LM.ParaV;
this->LOC.ParaV = this->LM.ParaV;;
this->LOWF.ParaV = this->LM.ParaV;

// init DensityMatrix
dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)
Expand All @@ -111,7 +112,7 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(Input& inp, UnitCell& ucell)
// init Psi, HSolver, ElecState, Hamilt
if (this->phsol == nullptr)
{
this->phsol = new hsolver::HSolverLCAO<std::complex<double>>(this->LOWF.ParaV);
this->phsol = new hsolver::HSolverLCAO<std::complex<double>>(this->LM.ParaV);
this->phsol->method = GlobalV::KS_SOLVER;
}

Expand Down Expand Up @@ -274,7 +275,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
GlobalV::out_app_flag,
"H",
"data-" + std::to_string(ik),
*this->LOWF.ParaV,
*this->LM.ParaV,
GlobalV::DRANK);

ModuleIO::save_mat(istep,
Expand All @@ -286,7 +287,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
GlobalV::out_app_flag,
"S",
"data-" + std::to_string(ik),
*this->LOWF.ParaV,
*this->LM.ParaV,
GlobalV::DRANK);
}
}
Expand Down Expand Up @@ -333,8 +334,8 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter)
{
#ifdef __MPI
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(),
this->LOWF.ParaV->ncol_bands,
this->LOWF.ParaV->nrow,
this->LM.ParaV->ncol_bands,
this->LM.ParaV->nrow,
nullptr);
#else
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), GlobalV::NBANDS, GlobalV::NLOCAL, nullptr);
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/io_npz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void ESolver_KS_LCAO<TK, TR>::read_mat_npz(std::string& zipname, hamilt::HContai
{
ModuleBase::TITLE("LCAO_Hamilt","read_mat_npz");

const Parallel_Orbitals* paraV = this->LOWF.ParaV;
const Parallel_Orbitals* paraV = &(this->orb_con.ParaV);

#ifdef __USECNPY

Expand Down
24 changes: 19 additions & 5 deletions source/module_hamilt_lcao/hamilt_lcaodft/local_orbital_wfc.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,20 @@ class Local_Orbital_wfc
{
public:

Local_Orbital_wfc();
Local_Orbital_wfc();
~Local_Orbital_wfc();

// refactor new implementation: RAII
// a more look-looking name would be LocalOrbitalWfc, I suppose...
Local_Orbital_wfc(const int& nspin,
const int& nks,
const int& nbands,
const int& nlocal,
const int& gamma_only,
const int& nb2d,
const std::string& ks_solver,
const std::string& readin_dir);
//
void initialize();
///=========================================
/// grid wfc
/// used to generate density matrix: LOC.DM_R,
Expand All @@ -25,7 +36,10 @@ class Local_Orbital_wfc
std::complex<double>*** wfc_k_grid; // [NK, GlobalV::NBANDS, GlobalV::NLOCAL]
std::complex<double>* wfc_k_grid2; // [NK*GlobalV::NBANDS*GlobalV::NLOCAL]

// pointer to const Parallel_Orbitals object
const Parallel_Orbitals* ParaV;
// pointer to const Grid_Technique object, although have no idea about what it is...
// the name of Grid_Technique should be changed to be more informative
const Grid_Technique* gridt;

/// read wavefunction coefficients: LOWF_*.txt
Expand Down Expand Up @@ -102,7 +116,7 @@ class Local_Orbital_wfc
int nks;
};

#ifdef __MPI

// the function should not be defined here!! mohan 2024-03-28
template <typename T>
int Local_Orbital_wfc::set_wfc_grid(int naroc[2],
Expand All @@ -116,6 +130,7 @@ int Local_Orbital_wfc::set_wfc_grid(int naroc[2],
int myid,
T** ctot)
{
#ifdef __MPI
ModuleBase::TITLE(" Local_Orbital_wfc", "set_wfc_grid");
if (!wfc && !ctot)
{
Expand All @@ -142,8 +157,7 @@ int Local_Orbital_wfc::set_wfc_grid(int naroc[2],
}
}
}
#endif
return 0;
}
#endif

#endif