Skip to content

Commit

Permalink
Feature: make force and stress of sDFT support GPU (#5487)
Browse files Browse the repository at this point in the history
* refactor force in sdft

* refactor stress in sDFT

* make stress_ekin GPU

* finish sdft GPU

* fix compile

* add annotations

* fix bug of stress and force

* modify
  • Loading branch information
Qianruipku authored Nov 15, 2024
1 parent 5b1777c commit f2e91bd
Show file tree
Hide file tree
Showing 62 changed files with 1,645 additions and 1,073 deletions.
2 changes: 2 additions & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ OBJS_PARALLEL=parallel_common.o\
parallel_grid.o\
parallel_kpoints.o\
parallel_reduce.o\
parallel_device.o

OBJS_SRCPW=H_Ewald_pw.o\
dnrm2.o\
Expand All @@ -640,6 +641,7 @@ OBJS_SRCPW=H_Ewald_pw.o\
forces_cc.o\
forces_scc.o\
fs_nonlocal_tools.o\
fs_kin_tools.o\
force_op.o\
stress_op.o\
wf_op.o\
Expand Down
1 change: 1 addition & 0 deletions source/module_base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ add_library(
parallel_global.cpp
parallel_comm.cpp
parallel_reduce.cpp
parallel_device.cpp
spherical_bessel_transformer.cpp
cubic_spline.cpp
module_mixing/mixing_data.cpp
Expand Down
31 changes: 17 additions & 14 deletions source/module_base/module_device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,27 +191,30 @@ else { return "cpu";
}
}

int get_device_kpar(const int &kpar) {
int get_device_kpar(const int& kpar, const int& bndpar)
{
#if __MPI && (__CUDA || __ROCM)
int temp_nproc;
MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
if (temp_nproc != kpar) {
ModuleBase::WARNING("Input_conv",
"None kpar set in INPUT file, auto set kpar value.");
}
// GlobalV::KPAR = temp_nproc;
// band the CPU processor to the devices
int node_rank = base_device::information::get_node_rank();
int temp_nproc = 0;
int new_kpar = kpar;
MPI_Comm_size(MPI_COMM_WORLD, &temp_nproc);
if (temp_nproc != kpar * bndpar)
{
new_kpar = temp_nproc / bndpar;
ModuleBase::WARNING("Input_conv", "kpar is not compatible with the number of processors, auto set kpar value.");
}

// get the CPU rank of current node
int node_rank = base_device::information::get_node_rank();

int device_num = -1;
int device_num = -1;
#if defined(__CUDA)
cudaGetDeviceCount(&device_num);
cudaSetDevice(node_rank % device_num);
cudaGetDeviceCount(&device_num); // get the number of GPU devices of current node
cudaSetDevice(node_rank % device_num); // bind this CPU process to one of the node's devices
#elif defined(__ROCM)
hipGetDeviceCount(&device_num);
hipSetDevice(node_rank % device_num);
#endif
return temp_nproc;
return new_kpar;
#endif
return kpar;
}
Expand Down
8 changes: 7 additions & 1 deletion source/module_base/module_device/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ std::string get_device_info(std::string device_flag);
* @brief Get the device kpar object
* for module_io GlobalV::KPAR
*/
int get_device_kpar(const int& kpar);
int get_device_kpar(const int& kpar, const int& bndpar);

/**
* @brief Get the device flag object
Expand All @@ -50,6 +50,12 @@ std::string get_device_flag(const std::string& device,
const std::string& basis_type);

#if __MPI
/**
* @brief Get the rank of current node
 * Note that a GPU can only be bound to a CPU process on the same node
*
* @return int
*/
int get_node_rank();
int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD);
int stringCmp(const void* a, const void* b);
Expand Down
13 changes: 7 additions & 6 deletions source/module_base/parallel_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ void Parallel_Common::bcast_string(std::string& object) // Peize Lin fix bug 201
{
int size = object.size();
MPI_Bcast(&size, 1, MPI_INT, 0, MPI_COMM_WORLD);
char* swap = new char[size + 1];

int my_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
if (0 == my_rank)
strcpy(swap, object.c_str());
MPI_Bcast(swap, size + 1, MPI_CHAR, 0, MPI_COMM_WORLD);

if (0 != my_rank)
object = static_cast<std::string>(swap);
delete[] swap;
{
object.resize(size);
}

MPI_Bcast(&object[0], size, MPI_CHAR, 0, MPI_COMM_WORLD);
return;
}

Expand Down
38 changes: 38 additions & 0 deletions source/module_base/parallel_device.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "parallel_device.h"
#ifdef __MPI
namespace Parallel_Common
{
// Broadcast helpers: send the buffer on rank 0 of `comm` to every rank.
// A std::complex<T> is laid out as two contiguous T values, so complex
// buffers are transferred with a doubled count and the underlying real
// MPI datatype; this stays portable across MPI builds that lack the
// C/C++ complex datatypes.
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm)
{
    const int count = 2 * n; // real + imaginary parts per element
    MPI_Bcast(object, count, MPI_DOUBLE, 0, comm);
}
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm)
{
    const int count = 2 * n; // real + imaginary parts per element
    MPI_Bcast(object, count, MPI_FLOAT, 0, comm);
}
void bcast_data(double* object, const int& n, const MPI_Comm& comm)
{
    MPI_Bcast(object, n, MPI_DOUBLE, 0, comm);
}
void bcast_data(float* object, const int& n, const MPI_Comm& comm)
{
    MPI_Bcast(object, n, MPI_FLOAT, 0, comm);
}

// Reduction helpers: in-place element-wise sum over all ranks of `comm`.
// For complex data the element-wise sum of the real/imaginary parts is
// identical to the complex sum, so the same doubled-count trick applies.
void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm)
{
    const int count = 2 * n; // real + imaginary parts per element
    MPI_Allreduce(MPI_IN_PLACE, object, count, MPI_DOUBLE, MPI_SUM, comm);
}
void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm)
{
    const int count = 2 * n; // real + imaginary parts per element
    MPI_Allreduce(MPI_IN_PLACE, object, count, MPI_FLOAT, MPI_SUM, comm);
}
void reduce_data(double* object, const int& n, const MPI_Comm& comm)
{
    MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_DOUBLE, MPI_SUM, comm);
}
void reduce_data(float* object, const int& n, const MPI_Comm& comm)
{
    MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_FLOAT, MPI_SUM, comm);
}
} // namespace Parallel_Common
#endif
45 changes: 21 additions & 24 deletions source/module_base/parallel_device.h
Original file line number Diff line number Diff line change
@@ -1,39 +1,34 @@
#ifndef __PARALLEL_DEVICE_H__
#define __PARALLEL_DEVICE_H__
#ifdef __MPI
#include "mpi.h"
#include "module_base/module_device/device.h"
#include "module_base/module_device/memory_op.h"
#include <complex>
#include <string>
#include <vector>
namespace Parallel_Common
{
void bcast_complex(std::complex<double>* object, const int& n, const MPI_Comm& comm)
{
MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm);
}
void bcast_complex(std::complex<float>* object, const int& n, const MPI_Comm& comm)
{
MPI_Bcast(object, n * 2, MPI_FLOAT, 0, comm);
}
void bcast_real(double* object, const int& n, const MPI_Comm& comm)
{
MPI_Bcast(object, n, MPI_DOUBLE, 0, comm);
}
void bcast_real(float* object, const int& n, const MPI_Comm& comm)
{
MPI_Bcast(object, n, MPI_FLOAT, 0, comm);
}
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
void bcast_data(double* object, const int& n, const MPI_Comm& comm);
void bcast_data(float* object, const int& n, const MPI_Comm& comm);
void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
void reduce_data(double* object, const int& n, const MPI_Comm& comm);
void reduce_data(float* object, const int& n, const MPI_Comm& comm);

template <typename T, typename Device>
/**
* @brief bcast complex in Device
* @brief bcast data in Device
*
* @tparam T: float, double, std::complex<float>, std::complex<double>
* @tparam Device
* @param ctx Device ctx
* @param object complex arrays in Device
* @param n the size of complex arrays
* @param comm MPI_Comm
* @param tmp_space tmp space in CPU
*/
void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
template <typename T, typename Device>
void bcast_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
{
const base_device::DEVICE_CPU* cpu_ctx = {};
T* object_cpu = nullptr;
Expand All @@ -56,7 +51,7 @@ void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& c
object_cpu = object;
}

bcast_complex(object_cpu, n, comm);
bcast_data(object_cpu, n, comm);

if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
{
Expand All @@ -70,7 +65,7 @@ void bcast_complex(const Device* ctx, T* object, const int& n, const MPI_Comm& c
}

template <typename T, typename Device>
void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
void reduce_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
{
const base_device::DEVICE_CPU* cpu_ctx = {};
T* object_cpu = nullptr;
Expand All @@ -93,7 +88,7 @@ void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm
object_cpu = object;
}

bcast_real(object_cpu, n, comm);
reduce_data(object_cpu, n, comm);

if (base_device::get_device_type<Device>(ctx) == base_device::GpuDevice)
{
Expand All @@ -105,7 +100,9 @@ void bcast_real(const Device* ctx, T* object, const int& n, const MPI_Comm& comm
}
return;
}

}


#endif
#endif
9 changes: 4 additions & 5 deletions source/module_elecstate/elecstate_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@ void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T, Device>& psi)
ModuleBase::TITLE(this->classname, "psiToRho");
ModuleBase::timer::tick(this->classname, "psiToRho");
const int nspin = PARAM.inp.nspin;
for (int is = 0; is < nspin; is++)
{
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
}

if (GlobalV::MY_STOGROUP == 0)
{
for (int is = 0; is < nspin; is++)
{
setmem_var_op()(this->ctx, this->rho[is], 0, this->charge->nrxx);
}

for (int ik = 0; ik < psi.get_nk(); ++ik)
{
psi.fix_k(ik);
Expand Down
17 changes: 14 additions & 3 deletions source/module_elecstate/module_charge/charge_extra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,20 @@ void Charge_Extra::Init_CE(const int& nspin, const int& natom, const int& nrxx,

if (pot_order > 0)
{
delta_rho1.resize(this->nspin, std::vector<double>(nrxx, 0.0));
delta_rho2.resize(this->nspin, std::vector<double>(nrxx, 0.0));
delta_rho3.resize(this->nspin, std::vector<double>(nrxx, 0.0));
// delta_rho1.resize(this->nspin, std::vector<double>(nrxx, 0.0));
// delta_rho2.resize(this->nspin, std::vector<double>(nrxx, 0.0));
// delta_rho3.resize(this->nspin, std::vector<double>(nrxx, 0.0));
// qianrui replaced the above code with the following code.
// The above code could not pass valgrind tests: it caused an invalid read of size 32.
delta_rho1.resize(this->nspin);
delta_rho2.resize(this->nspin);
delta_rho3.resize(this->nspin);
for (int is = 0; is < this->nspin; is++)
{
delta_rho1[is].resize(nrxx, 0.0);
delta_rho2[is].resize(nrxx, 0.0);
delta_rho3[is].resize(nrxx, 0.0);
}
}

if(pot_order == 3)
Expand Down
1 change: 0 additions & 1 deletion source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,6 @@ void ESolver_KS_PW<T, Device>::cal_stress(ModuleBase::matrix& stress)
&this->sf,
&this->kv,
this->pw_wfc,
this->psi,
this->__kspw_psi);

// external stress
Expand Down
34 changes: 14 additions & 20 deletions source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ void ESolver_SDFT_PW<T, Device>::after_scf(const int istep)
template <typename T, typename Device>
void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, double ethr)
{
ModuleBase::TITLE("ESolver_SDFT_PW", "hamilt2density");
ModuleBase::timer::tick("ESolver_SDFT_PW", "hamilt2density");

// reset energy
this->pelec->f_en.eband = 0.0;
this->pelec->f_en.demet = 0.0;
Expand Down Expand Up @@ -241,6 +244,7 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(int istep, int iter, doub
#ifdef __MPI
MPI_Bcast(&(this->pelec->f_en.deband), 1, MPI_DOUBLE, 0, PARAPW_WORLD);
#endif
ModuleBase::timer::tick("ESolver_SDFT_PW", "hamilt2density");
}

template <typename T, typename Device>
Expand All @@ -249,10 +253,10 @@ double ESolver_SDFT_PW<T, Device>::cal_energy()
return this->pelec->f_en.etot;
}

template <>
void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>::cal_force(ModuleBase::matrix& force)
template <typename T, typename Device>
void ESolver_SDFT_PW<T, Device>::cal_force(ModuleBase::matrix& force)
{
Sto_Forces ff(GlobalC::ucell.nat);
Sto_Forces<double, Device> ff(GlobalC::ucell.nat);

ff.cal_stoforce(force,
*this->pelec,
Expand All @@ -261,40 +265,30 @@ void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>::cal_force(M
&this->sf,
&this->kv,
this->pw_wfc,
this->psi,
GlobalC::ppcell,
GlobalC::ucell,
*this->kspw_psi,
this->stowf);
}

template <>
void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_GPU>::cal_force(ModuleBase::matrix& force)
{
ModuleBase::WARNING_QUIT("ESolver_SDFT_PW<T, Device>::cal_force", "DEVICE_GPU is not supported");
}

template <>
void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>::cal_stress(ModuleBase::matrix& stress)
template <typename T, typename Device>
void ESolver_SDFT_PW<T, Device>::cal_stress(ModuleBase::matrix& stress)
{
Sto_Stress_PW ss;
Sto_Stress_PW<double, Device> ss;
ss.cal_stress(stress,
*this->pelec,
this->pw_rho,
&GlobalC::ucell.symm,
&this->sf,
&this->kv,
this->pw_wfc,
this->psi,
*this->kspw_psi,
this->stowf,
this->pelec->charge,
&GlobalC::ppcell,
GlobalC::ucell);
}

template <>
void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_GPU>::cal_stress(ModuleBase::matrix& stress)
{
ModuleBase::WARNING_QUIT("ESolver_SDFT_PW<T, Device>::cal_stress", "DEVICE_GPU is not supported");
}

template <typename T, typename Device>
void ESolver_SDFT_PW<T, Device>::after_all_runners()
{
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_pw/hamilt_pwdft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ list(APPEND objects
parallel_grid.cpp
elecond.cpp
fs_nonlocal_tools.cpp
fs_kin_tools.cpp
radial_proj.cpp
)

Expand Down
Loading

0 comments on commit f2e91bd

Please sign in to comment.