Skip to content

Commit

Permalink
fix UTs
Browse files Browse the repository at this point in the history
  • Loading branch information
maki49 committed Nov 14, 2024
1 parent 4f4f356 commit 4cc41f5
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 23 deletions.
2 changes: 1 addition & 1 deletion source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(PreFunc&& precondition_in,
const int& diag_nmax_in,
const bool& need_subspace_in,
const diag_comm_info& diag_comm_in)
: precondition(std::forward<PreFunc>(precondition_in)), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in* david_ndim_in),
: precondition(precondition_in), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in* david_ndim_in),
diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), is_subspace(need_subspace_in), diag_comm(diag_comm_in)
{
this->device = base_device::get_device_type<Device>(this->ctx);
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/diago_david.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ DiagoDavid<T, Device>::DiagoDavid(PreFunc&& precondition_in,
const bool use_paw_in,
const diag_comm_info& diag_comm_in)
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in* nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in),
precondition(std::forward<PreFunc>(precondition_in))
precondition(precondition_in)
{
this->device = base_device::get_device_type<Device>(this->ctx);

Expand Down
16 changes: 8 additions & 8 deletions source/module_hsolver/precondition_funcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include "module_base/module_device/types.h"
#include "module_base/module_device/memory_op.h"
#include "module_hsolver/kernels/math_kernel_op.h"

/// @brief Preconditioner Function Library
/// Users can add other types of operation than the following ones at one's need.
namespace hsolver
{
template <typename T>
Expand All @@ -19,10 +22,6 @@ namespace hsolver
/// @brief Transform vectors
namespace fvec
{
/// @brief To be called in the iterative eigensolver.
/// Users can add other types of operation than the following ones at one's need.
/// fixed parameters: object vector, eigenvalue, leading dimension, number of vectors

///---------------------------------------------------------------------------------------------
/// type 1: directly divide each vector by the precondition vector
///---------------------------------------------------------------------------------------------
Expand All @@ -36,7 +35,7 @@ namespace hsolver
vector_div_vector_op<T, Device>()({}, dim, ptr_m, ptr_m, pre);
}
}
/// calling intereface in the eigensolver
/// Intereface to be called in the eigensolver
template <typename T>
using Div = std::function<void(T*, const size_t&, const size_t&)>;
// Kernel function full of dependence
Expand Down Expand Up @@ -73,17 +72,18 @@ namespace hsolver
}
}
}
/// calling intereface in the eigensolver
/// Intereface to be called in the eigensolver
template <typename T>
using DivTransMinusEig = std::function<void(T*, const Real<T>*, const size_t&, const size_t&)>;
// Kernel function full of dependence
/// Kernel function full of dependence
template <typename T, typename Device = base_device::DEVICE_CPU>
using DivTransMinusEigKernel = std::function<decltype(div_trans_prevec_minus_eigen<T, Device>)>;
}

/// @brief A operator-like class of precondition function
/// to encapsulate the pre-allocation of memory on different devices before starting the iterative eigensolver.
/// One can pass the operatr() function of this class, or other custom lambdas/functions to eigensolvers.
/// One can use `.get()` interface to get the function to be called by the eigensovler,
/// or pass a custom lambdas/function to replace the one returned by `.get()`.
template <typename T, typename Device = base_device::DEVICE_CPU, typename Kernel_t = fvec::DivKernel<T, Device>>
struct PreOP
{
Expand Down
10 changes: 5 additions & 5 deletions source/module_hsolver/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ if (ENABLE_MPI)
SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp
)

# AddTest(
# TARGET HSolver_sdft
# LIBS parameter ${math_libs} psi device base container
# SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp
# )
AddTest(
TARGET HSolver_sdft
LIBS parameter ${math_libs} psi device base container
SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp
)

if(ENABLE_LCAO)
if(USE_ELPA)
Expand Down
5 changes: 3 additions & 2 deletions source/module_hsolver/test/diago_david_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ class DiagoDavPrepare

const int dim = phi.get_current_nbas() ;
const int nband = phi.get_nbands();
const int ld_psi =phi.get_nbasis();
hsolver::DiagoDavid<std::complex<float>> dav(precondition, nband, dim, order, false, comm_info);
const int ld_psi = phi.get_nbasis();
const hsolver::PreOP<std::complex<float>> pre_op(precondition, dim);
hsolver::DiagoDavid<std::complex<float>> dav(pre_op.get(), nband, dim, order, false, comm_info);

hsolver::DiagoIterAssist<std::complex<float>>::PW_DIAG_NMAX = maxiter;
hsolver::DiagoIterAssist<std::complex<float>>::PW_DIAG_THR = eps;
Expand Down
3 changes: 2 additions & 1 deletion source/module_hsolver/test/diago_david_real_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ class DiagoDavPrepare
const int dim = phi.get_current_nbas();
const int nband = phi.get_nbands();
const int ld_psi = phi.get_nbasis();
hsolver::DiagoDavid<double> dav(precondition, nband, dim, order, false, comm_info);
const hsolver::PreOP<double> pre_op(precondition, dim);
hsolver::DiagoDavid<double> dav(pre_op.get(), nband, dim, order, false, comm_info);

hsolver::DiagoIterAssist<double>::PW_DIAG_NMAX = maxiter;
hsolver::DiagoIterAssist<double>::PW_DIAG_THR = eps;
Expand Down
6 changes: 4 additions & 2 deletions source/module_hsolver/test/diago_david_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ class DiagoDavPrepare

const int dim = phi.get_current_nbas();
const int nband = phi.get_nbands();
const int ld_psi = phi.get_nbasis();
hsolver::DiagoDavid<std::complex<double>> dav(precondition, nband, dim, order, false, comm_info);
const int ld_psi = phi.get_nbasis();
const auto pre_func = [&precondition](std::complex<double>* ptr, const int& ld, const int& nvec)->void
{ hsolver::fvec::div_prevec(ptr, ld, nvec, precondition); };
hsolver::DiagoDavid<std::complex<double>> dav(pre_func, nband, dim, order, false, comm_info);

hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX = maxiter;
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR = eps;
Expand Down
7 changes: 4 additions & 3 deletions source/module_hsolver/test/hsolver_pw_sup.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "module_basis/module_pw/pw_basis_k.h"
#include "module_hsolver/precondition_funcs.h"

namespace ModulePW {

Expand Down Expand Up @@ -121,15 +122,15 @@ template class DiagoCG<std::complex<float>, base_device::DEVICE_CPU>;
template class DiagoCG<std::complex<double>, base_device::DEVICE_CPU>;

template <typename T, typename Device>
DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
DiagoDavid<T, Device>::DiagoDavid(PreFunc&& precondition_in,
const int nband_in,
const int dim_in,
const int david_ndim_in,
const bool use_paw_in,
const diag_comm_info& diag_comm_in)
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in) {
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in* nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in),
precondition(std::forward<PreFunc>(precondition_in)) {
this->device = base_device::get_device_type<Device>(this->ctx);
this->precondition = precondition_in;

test_david = 2;
// 1: check which function is called and which step is executed
Expand Down

0 comments on commit 4cc41f5

Please sign in to comment.