Skip to content

Commit

Permalink
Refactor: Use lambda expressions to reduce if conditions. (#4828)
Browse files Browse the repository at this point in the history
* Use lambda expressions to reduce if conditions.

* Fix DCU issue.

* Delete redundant parameter.
  • Loading branch information
grysgreat authored Aug 1, 2024
1 parent abaa9ad commit 56faf8f
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 129 deletions.
110 changes: 45 additions & 65 deletions source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,6 @@ __global__ void cal_vq_deri(
tab, it, ib, tab_2, tab_3, table_interval, gnorm[idx]);
}


template <typename FPTYPE>
__global__ void cal_stress_drhoc_aux0(
const FPTYPE* r, const FPTYPE* rhoc,
Expand All @@ -364,29 +363,23 @@ __global__ void cal_stress_drhoc_aux0(

int idx = threadIdx.x + blockIdx.x * blockDim.x;

FPTYPE aux_d[2];
FPTYPE rhocg1=0.0, f_0=0.0, f_2=0.0, f_1=0.0;

if (idx >= ngg) {return;}

for( int ir = 0;ir< mesh; ir++)
{
const int ir_2 = ir%2;
const FPTYPE gx_r = gx_arr[idx] * r [ir];
aux_d [ir_2] = r [ir] * rhoc [ir] * (r [ir] * cos (gx_r) / gx_arr[idx] - sin (gx_r) / pow(gx_arr[idx],2));

if(ir==0){
f_0 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-2){
f_2 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-1) {
f_1 = aux_d[ir_2]*rab[ir];
} else if(ir_2==0){
const double f1 = aux_d[1]*rab[ir-1];
rhocg1 += f1 + f1 + aux_d[0]*rab[ir];
}

FPTYPE rhocg1=0.0;
FPTYPE gx = gx_arr[idx];

auto aux = [](FPTYPE r, FPTYPE rhoc, FPTYPE gx, FPTYPE rab) -> FPTYPE{
return r * rhoc * (r * cos (gx * r) / gx - sin (gx * r) / (gx * gx)) * rab;
};

FPTYPE f_0 = aux(r[0],rhoc[0], gx, rab[0]);
for( int ir = 1 ; ir< mesh - 2; ir+=2)
{
rhocg1 += 2 * aux(r[ir],rhoc[ir], gx, rab[ir]) + aux(r[ir+1],rhoc[ir+1], gx, rab[ir+1]);
}//ir
FPTYPE f_2 = aux(r[mesh - 2],rhoc[mesh - 2], gx, rab[mesh - 2]);
FPTYPE f_1 = aux(r[mesh - 1],rhoc[mesh - 1], gx, rab[mesh - 1]);

rhocg1 += f_2+f_2;
rhocg1 += rhocg1;
rhocg1 += f_0 + f_1;
Expand All @@ -405,31 +398,24 @@ __global__ void cal_stress_drhoc_aux1(

int idx = threadIdx.x + blockIdx.x * blockDim.x;

FPTYPE aux_d[2];
FPTYPE rhocg1=0.0, f_0=0.0, f_2=0.0, f_1=0.0;

if (idx >= ngg) {return;}

for( int ir = 0;ir< mesh; ir++)
{
const int ir_2 = ir%2;
const FPTYPE gx_r = gx_arr[idx] * r [ir];

aux_d [ir_2] = ir!=0 ? sin(gx_r) / (gx_r) : 1.0;
aux_d [ir_2] = r[ir] * r[ir] * rhoc [ir] * aux_d [ir_2];

if(ir==0){
f_0 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-2){
f_2 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-1) {
f_1 = aux_d[ir_2]*rab[ir];
} else if(ir_2==0){
const double f1 = aux_d[1]*rab[ir-1];
rhocg1 += f1 + f1 + aux_d[0]*rab[ir];
}

FPTYPE rhocg1=0.0;
FPTYPE gx = gx_arr[idx];

auto aux = [](FPTYPE r, FPTYPE rhoc, FPTYPE gx, FPTYPE rab) -> FPTYPE{
return sin (gx * r) / (gx * r) * r * r * rhoc * rab;
};

FPTYPE f_0 = r[0] * r[0] * rhoc[0] * rab[0];
for( int ir = 1 ; ir< mesh - 2; ir+=2)
{
rhocg1 += 2 * aux(r[ir],rhoc[ir], gx, rab[ir]) + aux(r[ir+1],rhoc[ir+1], gx, rab[ir+1]);
}//ir

FPTYPE f_2 = aux(r[mesh - 2],rhoc[mesh - 2], gx, rab[mesh - 2]);
FPTYPE f_1 = aux(r[mesh - 1],rhoc[mesh - 1], gx, rab[mesh - 1]);

rhocg1 += f_2+f_2;
rhocg1 += rhocg1;
rhocg1 += f_0 + f_1;
Expand All @@ -445,33 +431,28 @@ __global__ void cal_stress_drhoc_aux2(
const FPTYPE *gx_arr, const FPTYPE *rab, FPTYPE *drhocg,
const int mesh, const int igl0, const int ngg, const double omega
){
const double FOUR_PI = 4.0 * 3.14159265358979323846;

int idx = threadIdx.x + blockIdx.x * blockDim.x;

FPTYPE aux_d[2];
FPTYPE rhocg1=0.0, f_0=0.0, f_2=0.0, f_1=0.0;
int idx = threadIdx.x + blockIdx.x * blockDim.x;

if (idx >= ngg) {return;}

for( int ir = 0;ir< mesh; ir++)
{
const int ir_2 = ir%2;
const FPTYPE gx_r = gx_arr[idx] * r [ir];

aux_d [ir_2] = r[ir] < 1.0e-8 ? rhoc [ir] : rhoc [ir] * sin(gx_r) / (gx_r);
if(ir==0){
f_0 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-2){
f_2 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-1) {
f_1 = aux_d[ir_2]*rab[ir];
} else if(ir_2==0){
const double f1 = aux_d[1]*rab[ir-1];
rhocg1 += f1 + f1 + aux_d[0]*rab[ir];
}

FPTYPE rhocg1=0.0;
FPTYPE gx = gx_arr[idx];

auto aux = [](FPTYPE r, FPTYPE rhoc, FPTYPE gx, FPTYPE rab) -> FPTYPE{
return r < 1.0e-8 ? rab * rhoc : rab * rhoc * sin(gx * r) / (gx * r);
};


FPTYPE f_0 = r[0] * r[0] * rhoc[0] * rab[0];
for( int ir = 1 ; ir< mesh - 2; ir+=2)
{
rhocg1 += 2 * aux(r[ir],rhoc[ir], gx, rab[ir]) + aux(r[ir+1],rhoc[ir+1], gx, rab[ir+1]);
}//ir
FPTYPE f_2 = aux(r[mesh - 2],rhoc[mesh - 2], gx, rab[mesh - 2]);
FPTYPE f_1 = aux(r[mesh - 1],rhoc[mesh - 1], gx, rab[mesh - 1]);

rhocg1 += f_2+f_2;
rhocg1 += rhocg1;
rhocg1 += f_0 + f_1;
Expand All @@ -480,7 +461,6 @@ __global__ void cal_stress_drhoc_aux2(
drhocg [idx] = rhocg1;
}


template <typename FPTYPE>
void cal_vkb_op<FPTYPE, base_device::DEVICE_GPU>::operator()(
const base_device::DEVICE_GPU* ctx,
Expand Down
114 changes: 50 additions & 64 deletions source/module_hamilt_pw/hamilt_pwdft/kernels/rocm/stress_op.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ __global__ void cal_vq_deri(
}



template <typename FPTYPE>
__global__ void cal_stress_drhoc_aux0(
const FPTYPE* r, const FPTYPE* rhoc,
Expand All @@ -363,29 +362,23 @@ __global__ void cal_stress_drhoc_aux0(

int idx = threadIdx.x + blockIdx.x * blockDim.x;

FPTYPE aux_d[2];
FPTYPE rhocg1=0.0, f_0=0.0, f_2=0.0, f_1=0.0;

if (idx >= ngg) {return;}

for( int ir = 0;ir< mesh; ir++)
{
const int ir_2 = ir%2;
const FPTYPE gx_r = gx_arr[idx] * r [ir];
aux_d [ir_2] = r [ir] * rhoc [ir] * (r [ir] * cos (gx_r) / gx_arr[idx] - sin (gx_r) / (gx_arr[idx] * gx_arr[idx]));

if(ir==0){
f_0 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-2){
f_2 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-1) {
f_1 = aux_d[ir_2]*rab[ir];
} else if(ir_2==0){
const double f1 = aux_d[1]*rab[ir-1];
rhocg1 += f1 + f1 + aux_d[0]*rab[ir];
}

FPTYPE rhocg1=0.0;
FPTYPE gx = gx_arr[idx];

auto aux = [](FPTYPE r, FPTYPE rhoc, FPTYPE gx, FPTYPE rab) -> FPTYPE{
return r * rhoc * (r * cos (gx * r) / gx - sin (gx * r) / (gx * gx)) * rab;
};

FPTYPE f_0 = aux(r[0],rhoc[0], gx, rab[0]);
for( int ir = 1 ; ir< mesh - 2; ir+=2)
{
rhocg1 += 2 * aux(r[ir],rhoc[ir], gx, rab[ir]) + aux(r[ir+1],rhoc[ir+1], gx, rab[ir+1]);
}//ir
FPTYPE f_2 = aux(r[mesh - 2],rhoc[mesh - 2], gx, rab[mesh - 2]);
FPTYPE f_1 = aux(r[mesh - 1],rhoc[mesh - 1], gx, rab[mesh - 1]);

rhocg1 += f_2+f_2;
rhocg1 += rhocg1;
rhocg1 += f_0 + f_1;
Expand All @@ -404,31 +397,24 @@ __global__ void cal_stress_drhoc_aux1(

int idx = threadIdx.x + blockIdx.x * blockDim.x;

FPTYPE aux_d[2];
FPTYPE rhocg1=0.0, f_0=0.0, f_2=0.0, f_1=0.0;

if (idx >= ngg) {return;}

for( int ir = 0;ir< mesh; ir++)
{
const int ir_2 = ir%2;
const FPTYPE gx_r = gx_arr[idx] * r [ir];

aux_d [ir_2] = ir!=0 ? sin(gx_r) / (gx_r) : 1.0;
aux_d [ir_2] = r[ir] * r[ir] * rhoc [ir] * aux_d [ir_2];

if(ir==0){
f_0 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-2){
f_2 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-1) {
f_1 = aux_d[ir_2]*rab[ir];
} else if(ir_2==0){
const double f1 = aux_d[1]*rab[ir-1];
rhocg1 += f1 + f1 + aux_d[0]*rab[ir];
}

FPTYPE rhocg1=0.0;
FPTYPE gx = gx_arr[idx];

auto aux = [](FPTYPE r, FPTYPE rhoc, FPTYPE gx, FPTYPE rab) -> FPTYPE{
return sin (gx * r) / (gx * r) * r * r * rhoc * rab;
};

FPTYPE f_0 = r[0] * r[0] * rhoc[0] * rab[0];
for( int ir = 1 ; ir< mesh - 2; ir+=2)
{
rhocg1 += 2 * aux(r[ir],rhoc[ir], gx, rab[ir]) + aux(r[ir+1],rhoc[ir+1], gx, rab[ir+1]);
}//ir

FPTYPE f_2 = aux(r[mesh - 2],rhoc[mesh - 2], gx, rab[mesh - 2]);
FPTYPE f_1 = aux(r[mesh - 1],rhoc[mesh - 1], gx, rab[mesh - 1]);

rhocg1 += f_2+f_2;
rhocg1 += rhocg1;
rhocg1 += f_0 + f_1;
Expand All @@ -444,33 +430,28 @@ __global__ void cal_stress_drhoc_aux2(
const FPTYPE *gx_arr, const FPTYPE *rab, FPTYPE *drhocg,
const int mesh, const int igl0, const int ngg, const double omega
){
const double FOUR_PI = 4.0 * 3.14159265358979323846;

int idx = threadIdx.x + blockIdx.x * blockDim.x;

FPTYPE aux_d[2];
FPTYPE rhocg1=0.0, f_0=0.0, f_2=0.0, f_1=0.0;
int idx = threadIdx.x + blockIdx.x * blockDim.x;

if (idx >= ngg) {return;}

for( int ir = 0;ir< mesh; ir++)
{
const int ir_2 = ir%2;
const FPTYPE gx_r = gx_arr[idx] * r [ir];

aux_d [ir_2] = r[ir] < 1.0e-8 ? rhoc [ir] : rhoc [ir] * sin(gx_r) / (gx_r);
if(ir==0){
f_0 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-2){
f_2 = aux_d[ir_2]*rab[ir];
} else if(ir==mesh-1) {
f_1 = aux_d[ir_2]*rab[ir];
} else if(ir_2==0){
const double f1 = aux_d[1]*rab[ir-1];
rhocg1 += f1 + f1 + aux_d[0]*rab[ir];
}

FPTYPE rhocg1=0.0;
FPTYPE gx = gx_arr[idx];

auto aux = [](FPTYPE r, FPTYPE rhoc, FPTYPE gx, FPTYPE rab) -> FPTYPE{
return r < 1.0e-8 ? rab * rhoc : rab * rhoc * sin(gx * r) / (gx * r);
};


FPTYPE f_0 = r[0] * r[0] * rhoc[0] * rab[0];
for( int ir = 1 ; ir< mesh - 2; ir+=2)
{
rhocg1 += 2 * aux(r[ir],rhoc[ir], gx, rab[ir]) + aux(r[ir+1],rhoc[ir+1], gx, rab[ir+1]);
}//ir
FPTYPE f_2 = aux(r[mesh - 2],rhoc[mesh - 2], gx, rab[mesh - 2]);
FPTYPE f_1 = aux(r[mesh - 1],rhoc[mesh - 1], gx, rab[mesh - 1]);

rhocg1 += f_2+f_2;
rhocg1 += rhocg1;
rhocg1 += f_0 + f_1;
Expand All @@ -480,6 +461,7 @@ __global__ void cal_stress_drhoc_aux2(
}



template <typename FPTYPE>
void cal_vkb_op<FPTYPE, base_device::DEVICE_GPU>::operator()(
const base_device::DEVICE_GPU* ctx,
Expand Down Expand Up @@ -595,6 +577,10 @@ void cal_stress_drhoc_aux_op<FPTYPE, base_device::DEVICE_GPU>::operator()(
hipLaunchKernelGGL(HIP_KERNEL_NAME(cal_stress_drhoc_aux1<FPTYPE>),block,THREADS_PER_BLOCK,0,0,
r,rhoc,gx_arr,rab,drhocg,mesh,igl0,ngg,omega
);
} else if(type == 2 ){
hipLaunchKernelGGL(HIP_KERNEL_NAME(cal_stress_drhoc_aux2<FPTYPE>),block,THREADS_PER_BLOCK,0,0,
r,rhoc,gx_arr,rab,drhocg,mesh,igl0,ngg,omega
);
}

return ;
Expand Down

0 comments on commit 56faf8f

Please sign in to comment.