Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: remove all globalV in diago_david #4211

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions source/module_hsolver/diagh.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
#include <string>

#include "module_base/macros.h"

#include "module_hamilt_general/hamilt.h"
#include "module_psi/psi.h"

#ifdef __MPI
#include "mpi.h"
#endif

template<typename T> struct consts
{
consts();
Expand All @@ -18,6 +23,22 @@ template<typename T> struct consts
namespace hsolver
{


struct diag_comm_info
{

const int rank;
const int nproc;

haozhihan marked this conversation as resolved.
Show resolved Hide resolved
#ifndef __MPI
diag_comm_info(const int rank_in, const int nproc_in) : rank(rank_in), nproc(nproc_in) {}
#else
const MPI_Comm comm;
diag_comm_info(const MPI_Comm &comm_in, const int rank_in, const int nproc_in) : comm(comm_in), rank(rank_in), nproc(nproc_in) {}
#endif
};


template <typename T, typename Device = base_device::DEVICE_CPU>
class DiagH
{
Expand Down
12 changes: 4 additions & 8 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
#include "diago_dav_subspace.h"

#include <algorithm>
#include <type_traits>

#include "diago_iter_assist.h"
#include "module_base/blas_connector.h"
#include "module_base/constants.h"
#include "module_base/lapack_connector.h"

#include "module_base/memory.h"
#include "module_base/parallel_common.h"
#include "module_base/parallel_reduce.h"
#include "module_base/timer.h"
#include "module_base/parallel_global.h"
#include "module_base/module_device/device.h"

#include "module_hsolver/kernels/dngvd_op.h"
#include "module_hsolver/kernels/math_kernel_op.h"

Expand Down
4 changes: 0 additions & 4 deletions source/module_hsolver/diago_dav_subspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
#define DIAGO_NEW_DAV_H

#include "diagh.h"
#include "module_base/complexmatrix.h"
#include "module_base/macros.h"
#include "module_hamilt_pw/hamilt_pwdft/structure_factor.h"
#include "module_base/module_device/device.h"

namespace hsolver
{
Expand Down
92 changes: 50 additions & 42 deletions source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
#include "diago_david.h"

#include "diago_iter_assist.h"
#include "module_base/blas_connector.h"
#include "module_base/constants.h"
#include "module_base/lapack_connector.h"

#include "module_base/memory.h"
#include "module_base/parallel_common.h"
#include "module_base/parallel_reduce.h"
#include "module_base/timer.h"
#include "module_base/module_device/device.h"

#include "module_hsolver/kernels/dngvd_op.h"
#include "module_hsolver/kernels/math_kernel_op.h"

Expand All @@ -17,22 +15,30 @@

using namespace hsolver;

template <typename T, typename Device> DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in)

template <typename T, typename Device>
DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
const int david_ndim_in,
const bool use_paw_in,
const diag_comm_info& diag_comm_in)
: david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in)
{
this->device = base_device::get_device_type<Device>(this->ctx);
this->precondition = precondition_in;

test_david = 2;
this->one = &this->cs.one;
this->zero = &this->cs.zero;
this->neg_one = &this->cs.neg_one;

test_david = 2;
// 1: check which function is called and which step is executed
// 2: check the eigenvalues of the result of each iteration
// 3: check the eigenvalues and errors of the last result
// default: no check
}

template <typename T, typename Device> DiagoDavid<T, Device>::~DiagoDavid()
template <typename T, typename Device>
DiagoDavid<T, Device>::~DiagoDavid()
{
delmem_complex_op()(this->ctx, this->hphi);
delmem_complex_op()(this->ctx, this->sphi);
Expand All @@ -41,10 +47,6 @@ template <typename T, typename Device> DiagoDavid<T, Device>::~DiagoDavid()
delmem_complex_op()(this->ctx, this->vcc);
delmem_complex_op()(this->ctx, this->lagrange_matrix);
base_device::memory::delete_memory_op<Real, base_device::DEVICE_CPU>()(this->cpu_ctx, this->eigenvalue);
if (this->device == base_device::GpuDevice)
{
delmem_var_op()(this->ctx, this->d_precondition);
}
}

template <typename T, typename Device>
Expand All @@ -53,11 +55,14 @@ void DiagoDavid<T, Device>::diag_mock(hamilt::Hamilt<T, Device>* phm_in,
Real* eigenvalue_in)
{
if (test_david == 1)
{
ModuleBase::TITLE("DiagoDavid", "diag_mock");
}
ModuleBase::timer::tick("DiagoDavid", "diag_mock");

assert(DiagoDavid::PW_DIAG_NDIM > 1);
assert(DiagoDavid::PW_DIAG_NDIM * psi.get_nbands() < psi.get_current_nbas() * GlobalV::NPROC_IN_POOL);
assert(this->david_ndim > 1);
assert(this->david_ndim * psi.get_nbands() < psi.get_current_nbas() * diag_comm.nproc);
haozhihan marked this conversation as resolved.
Show resolved Hide resolved

// qianrui change it 2021-7-25.
// In strictly speaking, it shoule be PW_DIAG_NDIM*nband < npw sum of all pools. We roughly estimate it here.
// However, in most cases, total number of plane waves should be much larger than nband*PW_DIAG_NDIM
Expand All @@ -71,25 +76,21 @@ void DiagoDavid<T, Device>::diag_mock(hamilt::Hamilt<T, Device>* phm_in,
/// - "band" means the superscript I : the number of excited states to be solved
/// - k : k-points, the same meaning as the ground state
/// - "basis" : number of occupied ks-orbitals(subscripts i,j) * number of unoccupied ks-orbitals(subscripts a,b), corresponding to "bands" of the ground state

this->dim = psi.get_k_first() ? psi.get_current_nbas() : psi.get_nk() * psi.get_nbasis();
this->dmx = psi.get_k_first() ? psi.get_nbasis() : psi.get_nk() * psi.get_nbasis();
this->n_band = psi.get_nbands();
this->nbase_x = DiagoDavid::PW_DIAG_NDIM * this->n_band; // maximum dimension of the reduced basis set
this->nbase_x = this->david_ndim * this->n_band; // maximum dimension of the reduced basis set

// the lowest N eigenvalues
base_device::memory::resize_memory_op<Real, base_device::DEVICE_CPU>()(this->cpu_ctx,
this->eigenvalue,
this->nbase_x,
"DAV::eig");
base_device::memory::set_memory_op<Real, base_device::DEVICE_CPU>()(this->cpu_ctx,
this->eigenvalue,
0,
this->nbase_x);
base_device::memory::resize_memory_op<Real, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->eigenvalue, this->nbase_x, "DAV::eig");
base_device::memory::set_memory_op<Real, base_device::DEVICE_CPU>()(
this->cpu_ctx, this->eigenvalue, 0, this->nbase_x);

psi::Psi<T, Device> basis(1,
this->nbase_x,
this->dim,
&(psi.get_ngk(0))); // the reduced basis set
this->nbase_x,
this->dim,
&(psi.get_ngk(0))); // the reduced basis set
ModuleBase::Memory::record("DAV::basis", this->nbase_x * this->dim * sizeof(T));

//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
Expand Down Expand Up @@ -141,7 +142,7 @@ void DiagoDavid<T, Device>::diag_mock(hamilt::Hamilt<T, Device>* phm_in,

for (int m = 0; m < this->n_band; m++)
{
if(GlobalV::use_paw)
if(this->use_paw)
{
#ifdef USE_PAW
#ifdef __DEBUG
Expand Down Expand Up @@ -174,7 +175,7 @@ void DiagoDavid<T, Device>::diag_mock(hamilt::Hamilt<T, Device>* phm_in,
&this->lagrange_matrix[m * this->n_band],
pre_matrix_mm_m[m],
pre_matrix_mv_m[m]);
if(GlobalV::use_paw)
if(this->use_paw)
{
#ifdef USE_PAW
GlobalC::paw_cell.paw_nl_psi(1,reinterpret_cast<const std::complex<double>*> (&basis(m, 0)),
Expand Down Expand Up @@ -248,7 +249,7 @@ void DiagoDavid<T, Device>::diag_mock(hamilt::Hamilt<T, Device>* phm_in,
// updata eigenvectors of Hamiltonian

// ModuleBase::GlobalFunc::ZEROS(psi.get_pointer(), n_band * this->dmx);
setmem_complex_op()(this->ctx, psi.get_pointer(), 0, n_band * this->dmx);
setmem_complex_op()(this->ctx, psi.get_pointer(), 0, n_band * psi.get_nbasis());
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// haozhihan repalce 2022-10-18
gemm_op<T, Device>()(this->ctx,
Expand All @@ -264,7 +265,7 @@ void DiagoDavid<T, Device>::diag_mock(hamilt::Hamilt<T, Device>* phm_in,
this->nbase_x,
this->zero,
psi.get_pointer(), // C dim * n_band
this->dmx
psi.get_nbasis()
);

if (!this->notconv || (dav_iter == DiagoIterAssist<T, Device>::PW_DIAG_NMAX))
Expand Down Expand Up @@ -477,7 +478,7 @@ void DiagoDavid<T, Device>::cal_grad(hamilt::Hamilt<T, Device>* phm_in,
this->planSchmitOrth(notconv, pre_matrix_mm_m.data(), pre_matrix_mv_m.data());
for (int m = 0; m < notconv; m++)
{
if(GlobalV::use_paw)
if(this->use_paw)
{
#ifdef USE_PAW
GlobalC::paw_cell.paw_nl_psi(1,reinterpret_cast<const std::complex<double>*> (&basis(nbase + m, 0)),
Expand Down Expand Up @@ -520,7 +521,7 @@ void DiagoDavid<T, Device>::cal_grad(hamilt::Hamilt<T, Device>* phm_in,
&lagrange[m * (nbase + notconv)],
pre_matrix_mm_m[m],
pre_matrix_mv_m[m]);
if(GlobalV::use_paw)
if(this->use_paw)
{
#ifdef USE_PAW
GlobalC::paw_cell.paw_nl_psi(1,reinterpret_cast<const std::complex<double>*> (&basis(nbase + m, 0)),
Expand Down Expand Up @@ -594,7 +595,7 @@ void DiagoDavid<T, Device>::cal_elem(const int& dim,


#ifdef __MPI
if (GlobalV::NPROC_IN_POOL > 1)
if (diag_comm.nproc > 1)
{
matrixTranspose_op<T, Device>()(this->ctx, this->nbase_x, this->nbase_x, hcc, hcc);
matrixTranspose_op<T, Device>()(this->ctx, this->nbase_x, this->nbase_x, scc, scc);
Expand All @@ -608,17 +609,17 @@ void DiagoDavid<T, Device>::cal_elem(const int& dim,
else
{
if (base_device::get_current_precision(swap) == "single") {
MPI_Reduce(swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x, MPI_COMPLEX, MPI_SUM, 0, POOL_WORLD);
MPI_Reduce(swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm.comm);
}
else {
MPI_Reduce(swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, POOL_WORLD);
MPI_Reduce(swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm.comm);
}
syncmem_complex_op()(this->ctx, this->ctx, swap, scc + nbase * this->nbase_x, notconv * this->nbase_x);
if (base_device::get_current_precision(swap) == "single") {
MPI_Reduce(swap, scc + nbase * this->nbase_x, notconv * this->nbase_x, MPI_COMPLEX, MPI_SUM, 0, POOL_WORLD);
MPI_Reduce(swap, scc + nbase * this->nbase_x, notconv * this->nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm.comm);
}
else {
MPI_Reduce(swap, scc + nbase * this->nbase_x, notconv * this->nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, POOL_WORLD);
MPI_Reduce(swap, scc + nbase * this->nbase_x, notconv * this->nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm.comm);
}
}
delete[] swap;
Expand Down Expand Up @@ -658,7 +659,7 @@ void DiagoDavid<T, Device>::diag_zhegvx(const int& nbase,
{
// ModuleBase::TITLE("DiagoDavid","diag_zhegvx");
ModuleBase::timer::tick("DiagoDavid", "diag_zhegvx");
if (GlobalV::RANK_IN_POOL == 0)
if (diag_comm.rank == 0)
{
assert(nbase_x >= std::max(1, nbase));

Expand All @@ -682,14 +683,14 @@ void DiagoDavid<T, Device>::diag_zhegvx(const int& nbase,
}

#ifdef __MPI
if (GlobalV::NPROC_IN_POOL > 1)
if (diag_comm.nproc > 1)
{
// vcc: nbase * nband
for (int i = 0; i < nband; i++)
{
MPI_Bcast(&vcc[i * this->nbase_x], nbase, MPI_DOUBLE_COMPLEX, 0, POOL_WORLD);
MPI_Bcast(&vcc[i * this->nbase_x], nbase, MPI_DOUBLE_COMPLEX, 0, diag_comm.comm);
}
MPI_Bcast(this->eigenvalue, nband, MPI_DOUBLE, 0, POOL_WORLD);
MPI_Bcast(this->eigenvalue, nband, MPI_DOUBLE, 0, diag_comm.comm);
}
#endif

Expand Down Expand Up @@ -1055,6 +1056,13 @@ void DiagoDavid<T, Device>::diag(hamilt::Hamilt<T, Device>* phm_in,
std::cout << "\n notconv = " << this->notconv;
std::cout << "\n DiagoDavid::diag', too many bands are not converged! \n";
}

#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
delmem_var_op()(this->ctx, this->d_precondition);
}
#endif
return;
}

Expand Down
43 changes: 18 additions & 25 deletions source/module_hsolver/diago_david.h
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
//==========================================================
// AUTHOR : wangjp
// Data :2009-04
// Last Update:
//
// 09-05-10 modify SchmitOrth() diag_zhegvx() as static
// member function
//==========================================================

#ifndef DIAGODAVID_H
#define DIAGODAVID_H

#include "diagh.h"
#include "module_base/complexmatrix.h"
#include "module_base/macros.h"
#include "module_hamilt_pw/hamilt_pwdft/structure_factor.h"
#include "module_base/module_device/device.h"


namespace hsolver
{
Expand All @@ -27,31 +15,37 @@ class DiagoDavid : public DiagH<T, Device>
// return T if T is real type(float, double),
// otherwise return the real type of T(complex<float>, complex<double>)
using Real = typename GetTypeReal<T>::type;

public:
DiagoDavid(const Real* precondition_in);
~DiagoDavid();

// this is the override function diag() for CG method
void diag(hamilt::Hamilt<T, Device>* phm_in,
psi::Psi<T, Device>& phi,
Real* eigenvalue_in);
DiagoDavid(const Real* precondition_in,
const int david_ndim_in,
const bool use_paw_in,
const diag_comm_info& diag_comm_in);

virtual ~DiagoDavid() override;

static int PW_DIAG_NDIM;
virtual void diag(hamilt::Hamilt<T, Device>* phm_in,
psi::Psi<T, Device>& phi,
Real* eigenvalue_in) override ;

private:
int david_ndim = 4;
bool use_paw = false;
int test_david = 0;

/// record for how many bands not have convergence eigenvalues
int notconv = 0;
diag_comm_info diag_comm;

/// row size for input psi matrix
int n_band = 0;
/// col size for input psi matrix
int dmx = 0;
/// non-zero col size for inputted psi matrix
int dim = 0;
// maximum dimension of the reduced basis set
int nbase_x = 0;

/// record for how many bands not have convergence eigenvalues
int notconv = 0;

/// precondition for cg diag
const Real* precondition = nullptr;
Real* d_precondition = nullptr;
Expand Down Expand Up @@ -150,7 +144,6 @@ class DiagoDavid : public DiagH<T, Device>
consts<T> cs;
const T* one = nullptr, * zero = nullptr, * neg_one = nullptr;
};
template <typename Real, typename Device> int DiagoDavid<Real, Device>::PW_DIAG_NDIM = 4;
} // namespace hsolver

#endif
Loading