Skip to content

Commit

Permalink
rename diago_blas to diago_scalapack (#4233)
Browse files Browse the repository at this point in the history
  • Loading branch information
haozhihan committed May 25, 2024
1 parent e98ce38 commit 95bdff3
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 26 deletions.
2 changes: 1 addition & 1 deletion source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ OBJS_HSOLVER=diago_cg.o\
dngvd_op.o\

OBJS_HSOLVER_LCAO=hsolver_lcao.o\
diago_blas.o\
diago_scalapack.o\
diago_elpa.o\
elpa_new.o\
elpa_new_real.o\
Expand Down
2 changes: 1 addition & 1 deletion source/module_hsolver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ list(APPEND objects
if(ENABLE_LCAO)
list(APPEND objects
hsolver_lcao.cpp
diago_blas.cpp
diago_scalapack.cpp
)
if (USE_ELPA)
list(APPEND objects
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// DATE : 2022-04-14
//=====================

#include "diago_blas.h"
#include "diago_scalapack.h"

#include <cassert>
#include <cstring>
Expand All @@ -21,9 +21,9 @@ typedef hamilt::MatrixBlock<std::complex<double>> matcd;
namespace hsolver
{
template<>
void DiagoBlas<double>::diag(hamilt::Hamilt<double>* phm_in, psi::Psi<double>& psi, Real* eigenvalue_in)
void DiagoScalapack<double>::diag(hamilt::Hamilt<double>* phm_in, psi::Psi<double>& psi, Real* eigenvalue_in)
{
ModuleBase::TITLE("DiagoElpa", "diag");
ModuleBase::TITLE("DiagoScalapack", "diag");
matd h_mat, s_mat;
phm_in->matrix(h_mat, s_mat);
assert(h_mat.col == s_mat.col && h_mat.row == s_mat.row && h_mat.desc == s_mat.desc);
Expand All @@ -33,9 +33,9 @@ namespace hsolver
BlasConnector::copy(GlobalV::NBANDS, eigen.data(), inc, eigenvalue_in, inc);
}
template<>
void DiagoBlas<std::complex<double>>::diag(hamilt::Hamilt<std::complex<double>>* phm_in, psi::Psi<std::complex<double>>& psi, Real* eigenvalue_in)
void DiagoScalapack<std::complex<double>>::diag(hamilt::Hamilt<std::complex<double>>* phm_in, psi::Psi<std::complex<double>>& psi, Real* eigenvalue_in)
{
ModuleBase::TITLE("DiagoElpa", "diag");
ModuleBase::TITLE("DiagoScalapack", "diag");
matcd h_mat, s_mat;
phm_in->matrix(h_mat, s_mat);
assert(h_mat.col == s_mat.col && h_mat.row == s_mat.row && h_mat.desc == s_mat.desc);
Expand All @@ -46,7 +46,7 @@ namespace hsolver
}

template<typename T>
std::pair<int, std::vector<int>> DiagoBlas<T>::pdsygvx_once(const int* const desc,
std::pair<int, std::vector<int>> DiagoScalapack<T>::pdsygvx_once(const int* const desc,
const int ncol,
const int nrow,
const double *const h_mat,
Expand Down Expand Up @@ -169,7 +169,7 @@ namespace hsolver
+ ModuleBase::GlobalFunc::TO_STRING(__LINE__));
}
template<typename T>
std::pair<int, std::vector<int>> DiagoBlas<T>::pzhegvx_once(const int* const desc,
std::pair<int, std::vector<int>> DiagoScalapack<T>::pzhegvx_once(const int* const desc,
const int ncol,
const int nrow,
const std::complex<double> *const h_mat,
Expand Down Expand Up @@ -303,7 +303,7 @@ namespace hsolver
+ ModuleBase::GlobalFunc::TO_STRING(__LINE__));
}
template<typename T>
void DiagoBlas<T>::pdsygvx_diag(const int* const desc,
void DiagoScalapack<T>::pdsygvx_diag(const int* const desc,
const int ncol,
const int nrow,
const double *const h_mat,
Expand All @@ -321,7 +321,7 @@ namespace hsolver
}

template<typename T>
void DiagoBlas<T> ::pzhegvx_diag(const int* const desc,
void DiagoScalapack<T> ::pzhegvx_diag(const int* const desc,
const int ncol,
const int nrow,
const std::complex<double> *const h_mat,
Expand All @@ -339,7 +339,7 @@ namespace hsolver
}

template<typename T>
void DiagoBlas<T>::post_processing(const int info, const std::vector<int>& vec)
void DiagoScalapack<T>::post_processing(const int info, const std::vector<int>& vec)
{
const std::string str_info = "info = " + ModuleBase::GlobalFunc::TO_STRING(info) + ".\n";
const std::string str_FILE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
// DATE : 2022-04-14
//=====================

#ifndef DIAGOBLAS_H
#define DIAGOBLAS_H
#ifndef DIAGO_SCALAPACK_H
#define DIAGO_SCALAPACK_H

#include <complex>
#include <utility>
Expand All @@ -20,7 +20,7 @@
namespace hsolver
{
template<typename T>
class DiagoBlas : public DiagH<T>
class DiagoScalapack : public DiagH<T>
{
private:
using Real = typename GetTypeReal<T>::type;
Expand Down
4 changes: 2 additions & 2 deletions source/module_hsolver/hsolver_lcao.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "hsolver_lcao.h"

#include "diago_blas.h"
#include "diago_scalapack.h"
#include "diago_cg.h"
#include <ATen/core/tensor.h>
#include <ATen/core/tensor_types.h>
Expand Down Expand Up @@ -49,7 +49,7 @@ void HSolverLCAO<T, Device>::solveTemplate(hamilt::Hamilt<T>* pHamilt,
}
if (this->pdiagh == nullptr)
{
this->pdiagh = new DiagoBlas<T>();
this->pdiagh = new DiagoScalapack<T>();
this->pdiagh->method = this->method;
}
}
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ if(ENABLE_LCAO)
AddTest(
TARGET HSolver_LCAO
LIBS ${math_libs} ELPA::ELPA base genelpa psi device
SOURCES diago_lcao_test.cpp ../diago_elpa.cpp ../diago_blas.cpp
SOURCES diago_lcao_test.cpp ../diago_elpa.cpp ../diago_scalapack.cpp
)
else()
AddTest(
TARGET HSolver_LCAO
LIBS ${math_libs} base psi device
SOURCES diago_lcao_test.cpp ../diago_blas.cpp
SOURCES diago_lcao_test.cpp ../diago_scalapack.cpp
)
endif()

Expand All @@ -106,7 +106,7 @@ if (USE_CUDA)
AddTest(
TARGET HSolver_LCAO_cusolver
LIBS ${math_libs} base psi device
SOURCES diago_lcao_cusolver_test.cpp ../diago_cusolver.cpp ../diago_blas.cpp
SOURCES diago_lcao_cusolver_test.cpp ../diago_cusolver.cpp ../diago_scalapack.cpp
../kernels/math_kernel_op.cpp
../kernels/dngvd_op.cpp
../kernels/cuda/diag_cusolver.cu
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/test/diago_lcao_cusolver_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <vector>

#include "gtest/gtest.h"
#include "module_hsolver/diago_blas.h"
#include "module_hsolver/diago_scalapack.h"
#include "module_hsolver/test/diago_elpa_utils.h"
#include "mpi.h"
#include "string.h"
Expand All @@ -24,7 +24,7 @@
/**
* Tested function:
* - hsolver::DiagoElpa::diag (for ELPA)
* - hsolver::DiagoBlas::diag (for Scalapack)
* - hsolver::DiagoScalapack::diag (for Scalapack)
*
* The 2d block cyclic distribution of H/S matrix is done by
* self-realized functions in module_hsolver/test/diago_elpa_utils.h
Expand Down Expand Up @@ -76,7 +76,7 @@ class DiagoPrepare
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);

if (ks_solver == "scalapack_gvx")
dh = new hsolver::DiagoBlas<T>;
dh = new hsolver::DiagoScalapack<T>;
#ifdef __CUDA
else if (ks_solver == "cusolver")
dh = new hsolver::DiagoCusolver<T>;
Expand Down
6 changes: 3 additions & 3 deletions source/module_hsolver/test/diago_lcao_test.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "module_hsolver/test/diago_elpa_utils.h"
#include "module_hsolver/diago_blas.h"
#include "module_hsolver/diago_scalapack.h"
#include "mpi.h"
#include "string.h"
#include "gtest/gtest.h"
Expand All @@ -20,7 +20,7 @@
/**
* Tested function:
* - hsolver::DiagoElpa::diag (for ELPA)
* - hsolver::DiagoBlas::diag (for Scalapack)
* - hsolver::DiagoScalapack::diag (for Scalapack)
*
* The 2d block cyclic distribution of H/S matrix is done by
* self-realized functions in module_hsolver/test/diago_elpa_utils.h
Expand Down Expand Up @@ -60,7 +60,7 @@ template<class T> class DiagoPrepare
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);

if (ks_solver == "scalapack_gvx")
dh = new hsolver::DiagoBlas<T>;
dh = new hsolver::DiagoScalapack<T>;
#ifdef __ELPA
else if(ks_solver == "genelpa")
dh = new hsolver::DiagoElpa<T>;
Expand Down

0 comments on commit 95bdff3

Please sign in to comment.