Skip to content

Commit

Permalink
Refactor: remove nscf() (#5455)
Browse files Browse the repository at this point in the history
* Refactor: remove nscf()

* move calculate_weight() out of psi2rho()

* update PW_DIAG_NMAX for nscf
  • Loading branch information
YuLiu98 authored Nov 13, 2024
1 parent 2a40059 commit 9cc7f89
Show file tree
Hide file tree
Showing 33 changed files with 223 additions and 705 deletions.
3 changes: 0 additions & 3 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,8 @@ OBJS_ESOLVER=esolver.o\
esolver_of.o\
esolver_of_tool.o\
esolver_of_interface.o\
pw_fun.o\
pw_init_after_vc.o\
pw_init_globalc.o\
pw_nscf.o\
pw_others.o\

OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
Expand All @@ -259,7 +257,6 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
set_matrix_grid.o\
lcao_before_scf.o\
lcao_gets.o\
lcao_nscf.o\
lcao_others.o\
lcao_init_after_vc.o\
lcao_fun.o\
Expand Down
3 changes: 1 addition & 2 deletions source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,14 @@ void Driver::driver_run() {
{
Run_MD::md_line(GlobalC::ucell, p_esolver, PARAM);
}
else if (cal_type == "scf" || cal_type == "relax" || cal_type == "cell-relax")
else if (cal_type == "scf" || cal_type == "relax" || cal_type == "cell-relax" || cal_type == "nscf")
{
Relax_Driver rl_driver;
rl_driver.relax_driver(p_esolver);
}
else
{
//! supported "other" functions:
//! nscf(PW,LCAO),
//! get_pchg(LCAO),
//! test_memory(PW,LCAO),
//! test_neighbour(LCAO),
Expand Down
74 changes: 27 additions & 47 deletions source/module_elecstate/elecstate_lcao.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "elecstate_lcao.h"

#include <vector>

#include "cal_dm.h"
#include "module_base/timer.h"
#include "module_elecstate/module_dm/cal_dm_psi.h"
Expand All @@ -11,6 +9,8 @@
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_parameter/parameter.h"

#include <vector>

namespace elecstate
{

Expand All @@ -21,34 +21,31 @@ void ElecStateLCAO<std::complex<double>>::psiToRho(const psi::Psi<std::complex<d
ModuleBase::TITLE("ElecStateLCAO", "psiToRho");
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");

this->calculate_weights();

// the calculations of dm, and dm -> rho are, technically, two separate
// functionalities, as we cannot rule out the possibility that we may have a
// dm from other sources, such as read from file. However, since we are not
// separating them now, I opt to add a flag to control how dm is obtained as
// of now
if (!PARAM.inp.dm_to_rho)
{
this->calEBand();

ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");

// this part for calculating DMK in 2d-block format, not used for charge
// now
// psi::Psi<std::complex<double>> dm_k_2d();

if (PARAM.inp.ks_solver == "genelpa" || PARAM.inp.ks_solver == "elpa" || PARAM.inp.ks_solver == "scalapack_gvx" || PARAM.inp.ks_solver == "lapack"
|| PARAM.inp.ks_solver == "cusolver" || PARAM.inp.ks_solver == "cusolvermp"
|| PARAM.inp.ks_solver == "cg_in_lcao") // Peize Lin test 2019-05-15
{
elecstate::cal_dm_psi(this->DM->get_paraV_pointer(),
this->wg,
psi,
*(this->DM));
this->DM->cal_DMR();
}
}
// // the calculations of dm, and dm -> rho are, technically, two separate
// // functionalities, as we cannot rule out the possibility that we may have a
// // dm from other sources, such as read from file. However, since we are not
// // separating them now, I opt to add a flag to control how dm is obtained as
// // of now
// if (!PARAM.inp.dm_to_rho)
// {
// ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");

// // this part for calculating DMK in 2d-block format, not used for charge
// // now
// // psi::Psi<std::complex<double>> dm_k_2d();

// if (PARAM.inp.ks_solver == "genelpa" || PARAM.inp.ks_solver == "elpa" || PARAM.inp.ks_solver ==
// "scalapack_gvx" || PARAM.inp.ks_solver == "lapack"
// || PARAM.inp.ks_solver == "cusolver" || PARAM.inp.ks_solver == "cusolvermp"
// || PARAM.inp.ks_solver == "cg_in_lcao") // Peize Lin test 2019-05-15
// {
// elecstate::cal_dm_psi(this->DM->get_paraV_pointer(),
// this->wg,
// psi,
// *(this->DM));
// this->DM->cal_DMR();
// }
// }

for (int is = 0; is < PARAM.inp.nspin; is++)
{
Expand Down Expand Up @@ -83,23 +80,6 @@ void ElecStateLCAO<double>::psiToRho(const psi::Psi<double>& psi)
ModuleBase::TITLE("ElecStateLCAO", "psiToRho");
ModuleBase::timer::tick("ElecStateLCAO", "psiToRho");

this->calculate_weights();
this->calEBand();

if (PARAM.inp.ks_solver == "genelpa" || PARAM.inp.ks_solver == "elpa" || PARAM.inp.ks_solver == "scalapack_gvx" || PARAM.inp.ks_solver == "lapack"
|| PARAM.inp.ks_solver == "cusolver" || PARAM.inp.ks_solver == "cusolvermp" || PARAM.inp.ks_solver == "cg_in_lcao")
{
ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");

// get DMK in 2d-block format
elecstate::cal_dm_psi(this->DM->get_paraV_pointer(),
this->wg,
psi,
*(this->DM));
this->DM->cal_DMR();
ModuleBase::timer::tick("ElecStateLCAO", "cal_dm_2d");
}

for (int is = 0; is < PARAM.inp.nspin; is++)
{
ModuleBase::GlobalFunc::ZEROS(this->charge->rho[is],
Expand Down
3 changes: 2 additions & 1 deletion source/module_elecstate/elecstate_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class ElecStateLCAO : public ElecState
void dmToRho(std::vector<TK*> pexsi_DM, std::vector<TK*> pexsi_EDM);
#endif

DensityMatrix<TK, double>* DM = nullptr;

protected:
// calculate electronic charge density on grid points or density matrix in real space
// the consequence charge density rho saved into rho_out, preparing for charge mixing.
Expand All @@ -85,7 +87,6 @@ class ElecStateLCAO : public ElecState

Gint_Gamma* gint_gamma = nullptr; // mohan add 2024-04-01
Gint_k* gint_k = nullptr; // mohan add 2024-04-01
DensityMatrix<TK, double>* DM = nullptr;
};

template <typename TK>
Expand Down
3 changes: 0 additions & 3 deletions source/module_elecstate/elecstate_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ void ElecStatePW<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
ModuleBase::timer::tick("ElecStatePW", "psiToRho");

this->init_rho_data();
this->calculate_weights();

this->calEBand();

for(int is=0; is<PARAM.inp.nspin; is++)
{
Expand Down
2 changes: 0 additions & 2 deletions source/module_elecstate/elecstate_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)

if (GlobalV::MY_STOGROUP == 0)
{
this->calEBand();

for (int is = 0; is < nspin; is++)
{
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
Expand Down
3 changes: 0 additions & 3 deletions source/module_esolver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ list(APPEND objects
esolver_of.cpp
esolver_of_interface.cpp
esolver_of_tool.cpp
pw_fun.cpp
pw_init_after_vc.cpp
pw_init_globalc.cpp
pw_nscf.cpp
pw_others.cpp
)
if(ENABLE_LCAO)
Expand All @@ -26,7 +24,6 @@ if(ENABLE_LCAO)
dftu_cal_occup_m.cpp
lcao_before_scf.cpp
lcao_gets.cpp
lcao_nscf.cpp
lcao_others.cpp
lcao_init_after_vc.cpp
lcao_fun.cpp
Expand Down
3 changes: 0 additions & 3 deletions source/module_esolver/esolver_fp.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ namespace ModuleESolver
//! Electorn charge density
Charge chr;

//! Non-Self-Consistant Filed (NSCF) calculations
virtual void nscf(){};

//! Structure factors that used with plane-wave basis set
Structure_Factor sf;

Expand Down
10 changes: 3 additions & 7 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ void ESolver_KS<T, Device>::hamilt2density(const int istep, const int iter, cons

drho = p_chgmix->get_drho(pelec->charge, PARAM.inp.nelec);
hsolver_error = 0.0;
if (iter == 1)
if (iter == 1 && PARAM.inp.calculation != "nscf")
{
hsolver_error
= hsolver::cal_hsolve_error(PARAM.inp.basis_type, PARAM.inp.esolver_type, diag_ethr, PARAM.inp.nelec);
Expand Down Expand Up @@ -472,13 +472,10 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)

ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT SCF");

// 4) SCF iterations
this->conv_esolver = false;
this->niter = this->maxniter;

// 4) SCF iterations
this->diag_ethr = PARAM.inp.pw_diag_thr;

std::cout << " * * * * * *\n << Start SCF iteration." << std::endl;
for (int iter = 1; iter <= this->maxniter; ++iter)
{
// 6) initialization of SCF iterations
Expand All @@ -500,7 +497,6 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
break;
}
} // end scf iterations
std::cout << " >> Leave SCF iteration.\n * * * * * *" << std::endl;

// 15) after scf
ModuleBase::timer::tick(this->classname, "after_scf");
Expand Down Expand Up @@ -632,7 +628,7 @@ void ESolver_KS<T, Device>::iter_finish(const int istep, int& iter)

// If drho < hsolver_error in the first iter or drho < scf_thr, we
// do not change rho.
if (drho < hsolver_error || this->conv_esolver)
if (drho < hsolver_error || this->conv_esolver || PARAM.inp.calculation == "nscf")
{
if (drho < hsolver_error)
{
Expand Down
3 changes: 0 additions & 3 deletions source/module_esolver/esolver_ks.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ class ESolver_KS : public ESolver_FP
// calculate electron density from a specific Hamiltonian with ethr
virtual void hamilt2density_single(const int istep, const int iter, const double ethr);

// calculate electron states from a specific Hamiltonian
virtual void hamilt2estates(const double ethr) {};

// calculate electron density from a specific Hamiltonian
void hamilt2density(const int istep, const int iter, const double ethr);

Expand Down
66 changes: 63 additions & 3 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#include "esolver_ks_lcao.h"

#include "module_base/formatter.h"
#include "module_base/global_variable.h"
#include "module_base/tool_title.h"
#include "module_elecstate/module_dm/cal_dm_psi.h"
#include "module_io/berryphase.h"
#include "module_io/cube_io.h"
#include "module_io/dos_nao.h"
#include "module_io/nscf_band.h"
Expand All @@ -10,6 +13,8 @@
#include "module_io/output_mulliken.h"
#include "module_io/output_sk.h"
#include "module_io/to_qo.h"
#include "module_io/to_wannier90_lcao.h"
#include "module_io/to_wannier90_lcao_in_pw.h"
#include "module_io/write_HS.h"
#include "module_io/write_eband_terms.hpp"
#include "module_io/write_elecstat_pot.h"
Expand Down Expand Up @@ -47,7 +52,7 @@
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
#include "module_hsolver/hsolver_lcao.h"
// function used by deepks
#include "module_elecstate/cal_dm.h"
// #include "module_elecstate/cal_dm.h"
//---------------------------------------------------

#include "module_hamilt_lcao/module_deltaspin/spin_constrain.h"
Expand Down Expand Up @@ -586,6 +591,14 @@ void ESolver_KS_LCAO<TK, TR>::iter_init(const int istep, const int iter)
// and the ncalculate the charge density on grid.

this->pelec->skip_weights = true;
this->pelec->calculate_weights();
if (!PARAM.inp.dm_to_rho)
{
auto _pelec = dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec);
_pelec->calEBand();
elecstate::cal_dm_psi(_pelec->DM->get_paraV_pointer(), _pelec->wg, *this->psi, *(_pelec->DM));
_pelec->DM->cal_DMR();
}
this->pelec->psiToRho(*this->psi);
this->pelec->skip_weights = false;

Expand Down Expand Up @@ -714,9 +727,9 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density_single(int istep, int iter, double
// reset energy
this->pelec->f_en.eband = 0.0;
this->pelec->f_en.demet = 0.0;

bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
hsolver::HSolverLCAO<TK> hsolver_lcao_obj(&(this->pv), PARAM.inp.ks_solver);
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, false);
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec, skip_charge);

// 5) what's the exd used for?
#ifdef __EXX
Expand Down Expand Up @@ -1192,6 +1205,53 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)

delete ekinetic;
}

// add by jingan in 2018.11.7
if (PARAM.inp.calculation == "nscf" && PARAM.inp.towannier90)
{
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Wave function to Wannier90");
if (PARAM.inp.wannier_method == 1)
{
toWannier90_LCAO_IN_PW myWannier(PARAM.inp.out_wannier_mmn,
PARAM.inp.out_wannier_amn,
PARAM.inp.out_wannier_unk,
PARAM.inp.out_wannier_eig,
PARAM.inp.out_wannier_wvfn_formatted,
PARAM.inp.nnkpfile,
PARAM.inp.wannier_spin);

myWannier
.calculate(this->pelec->ekb, this->pw_wfc, this->pw_big, this->sf, this->kv, this->psi, &(this->pv));
}
else if (PARAM.inp.wannier_method == 2)
{
toWannier90_LCAO myWannier(PARAM.inp.out_wannier_mmn,
PARAM.inp.out_wannier_amn,
PARAM.inp.out_wannier_unk,
PARAM.inp.out_wannier_eig,
PARAM.inp.out_wannier_wvfn_formatted,
PARAM.inp.nnkpfile,
PARAM.inp.wannier_spin,
orb_);

myWannier.calculate(this->pelec->ekb, this->kv, *(this->psi), &(this->pv));
}
std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Wave function to Wannier90");
}

// add by jingan
if (PARAM.inp.calculation == "nscf" && berryphase::berry_phase_flag && ModuleSymmetry::Symmetry::symm_flag != 1)
{
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "Berry phase calculation");
berryphase bp(&(this->pv));
bp.lcao_init(this->kv,
this->GridT,
orb_); // additional step before calling
// macroscopic_polarization (why capitalize
// the function name?)
bp.Macroscopic_polarization(this->pw_wfc->npwk_max, this->psi, this->pw_rho, this->pw_wfc, this->kv);
std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "Berry phase calculation");
}
}

//------------------------------------------------------------------------------
Expand Down
2 changes: 0 additions & 2 deletions source/module_esolver/esolver_ks_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ class ESolver_KS_LCAO : public ESolver_KS<TK> {

void after_all_runners() override;

void nscf() override;

void get_S();

void cal_mag(const int istep, const bool print = false);
Expand Down
3 changes: 2 additions & 1 deletion source/module_esolver/esolver_ks_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,9 @@ void ESolver_KS_LCAO_TDDFT::hamilt2density_single(const int istep, const int ite
this->pelec->f_en.demet = 0.0;
if (this->psi != nullptr)
{
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;
hsolver::HSolverLCAO<std::complex<double>> hsolver_lcao_obj(&this->pv, PARAM.inp.ks_solver);
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec_td, false);
hsolver_lcao_obj.solve(this->p_hamilt, this->psi[0], this->pelec_td, skip_charge);
}
}

Expand Down
3 changes: 2 additions & 1 deletion source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ namespace ModuleESolver
hsolver::DiagoIterAssist<T>::SCF_ITER = iter;
hsolver::DiagoIterAssist<T>::PW_DIAG_THR = ethr;
hsolver::DiagoIterAssist<T>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
bool skip_charge = PARAM.inp.calculation == "nscf" ? true : false;

// It is not a good choice to overload another solve function here, this will spoil the concept of
// multiple inheritance and polymorphism. But for now, we just do it in this way.
Expand All @@ -138,7 +139,7 @@ namespace ModuleESolver
}

hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], false);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge);

// add exx
#ifdef __EXX
Expand Down
Loading

0 comments on commit 9cc7f89

Please sign in to comment.