Skip to content

Commit

Permalink
Refactor: add CalAtomsInfo to modify parameter (#5132)
Browse files Browse the repository at this point in the history
* Refactor: add CalAtomsInfo to modify parameter
   prepare to remove globalv

* fix compile

* fix UTs
  • Loading branch information
Qianruipku authored Sep 21, 2024
1 parent 80b2c75 commit 90e4921
Show file tree
Hide file tree
Showing 21 changed files with 435 additions and 346 deletions.
75 changes: 75 additions & 0 deletions source/module_cell/cal_atoms_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#ifndef CAL_ATOMS_INFO_H
#define CAL_ATOMS_INFO_H
#include "module_parameter/parameter.h"
#include "unitcell.h"
class CalAtomsInfo
{
public:
CalAtomsInfo(){};
~CalAtomsInfo(){};

/**
* @brief Calculate the atom information from pseudopotential to set Parameter
*
* @param atoms [in] Atom pointer
* @param ntype [in] number of atom types
* @param para [out] Parameter object
*/
void cal_atoms_info(const Atom* atoms, const int& ntype, Parameter& para)
{
// calculate initial total magnetization when NSPIN=2
if (para.inp.nspin == 2 && !para.globalv.two_fermi)
{
for (int it = 0; it < ntype; ++it)
{
for (int ia = 0; ia < atoms[it].na; ++ia)
{
GlobalV::nupdown += atoms[it].mag[ia];
}
}
GlobalV::ofs_running << " The readin total magnetization is " << GlobalV::nupdown << std::endl;
}

if (!para.inp.use_paw)
{
// decide whether to be USPP
for (int it = 0; it < ntype; ++it)
{
if (atoms[it].ncpp.tvanp)
{
GlobalV::use_uspp = true;
}
}

// calculate the total number of local basis
GlobalV::NLOCAL = 0;
for (int it = 0; it < ntype; ++it)
{
const int nlocal_it = atoms[it].nw * atoms[it].na;
if (para.inp.nspin != 4)
{
GlobalV::NLOCAL += nlocal_it;
}
else
{
GlobalV::NLOCAL += nlocal_it * 2; // zhengdy-soc
}
}
}

// calculate the total number of electrons
cal_nelec(atoms, ntype, GlobalV::nelec);

// autoset and check GlobalV::NBANDS
std::vector<double> nelec_spin(2, 0.0);
if (para.inp.nspin == 2)
{
nelec_spin[0] = (GlobalV::nelec + GlobalV::nupdown) / 2.0;
nelec_spin[1] = (GlobalV::nelec - GlobalV::nupdown) / 2.0;
}
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);

return;
}
};
#endif
4 changes: 2 additions & 2 deletions source/module_cell/module_neighbor/test/prepare_unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in,
coor_type(coor_type_in),
coordinates(coordinates_in)
{
mbl = {0};
velocity = {0};
mbl = std::valarray<double>(0.0, coordinates_in.size());
velocity = std::valarray<double>(0.0, coordinates_in.size());
}

UcellTestPrepare::UcellTestPrepare(std::string latname_in,
Expand Down
4 changes: 2 additions & 2 deletions source/module_cell/test/prepare_unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in,
coor_type(coor_type_in),
coordinates(coordinates_in)
{
mbl = {0};
velocity = {0};
mbl = std::valarray<double>(0.0, coordinates_in.size());
velocity = std::valarray<double>(0.0, coordinates_in.size());
}

UcellTestPrepare::UcellTestPrepare(std::string latname_in,
Expand Down
2 changes: 1 addition & 1 deletion source/module_cell/test/support/mock_unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,5 @@ void UnitCell::setup(const std::string& latname_in,
const int& lmaxmax_in,
const bool& init_vel_in,
const std::string& fixed_axes_in) {}
void UnitCell::cal_nelec(double& nelec) {}
void cal_nelec(const Atom* atoms, const int& ntype, double& nelec) {}
void UnitCell::compare_atom_labels(std::string label1, std::string label2) {}
170 changes: 166 additions & 4 deletions source/module_cell/test/unitcell_test_readpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ Magnetism::~Magnetism() { delete[] this->start_magnetization; }
* possible of an element
* - CalNelec: UnitCell::cal_nelec
* - calculate the total number of valence electrons from psp files
* - CalNbands: elecstate::ElecState::cal_nbands()
* - calculate the number of bands
*/

// mock function
Expand All @@ -114,9 +116,16 @@ class UcellTest : public ::testing::Test {
pp_dir = "./support/";
PARAM.input.pseudo_rcut = 15.0;
PARAM.input.dft_functional = "default";
PARAM.input.esolver_type = "ksdft";
PARAM.input.test_pseudo_cell = true;
PARAM.input.nspin = 1;
PARAM.input.basis_type = "pw";
GlobalV::nelec = 10.0;
GlobalV::nupdown = 0.0;
PARAM.sys.two_fermi = false;
GlobalV::NBANDS = 6;
GlobalV::NLOCAL = 6;
PARAM.input.lspinorb = false;
}
void TearDown() { ofs.close(); }
};
Expand Down Expand Up @@ -256,6 +265,7 @@ TEST_F(UcellTest, CalNwfc1) {
ucell->read_cell_pseudopots(pp_dir, ofs);
EXPECT_FALSE(ucell->atoms[0].ncpp.has_so);
EXPECT_FALSE(ucell->atoms[1].ncpp.has_so);
GlobalV::NLOCAL = 3 * 9;
ucell->cal_nwfc(ofs);
EXPECT_EQ(ucell->atoms[0].iw2l[8], 2);
EXPECT_EQ(ucell->atoms[0].iw2n[8], 0);
Expand All @@ -282,7 +292,6 @@ TEST_F(UcellTest, CalNwfc1) {
EXPECT_EQ(ucell->atoms[0].nw, 9);
EXPECT_EQ(ucell->atoms[1].nw, 9);
EXPECT_EQ(ucell->nwmax, 9);
EXPECT_EQ(GlobalV::NLOCAL, 3 * 9);
// check itia2iat
EXPECT_EQ(ucell->itia2iat.getSize(), 4);
EXPECT_EQ(ucell->itia2iat(0, 0), 0);
Expand Down Expand Up @@ -322,8 +331,8 @@ TEST_F(UcellTest, CalNwfc2) {
ucell->read_cell_pseudopots(pp_dir, ofs);
EXPECT_FALSE(ucell->atoms[0].ncpp.has_so);
EXPECT_FALSE(ucell->atoms[1].ncpp.has_so);
ucell->cal_nwfc(ofs);
EXPECT_EQ(GlobalV::NLOCAL, 3 * 9 * 2);
GlobalV::NLOCAL = 3 * 9 * 2;
EXPECT_NO_THROW(ucell->cal_nwfc(ofs));
}

TEST_F(UcellDeathTest, CheckStructure) {
Expand Down Expand Up @@ -396,10 +405,163 @@ TEST_F(UcellTest, CalNelec) {
EXPECT_EQ(1, ucell->atoms[0].na);
EXPECT_EQ(2, ucell->atoms[1].na);
double nelec = 0;
ucell->cal_nelec(nelec);
cal_nelec(ucell->atoms, ucell->ntype, nelec);
EXPECT_DOUBLE_EQ(6, nelec);
}

TEST_F(UcellTest, CalNbands)
{
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 6);
}

TEST_F(UcellTest, CalNbandsFractionElec)
{
GlobalV::nelec = 9.5;
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 6);
}

TEST_F(UcellTest, CalNbandsSOC)
{
PARAM.input.lspinorb = true;
GlobalV::NBANDS = 0;
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 20);
}

TEST_F(UcellTest, CalNbandsSDFT)
{
PARAM.input.esolver_type = "sdft";
std::vector<double> nelec_spin(2, 5.0);
EXPECT_NO_THROW(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS));
}

TEST_F(UcellTest, CalNbandsLCAO)
{
PARAM.input.basis_type = "lcao";
std::vector<double> nelec_spin(2, 5.0);
EXPECT_NO_THROW(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS));
}

TEST_F(UcellTest, CalNbandsLCAOINPW)
{
PARAM.input.basis_type = "lcao_in_pw";
GlobalV::NLOCAL = GlobalV::NBANDS - 1;
std::vector<double> nelec_spin(2, 5.0);
testing::internal::CaptureStdout();
EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("NLOCAL < NBANDS"));
}

TEST_F(UcellTest, CalNbandsWarning1)
{
GlobalV::NBANDS = GlobalV::nelec / 2 - 1;
std::vector<double> nelec_spin(2, 5.0);
testing::internal::CaptureStdout();
EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("Too few bands!"));
}

TEST_F(UcellTest, CalNbandsWarning2)
{
PARAM.input.nspin = 2;
GlobalV::nupdown = 4.0;
std::vector<double> nelec_spin(2);
nelec_spin[0] = (GlobalV::nelec + GlobalV::nupdown) / 2.0;
nelec_spin[1] = (GlobalV::nelec - GlobalV::nupdown) / 2.0;
testing::internal::CaptureStdout();
EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("Too few spin up bands!"));
}

TEST_F(UcellTest, CalNbandsWarning3)
{
PARAM.input.nspin = 2;
GlobalV::nupdown = -4.0;
std::vector<double> nelec_spin(2);
nelec_spin[0] = (GlobalV::nelec + GlobalV::nupdown) / 2.0;
nelec_spin[1] = (GlobalV::nelec - GlobalV::nupdown) / 2.0;
testing::internal::CaptureStdout();
EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("Too few spin down bands!"));
}

TEST_F(UcellTest, CalNbandsSpin1)
{
PARAM.input.nspin = 1;
GlobalV::NBANDS = 0;
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 15);
}

TEST_F(UcellTest, CalNbandsSpin1LCAO)
{
PARAM.input.nspin = 1;
GlobalV::NBANDS = 0;
PARAM.input.basis_type = "lcao";
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 6);
}

TEST_F(UcellTest, CalNbandsSpin4)
{
PARAM.input.nspin = 4;
GlobalV::NBANDS = 0;
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 30);
}

TEST_F(UcellTest, CalNbandsSpin4LCAO)
{
PARAM.input.nspin = 4;
GlobalV::NBANDS = 0;
PARAM.input.basis_type = "lcao";
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 6);
}

TEST_F(UcellTest, CalNbandsSpin2)
{
PARAM.input.nspin = 2;
GlobalV::NBANDS = 0;
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 16);
}

TEST_F(UcellTest, CalNbandsSpin2LCAO)
{
PARAM.input.nspin = 2;
GlobalV::NBANDS = 0;
PARAM.input.basis_type = "lcao";
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 6);
}

TEST_F(UcellTest, CalNbandsGaussWarning)
{
GlobalV::NBANDS = 5;
std::vector<double> nelec_spin(2, 5.0);
PARAM.input.smearing_method = "gaussian";
testing::internal::CaptureStdout();
EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("for smearing, num. of bands > num. of occupied bands"));
}

#ifdef __MPI
#include "mpi.h"
int main(int argc, char** argv) {
Expand Down
Loading

0 comments on commit 90e4921

Please sign in to comment.