Skip to content

Commit

Permalink
Feature: Support outputting partial charge densities for different k-…
Browse files Browse the repository at this point in the history
…points and spins separately when using PW basis set (#4829)

* Add INPUT parameter if_separate_k for partial charge calc under PW basis set

* Modified some annotation

* Change file names and add files for get_pchg calculation

* Did nothing, just formatting

* Resolve conflicts and edit get_chg_pw (draft 1)

* Edit get_pchg_pw (draft 2)

* Edit get_pchg_pw (draft 3)

* Edit get_pchg_pw (draft 4)

* Modify CMakeLists to compile get_pchg_pw

* Move definition of template function get_pchg_pw from cpp to header file

* Modify CMakeLists to compile get_pchg_pw stage 2

* Delete original implementation in esolver_ks_pw.cpp
  • Loading branch information
AsTonyshment authored Aug 1, 2024
1 parent f99fc2b commit ef393d8
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 142 deletions.
6 changes: 3 additions & 3 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -520,16 +520,16 @@ OBJS_IO=input_conv.o\
read_input_item_exx_dftu.o\
read_input_item_other.o\
read_input_item_output.o\
bcast_globalv.o
bcast_globalv.o\

OBJS_IO_LCAO=cal_r_overlap_R.o\
write_orb_info.o\
write_dos_lcao.o\
write_proj_band_lcao.o\
write_istate_info.o\
nscf_fermi_surf.o\
get_pchg.o\
get_wf.o\
get_pchg_lcao.o\
get_wf_lcao.o\
io_dmk.o\
unk_overlap_lcao.o\
read_wfc_nao.o\
Expand Down
144 changes: 25 additions & 119 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "module_hsolver/kernels/math_kernel_op.h"
#include "module_io/berryphase.h"
#include "module_io/cube_io.h"
#include "module_io/get_pchg_pw.h"
#include "module_io/numerical_basis.h"
#include "module_io/numerical_descriptor.h"
#include "module_io/to_wannier90_pw.h"
Expand Down Expand Up @@ -99,7 +100,8 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
}

template <typename T, typename Device>
void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCell& ucell) {
void ESolver_KS_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCell& ucell)
{
// 1) call before_all_runners() of ESolver_KS
ESolver_KS<T, Device>::before_all_runners(inp, ucell);

Expand Down Expand Up @@ -661,127 +663,31 @@ void ESolver_KS_PW<T, Device>::after_scf(const int istep)
this->psi[0].size());
}

// Get bands_to_print through public function of INPUT (returns a const
// pointer to string)
// Calculate band-decomposed (partial) charge density
const std::vector<int> bands_to_print = PARAM.inp.bands_to_print;
if (bands_to_print.size() > 0)
{
// bands_picked is a vector of 0s and 1s, where 1 means the band is
// picked to output
std::vector<int> bands_picked;
bands_picked.resize(this->kspw_psi->get_nbands());
ModuleBase::GlobalFunc::ZEROS(bands_picked.data(), this->kspw_psi->get_nbands());

// Check if length of bands_to_print is valid
if (static_cast<int>(bands_to_print.size()) > this->kspw_psi->get_nbands())
{
ModuleBase::WARNING_QUIT("ESolver_KS_PW::after_scf",
"The number of bands specified by `bands_to_print` in the "
"INPUT file exceeds `nbands`!");
}

// Check if all elements in bands_picked are 0 or 1
for (int value: bands_to_print)
{
if (value != 0 && value != 1)
{
ModuleBase::WARNING_QUIT("ESolver_KS_PW::after_scf",
"The elements of `bands_to_print` must be either 0 or 1. "
"Invalid values found!");
}
}

// Fill bands_picked with values from bands_to_print
// Remaining bands are already set to 0
int length = std::min(static_cast<int>(bands_to_print.size()), this->kspw_psi->get_nbands());
for (int i = 0; i < length; ++i)
{
// bands_to_print rely on function parse_expression
// Initially designed for ocp_set, which can be double
bands_picked[i] = static_cast<int>(bands_to_print[i]);
}

std::complex<double>* wfcr = new std::complex<double>[this->pw_rho->nxyz];
double* rho_band = new double[this->pw_rho->nxyz];

for (int ib = 0; ib < this->kspw_psi->get_nbands(); ++ib)
{
// Skip the loop iteration if bands_picked[ib] is 0
if (!bands_picked[ib])
{
continue;
}

for (int i = 0; i < this->pw_rho->nxyz; i++)
{
// Initialize rho_band to zero for each band
rho_band[i] = 0.0;
}

for (int ik = 0; ik < this->kv.get_nks(); ik++)
{
this->psi->fix_k(ik);
this->pw_wfc->recip_to_real(this->ctx, &psi[0](ib, 0), wfcr, ik);

double w1 = static_cast<double>(this->kv.wk[ik] / GlobalC::ucell.omega);

for (int i = 0; i < this->pw_rho->nxyz; i++)
{
rho_band[i] += std::norm(wfcr[i]) * w1;
}
}

// Symmetrize the charge density, otherwise the results are incorrect if the symmetry is on
std::cout << " Symmetrizing band-decomposed charge density..." << std::endl;
Symmetry_rho srho;
for (int is = 0; is < GlobalV::NSPIN; is++)
{
// Use vector instead of raw pointers
std::vector<double*> rho_save_pointers(GlobalV::NSPIN, rho_band);
std::vector<std::vector<std::complex<double>>> rhog(
GlobalV::NSPIN,
std::vector<std::complex<double>>(this->pelec->charge->ngmc));

// Convert vector of vectors to vector of pointers
std::vector<std::complex<double>*> rhog_pointers(GlobalV::NSPIN);
for (int s = 0; s < GlobalV::NSPIN; s++)
{
rhog_pointers[s] = rhog[s].data();
}

srho.begin(is,
rho_save_pointers.data(),
rhog_pointers.data(),
this->pelec->charge->ngmc,
nullptr,
this->pw_rhod,
GlobalC::Pgrid,
GlobalC::ucell.symm);
}

std::stringstream ssc;
ssc << GlobalV::global_out_dir << "BAND" << ib + 1 << "_CHG.cube"; // band index starts from 1

ModuleIO::write_cube(
#ifdef __MPI
this->pw_big->bz,
this->pw_big->nbz,
this->pw_rhod->nplane,
this->pw_rhod->startz_current,
#endif
rho_band,
0,
GlobalV::NSPIN,
0,
ssc.str(),
this->pw_rhod->nx,
this->pw_rhod->ny,
this->pw_rhod->nz,
0.0,
&(GlobalC::ucell));
}
delete[] wfcr;
delete[] rho_band;
ModuleIO::get_pchg_pw(bands_to_print,
this->kspw_psi->get_nbands(),
GlobalV::NSPIN,
this->pw_rhod->nx,
this->pw_rhod->ny,
this->pw_rhod->nz,
this->pw_rhod->nxyz,
this->kv.get_nks(),
this->kv.isk,
this->kv.wk,
this->pw_big->bz,
this->pw_big->nbz,
this->pelec->charge->ngmc,
&GlobalC::ucell,
this->psi,
this->pw_rhod,
this->pw_wfc,
this->ctx,
GlobalC::Pgrid,
GlobalV::global_out_dir,
PARAM.inp.if_separate_k);
}
}

Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/lcao_before_scf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include "module_cell/module_neighbor/sltk_atom_arrange.h"
#include "module_cell/module_neighbor/sltk_grid_driver.h"
#include "module_io/berryphase.h"
#include "module_io/get_pchg.h"
#include "module_io/get_wf.h"
#include "module_io/get_pchg_lcao.h"
#include "module_io/get_wf_lcao.h"
#include "module_io/to_wannier90_lcao.h"
#include "module_io/to_wannier90_lcao_in_pw.h"
#include "module_io/write_HS_R.h"
Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/lcao_nscf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include "module_cell/module_neighbor/sltk_grid_driver.h"
#include "module_io/berryphase.h"
#include "module_io/cube_io.h"
#include "module_io/get_pchg.h"
#include "module_io/get_wf.h"
#include "module_io/get_pchg_lcao.h"
#include "module_io/get_wf_lcao.h"
#include "module_io/to_wannier90_lcao.h"
#include "module_io/to_wannier90_lcao_in_pw.h"
#include "module_io/write_HS_R.h"
Expand Down
20 changes: 11 additions & 9 deletions source/module_esolver/lcao_others.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#include "module_cell/module_neighbor/sltk_atom_arrange.h"
#include "module_cell/module_neighbor/sltk_grid_driver.h"
#include "module_io/berryphase.h"
#include "module_io/get_pchg.h"
#include "module_io/get_wf.h"
#include "module_io/get_pchg_lcao.h"
#include "module_io/get_wf_lcao.h"
#include "module_io/to_wannier90_lcao.h"
#include "module_io/to_wannier90_lcao_in_pw.h"
#include "module_io/write_HS_R.h"
Expand Down Expand Up @@ -45,7 +45,8 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)

const std::string cal_type = GlobalV::CALCULATION;

if (cal_type == "get_S") {
if (cal_type == "get_S")
{
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "writing the overlap matrix");
this->get_S();
std::cout << FmtCore::format(" >> Finish %s.\n * * * * * *\n", "writing the overlap matrix");
Expand All @@ -54,7 +55,9 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)

// return; // use 'return' will cause segmentation fault. by mohan
// 2024-06-09
} else if (cal_type == "test_memory") {
}
else if (cal_type == "test_memory")
{
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "testing memory");
Cal_Test::test_memory(this->pw_rho,
this->pw_wfc,
Expand All @@ -67,9 +70,9 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)
{
// test_search_neighbor();
std::cout << FmtCore::format("\n * * * * * *\n << Start %s.\n", "testing neighbour");
if (GlobalV::SEARCH_RADIUS < 0) {
std::cout << " SEARCH_RADIUS : " << GlobalV::SEARCH_RADIUS
<< std::endl;
if (GlobalV::SEARCH_RADIUS < 0)
{
std::cout << " SEARCH_RADIUS : " << GlobalV::SEARCH_RADIUS << std::endl;
std::cout << " please make sure search_radius > 0" << std::endl;
}

Expand Down Expand Up @@ -205,8 +208,7 @@ void ESolver_KS_LCAO<TK, TR>::others(const int istep)
}
else
{
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO<TK, TR>::others",
"CALCULATION type not supported");
ModuleBase::WARNING_QUIT("ESolver_KS_LCAO<TK, TR>::others", "CALCULATION type not supported");
}

ModuleBase::timer::tick("ESolver_KS_LCAO", "others");
Expand Down
4 changes: 2 additions & 2 deletions source/module_io/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ if(ENABLE_LCAO)
write_orb_info.cpp
write_proj_band_lcao.cpp
nscf_fermi_surf.cpp
get_pchg.cpp
get_wf.cpp
get_pchg_lcao.cpp
get_wf_lcao.cpp
read_wfc_nao.cpp
read_wfc_lcao.cpp
write_wfc_nao.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "get_pchg.h"
#include "get_pchg_lcao.h"

#include "module_base/blas_connector.h"
#include "module_base/global_function.h"
Expand Down Expand Up @@ -120,7 +120,7 @@ void IState_Charge::begin(Gint_Gamma& gg,
ModuleBase::GlobalFunc::DCOPY(rho[is], rho_save[is].data(), rhopw_nrxx); // Copy data
}

std::cout << " Writting cube files...";
std::cout << " Writing cube files...";

for (int is = 0; is < nspin; ++is)
{
Expand Down Expand Up @@ -257,7 +257,7 @@ void IState_Charge::begin(Gint_k& gk,
ModuleBase::GlobalFunc::DCOPY(rho[is], rho_save[is].data(), rhopw_nrxx); // Copy data
}

std::cout << " Writting cube files...";
std::cout << " Writing cube files...";

for (int is = 0; is < nspin; ++is)
{
Expand Down Expand Up @@ -332,7 +332,7 @@ void IState_Charge::begin(Gint_k& gk,
ucell_in->symm);
}

std::cout << " Writting cube files...";
std::cout << " Writing cube files...";

for (int is = 0; is < nspin; ++is)
{
Expand Down
File renamed without changes.
Loading

0 comments on commit ef393d8

Please sign in to comment.