Skip to content

Commit

Permalink
Refactor namespace Exx_Opt_Orb
Browse files Browse the repository at this point in the history
  • Loading branch information
PeizeLin committed Oct 13, 2024
1 parent 54659c6 commit ff65831
Show file tree
Hide file tree
Showing 24 changed files with 635 additions and 576 deletions.
37 changes: 24 additions & 13 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <memory>
#ifdef __EXX
#include "module_io/restart_exx_csr.h"
#include "module_ri/exx_opt_orb.h"
#include "module_ri/RPA_LRI.h"
#endif

Expand Down Expand Up @@ -170,9 +171,19 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
inp.lcao_rmax,
ucell,
two_center_bundle_,
orb_);
this->orb_);
//------------------init Basis_lcao----------------------

if (PARAM.inp.calculation == "gen_opt_abfs")
{
#ifdef __EXX
Exx_Opt_Orb::generate_matrix(GlobalC::exx_info.info_opt_abfs, this->kv, this->orb_);
#else
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO::before_all_runners", "calculation=gen_opt_abfs must compile __EXX");
#endif
return;
}

// 5) initialize density matrix
// DensityMatrix is allocated here, DMK is also initialized here
// DMR is not initialized here, it will be constructed in each before_scf
Expand All @@ -188,7 +199,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
// 6) initialize Hamilt in LCAO
// * allocate H and S matrices according to computational resources
// * set the 'trace' between local H/S and global H/S
LCAO_domain::divide_HS_in_frag(PARAM.globalv.gamma_only_local, pv, this->kv.get_nks(), orb_);
LCAO_domain::divide_HS_in_frag(PARAM.globalv.gamma_only_local, pv, this->kv.get_nks(), this->orb_);

#ifdef __EXX
// 7) initialize exx
Expand All @@ -202,11 +213,11 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
// initialize 2-center radial tables for EXX-LRI
if (GlobalC::exx_info.info_ri.real_number)
{
this->exx_lri_double->init(MPI_COMM_WORLD, this->kv, orb_);
this->exx_lri_double->init(MPI_COMM_WORLD, this->kv, this->orb_);
}
else
{
this->exx_lri_complex->init(MPI_COMM_WORLD, this->kv, orb_);
this->exx_lri_complex->init(MPI_COMM_WORLD, this->kv, this->orb_);
}
}
}
Expand All @@ -215,7 +226,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
// 8) initialize DFT+U
if (PARAM.inp.dft_plus_u)
{
GlobalC::dftu.init(ucell, &this->pv, this->kv.get_nks(), orb_);
GlobalC::dftu.init(ucell, &this->pv, this->kv.get_nks(), this->orb_);
}

// 9) initialize ppcell
Expand Down Expand Up @@ -244,7 +255,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
// load the DeePKS model from deep neural network
GlobalC::ld.load_model(PARAM.inp.deepks_model);
// read pdm from file for NSCF or SCF-restart, do it only once in whole calculation
GlobalC::ld.read_projected_DM((PARAM.inp.init_chg == "file"), PARAM.inp.deepks_equiv, *orb_.Alpha);
GlobalC::ld.read_projected_DM((PARAM.inp.init_chg == "file"), PARAM.inp.deepks_equiv, *this->orb_.Alpha);
}
#endif

Expand Down Expand Up @@ -313,7 +324,7 @@ void ESolver_KS_LCAO<TK, TR>::cal_force(ModuleBase::matrix& force)
this->GG, // mohan add 2024-04-01
this->GK, // mohan add 2024-04-01
two_center_bundle_,
orb_,
this->orb_,
force,
this->scs,
this->sf,
Expand Down Expand Up @@ -466,7 +477,7 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners()
this->GG,
this->GK,
this->kv,
orb_.cutoffs(),
this->orb_.cutoffs(),
this->pelec->wg,
GlobalC::GridD
#ifdef __EXX
Expand Down Expand Up @@ -495,7 +506,7 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners()
this->kv,
this->pelec->wg,
GlobalC::GridD,
orb_.cutoffs(),
this->orb_.cutoffs(),
this->two_center_bundle_
#ifdef __EXX
,
Expand Down Expand Up @@ -1168,7 +1179,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
this->pelec->ekb,
this->pelec->klist->kvec_d,
GlobalC::ucell,
orb_,
this->orb_,
GlobalC::GridD,
&(this->pv),
*(this->psi),
Expand All @@ -1188,8 +1199,8 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
rpa_lri_double.cal_postSCF_exx(*dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
MPI_COMM_WORLD,
this->kv,
orb_);
rpa_lri_double.init(MPI_COMM_WORLD, this->kv, orb_.cutoffs());
this->orb_);
rpa_lri_double.init(MPI_COMM_WORLD, this->kv, this->orb_.cutoffs());
rpa_lri_double.out_for_RPA(this->pv, *(this->psi), this->pelec);
}
#endif
Expand Down Expand Up @@ -1278,7 +1289,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
this->kv.kvec_d,
&hR,
&GlobalC::ucell,
orb_.cutoffs(),
this->orb_.cutoffs(),
&GlobalC::GridD,
two_center_bundle_.kinetic_orb.get());

Expand Down
4 changes: 4 additions & 0 deletions source/module_esolver/lcao_others.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)
// return; // use 'return' will cause segmentation fault. by mohan
// 2024-06-09
}
else if (cal_type == "gen_opt_abfs")
{
return;
}
else if (cal_type == "test_memory")
{
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "testing memory");
Expand Down
33 changes: 23 additions & 10 deletions source/module_hamilt_general/module_xc/exx_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,20 @@ struct Exx_Info

struct Exx_Info_Lip
{
const Conv_Coulomb_Pot_K::Ccp_Type& ccp_type;
const double& hse_omega;
const Conv_Coulomb_Pot_K::Ccp_Type &ccp_type;
const double &hse_omega;
double lambda = 0.3;

Exx_Info_Lip(const Exx_Info::Exx_Info_Global& info_global)
:ccp_type(info_global.ccp_type),
hse_omega(info_global.hse_omega) {}
hse_omega(info_global.hse_omega) {}
};
Exx_Info_Lip info_lip;

struct Exx_Info_RI
{
const Conv_Coulomb_Pot_K::Ccp_Type& ccp_type;
const double& hse_omega;
const Conv_Coulomb_Pot_K::Ccp_Type &ccp_type;
const double &hse_omega;

bool real_number = false;

Expand All @@ -58,15 +58,28 @@ struct Exx_Info
int abfs_Lmax = 0; // tmp

Exx_Info_RI(const Exx_Info::Exx_Info_Global& info_global)
: ccp_type(info_global.ccp_type), hse_omega(info_global.hse_omega)
{
}
:ccp_type(info_global.ccp_type),
hse_omega(info_global.hse_omega) {}
};
Exx_Info_RI info_ri;

Exx_Info() : info_lip(this->info_global), info_ri(this->info_global)
struct Exx_Info_Opt_ABFs
{
}
//const Conv_Coulomb_Pot_K::Ccp_Type &ccp_type;
//const double &hse_omega;

int abfs_Lmax = 0; // tmp
double ecut_exx = 60;
double tolerence = 1E-2;
double kmesh_times = 4;

//Exx_Info_Opt_ABFs(const Exx_Info::Exx_Info_Global& info_global)
// :ccp_type(info_global.ccp_type),
// hse_omega(info_global.hse_omega) {}
};
Exx_Info_Opt_ABFs info_opt_abfs;

Exx_Info() : info_lip(this->info_global), info_ri(this->info_global) {}
};

#endif
2 changes: 1 addition & 1 deletion source/module_hamilt_general/module_xc/xc_functional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ void XC_Functional::set_xc_type(const std::string xc_func_in)
func_type = 4;
use_libxc = false;
}
else if( xc_func == "OPT_ORB" || xc_func == "NONE" || xc_func == "NOX+NOC")
else if( xc_func == "NONE" || xc_func == "NOX+NOC")
{
// not doing anything
}
Expand Down
13 changes: 4 additions & 9 deletions source/module_io/input_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,6 @@ void Input_Conv::Convert()
ModuleBase::GlobalFunc::MAKE_DIR(GlobalC::restart.folder);
if (dft_functional_lower == "hf" || dft_functional_lower == "pbe0"
|| dft_functional_lower == "hse"
|| dft_functional_lower == "opt_orb"
|| dft_functional_lower == "scan0") {
GlobalC::restart.info_save.save_charge = true;
GlobalC::restart.info_save.save_H = true;
Expand All @@ -344,7 +343,6 @@ void Input_Conv::Convert()
GlobalC::restart.folder = PARAM.globalv.global_readin_dir + "restart/";
if (dft_functional_lower == "hf" || dft_functional_lower == "pbe0"
|| dft_functional_lower == "hse"
|| dft_functional_lower == "opt_orb"
|| dft_functional_lower == "scan0") {
GlobalC::restart.info_load.load_charge = true;
GlobalC::restart.info_load.load_H = true;
Expand Down Expand Up @@ -373,14 +371,11 @@ void Input_Conv::Convert()
GlobalC::exx_info.info_global.cal_exx = true;
GlobalC::exx_info.info_global.ccp_type
= Conv_Coulomb_Pot_K::Ccp_Type::Hse;
} else if (dft_functional_lower == "opt_orb") {
GlobalC::exx_info.info_global.cal_exx = false;
Exx_Abfs::Jle::generate_matrix = true;
} else {
GlobalC::exx_info.info_global.cal_exx = false;
}

if (GlobalC::exx_info.info_global.cal_exx || Exx_Abfs::Jle::generate_matrix || PARAM.inp.rpa)
if (GlobalC::exx_info.info_global.cal_exx || PARAM.inp.rpa)
{
// EXX case, convert all EXX related variables
// GlobalC::exx_info.info_global.cal_exx = true;
Expand All @@ -407,9 +402,9 @@ void Input_Conv::Convert()
GlobalC::exx_info.info_ri.cauchy_stress_threshold = PARAM.inp.exx_cauchy_stress_threshold;
GlobalC::exx_info.info_ri.ccp_rmesh_times = std::stod(PARAM.inp.exx_ccp_rmesh_times);

Exx_Abfs::Jle::Lmax = PARAM.inp.exx_opt_orb_lmax;
Exx_Abfs::Jle::Ecut_exx = PARAM.inp.exx_opt_orb_ecut;
Exx_Abfs::Jle::tolerence = PARAM.inp.exx_opt_orb_tolerence;
GlobalC::exx_info.info_opt_abfs.abfs_Lmax = PARAM.inp.exx_opt_orb_lmax;
GlobalC::exx_info.info_opt_abfs.ecut_exx = PARAM.inp.exx_opt_orb_ecut;
GlobalC::exx_info.info_opt_abfs.tolerence = PARAM.inp.exx_opt_orb_tolerence;

// EXX does not support symmetry for nspin==4
if (PARAM.inp.calculation != "nscf" && PARAM.inp.symmetry == "1" && PARAM.inp.nspin == 4)
Expand Down
5 changes: 3 additions & 2 deletions source/module_io/read_input_item_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void ReadInput::item_system()
}
{
Input_Item item("calculation");
item.annotation = "test; scf; relax; nscf; get_wf; get_pchg";
item.annotation = "scf; relax; md; cell-relax; nscf; get_S; get_wf; get_pchg; gen_bessel; gen_opt_abfs; test_memory; test_neighbour";
item.read_value = [](const Input_Item& item, Parameter& para) {
para.input.calculation = strvalue;
std::string& calculation = para.input.calculation;
Expand All @@ -78,7 +78,8 @@ void ReadInput::item_system()
"get_S",
"get_wf",
"get_pchg",
"gen_bessel"};
"gen_bessel",
"gen_opt_abfs"};
if (!find_str(callist, calculation))
{
const std::string warningstr = nofound_str(callist, "calculation");
Expand Down
10 changes: 5 additions & 5 deletions source/module_ri/Exx_LRI.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "module_exx_symmetry/symmetry_rotation.h"

class Parallel_Orbitals;

template<typename T, typename Tdata>
class RPA_LRI;

Expand Down Expand Up @@ -59,19 +59,19 @@ class Exx_LRI
void init(const MPI_Comm &mpi_comm_in, const K_Vectors &kv_in, const LCAO_Orbitals& orb);
void cal_exx_force();
void cal_exx_stress();
std::vector<std::vector<int>> get_abfs_nchis() const;
// std::vector<std::vector<int>> get_abfs_nchis() const;

std::vector< std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>> Hexxs;
double Eexx;
ModuleBase::matrix force_exx;
ModuleBase::matrix stress_exx;


private:
const Exx_Info::Exx_Info_RI &info;
MPI_Comm mpi_comm;
const K_Vectors *p_kv = nullptr;
std::vector<double> orb_cutoff_;
std::vector<double> orb_cutoff_;

std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> lcaos;
std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> abfs;
Expand Down Expand Up @@ -99,4 +99,4 @@ class Exx_LRI

#include "Exx_LRI.hpp"

#endif
#endif
Loading

0 comments on commit ff65831

Please sign in to comment.