diff --git a/source/module_elecstate/elecstate_lcao.h b/source/module_elecstate/elecstate_lcao.h index c3e7ae3a2d..a3637fc64d 100644 --- a/source/module_elecstate/elecstate_lcao.h +++ b/source/module_elecstate/elecstate_lcao.h @@ -16,13 +16,23 @@ class ElecStateLCAO : public ElecState { public: ElecStateLCAO(){} // will be called by ElecStateLCAO_TDDFT + /* + Note: on the removal of LOWF + The entire instance of Local_Orbital_wfc is not really necessary, because + this class only need it to do 2dbcd wavefunction gathering. Therefore, + what is critical is the 2dbcd handle which stores information about the + wavefunction, and another free function to do the 2dbcd gathering. + + A future work would be replace the Local_Orbital_wfc with a 2dbcd handle. + A free gathering function will also be needed. + */ ElecStateLCAO(Charge* chg_in , const K_Vectors* klist_in , int nks_in, Local_Orbital_Charge* loc_in , Gint_Gamma* gint_gamma_in, //mohan add 2024-04-01 Gint_k* gint_k_in, //mohan add 2024-04-01 - Local_Orbital_wfc* lowf_in , + Local_Orbital_wfc* lowf_in, ModulePW::PW_Basis* rhopw_in , ModulePW::PW_Basis_Big* bigpw_in ) { diff --git a/source/module_elecstate/elecstate_lcao_tddft.cpp b/source/module_elecstate/elecstate_lcao_tddft.cpp index b25ac19338..114ca33974 100644 --- a/source/module_elecstate/elecstate_lcao_tddft.cpp +++ b/source/module_elecstate/elecstate_lcao_tddft.cpp @@ -51,7 +51,6 @@ void ElecStateLCAO_TDDFT::psiToRho_td(const psi::Psi>& psi) } } - //this->loc->cal_dk_k(*this->lowf->gridt, this->wg, *(this->klist)); for (int is = 0; is < GlobalV::NSPIN; is++) { ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is], this->charge->nrxx); // mohan 2009-11-10 diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index 8b1066073f..317bcf109f 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -163,10 +163,26 @@ void ESolver_KS_LCAO::before_all_runners(Input& inp, UnitCell& ucell) this->gen_h.LM = &this->LM; //! pass basis-pointer to EState and Psi - this->LOC.ParaV = this->LOWF.ParaV = this->LM.ParaV = &(this->orb_con.ParaV); + /* + Inform: on getting rid of ORB_control and Parallel_Orbitals + + Have to say it is all the stories start, the ORB_control instance pass its Parallel_Orbitals instance to + Local_Orbital_Charge, Local_Orbital_Wfc and LCAO_Matrix, which is actually for getting information + of 2D block-cyclic distribution. + + To remove LOC, LOWF and LM use in functions, one must make sure there is no more information imported + to those classes. Then places where to get information from them can be substituted to orb_con + + Plan: + 1. Specifically for paraV, the thing to do first is to replace the use of ParaV to the oen of ORB_control. + Then remove ORB_control and place paraV somewhere. + */ + this->LOC.ParaV = &(this->orb_con.ParaV); + this->LOWF.ParaV = &(this->orb_con.ParaV); + this->LM.ParaV = &(this->orb_con.ParaV); // 5) initialize Density Matrix - dynamic_cast*>(this->pelec)->init_DM(&this->kv, this->LM.ParaV, GlobalV::NSPIN); + dynamic_cast*>(this->pelec)->init_DM(&this->kv, &(this->orb_con.ParaV), GlobalV::NSPIN); if (GlobalV::CALCULATION == "get_S") @@ -232,7 +248,7 @@ void ESolver_KS_LCAO::before_all_runners(Input& inp, UnitCell& ucell) // 10) init HSolver if (this->phsol == nullptr) { - this->phsol = new hsolver::HSolverLCAO(this->LOWF.ParaV); + this->phsol = new hsolver::HSolverLCAO(&(this->orb_con.ParaV)); this->phsol->method = GlobalV::KS_SOLVER; } @@ -286,7 +302,14 @@ void ESolver_KS_LCAO::init_after_vc(Input& inp, UnitCell& ucell) ModuleBase::timer::tick("ESolver_KS_LCAO", "init_after_vc"); ESolver_KS::init_after_vc(inp, ucell); - + /* + Notes: on the removal of LOWF + Following constructor of ElecStateLCAO requires LOWF. However, ElecState only need + LOWF to do wavefunction 2dbcd (2D BlockCyclicDistribution) gathering. So, a free + function is needed to replace the use of LOWF. The function indeed needs the information + about 2dbcd, therefore another instance storing the information is needed instead. + Then that instance will be the input of "the free function to gather". + */ if (GlobalV::md_prec_level == 2) { delete this->pelec; @@ -296,7 +319,7 @@ void ESolver_KS_LCAO::init_after_vc(Input& inp, UnitCell& ucell) &(this->LOC), &(this->GG), // mohan add 2024-04-01 &(this->GK), // mohan add 2024-04-01 - &(this->LOWF), + &(this->LOWF), // should be replaced by a 2dbcd handle, if insist the "print_psi" must be in ElecState class this->pw_rho, this->pw_big); @@ -657,10 +680,10 @@ void ESolver_KS_LCAO::iter_init(const int istep, const int iter) // first need to calculate the weight according to // electrons number. if (istep == 0 - && this->wf.init_wfc == "file" - && this->LOWF.error == 0) - { - if (iter == 1) + && this->wf.init_wfc == "file" // Note: on the removal of LOWF + && this->LOWF.error == 0) // this means the wavefunction is read without any error. + { // However the I/O of wavefunction are nonsence to be implmented in different places. + if (iter == 1) // once the reading of wavefunction has any error, should exit immediately. { std::cout << " WAVEFUN -> CHARGE " << std::endl; @@ -831,11 +854,11 @@ void ESolver_KS_LCAO::hamilt2density(int istep, int iter, double ethr) #ifdef __EXX if (GlobalC::exx_info.info_ri.real_number) { - this->exd->exx_hamilt2density(*this->pelec, *this->LOWF.ParaV, iter); + this->exd->exx_hamilt2density(*this->pelec, this->orb_con.ParaV, iter); } else { - this->exc->exx_hamilt2density(*this->pelec, *this->LOWF.ParaV, iter); + this->exc->exx_hamilt2density(*this->pelec, this->orb_con.ParaV, iter); } #endif @@ -861,8 +884,6 @@ void ESolver_KS_LCAO::hamilt2density(int istep, int iter, double ethr) #ifdef __DEEPKS if (GlobalV::deepks_scf) { - const Parallel_Orbitals* pv = this->LOWF.ParaV; - const std::vector>& dm = dynamic_cast*>(this->pelec)->get_DM()->get_DMK_vector(); @@ -940,7 +961,7 @@ void ESolver_KS_LCAO::update_pot(const int istep, const int iter) GlobalV::out_app_flag, "H", "data-" + std::to_string(ik), - *this->LOWF.ParaV, + this->orb_con.ParaV, GlobalV::DRANK); ModuleIO::save_mat(istep, s_mat.p, @@ -951,7 +972,7 @@ void ESolver_KS_LCAO::update_pot(const int istep, const int iter) GlobalV::out_app_flag, "S", "data-" + std::to_string(ik), - *this->LOWF.ParaV, + this->orb_con.ParaV, GlobalV::DRANK); } } @@ -1048,8 +1069,8 @@ void ESolver_KS_LCAO::iter_finish(int iter) if (GlobalC::restart.info_save.save_H && two_level_step > 0 && (!GlobalC::exx_info.info_global.separate_loop || iter == 1)) // to avoid saving the same value repeatedly { - std::vector Hexxk_save(this->LOWF.ParaV->get_local_size()); - for (int ik = 0;ik < this->kv.get_nks();++ik) + std::vector Hexxk_save(this->orb_con.ParaV.get_local_size()); + for (int ik = 0; ik < this->kv.get_nks(); ++ik) { ModuleBase::GlobalFunc::ZEROS(Hexxk_save.data(), Hexxk_save.size()); @@ -1057,7 +1078,7 @@ void ESolver_KS_LCAO::iter_finish(int iter) opexx_save.contributeHk(ik); - GlobalC::restart.save_disk("Hexx", ik, this->LOWF.ParaV->get_local_size(), Hexxk_save.data()); + GlobalC::restart.save_disk("Hexx", ik, this->orb_con.ParaV.get_local_size(), Hexxk_save.data()); } if (GlobalV::MY_RANK == 0) { @@ -1234,7 +1255,7 @@ void ESolver_KS_LCAO::after_scf(const int istep) GlobalC::ucell, GlobalC::ORB, GlobalC::GridD, - this->LOWF.ParaV, + &(this->orb_con.ParaV), *(this->psi), dynamic_cast*>(this->pelec)->get_DM()); @@ -1251,7 +1272,7 @@ void ESolver_KS_LCAO::after_scf(const int istep) RPA_LRI rpa_lri_double(GlobalC::exx_info.info_ri); rpa_lri_double.cal_postSCF_exx(*dynamic_cast*>(this->pelec)->get_DM(), MPI_COMM_WORLD, this->kv); rpa_lri_double.init(MPI_COMM_WORLD, this->kv); - rpa_lri_double.out_for_RPA(*(this->LOWF.ParaV), *(this->psi), this->pelec); + rpa_lri_double.out_for_RPA(this->orb_con.ParaV, *(this->psi), this->pelec); } #endif @@ -1418,7 +1439,7 @@ ModuleIO::Output_Mat_Sparse ESolver_KS_LCAO::create_Output_Mat_Spars INPUT.out_mat_r, istep, this->pelec->pot->get_effective_v(), - *this->LOWF.ParaV, + this->orb_con.ParaV, this->gen_h, // mohan add 2024-04-06 this->GK, // mohan add 2024-04-01 this->LM, diff --git a/source/module_esolver/esolver_ks_lcao_elec.cpp b/source/module_esolver/esolver_ks_lcao_elec.cpp index 882b427e76..830cad8287 100644 --- a/source/module_esolver/esolver_ks_lcao_elec.cpp +++ b/source/module_esolver/esolver_ks_lcao_elec.cpp @@ -109,23 +109,23 @@ void ESolver_KS_LCAO::beforesolver(const int istep) if (GlobalV::GAMMA_ONLY_LOCAL) { nsk = GlobalV::NSPIN; - ncol = this->LOWF.ParaV->ncol_bands; + ncol = this->orb_con.ParaV.ncol_bands; if (GlobalV::KS_SOLVER == "genelpa" || GlobalV::KS_SOLVER == "lapack_gvx" || GlobalV::KS_SOLVER == "pexsi" || GlobalV::KS_SOLVER == "cusolver") { - ncol = this->LOWF.ParaV->ncol; + ncol = this->orb_con.ParaV.ncol; } } else { nsk = this->kv.get_nks(); #ifdef __MPI - ncol = this->LOWF.ParaV->ncol_bands; + ncol = this->orb_con.ParaV.ncol_bands; #else ncol = GlobalV::NBANDS; #endif } - this->psi = new psi::Psi(nsk, ncol, this->LOWF.ParaV->nrow, nullptr); + this->psi = new psi::Psi(nsk, ncol, this->orb_con.ParaV.nrow, nullptr); } // prepare grid in Gint @@ -504,7 +504,6 @@ void ESolver_KS_LCAO::nscf(void) // Peize Lin add 2018-08-14 if (GlobalC::exx_info.info_global.cal_exx) { - // GlobalC::exx_lcao.cal_exx_elec_nscf(this->LOWF.ParaV[0]); const std::string file_name_exx = GlobalV::global_out_dir + "HexxR" + std::to_string(GlobalV::MY_RANK); if (GlobalC::exx_info.info_ri.real_number) { @@ -600,7 +599,7 @@ void ESolver_KS_LCAO::nscf(void) this->sf, this->kv, this->psi, - this->LOWF.ParaV); + &(this->orb_con.ParaV)); } else if (INPUT.wannier_method == 2) { @@ -612,7 +611,7 @@ void ESolver_KS_LCAO::nscf(void) INPUT.nnkpfile, INPUT.wannier_spin); - myWannier.calculate(this->pelec->ekb, this->kv, *(this->psi), this->LOWF.ParaV); + myWannier.calculate(this->pelec->ekb, this->kv, *(this->psi), &(this->orb_con.ParaV)); } #endif } @@ -626,7 +625,6 @@ void ESolver_KS_LCAO::nscf(void) // below is for DeePKS NSCF calculation #ifdef __DEEPKS - const Parallel_Orbitals* pv = this->LOWF.ParaV; if (GlobalV::deepks_out_labels || GlobalV::deepks_scf) { const elecstate::DensityMatrix* dm diff --git a/source/module_esolver/esolver_ks_lcao_tddft.cpp b/source/module_esolver/esolver_ks_lcao_tddft.cpp index 9bd2915ee8..99d2c052f6 100644 --- a/source/module_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/module_esolver/esolver_ks_lcao_tddft.cpp @@ -102,7 +102,8 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(Input& inp, UnitCell& ucell) // pass Hamilt-pointer to Operator this->gen_h.LM = &this->LM; // pass basis-pointer to EState and Psi - this->LOC.ParaV = this->LOWF.ParaV = this->LM.ParaV; + this->LOC.ParaV = this->LM.ParaV;; + this->LOWF.ParaV = this->LM.ParaV; // init DensityMatrix dynamic_cast>*>(this->pelec) @@ -111,7 +112,7 @@ void ESolver_KS_LCAO_TDDFT::before_all_runners(Input& inp, UnitCell& ucell) // init Psi, HSolver, ElecState, Hamilt if (this->phsol == nullptr) { - this->phsol = new hsolver::HSolverLCAO>(this->LOWF.ParaV); + this->phsol = new hsolver::HSolverLCAO>(this->LM.ParaV); this->phsol->method = GlobalV::KS_SOLVER; } @@ -274,7 +275,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter) GlobalV::out_app_flag, "H", "data-" + std::to_string(ik), - *this->LOWF.ParaV, + *this->LM.ParaV, GlobalV::DRANK); ModuleIO::save_mat(istep, @@ -286,7 +287,7 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter) GlobalV::out_app_flag, "S", "data-" + std::to_string(ik), - *this->LOWF.ParaV, + *this->LM.ParaV, GlobalV::DRANK); } } @@ -333,8 +334,8 @@ void ESolver_KS_LCAO_TDDFT::update_pot(const int istep, const int iter) { #ifdef __MPI this->psi_laststep = new psi::Psi>(kv.get_nks(), - this->LOWF.ParaV->ncol_bands, - this->LOWF.ParaV->nrow, + this->LM.ParaV->ncol_bands, + this->LM.ParaV->nrow, nullptr); #else this->psi_laststep = new psi::Psi>(kv.get_nks(), GlobalV::NBANDS, GlobalV::NLOCAL, nullptr); diff --git a/source/module_esolver/io_npz.cpp b/source/module_esolver/io_npz.cpp index 8a557884fe..9c58bd7b0b 100644 --- a/source/module_esolver/io_npz.cpp +++ b/source/module_esolver/io_npz.cpp @@ -33,7 +33,7 @@ void ESolver_KS_LCAO::read_mat_npz(std::string& zipname, hamilt::HContai { ModuleBase::TITLE("LCAO_Hamilt","read_mat_npz"); - const Parallel_Orbitals* paraV = this->LOWF.ParaV; + const Parallel_Orbitals* paraV = &(this->orb_con.ParaV); #ifdef __USECNPY diff --git a/source/module_hamilt_lcao/hamilt_lcaodft/local_orbital_wfc.h b/source/module_hamilt_lcao/hamilt_lcaodft/local_orbital_wfc.h index d994289029..9cf66ea0ea 100644 --- a/source/module_hamilt_lcao/hamilt_lcaodft/local_orbital_wfc.h +++ b/source/module_hamilt_lcao/hamilt_lcaodft/local_orbital_wfc.h @@ -12,9 +12,20 @@ class Local_Orbital_wfc { public: - Local_Orbital_wfc(); + Local_Orbital_wfc(); ~Local_Orbital_wfc(); - + // refactor new implementation: RAII + // a more look-looking name would be LocalOrbitalWfc, I suppose... + Local_Orbital_wfc(const int& nspin, + const int& nks, + const int& nbands, + const int& nlocal, + const int& gamma_only, + const int& nb2d, + const std::string& ks_solver, + const std::string& readin_dir); + // + void initialize(); ///========================================= /// grid wfc /// used to generate density matrix: LOC.DM_R, @@ -25,7 +36,10 @@ class Local_Orbital_wfc std::complex*** wfc_k_grid; // [NK, GlobalV::NBANDS, GlobalV::NLOCAL] std::complex* wfc_k_grid2; // [NK*GlobalV::NBANDS*GlobalV::NLOCAL] + // pointer to const Parallel_Orbitals object const Parallel_Orbitals* ParaV; + // pointer to const Grid_Technique object, although have no idea about what it is... + // the name of Grid_Technique should be changed to be more informative const Grid_Technique* gridt; /// read wavefunction coefficients: LOWF_*.txt @@ -102,7 +116,7 @@ class Local_Orbital_wfc int nks; }; -#ifdef __MPI + // the function should not be defined here!! mohan 2024-03-28 template int Local_Orbital_wfc::set_wfc_grid(int naroc[2], @@ -116,6 +130,7 @@ int Local_Orbital_wfc::set_wfc_grid(int naroc[2], int myid, T** ctot) { +#ifdef __MPI ModuleBase::TITLE(" Local_Orbital_wfc", "set_wfc_grid"); if (!wfc && !ctot) { @@ -142,8 +157,7 @@ int Local_Orbital_wfc::set_wfc_grid(int naroc[2], } } } +#endif return 0; } #endif - -#endif