Skip to content

Commit

Permalink
Keep updating esolver and related source files (#3853)
Browse files Browse the repository at this point in the history
* update timer, add std

* update the format of esolver_fp.cpp

* add description in esolver_fp.h

* update a few esolver files

* update esolver description for LCAO

* update some formats of esolver codes

* update esolver_ks_pw.cpp formats

* update formats of esolver_ks_pw.cpp and esolver_ks_pw.h

* update esolver_ks_lcao_tddft.cpp formats

* update format of esolver_sdft_pw.cpp

* update the format of esolver_sdft_pw_tool.cpp

* keep formating esolver_ks_pw.cpp

* formating esolver_of_interface.cpp

* change GlobalC::ucell to ucell in esolver_ks.cpp

* remove some GlobalC::ucell in esolver_sdft_pw.cpp

* refactor the code before getting rid of RA in esolver_lcao

* refactor before getting rid of LOWF

* refactor before getting rid of LCAO_hamilt.h and LCAO_matrix.h

* refactor wavefunc_in_pw

* refactor density matrix

* refactor the format cal_dm_psi.cpp

* format forces.cpp

* refactor esolver_of_tool.cpp

* change member function beforescf in Esolver to before_scf

* change afterscf to after_scf

* change updatepot to update_pot

* change eachiterinit to iter_init, change eachiterfinish to iter_finish

* refactor esolvers, change member function names of most esolvers

* reformat esolver.h

* update tests for esolvers
  • Loading branch information
mohanchen authored Mar 31, 2024
1 parent 7f05684 commit 96b72b2
Show file tree
Hide file tree
Showing 48 changed files with 907 additions and 848 deletions.
28 changes: 13 additions & 15 deletions source/driver_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
* Esolver::Run takes in a configuration and provides force and stress,
* the configuration-changing subroutine takes force and stress and updates the configuration
*/
void Driver::driver_run()
void Driver::driver_run(void)
{
ModuleBase::TITLE("Driver", "driver_line");
ModuleBase::timer::tick("Driver", "driver_line");

// 1. Determine type of Esolver
//! 1: initialize the ESolver
ModuleESolver::ESolver *p_esolver = nullptr;
ModuleESolver::init_esolver(p_esolver);

// 2. Setup cell and atom information
//! 2: setup cell and atom information
#ifndef __LCAO
if(GlobalV::BASIS_TYPE == "lcao_in_pw" || GlobalV::BASIS_TYPE == "lcao")
{
Expand All @@ -37,30 +37,29 @@ void Driver::driver_run()
#endif
GlobalC::ucell.setup_cell(GlobalV::stru_file, GlobalV::ofs_running);

// 3. For these two types of calculations
//! 3: for these two types of calculations
// nothing else need to be initialized
if(GlobalV::CALCULATION == "test_neighbour" || GlobalV::CALCULATION == "test_memory")
if(GlobalV::CALCULATION == "test_neighbour"
|| GlobalV::CALCULATION == "test_memory")
{
p_esolver->Run(0, GlobalC::ucell);
p_esolver->run(0, GlobalC::ucell);
ModuleBase::QUIT();
}

// 4. Initialize Esolver,and fill json-structure
p_esolver->Init(INPUT, GlobalC::ucell);
//! 4: initialize Esolver and fill json-structure
p_esolver->init(INPUT, GlobalC::ucell);


#ifdef __RAPIDJSON
Json::gen_stru_wrapper(&GlobalC::ucell);
#endif

//------------------------------------------------------------
// This part onward needs to be refactored.
//---------------------------MD/Relax-------------------------
//! 5: md or relax calculations
if(GlobalV::CALCULATION == "md")
{
Run_MD::md_line(GlobalC::ucell, p_esolver, INPUT.mdp);
}
else // scf; cell relaxation; nscf; etc
else //! scf; cell relaxation; nscf; etc
{
if (GlobalV::precision_flag == "single")
{
Expand All @@ -73,10 +72,9 @@ void Driver::driver_run()
rl_driver.relax_driver(p_esolver);
}
}
//---------------------------MD/Relax------------------

// 6. clean up esolver
p_esolver->postprocess();
//! 6: clean up esolver
p_esolver->post_process();
ModuleESolver::clean_esolver(p_esolver);

ModuleBase::timer::tick("Driver", "driver_line");
Expand Down
16 changes: 11 additions & 5 deletions source/module_elecstate/elecstate_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ namespace elecstate
template <>
void ElecStateLCAO<double>::print_psi(const psi::Psi<double>& psi_in, const int istep)
{
if (!ElecStateLCAO<double>::out_wfc_lcao)
return;
if (!ElecStateLCAO<double>::out_wfc_lcao)
{
return;
}

// output but not do "2d-to-grid" conversion
double** wfc_grid = nullptr;
Expand All @@ -28,8 +30,11 @@ void ElecStateLCAO<double>::print_psi(const psi::Psi<double>& psi_in, const int
template <>
void ElecStateLCAO<std::complex<double>>::print_psi(const psi::Psi<std::complex<double>>& psi_in, const int istep)
{
if (!ElecStateLCAO<std::complex<double>>::out_wfc_lcao && !ElecStateLCAO<std::complex<double>>::need_psi_grid)
return;
if (!ElecStateLCAO<std::complex<double>>::out_wfc_lcao
&& !ElecStateLCAO<std::complex<double>>::need_psi_grid)
{
return;
}

// output but not do "2d-to-grid" conversion
std::complex<double>** wfc_grid = nullptr;
Expand All @@ -38,6 +43,7 @@ void ElecStateLCAO<std::complex<double>>::print_psi(const psi::Psi<std::complex<
{
wfc_grid = this->lowf->wfc_k_grid[ik];
}

#ifdef __MPI
this->lowf->wfc_2d_to_grid(istep,
ElecStateLCAO<std::complex<double>>::out_wfc_flag,
Expand Down Expand Up @@ -252,4 +258,4 @@ double ElecStateLCAO<std::complex<double>>::get_spin_constrain_energy()
template class ElecStateLCAO<double>; // Gamma_only case
template class ElecStateLCAO<std::complex<double>>; // multi-k case

} // namespace elecstate
} // namespace elecstate
17 changes: 11 additions & 6 deletions source/module_elecstate/module_dm/cal_dm_psi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV,
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
}
}
if (ib_global >= wg.nc)
continue;
if (ib_global >= wg.nc)
{
continue;
}
const double wg_local = wg(ik, ib_global);
double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
Expand Down Expand Up @@ -107,9 +109,11 @@ void cal_dm_psi(const Parallel_Orbitals* ParaV,
ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
}
}
if (ib_global >= wg.nc)
continue;
const double wg_local = wg(ik, ib_global);
if (ib_global >= wg.nc)
{
continue;
}
const double wg_local = wg(ik, ib_global);
std::complex<double>* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
}
Expand Down Expand Up @@ -233,7 +237,8 @@ void psiMulPsi(const psi::Psi<std::complex<double>>& psi1,
const char N_char = 'N', T_char = 'T';
const int nlocal = psi1.get_nbasis();
const int nbands = psi1.get_nbands();
const std::complex<double> one_complex = {1.0, 0.0}, zero_complex = {0.0, 0.0};
const std::complex<double> one_complex = {1.0, 0.0};
const std::complex<double> zero_complex = {0.0, 0.0};
zgemm_(&N_char,
&T_char,
&nlocal,
Expand Down
89 changes: 46 additions & 43 deletions source/module_elecstate/module_dm/density_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,11 @@ void DensityMatrix<TK,TR>::cal_DMR_test()
const int* r_index = tmp_ap.get_R_index(ir);
hamilt::BaseMatrix<TR>* tmp_matrix = tmp_ap.find_matrix(r_index[0], r_index[1], r_index[2]);
#ifdef __DEBUG
if (tmp_matrix == nullptr)
{
std::cout << "tmp_matrix is nullptr" << std::endl;
continue;
}
if (tmp_matrix == nullptr)
{
std::cout << "tmp_matrix is nullptr" << std::endl;
continue;
}
#endif
std::complex<TR> tmp_res;
// loop over k-points
Expand Down Expand Up @@ -515,44 +515,47 @@ void DensityMatrix<std::complex<double>, double>::cal_DMR()
#endif
// loop over k-points
if(GlobalV::NSPIN !=4 )
for (int ik = 0; ik < this->_nks; ++ik)
{
// cal k_phase
// if TK==std::complex<double>, kphase is e^{ikR}
const ModuleBase::Vector3<double> dR(r_index[0], r_index[1], r_index[2]);
const double arg = (this->_kv->kvec_d[ik] * dR) * ModuleBase::TWO_PI;
double sinp, cosp;
ModuleBase::libm::sincos(arg, &sinp, &cosp);
std::complex<double> kphase = std::complex<double>(cosp, sinp);
// set DMR element
double* tmp_DMR_pointer = tmp_matrix->get_pointer();
std::complex<double>* tmp_DMK_pointer = this->_DMK[ik + ik_begin].data();
double* DMK_real_pointer = nullptr;
double* DMK_imag_pointer = nullptr;
// jump DMK to fill DMR
// DMR is row-major, DMK is column-major
tmp_DMK_pointer += col_ap * this->_paraV->nrow + row_ap;
for (int mu = 0; mu < this->_paraV->get_row_size(iat1); ++mu)
{
DMK_real_pointer = (double*)tmp_DMK_pointer;
DMK_imag_pointer = DMK_real_pointer + 1;
BlasConnector::axpy(this->_paraV->get_col_size(iat2),
kphase.real(),
DMK_real_pointer,
ld_hk2,
tmp_DMR_pointer,
1);
// "-" since i^2 = -1
BlasConnector::axpy(this->_paraV->get_col_size(iat2),
-kphase.imag(),
DMK_imag_pointer,
ld_hk2,
tmp_DMR_pointer,
1);
tmp_DMK_pointer += 1;
tmp_DMR_pointer += this->_paraV->get_col_size(iat2);
}
}
{
for (int ik = 0; ik < this->_nks; ++ik)
{
// cal k_phase
// if TK==std::complex<double>, kphase is e^{ikR}
const ModuleBase::Vector3<double> dR(r_index[0], r_index[1], r_index[2]);
const double arg = (this->_kv->kvec_d[ik] * dR) * ModuleBase::TWO_PI;
double sinp, cosp;
ModuleBase::libm::sincos(arg, &sinp, &cosp);
std::complex<double> kphase = std::complex<double>(cosp, sinp);
// set DMR element
double* tmp_DMR_pointer = tmp_matrix->get_pointer();
std::complex<double>* tmp_DMK_pointer = this->_DMK[ik + ik_begin].data();
double* DMK_real_pointer = nullptr;
double* DMK_imag_pointer = nullptr;
// jump DMK to fill DMR
// DMR is row-major, DMK is column-major
tmp_DMK_pointer += col_ap * this->_paraV->nrow + row_ap;
for (int mu = 0; mu < this->_paraV->get_row_size(iat1); ++mu)
{
DMK_real_pointer = (double*)tmp_DMK_pointer;
DMK_imag_pointer = DMK_real_pointer + 1;
BlasConnector::axpy(this->_paraV->get_col_size(iat2),
kphase.real(),
DMK_real_pointer,
ld_hk2,
tmp_DMR_pointer,
1);
// "-" since i^2 = -1
BlasConnector::axpy(this->_paraV->get_col_size(iat2),
-kphase.imag(),
DMK_imag_pointer,
ld_hk2,
tmp_DMR_pointer,
1);
tmp_DMK_pointer += 1;
tmp_DMR_pointer += this->_paraV->get_col_size(iat2);
}
}
}

// treat DMR as pauli matrix when NSPIN=4
if(GlobalV::NSPIN==4)
{
Expand Down
50 changes: 28 additions & 22 deletions source/module_elecstate/module_dm/density_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,33 @@ namespace elecstate
* <TK,TR> = <double,double> for Gamma-only calculation
* <TK,TR> = <std::complex<double>,double> for multi-k calculation
*/
template<typename T> struct ShiftRealComplex
{
using type = void;
};
template<>
struct ShiftRealComplex<double> {
using type = std::complex<double>;
};
template<>
struct ShiftRealComplex<std::complex<double>> {
using type = double;
};

template <typename TK, typename TR>
class DensityMatrix
{
using TRShift = typename ShiftRealComplex<TR>::type;
public:
/**
* @brief Destructor of class DensityMatrix
*/
~DensityMatrix();
template<typename T> struct ShiftRealComplex
{
using type = void;
};

template<>
struct ShiftRealComplex<double>
{
using type = std::complex<double>;
};

template<>
struct ShiftRealComplex<std::complex<double>>
{
using type = double;
};

template <typename TK, typename TR>
class DensityMatrix
{
using TRShift = typename ShiftRealComplex<TR>::type;

public:
/**
* @brief Destructor of class DensityMatrix
*/
~DensityMatrix();

/**
* @brief Constructor of class DensityMatrix for multi-k calculation
Expand Down Expand Up @@ -160,6 +165,7 @@ namespace elecstate
* @brief get pointer of paraV
*/
const Parallel_Orbitals* get_paraV_pointer() const;

const K_Vectors* get_kv_pointer() const;

/**
Expand Down
Loading

0 comments on commit 96b72b2

Please sign in to comment.