Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: simplified Parallel_2d interface #4237

Merged
merged 12 commits into from
May 30, 2024
6 changes: 6 additions & 0 deletions source/module_base/blacs_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
//====================================8<----------------------------------------
// blacs
// Initialization
#ifndef BLACS_CONNECTOR_H
#define BLACS_CONNECTOR_H

extern "C"
{
void Cblacs_pinfo(int *myid, int *nprocs);
Expand All @@ -34,6 +37,7 @@ extern "C"
// Informational and Miscellaneous
void Cblacs_gridinfo(int icontxt, int* nprow, int *npcol, int *myprow, int *mypcol);
void Cblacs_gridinit(int* icontxt, char* layout, int nprow, int npcol);
void Cblacs_gridexit(int* icontxt);
int Cblacs_pnum(int icontxt, int prow, int pcol);
void Cblacs_pcoord(int icontxt, int pnum, int *prow, int *pcol);
void Cblacs_exit(int icontxt);
Expand All @@ -45,4 +49,6 @@ extern "C"
{
int Csys2blacs_handle(MPI_Comm SysCtxt);
}
#endif // __MPI

#endif
78 changes: 4 additions & 74 deletions source/module_base/scalapack_connector.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#ifndef SCALAPACK_CONNECTOR_H
#define SCALAPACK_CONNECTOR_H

#ifdef __MPI

#include <complex>

extern "C"
{
void blacs_gridinit_( int *ictxt, const char *order, const int *nprow, const int *npcol );
void blacs_gridinfo_( const int *ictxt, int *nprow, int *npcol, int *myprow, int *mypcol );
int numroc_( const int *n, const int *nb, const int *iproc, const int *srcproc, const int *nprocs );
void descinit_(
int *desc,
Expand Down Expand Up @@ -174,76 +174,6 @@ class ScalapackConnector
}
};

/*
class ScalapackConnector
{
public:
static void transpose_desc( int desc_T[9], const int desc[9] )
{
desc_T[0] = desc[0];
desc_T[1] = desc[1];
desc_T[2] = desc[3]; desc_T[3] = desc[2];
desc_T[4] = desc[5]; desc_T[5] = desc[4];
desc_T[6] = desc[6]; desc_T[7] = desc[7];
desc_T[8] = desc[8];
}

static void blacs_gridinit( int &ictxt, const char order, const int nprow, const int npcol )
{
blacs_gridinit_(&ictxt, &order, &nprow, &npcol);
}

static void blacs_gridinfo( const int &ictxt, int &nprow, int &npcol, int &myprow, int &mypcol )
{
blacs_gridinfo_( &ictxt, &nprow, &npcol, &myprow, &mypcol );
}

static int numroc( const int n, const int nb, const int iproc, const int srcproc, const int nprocs )
{
return numroc_(&n, &nb, &iproc, &srcproc, &nprocs);
}

static void descinit(
int *desc,
const int m, const int n, const int mb, const int nb, const int irsrc, const int icsrc,
const int ictxt, const int lld, int &info )
{
descinit_(desc, &m, &n, &mb, &nb, &irsrc, &icsrc, &ictxt, &lld, &info);
// descinit_(desc, &n, &m, &nb, &mb, &irsrc, &icsrc, &ictxt, &lld, &info);
}

// C = a * A.? * B.? + b * C
static void pgemm(
const char transa, const char transb,
const int M, const int N, const int K,
const double alpha,
const double *A, const int IA, const int JA, const int *DESCA,
const double *B, const int IB, const int JB, const int *DESCB,
const double beta,
double *C, const int IC, const int JC, const int *DESCC)
{
// int DESCA_T[9], DESCB_T[9], DESCC_T[9];
// transpose_desc( DESCA_T, DESCA );
// transpose_desc( DESCB_T, DESCB );
// transpose_desc( DESCC_T, DESCC );
// pdgemm_(
// &transb, &transa,
// &N, &M, &K,
// &alpha,
// B, &JB, &IB, DESCB_T,
// A, &JA, &IA, DESCA_T,
// &beta,
// C, &JC, &IC, DESCC_T);
pdgemm_(
&transa, &transb,
&M, &N, &K,
&alpha,
A, &JA, &IA, DESCA,
B, &JB, &IB, DESCB,
&beta,
C, &JC, &IC, DESCC);
}
};
*/
#endif // __MPI

#endif
#endif
5 changes: 0 additions & 5 deletions source/module_base/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ AddTest(
TARGET base_atom_in
SOURCES atom_in_test.cpp
)
AddTest(
TARGET base_blacs_connector
LIBS ${math_libs}
SOURCES blacs_connector_test.cpp
)
AddTest(
TARGET base_timer
SOURCES timer_test.cpp ../timer.cpp ../global_variable.cpp
Expand Down
53 changes: 0 additions & 53 deletions source/module_base/test/blacs_connector_test.cpp

This file was deleted.

14 changes: 14 additions & 0 deletions source/module_base/test_parallel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,17 @@ add_test(NAME base_parallel_reduce_test
COMMAND ${BASH} parallel_reduce_test.sh
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
)

if(ENABLE_LCAO)
AddTest(
TARGET blacs_connector
LIBS MPI::MPI_CXX ScaLAPACK::ScaLAPACK
SOURCES blacs_connector_test.cpp
)
install(FILES blacs_connector_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
add_test(NAME blacs_connector_test
COMMAND ${BASH} blacs_connector_test.sh
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
)
endif()

109 changes: 109 additions & 0 deletions source/module_base/test_parallel/blacs_connector_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#ifdef __MPI

#include "../blacs_connector.h"
#include <mpi.h>
#include "gtest/gtest.h"

/************************************************
* unit test of functions in blacs_connector.h
***********************************************/

/**
* - Tested Function
* - Cblacs_gridinit
* - Initializes a grid of processors with a given number of rows and columns.
* The function creates a cartesian topology of all the processors initialized
* by the BLACS library. In this topology, each processor is identified by its
* coordinates (row, col) in the grid.
*/

/// Test fixture holding the process info and BLACS grid parameters
/// shared by the BLACS connector tests.
class BLACSTest: public testing::Test
{
protected:
    // queries the global rank / process count via Cblacs_pinfo before each test
    void SetUp();

    int rank = 0;      // rank of this process (filled by Cblacs_pinfo)
    int nprocs = 0;    // total number of processes (filled by Cblacs_pinfo)
    char layout = 'R'; // grid ordering passed to Cblacs_gridinit: 'R' = row-major

    // number of rows and columns in the process grid
    int nprow = 0;
    int npcol = 0;

    // process coordinate within the grid (filled by Cblacs_gridinfo)
    int iprow = -1;
    int ipcol = -1;
};

// Per-test setup: query this process's rank and the total number of
// processes from BLACS so each test starts from a known state.
void BLACSTest::SetUp()
{
    Cblacs_pinfo(&rank, &nprocs);
}


/// Creates a row grid (1 x nprocs) and a column grid (nprocs x 1) over
/// MPI_COMM_WORLD and verifies each process's coordinates in both, then
/// releases the contexts (previously leaked).
TEST_F(BLACSTest, WorldGrid)
{
    // generate a grid of size 1 x nprocs
    nprow = 1;
    npcol = nprocs;

    int ictxt_row = Csys2blacs_handle(MPI_COMM_WORLD);
    Cblacs_gridinit(&ictxt_row, &layout, nprow, npcol);
    Cblacs_gridinfo(ictxt_row, &nprow, &npcol, &iprow, &ipcol);

    // in a single-row grid every process sits in row 0 at column == world rank
    EXPECT_EQ(iprow, 0);
    EXPECT_EQ(ipcol, rank);

    // generate a grid of size nprocs x 1
    nprow = nprocs;
    npcol = 1;

    int ictxt_col = Csys2blacs_handle(MPI_COMM_WORLD);
    Cblacs_gridinit(&ictxt_col, &layout, nprow, npcol);
    Cblacs_gridinfo(ictxt_col, &nprow, &npcol, &iprow, &ipcol);

    EXPECT_EQ(iprow, rank);
    EXPECT_EQ(ipcol, 0);

    // the two BLACS grids should have different context indices
    EXPECT_NE(ictxt_row, ictxt_col);

    // release both contexts so they do not leak into later tests
    Cblacs_gridexit(&ictxt_row);
    Cblacs_gridexit(&ictxt_col);
}

/// Splits MPI_COMM_WORLD into odd/even-rank sub-communicators and checks
/// that a BLACS grid built on a sub-communicator uses sub-communicator
/// ranks, not world ranks. Cleans up the context and communicator
/// (previously both leaked).
TEST_F(BLACSTest, SplitGrid)
{
    // this test creates BLACS grids based on a disjoint communicator

    const int n_blacs = 2;
    int rank_sub = -1;
    int nprocs_sub = 0;

    // sub communicators are divided based on odd / even ranks
    MPI_Comm comm_sub;
    MPI_Comm_split(MPI_COMM_WORLD, rank % n_blacs, rank, &comm_sub);
    MPI_Comm_rank(comm_sub, &rank_sub);
    MPI_Comm_size(comm_sub, &nprocs_sub);

    int ctxt_sub = Csys2blacs_handle(comm_sub);

    nprow = 1, npcol = nprocs_sub; // row-like grids
    Cblacs_gridinit(&ctxt_sub, &layout, nprow, npcol);
    Cblacs_gridinfo(ctxt_sub, &nprow, &npcol, &iprow, &ipcol);

    // verifies that the BLACS grid is created based on comm_sub instead of MPI_COMM_WORLD
    EXPECT_EQ(iprow, 0);
    EXPECT_EQ(ipcol, rank_sub);

    // release the BLACS context and the sub-communicator
    Cblacs_gridexit(&ctxt_sub);
    MPI_Comm_free(&comm_sub);
}

/// Entry point: MPI must be initialized before any BLACS call, and
/// finalized only after all tests have run.
int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    testing::InitGoogleTest(&argc, argv);

    const int status = RUN_ALL_TESTS();
    MPI_Finalize();

    return status;
}
#endif
15 changes: 15 additions & 0 deletions source/module_base/test_parallel/blacs_connector_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash -e

# Run the BLACS connector unit test in parallel with the largest process
# count (at most 8, at least 2) that fits on this machine.

# Prefer `nproc` (total usable processors); the /proc/cpuinfo "cpu cores"
# field only reports cores per socket and can undercount the machine.
if command -v nproc >/dev/null 2>&1; then
    np=$(nproc)
else
    np=`cat /proc/cpuinfo | grep "cpu cores" | uniq | awk '{print $NF}'`
fi
echo "nprocs in this machine is $np"

# bash brace expansion counts down: 8 7 6 5 4 3 2
for i in {8..2};
do
    if [[ $i -gt $np ]];then
        continue
    fi
    echo "TEST in parallel, nprocs=$i"
    mpirun -np $i ./blacs_connector
    break
done

27 changes: 6 additions & 21 deletions source/module_basis/module_ao/ORB_control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,8 @@ void ORB_control::setup_2d_division(std::ofstream& ofs_running,
else
{
std::cout << " Parallel Orbial, DIAGO_TYPE = " << ks_solver << std::endl;
ModuleBase::WARNING_QUIT("Parallel_Orbitals::set_global2local", "Check ks_solver.");
ModuleBase::WARNING_QUIT("Parallel_Orbitals::setup_2d_division", "Check ks_solver.");
}
// (2) set the trace, then we can calculate the nnr.
// for 2d: calculate po.nloc first, then global2local_row and global2local_col
// for O(N): calculate the three together.
this->ParaV.set_global2local(nlocal, nlocal, div_2d, ofs_running);
}


Expand Down Expand Up @@ -336,12 +332,6 @@ void ORB_control::divide_HS_2d(
if (dcolor != 0)
return; // mohan add 2012-01-13

// get the 2D index of computer.
pv->dim0 = (int)sqrt((double)dsize); // mohan update 2012/01/13
// while (GlobalV::NPROC_IN_POOL%dim0!=0)

pv->set_proc_dim(dsize);

if (pv->testpb)
ModuleBase::GlobalFunc::OUT(ofs_running, "dim0", pv->dim0);
if (pv->testpb)
Expand All @@ -352,32 +342,27 @@ void ORB_control::divide_HS_2d(
#ifdef __DEBUG
assert(nb2d > 0);
#endif
pv->set_block_size(nb2d); // mohan add 2010-06-28

ModuleBase::GlobalFunc::OUT(ofs_running, "nb2d", pv->get_block_size());

this->set_parameters(ofs_running, ofs_warning);

// call mpi_creat_cart
pv->mpi_create_cart(DIAG_WORLD);

int try_nb = pv->set_local2global(nlocal, nlocal, ofs_running, ofs_warning);
try_nb = pv->set_nloc_wfc_Eij(nbands, ofs_running, ofs_warning);
if (try_nb == 1)
int try_nb = pv->init(nlocal, nlocal, nb2d, DIAG_WORLD);
try_nb += pv->set_nloc_wfc_Eij(nbands, ofs_running, ofs_warning);
if (try_nb != 0)
{
ofs_running << " parameter nb2d is too large: nb2d = " << pv->get_block_size() << std::endl;
ofs_running << " reset nb2d to value 1, this set would make the program keep working but maybe get slower "
"during diagonalization."
<< std::endl;
pv->set_block_size(1);
try_nb = pv->set_local2global(nlocal, nlocal, ofs_running, ofs_warning);

pv->set(nlocal, nlocal, 1, pv->comm_2D, pv->blacs_ctxt);
try_nb = pv->set_nloc_wfc_Eij(nbands, ofs_running, ofs_warning);
}

// init blacs context for genelpa
if (ks_solver == "genelpa" || ks_solver == "scalapack_gvx" || ks_solver == "cusolver" || ks_solver == "cg_in_lcao" || ks_solver == "pexsi")
{
pv->set_desc(nlocal, nlocal, pv->nrow);
pv->set_desc_wfc_Eij(nlocal, nbands, pv->nrow);
}
#else // single processor used.
Expand Down
Loading