diff --git a/source/module_cell/cal_atoms_info.h b/source/module_cell/cal_atoms_info.h new file mode 100644 index 0000000000..92789eb1ee --- /dev/null +++ b/source/module_cell/cal_atoms_info.h @@ -0,0 +1,75 @@ +#ifndef CAL_ATOMS_INFO_H +#define CAL_ATOMS_INFO_H +#include "module_parameter/parameter.h" +#include "unitcell.h" +class CalAtomsInfo +{ + public: + CalAtomsInfo(){}; + ~CalAtomsInfo(){}; + + /** + * @brief Calculate the atom information from pseudopotential to set Parameter + * + * @param atoms [in] Atom pointer + * @param ntype [in] number of atom types + * @param para [out] Parameter object + */ + void cal_atoms_info(const Atom* atoms, const int& ntype, Parameter& para) + { + // calculate initial total magnetization when NSPIN=2 + if (para.inp.nspin == 2 && !para.globalv.two_fermi) + { + for (int it = 0; it < ntype; ++it) + { + for (int ia = 0; ia < atoms[it].na; ++ia) + { + GlobalV::nupdown += atoms[it].mag[ia]; + } + } + GlobalV::ofs_running << " The readin total magnetization is " << GlobalV::nupdown << std::endl; + } + + if (!para.inp.use_paw) + { + // decide whether to be USPP + for (int it = 0; it < ntype; ++it) + { + if (atoms[it].ncpp.tvanp) + { + GlobalV::use_uspp = true; + } + } + + // calculate the total number of local basis + GlobalV::NLOCAL = 0; + for (int it = 0; it < ntype; ++it) + { + const int nlocal_it = atoms[it].nw * atoms[it].na; + if (para.inp.nspin != 4) + { + GlobalV::NLOCAL += nlocal_it; + } + else + { + GlobalV::NLOCAL += nlocal_it * 2; // zhengdy-soc + } + } + } + + // calculate the total number of electrons + cal_nelec(atoms, ntype, GlobalV::nelec); + + // autoset and check GlobalV::NBANDS + std::vector nelec_spin(2, 0.0); + if (para.inp.nspin == 2) + { + nelec_spin[0] = (GlobalV::nelec + GlobalV::nupdown) / 2.0; + nelec_spin[1] = (GlobalV::nelec - GlobalV::nupdown) / 2.0; + } + cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS); + + return; + } +}; +#endif \ No newline at end of file diff --git a/source/module_cell/module_neighbor/test/prepare_unitcell.h b/source/module_cell/module_neighbor/test/prepare_unitcell.h index 5f217aadf2..a92f15842f 100644 --- a/source/module_cell/module_neighbor/test/prepare_unitcell.h +++ b/source/module_cell/module_neighbor/test/prepare_unitcell.h @@ -252,8 +252,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in, coor_type(coor_type_in), coordinates(coordinates_in) { - mbl = {0}; - velocity = {0}; + mbl = std::valarray(0.0, coordinates_in.size()); + velocity = std::valarray(0.0, coordinates_in.size()); } UcellTestPrepare::UcellTestPrepare(std::string latname_in, diff --git a/source/module_cell/test/prepare_unitcell.h b/source/module_cell/test/prepare_unitcell.h index a02eafdbb5..f6692ae8f7 100644 --- a/source/module_cell/test/prepare_unitcell.h +++ b/source/module_cell/test/prepare_unitcell.h @@ -251,8 +251,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in, coor_type(coor_type_in), coordinates(coordinates_in) { - mbl = {0}; - velocity = {0}; + mbl = std::valarray(0.0, coordinates_in.size()); + velocity = std::valarray(0.0, coordinates_in.size()); } UcellTestPrepare::UcellTestPrepare(std::string latname_in, diff --git a/source/module_cell/test/support/mock_unitcell.cpp b/source/module_cell/test/support/mock_unitcell.cpp index 23b4df3909..b91430987a 100644 --- a/source/module_cell/test/support/mock_unitcell.cpp +++ b/source/module_cell/test/support/mock_unitcell.cpp @@ -124,5 +124,5 @@ void UnitCell::setup(const std::string& latname_in, const int& lmaxmax_in, const bool& init_vel_in, const std::string& fixed_axes_in) {} -void UnitCell::cal_nelec(double& nelec) {} +void cal_nelec(const Atom* atoms, const int& ntype, double& nelec) {} void UnitCell::compare_atom_labels(std::string label1, std::string label2) {} \ No newline at end of file diff --git a/source/module_cell/test/unitcell_test_readpp.cpp b/source/module_cell/test/unitcell_test_readpp.cpp index b2b01c0559..f643db848b 100644 --- a/source/module_cell/test/unitcell_test_readpp.cpp +++ b/source/module_cell/test/unitcell_test_readpp.cpp @@ -89,6 +89,8 @@ Magnetism::~Magnetism() { delete[] this->start_magnetization; } * possible of an element * - CalNelec: UnitCell::cal_nelec * - calculate the total number of valence electrons from psp files + * - CalNbands: elecstate::ElecState::cal_nbands() + * - calculate the number of bands */ // mock function @@ -114,9 +116,16 @@ class UcellTest : public ::testing::Test { pp_dir = "./support/"; PARAM.input.pseudo_rcut = 15.0; PARAM.input.dft_functional = "default"; + PARAM.input.esolver_type = "ksdft"; PARAM.input.test_pseudo_cell = true; PARAM.input.nspin = 1; PARAM.input.basis_type = "pw"; + GlobalV::nelec = 10.0; + GlobalV::nupdown = 0.0; + PARAM.sys.two_fermi = false; + GlobalV::NBANDS = 6; + GlobalV::NLOCAL = 6; + PARAM.input.lspinorb = false; } void TearDown() { ofs.close(); } }; @@ -256,6 +265,7 @@ TEST_F(UcellTest, CalNwfc1) { ucell->read_cell_pseudopots(pp_dir, ofs); EXPECT_FALSE(ucell->atoms[0].ncpp.has_so); EXPECT_FALSE(ucell->atoms[1].ncpp.has_so); + GlobalV::NLOCAL = 3 * 9; ucell->cal_nwfc(ofs); EXPECT_EQ(ucell->atoms[0].iw2l[8], 2); EXPECT_EQ(ucell->atoms[0].iw2n[8], 0); @@ -282,7 +292,6 @@ TEST_F(UcellTest, CalNwfc1) { EXPECT_EQ(ucell->atoms[0].nw, 9); EXPECT_EQ(ucell->atoms[1].nw, 9); EXPECT_EQ(ucell->nwmax, 9); - EXPECT_EQ(GlobalV::NLOCAL, 3 * 9); // check itia2iat EXPECT_EQ(ucell->itia2iat.getSize(), 4); EXPECT_EQ(ucell->itia2iat(0, 0), 0); @@ -322,8 +331,8 @@ TEST_F(UcellTest, CalNwfc2) { ucell->read_cell_pseudopots(pp_dir, ofs); EXPECT_FALSE(ucell->atoms[0].ncpp.has_so); EXPECT_FALSE(ucell->atoms[1].ncpp.has_so); - ucell->cal_nwfc(ofs); - EXPECT_EQ(GlobalV::NLOCAL, 3 * 9 * 2); + GlobalV::NLOCAL = 3 * 9 * 2; + EXPECT_NO_THROW(ucell->cal_nwfc(ofs)); } TEST_F(UcellDeathTest, CheckStructure) { @@ -396,10 +405,163 @@ TEST_F(UcellTest, CalNelec) { EXPECT_EQ(1, ucell->atoms[0].na); EXPECT_EQ(2, ucell->atoms[1].na); double nelec = 0; - ucell->cal_nelec(nelec); + cal_nelec(ucell->atoms, ucell->ntype, nelec); EXPECT_DOUBLE_EQ(6, nelec); } +TEST_F(UcellTest, CalNbands) +{ + std::vector nelec_spin(2, 5.0); + cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS); + EXPECT_EQ(GlobalV::NBANDS, 6); +} + +TEST_F(UcellTest, CalNbandsFractionElec) +{ + GlobalV::nelec = 9.5; + std::vector nelec_spin(2, 5.0); + cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS); + EXPECT_EQ(GlobalV::NBANDS, 6); +} + +TEST_F(UcellTest, CalNbandsSOC) +{ + PARAM.input.lspinorb = true; + GlobalV::NBANDS = 0; + std::vector nelec_spin(2, 5.0); + cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS); + EXPECT_EQ(GlobalV::NBANDS, 20); +} + +TEST_F(UcellTest, CalNbandsSDFT) +{ + PARAM.input.esolver_type = "sdft"; + std::vector nelec_spin(2, 5.0); + EXPECT_NO_THROW(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS)); +} + +TEST_F(UcellTest, CalNbandsLCAO) +{ + PARAM.input.basis_type = "lcao"; + std::vector nelec_spin(2, 5.0); + EXPECT_NO_THROW(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS)); +} + +TEST_F(UcellTest, CalNbandsLCAOINPW) +{ + PARAM.input.basis_type = "lcao_in_pw"; + GlobalV::NLOCAL = GlobalV::NBANDS - 1; + std::vector nelec_spin(2, 5.0); + testing::internal::CaptureStdout(); + EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), ""); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output, testing::HasSubstr("NLOCAL < NBANDS")); +} + +TEST_F(UcellTest, CalNbandsWarning1) +{ + GlobalV::NBANDS = GlobalV::nelec / 2 - 1; + std::vector nelec_spin(2, 5.0); + testing::internal::CaptureStdout(); + EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), ""); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output, testing::HasSubstr("Too few bands!")); +} + +TEST_F(UcellTest, CalNbandsWarning2) +{ + PARAM.input.nspin = 2; + GlobalV::nupdown = 4.0; + std::vector nelec_spin(2); + nelec_spin[0] = (GlobalV::nelec + GlobalV::nupdown) / 2.0; + nelec_spin[1] = (GlobalV::nelec - GlobalV::nupdown) / 2.0; + testing::internal::CaptureStdout(); + EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), ""); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output, testing::HasSubstr("Too few spin up bands!")); +} + +TEST_F(UcellTest, CalNbandsWarning3) +{ + PARAM.input.nspin = 2; + GlobalV::nupdown = -4.0; + std::vector nelec_spin(2); + nelec_spin[0] = (GlobalV::nelec + GlobalV::nupdown) / 2.0; + nelec_spin[1] = (GlobalV::nelec - GlobalV::nupdown) / 2.0; + testing::internal::CaptureStdout(); + EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), ""); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output, testing::HasSubstr("Too few spin down bands!")); +} + +TEST_F(UcellTest, CalNbandsSpin1) +{ + PARAM.input.nspin = 1; + GlobalV::NBANDS = 0; + std::vector nelec_spin(2, 5.0); + cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS); + EXPECT_EQ(GlobalV::NBANDS, 15); +} + +TEST_F(UcellTest, CalNbandsSpin1LCAO) +{ + PARAM.input.nspin = 1; + GlobalV::NBANDS = 0; + PARAM.input.basis_type = "lcao"; + std::vector nelec_spin(2, 5.0); + cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS); + EXPECT_EQ(GlobalV::NBANDS, 6); +} + +TEST_F(UcellTest, CalNbandsSpin4) +{ + PARAM.input.nspin = 4; + GlobalV::NBANDS = 0; + std::vector nelec_spin(2, 5.0); + cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS); + EXPECT_EQ(GlobalV::NBANDS, 30); +} + +TEST_F(UcellTest, CalNbandsSpin4LCAO) +{ + PARAM.input.nspin = 4; + GlobalV::NBANDS = 0; + PARAM.input.basis_type = "lcao"; + std::vector nelec_spin(2, 5.0); + cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS); + EXPECT_EQ(GlobalV::NBANDS, 6); +} + +TEST_F(UcellTest, CalNbandsSpin2) +{ + PARAM.input.nspin = 2; + GlobalV::NBANDS = 0; + std::vector nelec_spin(2, 5.0); + cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS); + EXPECT_EQ(GlobalV::NBANDS, 16); +} + +TEST_F(UcellTest, CalNbandsSpin2LCAO) +{ + PARAM.input.nspin = 2; + GlobalV::NBANDS = 0; + PARAM.input.basis_type = "lcao"; + std::vector nelec_spin(2, 5.0); + cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS); + EXPECT_EQ(GlobalV::NBANDS, 6); +} + +TEST_F(UcellTest, CalNbandsGaussWarning) +{ + GlobalV::NBANDS = 5; + std::vector nelec_spin(2, 5.0); + PARAM.input.smearing_method = "gaussian"; + testing::internal::CaptureStdout(); + EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), ""); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output, testing::HasSubstr("for smearing, num. of bands > num. of occupied bands")); +} + #ifdef __MPI #include "mpi.h" int main(int argc, char** argv) { diff --git a/source/module_cell/unitcell.cpp b/source/module_cell/unitcell.cpp old mode 100644 new mode 100755 index 08c68aa259..4b90c06d9c --- a/source/module_cell/unitcell.cpp +++ b/source/module_cell/unitcell.cpp @@ -8,6 +8,7 @@ #include "module_base/global_variable.h" #include "unitcell.h" #include "module_parameter/parameter.h" +#include "cal_atoms_info.h" #ifdef __LCAO #include "../module_basis/module_ao/ORB_read.h" // to use 'ORB' -- mohan 2021-01-30 @@ -594,17 +595,6 @@ void UnitCell::setup_cell(const std::string& fn, std::ofstream& log) { this->bcast_unitcell(); #endif - // after read STRU, calculate initial total magnetization when NSPIN=2 - if (PARAM.inp.nspin == 2 && !PARAM.globalv.two_fermi) { - for (int it = 0; it < this->ntype; it++) { - for (int ia = 0; ia < this->atoms[it].na; ia++) { - GlobalV::nupdown += this->atoms[it].mag[ia]; - } - } - GlobalV::ofs_running << " The readin total magnetization is " - << GlobalV::nupdown << std::endl; - } - //======================================================== // Calculate unit cell volume // the reason to calculate volume here is @@ -831,14 +821,15 @@ void UnitCell::read_pseudo(std::ofstream& ofs) { ModuleBase::WARNING_QUIT("setup_cell", "All DFT functional must consistent."); } - if (atoms[it].ncpp.tvanp) { - GlobalV::use_uspp = true; - } } // setup the total number of PAOs cal_natomwfc(ofs); + // Calculate the information of atoms from the pseudopotential to set PARAM + CalAtomsInfo ca; + ca.cal_atoms_info(this->atoms, this->ntype, PARAM); + // setup GlobalV::NLOCAL cal_nwfc(ofs); @@ -950,14 +941,14 @@ void UnitCell::cal_nwfc(std::ofstream& log) { //=========================== // (3) set nwfc and stapos_wf //=========================== - GlobalV::NLOCAL = 0; + int nlocal_tmp = 0; for (int it = 0; it < ntype; it++) { - atoms[it].stapos_wf = GlobalV::NLOCAL; + atoms[it].stapos_wf = nlocal_tmp; const int nlocal_it = atoms[it].nw * atoms[it].na; if (PARAM.inp.nspin != 4) { - GlobalV::NLOCAL += nlocal_it; + nlocal_tmp += nlocal_it; } else { - GlobalV::NLOCAL += nlocal_it * 2; // zhengdy-soc + nlocal_tmp += nlocal_it * 2; // zhengdy-soc } // for tests @@ -968,17 +959,18 @@ void UnitCell::cal_nwfc(std::ofstream& log) { // OUT(GlobalV::ofs_running,"NLOCAL",GlobalV::NLOCAL); log << " " << std::setw(40) << "NLOCAL" - << " = " << GlobalV::NLOCAL << std::endl; + << " = " << nlocal_tmp << std::endl; //======================================================== // (4) set index for itia2iat, itiaiw2iwt //======================================================== // mohan add 2010-09-26 - assert(GlobalV::NLOCAL > 0); + assert(nlocal_tmp > 0); + assert(nlocal_tmp == GlobalV::NLOCAL); delete[] iwt2iat; delete[] iwt2iw; - this->iwt2iat = new int[GlobalV::NLOCAL]; - this->iwt2iw = new int[GlobalV::NLOCAL]; + this->iwt2iat = new int[nlocal_tmp]; + this->iwt2iw = new int[nlocal_tmp]; this->itia2iat.create(ntype, namax); // this->itiaiw2iwt.create(ntype, namax, nwmax*PARAM.globalv.npol); @@ -1582,65 +1574,160 @@ void UnitCell::remake_cell() { } } -void UnitCell::cal_nelec(double& nelec) { +void cal_nelec(const Atom* atoms, const int& ntype, double& nelec) +{ ModuleBase::TITLE("UnitCell", "cal_nelec"); GlobalV::ofs_running << "\n SETUP THE ELECTRONS NUMBER" << std::endl; - if (nelec == 0) { - if (PARAM.inp.use_paw) { + if (nelec == 0) + { + if (PARAM.inp.use_paw) + { #ifdef USE_PAW - for (int it = 0; it < this->ntype; it++) { + for (int it = 0; it < ntype; it++) + { std::stringstream ss1, ss2; - ss1 << " electron number of element " - << GlobalC::paw_cell.get_zat(it) << std::endl; - const int nelec_it - = GlobalC::paw_cell.get_val(it) * this->atoms[it].na; + ss1 << " electron number of element " << GlobalC::paw_cell.get_zat(it) << std::endl; + const int nelec_it = GlobalC::paw_cell.get_val(it) * atoms[it].na; nelec += nelec_it; - ss2 << "total electron number of element " - << GlobalC::paw_cell.get_zat(it); - - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, - ss1.str(), - GlobalC::paw_cell.get_val(it)); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, - ss2.str(), - nelec_it); + ss2 << "total electron number of element " << GlobalC::paw_cell.get_zat(it); + + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, ss1.str(), GlobalC::paw_cell.get_val(it)); + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, ss2.str(), nelec_it); } #endif - } else { - for (int it = 0; it < this->ntype; it++) { + } + else + { + for (int it = 0; it < ntype; it++) + { std::stringstream ss1, ss2; - ss1 << "electron number of element " << this->atoms[it].label; - const double nelec_it - = this->atoms[it].ncpp.zv * this->atoms[it].na; + ss1 << "electron number of element " << atoms[it].label; + const double nelec_it = atoms[it].ncpp.zv * atoms[it].na; nelec += nelec_it; - ss2 << "total electron number of element " - << this->atoms[it].label; - - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, - ss1.str(), - this->atoms[it].ncpp.zv); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, - ss2.str(), - nelec_it); + ss2 << "total electron number of element " << atoms[it].label; + + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, ss1.str(), atoms[it].ncpp.zv); + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, ss2.str(), nelec_it); } - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, - "AUTOSET number of electrons: ", - nelec); + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "AUTOSET number of electrons: ", nelec); } } - if (PARAM.inp.nelec_delta != 0) { - ModuleBase::GlobalFunc::OUT( - GlobalV::ofs_running, - "nelec_delta is NOT zero, please make sure you know what you are " - "doing! nelec_delta: ", - PARAM.inp.nelec_delta); + if (PARAM.inp.nelec_delta != 0) + { + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, + "nelec_delta is NOT zero, please make sure you know what you are " + "doing! nelec_delta: ", + PARAM.inp.nelec_delta); nelec += PARAM.inp.nelec_delta; ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "nelec now: ", nelec); } return; } +void cal_nbands(const int& nelec, const int& nlocal, const std::vector& nelec_spin, int& nbands) +{ + if (PARAM.inp.esolver_type == "sdft") // qianrui 2021-2-20 + { + return; + } + //======================================= + // calculate number of bands (setup.f90) + //======================================= + double occupied_bands = static_cast(nelec / ModuleBase::DEGSPIN); + if (PARAM.inp.lspinorb == 1) { + occupied_bands = static_cast(nelec); + } + + if ((occupied_bands - std::floor(occupied_bands)) > 0.0) + { + occupied_bands = std::floor(occupied_bands) + 1.0; // mohan fix 2012-04-16 + } + + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "occupied bands", occupied_bands); + + if (nbands == 0) + { + if (PARAM.inp.nspin == 1) + { + const int nbands1 = static_cast(occupied_bands) + 10; + const int nbands2 = static_cast(1.2 * occupied_bands) + 1; + nbands = std::max(nbands1, nbands2); + if (PARAM.inp.basis_type != "pw") { + nbands = std::min(nbands, nlocal); + } + } + else if (PARAM.inp.nspin == 4) + { + const int nbands3 = nelec + 20; + const int nbands4 = static_cast(1.2 * nelec) + 1; + nbands = std::max(nbands3, nbands4); + if (PARAM.inp.basis_type != "pw") { + nbands = std::min(nbands, nlocal); + } + } + else if (PARAM.inp.nspin == 2) + { + const double max_occ = std::max(nelec_spin[0], nelec_spin[1]); + const int nbands3 = static_cast(max_occ) + 11; + const int nbands4 = static_cast(1.2 * max_occ) + 1; + nbands = std::max(nbands3, nbands4); + if (PARAM.inp.basis_type != "pw") { + nbands = std::min(nbands, nlocal); + } + } + ModuleBase::GlobalFunc::AUTO_SET("NBANDS", nbands); + } + // else if ( PARAM.inp.calculation=="scf" || PARAM.inp.calculation=="md" || PARAM.inp.calculation=="relax") //pengfei + // 2014-10-13 + else + { + if (nbands < occupied_bands) { + ModuleBase::WARNING_QUIT("unitcell", "Too few bands!"); + } + if (PARAM.inp.nspin == 2) + { + if (nbands < nelec_spin[0]) + { + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "nelec_up", nelec_spin[0]); + ModuleBase::WARNING_QUIT("ElecState::cal_nbands", "Too few spin up bands!"); + } + if (nbands < nelec_spin[1]) + { + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "nelec_down", nelec_spin[1]); + ModuleBase::WARNING_QUIT("ElecState::cal_nbands", "Too few spin down bands!"); + } + } + } + + // mohan add 2010-09-04 + // std::cout << "nbands(this-> = " < num. of occupied bands"); + } + } + + // mohan update 2021-02-19 + // mohan add 2011-01-5 + if (PARAM.inp.basis_type == "lcao" || PARAM.inp.basis_type == "lcao_in_pw") + { + if (nbands > nlocal) + { + ModuleBase::WARNING_QUIT("ElecState::cal_nbandsc", "NLOCAL < NBANDS"); + } + else + { + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NLOCAL", nlocal); + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NBANDS", nbands); + } + } + + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NBANDS", nbands); +} + void UnitCell::compare_atom_labels(std::string label1, std::string label2) { if (label1 != label2) //'!( "Ag" == "Ag" || "47" == "47" || "Silver" == Silver" )' diff --git a/source/module_cell/unitcell.h b/source/module_cell/unitcell.h index 9a8e54ce3d..4657d97b62 100644 --- a/source/module_cell/unitcell.h +++ b/source/module_cell/unitcell.h @@ -307,10 +307,6 @@ class UnitCell { const bool& init_vel_in, const std::string& fixed_axes_in); - /// @brief calculate the total number of electrons in system - /// (GlobalV::nelec) - void cal_nelec(double& nelec); - /// @brief check consistency between two atom labels from STRU and pseudo or /// orb file void compare_atom_labels(std::string label1, std::string label2); @@ -334,4 +330,23 @@ class UnitCell { std::vector> get_lnchiCounts() const; }; +/** + * @brief calculate the total number of electrons in system + * + * @param atoms [in] atom pointer + * @param ntype [in] number of atom types + * @param nelec [out] total number of electrons + */ +void cal_nelec(const Atom* atoms, const int& ntype, double& nelec); + +/** + * @brief Calculate the number of bands. + * + * @param nelec [in] total number of electrons + * @param nlocal [in] total number of local basis + * @param nelec_spin [in] number of electrons for each spin + * @param nbands [out] number of bands + */ +void cal_nbands(const int& nelec, const int& nlocal, const std::vector& nelec_spin, int& nbands); + #endif // unitcell class diff --git a/source/module_elecstate/elecstate.cpp b/source/module_elecstate/elecstate.cpp index 720d920ed5..6d77e815a6 100644 --- a/source/module_elecstate/elecstate.cpp +++ b/source/module_elecstate/elecstate.cpp @@ -248,116 +248,11 @@ void ElecState::init_ks(Charge* chg_in, // pointer for class Charge this->bigpw = bigpw_in; // init nelec_spin with nelec and nupdown this->init_nelec_spin(); - // autoset and check GlobalV::NBANDS, nelec_spin is used when NSPIN==2 - this->cal_nbands(); // initialize ekb and wg this->ekb.create(nk_in, GlobalV::NBANDS); this->wg.create(nk_in, GlobalV::NBANDS); } -void ElecState::cal_nbands() -{ - if (PARAM.inp.esolver_type == "sdft") // qianrui 2021-2-20 - { - return; - } - //======================================= - // calculate number of bands (setup.f90) - //======================================= - double occupied_bands = static_cast(GlobalV::nelec / ModuleBase::DEGSPIN); - if (PARAM.inp.lspinorb == 1) { - occupied_bands = static_cast(GlobalV::nelec); - } - - if ((occupied_bands - std::floor(occupied_bands)) > 0.0) - { - occupied_bands = std::floor(occupied_bands) + 1.0; // mohan fix 2012-04-16 - } - - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "occupied bands", occupied_bands); - - if (GlobalV::NBANDS == 0) - { - if (PARAM.inp.nspin == 1) - { - const int nbands1 = static_cast(occupied_bands) + 10; - const int nbands2 = static_cast(1.2 * occupied_bands) + 1; - GlobalV::NBANDS = std::max(nbands1, nbands2); - if (PARAM.inp.basis_type != "pw") { - GlobalV::NBANDS = std::min(GlobalV::NBANDS, GlobalV::NLOCAL); - } - } - else if (PARAM.inp.nspin == 4) - { - const int nbands3 = GlobalV::nelec + 20; - const int nbands4 = static_cast(1.2 * GlobalV::nelec) + 1; - GlobalV::NBANDS = std::max(nbands3, nbands4); - if (PARAM.inp.basis_type != "pw") { - GlobalV::NBANDS = std::min(GlobalV::NBANDS, GlobalV::NLOCAL); - } - } - else if (PARAM.inp.nspin == 2) - { - const double max_occ = std::max(this->nelec_spin[0], this->nelec_spin[1]); - const int nbands3 = static_cast(max_occ) + 11; - const int nbands4 = static_cast(1.2 * max_occ) + 1; - GlobalV::NBANDS = std::max(nbands3, nbands4); - if (PARAM.inp.basis_type != "pw") { - GlobalV::NBANDS = std::min(GlobalV::NBANDS, GlobalV::NLOCAL); - } - } - ModuleBase::GlobalFunc::AUTO_SET("NBANDS", GlobalV::NBANDS); - } - // else if ( PARAM.inp.calculation=="scf" || PARAM.inp.calculation=="md" || PARAM.inp.calculation=="relax") //pengfei - // 2014-10-13 - else - { - if (GlobalV::NBANDS < occupied_bands) { - ModuleBase::WARNING_QUIT("unitcell", "Too few bands!"); - } - if (PARAM.inp.nspin == 2) - { - if (GlobalV::NBANDS < this->nelec_spin[0]) - { - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "nelec_up", this->nelec_spin[0]); - ModuleBase::WARNING_QUIT("ElecState::cal_nbands", "Too few spin up bands!"); - } - if (GlobalV::NBANDS < this->nelec_spin[1]) - { - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "nelec_down", this->nelec_spin[1]); - ModuleBase::WARNING_QUIT("ElecState::cal_nbands", "Too few spin down bands!"); - } - } - } - - // mohan add 2010-09-04 - // std::cout << "nbands(this-> = " < num. of occupied bands"); - } - } - - // mohan update 2021-02-19 - // mohan add 2011-01-5 - if (PARAM.inp.basis_type == "lcao" || PARAM.inp.basis_type == "lcao_in_pw") - { - if (GlobalV::NBANDS > GlobalV::NLOCAL) - { - ModuleBase::WARNING_QUIT("ElecState::cal_nbandsc", "NLOCAL < NBANDS"); - } - else - { - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NLOCAL", GlobalV::NLOCAL); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NBANDS", GlobalV::NBANDS); - } - } - - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "NBANDS", GlobalV::NBANDS); -} - void set_is_occupied(std::vector& is_occupied, elecstate::ElecState* pes, const int i_scf, diff --git a/source/module_elecstate/elecstate.h b/source/module_elecstate/elecstate.h index 023b72b03a..a9fc79ff7c 100644 --- a/source/module_elecstate/elecstate.h +++ b/source/module_elecstate/elecstate.h @@ -83,9 +83,6 @@ class ElecState //for NSPIN=4, it will record total number, magnetization for x, y, z direction std::vector nelec_spin; - //calculate nbands and - void cal_nbands(); - virtual void print_psi(const psi::Psi& psi_in, const int istep = -1) { return; diff --git a/source/module_elecstate/module_dm/test/prepare_unitcell.h b/source/module_elecstate/module_dm/test/prepare_unitcell.h index 6b295fc0f2..01f37ddd9c 100644 --- a/source/module_elecstate/module_dm/test/prepare_unitcell.h +++ b/source/module_elecstate/module_dm/test/prepare_unitcell.h @@ -308,8 +308,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in, coor_type(coor_type_in), coordinates(coordinates_in) { - mbl = {0}; - velocity = {0}; + mbl = std::valarray(0.0, coordinates_in.size()); + velocity = std::valarray(0.0, coordinates_in.size()); } UcellTestPrepare::UcellTestPrepare(std::string latname_in, diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index de015315ba..54a5839f4c 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -98,8 +98,6 @@ void Charge::check_rho() * - determine the number of electrons for spin up and down * - Constructor: elecstate::ElecState(charge, rhopw, bigpw) * - constructor ElecState using existing charge, rhopw, bigpw - * - CalNbands: elecstate::ElecState::cal_nbands() - * - calculate the number of bands * - InitKS: elecstate::ElecState::init_ks() * - initialize the elecstate for KS-DFT * - GetRho: elecstate::ElecState::getRho() @@ -183,145 +181,6 @@ TEST_F(ElecStateTest, Constructor) delete charge; } -TEST_F(ElecStateTest, CalNbands) -{ - elecstate->cal_nbands(); - EXPECT_EQ(GlobalV::NBANDS, 6); -} - -TEST_F(ElecStateTest, CalNbandsFractionElec) -{ - GlobalV::nelec = 9.5; - elecstate->cal_nbands(); - EXPECT_EQ(GlobalV::NBANDS, 6); -} - -TEST_F(ElecStateTest, CalNbandsSOC) -{ - PARAM.input.lspinorb = true; - GlobalV::NBANDS = 0; - elecstate->cal_nbands(); - EXPECT_EQ(GlobalV::NBANDS, 20); -} - -TEST_F(ElecStateTest, CalNbandsSDFT) -{ - PARAM.input.esolver_type = "sdft"; - EXPECT_NO_THROW(elecstate->cal_nbands()); -} - -TEST_F(ElecStateTest, CalNbandsLCAO) -{ - PARAM.input.basis_type = "lcao"; - EXPECT_NO_THROW(elecstate->cal_nbands()); -} - -TEST_F(ElecStateDeathTest, CalNbandsLCAOINPW) -{ - PARAM.input.basis_type = "lcao_in_pw"; - GlobalV::NLOCAL = GlobalV::NBANDS - 1; - testing::internal::CaptureStdout(); - EXPECT_EXIT(elecstate->cal_nbands(), ::testing::ExitedWithCode(0), ""); - output = testing::internal::GetCapturedStdout(); - EXPECT_THAT(output, testing::HasSubstr("NLOCAL < NBANDS")); -} - -TEST_F(ElecStateDeathTest, CalNbandsWarning1) -{ - GlobalV::NBANDS = GlobalV::nelec / 2 - 1; - testing::internal::CaptureStdout(); - EXPECT_EXIT(elecstate->cal_nbands(), ::testing::ExitedWithCode(0), ""); - output = testing::internal::GetCapturedStdout(); - EXPECT_THAT(output, testing::HasSubstr("Too few bands!")); -} - -TEST_F(ElecStateDeathTest, CalNbandsWarning2) -{ - PARAM.input.nspin = 2; - GlobalV::nupdown = 4.0; - elecstate->init_nelec_spin(); - testing::internal::CaptureStdout(); - EXPECT_EXIT(elecstate->cal_nbands(), ::testing::ExitedWithCode(0), ""); - output = testing::internal::GetCapturedStdout(); - EXPECT_THAT(output, testing::HasSubstr("Too few spin up bands!")); -} - -TEST_F(ElecStateDeathTest, CalNbandsWarning3) -{ - PARAM.input.nspin = 2; - GlobalV::nupdown = -4.0; - elecstate->init_nelec_spin(); - testing::internal::CaptureStdout(); - EXPECT_EXIT(elecstate->cal_nbands(), ::testing::ExitedWithCode(0), ""); - output = testing::internal::GetCapturedStdout(); - EXPECT_THAT(output, testing::HasSubstr("Too few spin down bands!")); -} - -TEST_F(ElecStateTest, CalNbandsSpin1) -{ - PARAM.input.nspin = 1; - GlobalV::NBANDS = 0; - elecstate->cal_nbands(); - EXPECT_EQ(GlobalV::NBANDS, 15); -} - -TEST_F(ElecStateTest, CalNbandsSpin1LCAO) -{ - PARAM.input.nspin = 1; - GlobalV::NBANDS = 0; - PARAM.input.basis_type = "lcao"; - elecstate->cal_nbands(); - EXPECT_EQ(GlobalV::NBANDS, 6); -} - -TEST_F(ElecStateTest, CalNbandsSpin4) -{ - PARAM.input.nspin = 4; - GlobalV::NBANDS = 0; - elecstate->cal_nbands(); - EXPECT_EQ(GlobalV::NBANDS, 30); -} - -TEST_F(ElecStateTest, CalNbandsSpin4LCAO) -{ - PARAM.input.nspin = 4; - GlobalV::NBANDS = 0; - PARAM.input.basis_type = "lcao"; - elecstate->cal_nbands(); - EXPECT_EQ(GlobalV::NBANDS, 6); -} - -TEST_F(ElecStateTest, CalNbandsSpin2) -{ - PARAM.input.nspin = 2; - GlobalV::NBANDS = 0; - elecstate->init_nelec_spin(); - elecstate->cal_nbands(); - EXPECT_EQ(GlobalV::NBANDS, 16); -} - -TEST_F(ElecStateTest, CalNbandsSpin2LCAO) -{ - PARAM.input.nspin = 2; - GlobalV::NBANDS = 0; - PARAM.input.basis_type = "lcao"; - elecstate->init_nelec_spin(); - elecstate->cal_nbands(); - EXPECT_EQ(GlobalV::NBANDS, 6); -} - -TEST_F(ElecStateDeathTest, CalNbandsGaussWarning) -{ - Occupy::use_gaussian_broadening = true; - EXPECT_TRUE(Occupy::gauss()); - GlobalV::NBANDS = 5; - testing::internal::CaptureStdout(); - EXPECT_EXIT(elecstate->cal_nbands(), ::testing::ExitedWithCode(0), ""); - output = testing::internal::GetCapturedStdout(); - EXPECT_THAT(output, testing::HasSubstr("for smearing, num. of bands > num. of occupied bands")); - Occupy::use_gaussian_broadening = false; -} - TEST_F(ElecStateTest, InitKS) { Charge* charge = new Charge; diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index 0bc5c6ad0d..240d34dc4e 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -11,6 +11,7 @@ #include "module_io/print_info.h" #include "module_io/write_istate_info.h" #include "module_parameter/parameter.h" +#include "module_cell/cal_atoms_info.h" #include //--------------Temporary---------------- @@ -187,13 +188,12 @@ void ESolver_KS::before_all_runners(const Input_para& inp, UnitCell& } delete[] atom_coord; delete[] atom_type; + CalAtomsInfo ca; + ca.cal_atoms_info(ucell.atoms, ucell.ntype, PARAM); } #endif /// End PAW - //! 3) calculate the electron number - ucell.cal_nelec(GlobalV::nelec); - //! 4) it has been established that // xc_func is same for all elements, therefore // only the first one if used diff --git a/source/module_esolver/esolver_of.cpp b/source/module_esolver/esolver_of.cpp index edd50a6962..49edcb90b2 100644 --- a/source/module_esolver/esolver_of.cpp +++ b/source/module_esolver/esolver_of.cpp @@ -70,7 +70,6 @@ void ESolver_OF::before_all_runners(const Input_para& inp, UnitCell& ucell) this->max_iter_ = inp.scf_nmax; this->dV_ = ucell.omega / this->pw_rho->nxyz; - ucell.cal_nelec(GlobalV::nelec); ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SETUP UNITCELL"); XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); diff --git a/source/module_hamilt_lcao/module_deltaspin/test/prepare_unitcell.h b/source/module_hamilt_lcao/module_deltaspin/test/prepare_unitcell.h index aa8221a8fe..ec95f74087 100644 --- a/source/module_hamilt_lcao/module_deltaspin/test/prepare_unitcell.h +++ b/source/module_hamilt_lcao/module_deltaspin/test/prepare_unitcell.h @@ -232,8 +232,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in, coor_type(coor_type_in), coordinates(coordinates_in) { - mbl = {0}; - velocity = {0}; + mbl = std::valarray(0.0, coordinates_in.size()); + velocity = std::valarray(0.0, coordinates_in.size()); } UcellTestPrepare::UcellTestPrepare(const UcellTestPrepare &utp): diff --git a/source/module_hamilt_lcao/module_hcontainer/test/prepare_unitcell.h b/source/module_hamilt_lcao/module_hcontainer/test/prepare_unitcell.h index 0fdce8f66b..70d9eebaea 100644 --- a/source/module_hamilt_lcao/module_hcontainer/test/prepare_unitcell.h +++ b/source/module_hamilt_lcao/module_hcontainer/test/prepare_unitcell.h @@ -308,8 +308,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in, coor_type(coor_type_in), coordinates(coordinates_in) { - mbl = {0}; - velocity = {0}; + mbl = std::valarray(0.0, coordinates_in.size()); + velocity = std::valarray(0.0, coordinates_in.size()); } UcellTestPrepare::UcellTestPrepare(std::string latname_in, diff --git a/source/module_hsolver/test/hsolver_supplementary_mock.h b/source/module_hsolver/test/hsolver_supplementary_mock.h index 8cee0698af..95c69e555e 100644 --- a/source/module_hsolver/test/hsolver_supplementary_mock.h +++ b/source/module_hsolver/test/hsolver_supplementary_mock.h @@ -57,11 +57,6 @@ void ElecState::init_ks(Charge* chg_in, // pointer for class Charge return; } -void ElecState::cal_nbands() -{ - return; -} - } // namespace elecstate diff --git a/source/module_io/test/prepare_unitcell.h b/source/module_io/test/prepare_unitcell.h index 5a831eea06..d9deaf0ff3 100644 --- a/source/module_io/test/prepare_unitcell.h +++ b/source/module_io/test/prepare_unitcell.h @@ -252,8 +252,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in, coor_type(coor_type_in), coordinates(coordinates_in) { - mbl = {0}; - velocity = {0}; + mbl = std::valarray(0.0, coordinates_in.size()); + velocity = std::valarray(0.0, coordinates_in.size()); } UcellTestPrepare::UcellTestPrepare(std::string latname_in, diff --git a/source/module_io/test/write_orb_info_test.cpp b/source/module_io/test/write_orb_info_test.cpp index 7da84e439f..b9c6f66483 100644 --- a/source/module_io/test/write_orb_info_test.cpp +++ b/source/module_io/test/write_orb_info_test.cpp @@ -49,6 +49,7 @@ TEST(OrbInfo,WriteOrbInfo) PARAM.input.nspin = 1; PARAM.input.basis_type = "pw"; PARAM.input.dft_functional = "default"; + GlobalV::NLOCAL = 18; ucell->read_cell_pseudopots(pp_dir,ofs); ucell->cal_nwfc(ofs); ModuleIO::write_orb_info(ucell); diff --git a/source/module_io/test_serial/prepare_unitcell.h b/source/module_io/test_serial/prepare_unitcell.h index 5f217aadf2..a92f15842f 100644 --- a/source/module_io/test_serial/prepare_unitcell.h +++ b/source/module_io/test_serial/prepare_unitcell.h @@ -252,8 +252,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in, coor_type(coor_type_in), coordinates(coordinates_in) { - mbl = {0}; - velocity = {0}; + mbl = std::valarray(0.0, coordinates_in.size()); + velocity = std::valarray(0.0, coordinates_in.size()); } UcellTestPrepare::UcellTestPrepare(std::string latname_in, diff --git a/source/module_lr/esolver_lrtd_lcao.cpp b/source/module_lr/esolver_lrtd_lcao.cpp index f5cb2d9b01..f8ae2037e0 100644 --- a/source/module_lr/esolver_lrtd_lcao.cpp +++ b/source/module_lr/esolver_lrtd_lcao.cpp @@ -236,7 +236,6 @@ LR::ESolver_LR::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu this->pelec = new elecstate::ElecStateLCAO(); // necessary steps in ESolver_KS::before_all_runners : symmetry and k-points - ucell.cal_nelec(GlobalV::nelec); if (ModuleSymmetry::Symmetry::symm_flag == 1) { GlobalC::ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running); diff --git a/source/module_parameter/parameter.h b/source/module_parameter/parameter.h index 38b544f2ed..87f1efe267 100644 --- a/source/module_parameter/parameter.h +++ b/source/module_parameter/parameter.h @@ -6,6 +6,7 @@ namespace ModuleIO { class ReadInput; } +class CalAtomInfo; class Parameter { public: @@ -32,8 +33,12 @@ class Parameter void set_start_time(const std::time_t& start_time); private: - // Only ReadInput can modify the value of Parameter. - friend class ModuleIO::ReadInput; + // Only ReadInput and CalAtomInfo can modify the value of Parameter. + // Do not add extra friend class here!!! + friend class ModuleIO::ReadInput; // ReadInput read INPUT file and give the value to Parameter + friend class CalAtomsInfo; // CalAtomInfo calculate the atom information from pseudopotential and give the value to + // Parameter + // INPUT parameters Input_para input; // System parameters