Skip to content

Commit

Permalink
Refactor: simplified Parallel_2d interface (#4237)
Browse files Browse the repository at this point in the history
* clean up

* change to new interface

* fix missing header

* add missing header

* add missing header

* follow advice from review

* update BLACS test

* compromise to pyabacus compilation
  • Loading branch information
jinzx10 authored May 30, 2024
1 parent aec6760 commit c4c99e1
Show file tree
Hide file tree
Showing 53 changed files with 450 additions and 684 deletions.
6 changes: 6 additions & 0 deletions source/module_base/blacs_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
//====================================8<----------------------------------------
// blacs
// Initialization
#ifndef BLACS_CONNECTOR_H
#define BLACS_CONNECTOR_H

extern "C"
{
void Cblacs_pinfo(int *myid, int *nprocs);
Expand All @@ -34,6 +37,7 @@ extern "C"
// Informational and Miscellaneous
void Cblacs_gridinfo(int icontxt, int* nprow, int *npcol, int *myprow, int *mypcol);
void Cblacs_gridinit(int* icontxt, char* layout, int nprow, int npcol);
void Cblacs_gridexit(int* icontxt);
int Cblacs_pnum(int icontxt, int prow, int pcol);
void Cblacs_pcoord(int icontxt, int pnum, int *prow, int *pcol);
void Cblacs_exit(int icontxt);
Expand All @@ -45,4 +49,6 @@ extern "C"
{
int Csys2blacs_handle(MPI_Comm SysCtxt);
}
#endif // __MPI

#endif
78 changes: 4 additions & 74 deletions source/module_base/scalapack_connector.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#ifndef SCALAPACK_CONNECTOR_H
#define SCALAPACK_CONNECTOR_H

#ifdef __MPI

#include <complex>

extern "C"
{
void blacs_gridinit_( int *ictxt, const char *order, const int *nprow, const int *npcol );
void blacs_gridinfo_( const int *ictxt, int *nprow, int *npcol, int *myprow, int *mypcol );
int numroc_( const int *n, const int *nb, const int *iproc, const int *srcproc, const int *nprocs );
void descinit_(
int *desc,
Expand Down Expand Up @@ -174,76 +174,6 @@ class ScalapackConnector
}
};

/*
class ScalapackConnector
{
public:
static void transpose_desc( int desc_T[9], const int desc[9] )
{
desc_T[0] = desc[0];
desc_T[1] = desc[1];
desc_T[2] = desc[3]; desc_T[3] = desc[2];
desc_T[4] = desc[5]; desc_T[5] = desc[4];
desc_T[6] = desc[6]; desc_T[7] = desc[7];
desc_T[8] = desc[8];
}
static void blacs_gridinit( int &ictxt, const char order, const int nprow, const int npcol )
{
blacs_gridinit_(&ictxt, &order, &nprow, &npcol);
}
static void blacs_gridinfo( const int &ictxt, int &nprow, int &npcol, int &myprow, int &mypcol )
{
blacs_gridinfo_( &ictxt, &nprow, &npcol, &myprow, &mypcol );
}
static int numroc( const int n, const int nb, const int iproc, const int srcproc, const int nprocs )
{
return numroc_(&n, &nb, &iproc, &srcproc, &nprocs);
}
static void descinit(
int *desc,
const int m, const int n, const int mb, const int nb, const int irsrc, const int icsrc,
const int ictxt, const int lld, int &info )
{
descinit_(desc, &m, &n, &mb, &nb, &irsrc, &icsrc, &ictxt, &lld, &info);
// descinit_(desc, &n, &m, &nb, &mb, &irsrc, &icsrc, &ictxt, &lld, &info);
}
// C = a * A.? * B.? + b * C
static void pgemm(
const char transa, const char transb,
const int M, const int N, const int K,
const double alpha,
const double *A, const int IA, const int JA, const int *DESCA,
const double *B, const int IB, const int JB, const int *DESCB,
const double beta,
double *C, const int IC, const int JC, const int *DESCC)
{
// int DESCA_T[9], DESCB_T[9], DESCC_T[9];
// transpose_desc( DESCA_T, DESCA );
// transpose_desc( DESCB_T, DESCB );
// transpose_desc( DESCC_T, DESCC );
// pdgemm_(
// &transb, &transa,
// &N, &M, &K,
// &alpha,
// B, &JB, &IB, DESCB_T,
// A, &JA, &IA, DESCA_T,
// &beta,
// C, &JC, &IC, DESCC_T);
pdgemm_(
&transa, &transb,
&M, &N, &K,
&alpha,
A, &JA, &IA, DESCA,
B, &JB, &IB, DESCB,
&beta,
C, &JC, &IC, DESCC);
}
};
*/
#endif // __MPI

#endif
#endif
5 changes: 0 additions & 5 deletions source/module_base/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ AddTest(
TARGET base_atom_in
SOURCES atom_in_test.cpp
)
AddTest(
TARGET base_blacs_connector
LIBS ${math_libs}
SOURCES blacs_connector_test.cpp
)
AddTest(
TARGET base_timer
SOURCES timer_test.cpp ../timer.cpp ../global_variable.cpp
Expand Down
53 changes: 0 additions & 53 deletions source/module_base/test/blacs_connector_test.cpp

This file was deleted.

14 changes: 14 additions & 0 deletions source/module_base/test_parallel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,17 @@ add_test(NAME base_parallel_reduce_test
COMMAND ${BASH} parallel_reduce_test.sh
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
)

# The BLACS connector test links against MPI and ScaLAPACK, so it is only
# configured when ENABLE_LCAO is set.
if(ENABLE_LCAO)
# Declare the test executable "blacs_connector" via the project's AddTest macro.
# NOTE(review): AddTest is defined elsewhere; whether it also registers a CTest
# entry on its own is not visible here - confirm.
AddTest(
TARGET blacs_connector
LIBS MPI::MPI_CXX ScaLAPACK::ScaLAPACK
SOURCES blacs_connector_test.cpp
)
# Install rule that copies the launcher script into the current binary directory,
# which is also the working directory of the test registered below.
install(FILES blacs_connector_test.sh DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
# Register a CTest entry that runs the launcher script through ${BASH}
# from the current binary directory.
add_test(NAME blacs_connector_test
COMMAND ${BASH} blacs_connector_test.sh
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
)
endif()

109 changes: 109 additions & 0 deletions source/module_base/test_parallel/blacs_connector_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#ifdef __MPI

#include "../blacs_connector.h"
#include <mpi.h>
#include "gtest/gtest.h"

/************************************************
* unit test of functions in blacs_connector.h
***********************************************/

/**
* - Tested Function
* - Cblacs_gridinit
* - Initializes a grid of processors with a given number of rows and columns.
* The function creates a cartesian topology of all the processors initialized
* by the BLACS library. In this topology, each processor is identified by its
* coordinates (row, col) in the grid.
*/

/// Fixture shared by the BLACS tests. It holds this process's id and the
/// total process count, plus scratch variables used to describe a process
/// grid and to receive what BLACS reports about it.
class BLACSTest : public testing::Test
{
  protected:
    void SetUp();

    // shape of the process grid: number of rows / columns
    int nprow{0};
    int npcol{0};

    // coordinate of this process inside the grid; stays -1 until queried
    int iprow{-1};
    int ipcol{-1};

    // id of this process and number of processes, filled in by SetUp()
    int rank{0};
    int nprocs{0};

    // layout argument handed to Cblacs_gridinit
    char layout{'R'};
};

void BLACSTest::SetUp()
{
Cblacs_pinfo(&rank, &nprocs);
}


/// Creates two grids over MPI_COMM_WORLD (one single-row, one single-column)
/// and checks the coordinate BLACS reports for this process in each of them.
TEST_F(BLACSTest, WorldGrid)
{
    // --- single-row grid: 1 x nprocs ---
    int ictxt_row = Csys2blacs_handle(MPI_COMM_WORLD);
    nprow = 1;
    npcol = nprocs;
    Cblacs_gridinit(&ictxt_row, &layout, nprow, npcol);
    Cblacs_gridinfo(ictxt_row, &nprow, &npcol, &iprow, &ipcol);

    // all processes share row 0; the column index equals the rank
    EXPECT_EQ(iprow, 0);
    EXPECT_EQ(ipcol, rank);

    // --- single-column grid: nprocs x 1 ---
    int ictxt_col = Csys2blacs_handle(MPI_COMM_WORLD);
    nprow = nprocs;
    npcol = 1;
    Cblacs_gridinit(&ictxt_col, &layout, nprow, npcol);
    Cblacs_gridinfo(ictxt_col, &nprow, &npcol, &iprow, &ipcol);

    // all processes share column 0; the row index equals the rank
    EXPECT_EQ(iprow, rank);
    EXPECT_EQ(ipcol, 0);

    // the two grids must be identified by different context indices
    EXPECT_NE(ictxt_row, ictxt_col);
}

/// Creates a BLACS grid on top of a sub-communicator obtained by splitting
/// MPI_COMM_WORLD, and checks that the coordinate BLACS reports refers to the
/// sub-communicator rather than to MPI_COMM_WORLD.
TEST_F(BLACSTest, SplitGrid)
{
    const int n_blacs = 2;
    int rank_sub = -1;
    int nprocs_sub = 0;

    // Split the world by (rank % n_blacs), i.e. odd / even ranks; using the
    // world rank as key keeps the relative order of ranks inside each group.
    MPI_Comm comm_sub = MPI_COMM_NULL;
    MPI_Comm_split(MPI_COMM_WORLD, rank % n_blacs, rank, &comm_sub);
    MPI_Comm_rank(comm_sub, &rank_sub);
    MPI_Comm_size(comm_sub, &nprocs_sub);

    int ctxt_sub = Csys2blacs_handle(comm_sub);

    // single-row grid covering every process of the sub-communicator
    nprow = 1;
    npcol = nprocs_sub;
    Cblacs_gridinit(&ctxt_sub, &layout, nprow, npcol);
    Cblacs_gridinfo(ctxt_sub, &nprow, &npcol, &iprow, &ipcol);

    // The column index must be the rank within comm_sub, which shows that the
    // grid was built from comm_sub instead of MPI_COMM_WORLD.
    EXPECT_EQ(iprow, 0);
    EXPECT_EQ(ipcol, rank_sub);

    // The communicator returned by MPI_Comm_split is owned by this test and is
    // not used past this point; release it so it does not leak.
    // NOTE(review): the BLACS grid and system handle derived from comm_sub are
    // still not released here - confirm whether they should be cleaned up too.
    MPI_Comm_free(&comm_sub);
}

/// Test driver: brings MPI up, runs every registered GoogleTest case, shuts
/// MPI down and returns the GoogleTest result as the process exit code.
int main(int argc, char** argv)
{
    // MPI is initialised before the tests run and finalised only after they finished
    MPI_Init(&argc, &argv);
    testing::InitGoogleTest(&argc, argv);

    const int result = RUN_ALL_TESTS();

    MPI_Finalize();
    return result;
}
#endif
15 changes: 15 additions & 0 deletions source/module_base/test_parallel/blacs_connector_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash -e

# Runs the blacs_connector unit test under MPI, using the largest process
# count in [2, 8] that does not exceed the number of CPU cores found.

# The "-e" on the shebang line only applies when the script is executed
# directly. When it is launched as "bash blacs_connector_test.sh" the option
# is ignored, and the script would then finish with the (zero) status of
# "break" even if mpirun failed. Enable it explicitly so a failing test run
# makes the script fail.
set -e

# "cpu cores" in /proc/cpuinfo may be missing (no matching line, or no
# /proc/cpuinfo at all) or may appear with several distinct values; take the
# first value and fall back to the number of online processors when the
# result is not a plain number.
np=$(grep "cpu cores" /proc/cpuinfo 2>/dev/null | uniq | awk '{print $NF}' | head -n 1)
if ! [[ $np =~ ^[0-9]+$ ]]; then
    np=$(getconf _NPROCESSORS_ONLN 2>/dev/null || echo 1)
fi
echo "nprocs in this machine is $np"

ran=0
for i in {8..2};
do
    if [[ $i -gt $np ]]; then
        continue
    fi
    echo "TEST in parallel, nprocs=$i"
    mpirun -np $i ./blacs_connector
    ran=1
    break
done

# Make it visible when no process count fits (fewer than 2 cores) instead of
# finishing silently without having run anything.
if [[ $ran -eq 0 ]]; then
    echo "TEST skipped: fewer than 2 cores available (np=$np)"
fi

27 changes: 6 additions & 21 deletions source/module_basis/module_ao/ORB_control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,8 @@ void ORB_control::setup_2d_division(std::ofstream& ofs_running,
else
{
std::cout << " Parallel Orbial, DIAGO_TYPE = " << ks_solver << std::endl;
ModuleBase::WARNING_QUIT("Parallel_Orbitals::set_global2local", "Check ks_solver.");
ModuleBase::WARNING_QUIT("Parallel_Orbitals::setup_2d_division", "Check ks_solver.");
}
// (2) set the trace, then we can calculate the nnr.
// for 2d: calculate po.nloc first, then global2local_row and global2local_col
// for O(N): calculate the three together.
this->ParaV.set_global2local(nlocal, nlocal, div_2d, ofs_running);
}


Expand Down Expand Up @@ -336,12 +332,6 @@ void ORB_control::divide_HS_2d(
if (dcolor != 0)
return; // mohan add 2012-01-13

// get the 2D index of computer.
pv->dim0 = (int)sqrt((double)dsize); // mohan update 2012/01/13
// while (GlobalV::NPROC_IN_POOL%dim0!=0)

pv->set_proc_dim(dsize);

if (pv->testpb)
ModuleBase::GlobalFunc::OUT(ofs_running, "dim0", pv->dim0);
if (pv->testpb)
Expand All @@ -352,32 +342,27 @@ void ORB_control::divide_HS_2d(
#ifdef __DEBUG
assert(nb2d > 0);
#endif
pv->set_block_size(nb2d); // mohan add 2010-06-28

ModuleBase::GlobalFunc::OUT(ofs_running, "nb2d", pv->get_block_size());

this->set_parameters(ofs_running, ofs_warning);

// call mpi_creat_cart
pv->mpi_create_cart(DIAG_WORLD);

int try_nb = pv->set_local2global(nlocal, nlocal, ofs_running, ofs_warning);
try_nb = pv->set_nloc_wfc_Eij(nbands, ofs_running, ofs_warning);
if (try_nb == 1)
int try_nb = pv->init(nlocal, nlocal, nb2d, DIAG_WORLD);
try_nb += pv->set_nloc_wfc_Eij(nbands, ofs_running, ofs_warning);
if (try_nb != 0)
{
ofs_running << " parameter nb2d is too large: nb2d = " << pv->get_block_size() << std::endl;
ofs_running << " reset nb2d to value 1, this set would make the program keep working but maybe get slower "
"during diagonalization."
<< std::endl;
pv->set_block_size(1);
try_nb = pv->set_local2global(nlocal, nlocal, ofs_running, ofs_warning);

pv->set(nlocal, nlocal, 1, pv->comm_2D, pv->blacs_ctxt);
try_nb = pv->set_nloc_wfc_Eij(nbands, ofs_running, ofs_warning);
}

// init blacs context for genelpa
if (ks_solver == "genelpa" || ks_solver == "scalapack_gvx" || ks_solver == "cusolver" || ks_solver == "cg_in_lcao" || ks_solver == "pexsi")
{
pv->set_desc(nlocal, nlocal, pv->nrow);
pv->set_desc_wfc_Eij(nlocal, nbands, pv->nrow);
}
#else // single processor used.
Expand Down
Loading

0 comments on commit c4c99e1

Please sign in to comment.