Skip to content

Commit

Permalink
Fix: solve the convergence problem of E value in dav_subspace method (
Browse files Browse the repository at this point in the history
#4052)

* solve the convergence problem of E value in dav_subspace method

* add DIAGO_FULL_ACC for abacus input

* update dav_subspace E_coverage code

* delete test result file

* fix the conditions of the `is_occupied` assignment

* add input parameter `diago_full_acc` in input_main.md

* remove diago_full_acc from globalV

* fix build bug

* fix build bug

* fix build bug

* meet @mohanchen's requirements

---------

Co-authored-by: Mohan Chen <[email protected]>
  • Loading branch information
haozhihan and mohanchen committed May 7, 2024
1 parent d0f69c2 commit 4385149
Show file tree
Hide file tree
Showing 13 changed files with 80 additions and 12 deletions.
13 changes: 10 additions & 3 deletions docs/advanced/input_files/input-main.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
- [pw\_diag\_thr](#pw_diag_thr)
- [pw\_diag\_nmax](#pw_diag_nmax)
- [pw\_diag\_ndim](#pw_diag_ndim)
- [diago\_full\_acc](#diago_full_acc)
- [erf\_ecut](#erf_ecut)
- [fft\_mode](#fft_mode)
- [erf\_height](#erf_height)
Expand Down Expand Up @@ -772,21 +773,27 @@ These variables are used to control the plane wave related parameters.
### pw_diag_thr

- **Type**: Real
- **Description**: Only used when you use `diago_type = cg` or `diago_type = david`. It indicates the threshold for the first electronic iteration, from the second iteration the pw_diag_thr will be updated automatically. **For nscf calculations with planewave basis set, pw_diag_thr should be <= 1e-3.**
- **Description**: Only used when you use `ks_solver = cg/dav/dav_subspace/bpcg`. It indicates the threshold for the first electronic iteration, from the second iteration the pw_diag_thr will be updated automatically. **For nscf calculations with planewave basis set, pw_diag_thr should be <= 1e-3.**
- **Default**: 0.01

### pw_diag_nmax

- **Type**: Integer
- **Description**: Only useful when you use `ks_solver = cg` or `ks_solver = dav`. It indicates the maximal iteration number for cg/david method.
- **Description**: Only useful when you use `ks_solver = cg/dav/dav_subspace/bpcg`. It indicates the maximal iteration number for cg/david/dav_subspace/bpcg method.
- **Default**: 50

### pw_diag_ndim

- **Type**: Integer
- **Description**: Only useful when you use `ks_solver = dav`. It indicates the maximal dimension for the Davidson method.
- **Description**: Only useful when you use `ks_solver = dav` or `ks_solver = dav_subspace`. It indicates the dimension of the workspace (the number of wavefunction packets; at least 2 are needed) for the Davidson method. A larger value may reduce the number of iterations in the algorithm, but uses more memory and more CPU time in the subspace diagonalization.
- **Default**: 4

### diago_full_acc

- **Type**: Boolean
- **Description**: Only useful when you use `ks_solver = dav_subspace`. If `TRUE`, all the empty states are diagonalized at the same level of accuracy as the occupied ones. Otherwise the empty states are diagonalized using a larger threshold (at least 1e-5); this should not affect the total energy, forces, and other ground-state properties.
- **Default**: false

### erf_ecut

- **Type**: Real
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void Diago_DavSubspace<T, Device>::diag_once(hamilt::Hamilt<T, Device>* phm_in,
}
else
{
double empty_ethr = std::max(DiagoIterAssist<T, Device>::PW_DIAG_THR * 5.0, 1e-5);
const double empty_ethr = std::max(DiagoIterAssist<T, Device>::PW_DIAG_THR * 5.0, Diago_DavSubspace::dav_large_thr);
convflag[m] = (std::abs(this->eigenvalue_in_dav[m] - eigenvalue_in_hsolver[m]) < empty_ethr);
}

Expand Down
5 changes: 5 additions & 0 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class Diago_DavSubspace : public DiagH<T, Device>

static int PW_DIAG_NDIM;

static double dav_large_thr;

private:
bool is_subspace = false;

Expand Down Expand Up @@ -137,6 +139,9 @@ class Diago_DavSubspace : public DiagH<T, Device>
template <typename Real, typename Device>
int Diago_DavSubspace<Real, Device>::PW_DIAG_NDIM = 4;

// Out-of-class definition of the static member dav_large_thr (default 1e-5).
// diag_once uses it as the lower bound of the looser convergence threshold
// `empty_ethr = std::max(PW_DIAG_THR * 5.0, dav_large_thr)`.
// NOTE(review): judging by the name `empty_ethr`, that threshold applies to
// bands treated as empty/unoccupied -- confirm against diag_once.
template <typename Real, typename Device>
double Diago_DavSubspace<Real, Device>::dav_large_thr = 1e-5;

} // namespace hsolver

#endif
22 changes: 16 additions & 6 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,30 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,

std::vector<Real> eigenvalues(pes->ekb.nr * pes->ekb.nc, 0);

if (this->is_first_scf == true)
if (this->is_first_scf)
{
is_occupied.resize(psi.get_nk() * psi.get_nbands(), true);
}
else
{
for (size_t i = 0; i < psi.get_nk(); i++)
if (this->diago_full_acc)
{
for (size_t j = 0; j < psi.get_nbands(); j++)
is_occupied.assign(is_occupied.size(), true);
}
else
{
for (int i = 0; i < psi.get_nk(); i++)
{
if (pes->wg(i, j) < 1.0)
if (pes->klist->wk[i] > 0.0)
{
is_occupied[i * psi.get_nbands() + j] = false;
}
for (int j = 0; j < psi.get_nbands(); j++)
{
if (pes->wg(i, j) / pes->klist->wk[i] < 0.01)
{
is_occupied[i * psi.get_nbands() + j] = false;
}
}
}
}
}
}
Expand Down
15 changes: 14 additions & 1 deletion source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,22 @@ class HSolverPW: public HSolver<T, Device>
{
private:
bool is_first_scf = true;

// Note GetTypeReal<T>::type will
// return T if T is real type(float, double),
// otherwise return the real type of T(complex<float>, complex<double>)
using Real = typename GetTypeReal<T>::type;

public:
/**
* @brief diago_full_acc
* If .TRUE. all the empty states are diagonalized at the same level of accuracy of the occupied ones.
* Otherwise the empty states are diagonalized using a larger threshold
* (this should not affect total energy, forces, and other ground-state properties).
*
*/
static bool diago_full_acc;

HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in, wavefunc* pwf_in);

/*void init(
Expand Down Expand Up @@ -79,9 +90,11 @@ class HSolverPW: public HSolver<T, Device>
using resmem_var_op = psi::memory::resize_memory_op<Real, psi::DEVICE_CPU>;
using delmem_var_op = psi::memory::delete_memory_op<Real, psi::DEVICE_CPU>;
using castmem_2d_2h_op = psi::memory::cast_memory_op<double, Real, psi::DEVICE_CPU, psi::DEVICE_CPU>;

};

// Out-of-class definition of the static flag diago_full_acc; it defaults to
// false for every HSolverPW<T, Device> instantiation. Input_Conv::Convert
// overwrites it per instantiation from INPUT.diago_full_acc. When true,
// solve() marks every band as occupied on non-first SCF steps.
template <typename T, typename Device>
bool HSolverPW<T, Device>::diago_full_acc = false;

} // namespace hsolver

#endif
1 change: 1 addition & 0 deletions source/module_io/DEFAULT_TYPE.conf
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ diago_proc int
pw_diag_nmax int
diago_cg_prec int
pw_diag_ndim int
diago_full_acc bool
pw_diag_thr double
nb2d int
nurse int
Expand Down
1 change: 1 addition & 0 deletions source/module_io/DEFAULT_VALUE.conf
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
pw_diag_nmax 50
diago_cg_prec 1
pw_diag_ndim 4
diago_full_acc false
pw_diag_thr 1.0e-2
nb2d 0
nurse 0
Expand Down
6 changes: 6 additions & 0 deletions source/module_io/input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ void Input::Default(void)
pw_diag_nmax = 50;
diago_cg_prec = 1; // mohan add 2012-03-31
pw_diag_ndim = 4;
diago_full_acc = false;
pw_diag_thr = 1.0e-2;
nb2d = 0;
nurse = 0;
Expand Down Expand Up @@ -1172,6 +1173,10 @@ bool Input::Read(const std::string& fn)
{
read_value(ifs, pw_diag_ndim);
}
else if (strcmp("diago_full_acc", word) == 0)
{
read_value(ifs, diago_full_acc);
}
else if (strcmp("pw_diag_thr", word) == 0)
{
read_value(ifs, pw_diag_thr);
Expand Down Expand Up @@ -3365,6 +3370,7 @@ void Input::Bcast()
Parallel_Common::bcast_int(pw_diag_nmax);
Parallel_Common::bcast_int(diago_cg_prec);
Parallel_Common::bcast_int(pw_diag_ndim);
Parallel_Common::bcast_bool(diago_full_acc);
Parallel_Common::bcast_double(pw_diag_thr);
Parallel_Common::bcast_int(nb2d);
Parallel_Common::bcast_int(nurse);
Expand Down
1 change: 1 addition & 0 deletions source/module_io/input.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class Input
int pw_diag_nmax;
int diago_cg_prec; // mohan add 2012-03-31
int pw_diag_ndim;
bool diago_full_acc;
double pw_diag_thr; // used in cg method

int nb2d; // matrix 2d division.
Expand Down
12 changes: 12 additions & 0 deletions source/module_io/input_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
#include "module_io/input.h"
#include "module_relax/relax_old/ions_move_basic.h"
#include "module_relax/relax_old/lattice_change_basic.h"

#ifdef __EXX
#include "module_ri/exx_abfs-jle.h"
#endif

#ifdef __LCAO
#include "module_basis/module_ao/ORB_read.h"
#include "module_elecstate/potentials/H_TDDFT_pw.h"
Expand All @@ -30,6 +32,7 @@
#include "module_elecstate/potentials/efield.h"
#include "module_elecstate/potentials/gatefield.h"
#include "module_hsolver/hsolver_lcao.h"
#include "module_hsolver/hsolver_pw.h"
#include "module_md/md_func.h"
#include "module_psi/kernels/device.h"

Expand Down Expand Up @@ -388,6 +391,15 @@ void Input_Conv::Convert(void)
GlobalV::PW_DIAG_NMAX = INPUT.pw_diag_nmax;
GlobalV::DIAGO_CG_PREC = INPUT.diago_cg_prec;
GlobalV::PW_DIAG_NDIM = INPUT.pw_diag_ndim;

hsolver::HSolverPW<std::complex<float>, psi::DEVICE_CPU>::diago_full_acc = INPUT.diago_full_acc;
hsolver::HSolverPW<std::complex<double>, psi::DEVICE_CPU>::diago_full_acc = INPUT.diago_full_acc;

#if ((defined __CUDA) || (defined __ROCM))
hsolver::HSolverPW<std::complex<float>, psi::DEVICE_GPU>::diago_full_acc = INPUT.diago_full_acc;
hsolver::HSolverPW<std::complex<double>, psi::DEVICE_GPU>::diago_full_acc = INPUT.diago_full_acc;
#endif

GlobalV::PW_DIAG_THR = INPUT.pw_diag_thr;
GlobalV::NB2D = INPUT.nb2d;
GlobalV::NURSE = INPUT.nurse;
Expand Down
4 changes: 4 additions & 0 deletions source/module_io/parameter_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,10 @@ void input_parameters_set(std::map<std::string, InputParameter> input_parameters
{
INPUT.pw_diag_ndim = *static_cast<int*>(input_parameters["pw_diag_ndim"].get());
}
else if (input_parameters.count("diago_full_acc") != 0)
{
    // diago_full_acc is declared as `bool` (DEFAULT_TYPE.conf, Input::diago_full_acc),
    // so the stored value must be read back through a bool pointer. Reading it
    // through `int*` would access bytes beyond the stored bool.
    INPUT.diago_full_acc = *static_cast<bool*>(input_parameters["diago_full_acc"].get());
}
else if (input_parameters.count("pw_diag_thr") != 0)
{
INPUT.pw_diag_thr = *static_cast<double*>(input_parameters["pw_diag_thr"].get());
Expand Down
5 changes: 4 additions & 1 deletion source/module_io/test/input_conv_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "gmock/gmock.h"
#include "module_io/input_conv.h"
#include "module_base/global_variable.h"
#include "module_hsolver/hsolver_pw.h"
#include "for_testing_input_conv.h"

/************************************************
Expand Down Expand Up @@ -88,7 +89,9 @@ TEST_F(InputConvTest, Conv)
EXPECT_EQ(GlobalV::DIAGO_PROC,4);
EXPECT_EQ(GlobalV::PW_DIAG_NMAX,50);
EXPECT_EQ(GlobalV::DIAGO_CG_PREC,1);
EXPECT_EQ(GlobalV::PW_DIAG_NDIM,4);
EXPECT_EQ(GlobalV::PW_DIAG_NDIM, 4);
EXPECT_EQ(hsolver::HSolverPW<std::complex<float>>::diago_full_acc, false);
EXPECT_EQ(hsolver::HSolverPW<std::complex<double>>::diago_full_acc, false);
EXPECT_DOUBLE_EQ(GlobalV::PW_DIAG_THR,0.01);
EXPECT_EQ(GlobalV::NB2D,0);
EXPECT_EQ(GlobalV::NURSE,0);
Expand Down
5 changes: 5 additions & 0 deletions source/module_io/write_input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ void Input::Print(const std::string &fn) const
{
ModuleBase::GlobalFunc::OUTP(ofs, "pw_diag_ndim", pw_diag_ndim, "max dimension for davidson");
}
else if (ks_solver == "dav_subspace")
{
    // Parameters specific to the dav_subspace solver.
    ModuleBase::GlobalFunc::OUTP(ofs, "pw_diag_ndim", pw_diag_ndim, "dimension of workspace (number of wavefunction packets, at least 2 needed)");
    // Write the flag's own value; the previous code printed pw_diag_ndim here
    // by mistake, so the emitted INPUT file recorded a wrong diago_full_acc.
    ModuleBase::GlobalFunc::OUTP(ofs, "diago_full_acc", diago_full_acc, "if all the empty states are diagonalized at the same level of accuracy of the occupied ones.");
}
ModuleBase::GlobalFunc::OUTP(ofs,
"pw_diag_thr",
pw_diag_thr,
Expand Down

0 comments on commit 4385149

Please sign in to comment.