Skip to content

Commit

Permalink
Parallelize z-direction for sycl-gen kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
uumesh committed Aug 12, 2023
1 parent b6c8abe commit 4598eb6
Showing 1 changed file with 55 additions and 42 deletions.
97 changes: 55 additions & 42 deletions include/ceed/jit-source/sycl/sycl-gen-templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,16 @@ inline void writeDofsStrided2d(const CeedInt num_comp, const CeedInt P_1D, const
inline void readDofsOffset3d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
const global CeedInt* restrict indices, const global CeedScalar* restrict d_u, private CeedScalar* restrict r_u) {
const CeedInt item_id_x = get_local_id(0);
const CeedInt item_id_y = get_local_id(1);
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;
const CeedInt elem = get_global_id(2);

if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
for (CeedInt z = 0; z < P_1D; ++z) {
const CeedInt node = item_id_x + P_1D * (item_id_y + P_1D * z);
if (item_id_x < P_1D && item_id_y < P_1D && item_id_z < P_1D && elem < num_elem) {
// for (CeedInt z = 0; z < P_1D; ++z) {
const CeedInt node = item_id_x + P_1D * (item_id_y + P_1D * item_id_z);
const CeedInt ind = indices[node + elem * P_1D * P_1D * P_1D];
for (CeedInt comp = 0; comp < num_comp; ++comp) r_u[z + comp * P_1D] = d_u[ind + strides_comp * comp];
}
for (CeedInt comp = 0; comp < num_comp; ++comp) r_u[comp] = d_u[ind + strides_comp * comp];
// }
}
}

Expand All @@ -199,15 +200,16 @@ inline void readDofsStrided3d(const CeedInt num_comp, const CeedInt P_1D, const
const CeedInt strides_elem, const CeedInt num_elem, const global CeedScalar* restrict d_u,
private CeedScalar* restrict r_u) {
const CeedInt item_id_x = get_local_id(0);
const CeedInt item_id_y = get_local_id(1);
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;
const CeedInt elem = get_global_id(2);

if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
for (CeedInt z = 0; z < P_1D; ++z) {
const CeedInt node = item_id_x + P_1D * (item_id_y + P_1D * z);
if (item_id_x < P_1D && item_id_y < P_1D && item_id_z < P_1D && elem < num_elem) {
// for (CeedInt z = 0; z < P_1D; ++z) {
const CeedInt node = item_id_x + P_1D * (item_id_y + P_1D * item_id_z);
const CeedInt ind = node * strides_node + elem * strides_elem;
for (CeedInt comp = 0; comp < num_comp; ++comp) r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
}
for (CeedInt comp = 0; comp < num_comp; ++comp) r_u[comp] = d_u[ind + comp * strides_comp];
// }
}
}

Expand Down Expand Up @@ -250,16 +252,17 @@ inline void readSliceQuadsStrided3d(const CeedInt num_comp, const CeedInt Q_1D,
inline void writeDofsOffset3d(const CeedInt num_comp, const CeedInt strides_comp, const CeedInt P_1D, const CeedInt num_elem,
const global CeedInt* restrict indices, const private CeedScalar* restrict r_v, global CeedAtomicScalar* restrict d_v) {
const CeedInt item_id_x = get_local_id(0);
const CeedInt item_id_y = get_local_id(1);
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;
const CeedInt elem = get_global_id(2);

if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
for (CeedInt z = 0; z < P_1D; ++z) {
const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
if (item_id_x < P_1D && item_id_y < P_1D && item_id_z < P_1D && elem < num_elem) {
// for (CeedInt z = 0; z < P_1D; ++z) {
const CeedInt node = item_id_x + P_1D * (item_id_y + item_id_z * P_1D);
const CeedInt ind = indices[node + elem * P_1D * P_1D * P_1D];
for (CeedInt comp = 0; comp < num_comp; ++comp)
atomic_fetch_add_explicit(&d_v[ind + strides_comp * comp], r_v[z + comp * P_1D], memory_order_relaxed, memory_scope_device);
}
atomic_fetch_add_explicit(&d_v[ind + strides_comp * comp], r_v[comp], memory_order_relaxed, memory_scope_device);
// }
}
}

Expand All @@ -270,15 +273,16 @@ inline void writeDofsStrided3d(const CeedInt num_comp, const CeedInt P_1D, const
const CeedInt strides_elem, const CeedInt num_elem, const private CeedScalar* restrict r_v,
global CeedScalar* restrict d_v) {
const CeedInt item_id_x = get_local_id(0);
const CeedInt item_id_y = get_local_id(1);
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;
const CeedInt elem = get_global_id(2);

if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
for (CeedInt z = 0; z < P_1D; ++z) {
const CeedInt node = item_id_x + P_1D * (item_id_y + P_1D * z);
if (item_id_x < P_1D && item_id_y < P_1D && item_id_z < P_1D && elem < num_elem) {
// for (CeedInt z = 0; z < P_1D; ++z) {
const CeedInt node = item_id_x + P_1D * (item_id_y + item_id_z * P_1D);
const CeedInt ind = node * strides_node + elem * strides_elem;
for (CeedInt comp = 0; comp < num_comp; ++comp) d_v[ind + comp * strides_comp] += r_v[z + comp * P_1D];
}
for (CeedInt comp = 0; comp < num_comp; ++comp) d_v[ind + comp * strides_comp] += r_v[comp];
// }
}
}

Expand All @@ -288,28 +292,30 @@ inline void writeDofsStrided3d(const CeedInt num_comp, const CeedInt P_1D, const
inline void gradCollo3d(const CeedInt num_comp, const CeedInt Q_1D, const CeedInt q, const private CeedScalar* restrict r_U,
const local CeedScalar* s_G, private CeedScalar* restrict r_V, local CeedScalar* restrict scratch) {
const CeedInt item_id_x = get_local_id(0);
const CeedInt item_id_y = get_local_id(1);
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;

for (CeedInt comp = 0; comp < num_comp; ++comp) {
if (item_id_x < Q_1D && item_id_y < Q_1D) {
scratch[item_id_x + item_id_y * T_1D] = r_U[q + comp * Q_1D];
if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < Q_1D) {
scratch[item_id_x + (item_id_y + item_id_z * T_1D) * T_1D] = r_U[comp];
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);

if (item_id_x < Q_1D && item_id_y < Q_1D) {
if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < Q_1D) {
// X derivative
r_V[comp + 0 * num_comp] = 0.0;
for (CeedInt i = 0; i < Q_1D; ++i)
r_V[comp + 0 * num_comp] += s_G[i + item_id_x * Q_1D] * scratch[i + item_id_y * T_1D]; // Contract x direction (X derivative)
r_V[comp + 0 * num_comp] += s_G[i + item_id_x * Q_1D] * scratch[i + (item_id_y + item_id_z * T_1D) * T_1D]; // Contract x direction (X derivative)

// Y derivative
r_V[comp + 1 * num_comp] = 0.0;
for (CeedInt i = 0; i < Q_1D; ++i)
r_V[comp + 1 * num_comp] += s_G[i + item_id_y * Q_1D] * scratch[item_id_x + i * T_1D]; // Contract y direction (Y derivative)
r_V[comp + 1 * num_comp] += s_G[i + item_id_y * Q_1D] * scratch[item_id_x + (i + item_id_z * T_1D) * T_1D]; // Contract y direction (Y derivative)

// Z derivative
r_V[comp + 2 * num_comp] = 0.0;
for (CeedInt i = 0; i < Q_1D; ++i) r_V[comp + 2 * num_comp] += s_G[i + q * Q_1D] * r_U[i + comp * Q_1D]; // Contract z direction (Z derivative)
for (CeedInt i = 0; i < Q_1D; ++i)
r_V[comp + 2 * num_comp] += s_G[i + item_id_z * Q_1D] * scratch[item_id_x + (item_id_y + i * T_1D) * T_1D]; // Contract z direction (Z derivative)
}

work_group_barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -322,38 +328,45 @@ inline void gradCollo3d(const CeedInt num_comp, const CeedInt Q_1D, const CeedIn
inline void gradColloTranspose3d(const CeedInt num_comp, const CeedInt Q_1D, const CeedInt q, const private CeedScalar* restrict r_U,
const local CeedScalar* restrict s_G, private CeedScalar* restrict r_V, local CeedScalar* restrict scratch) {
const CeedInt item_id_x = get_local_id(0);
const CeedInt item_id_y = get_local_id(1);
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;

for (CeedInt comp = 0; comp < num_comp; ++comp) {
// X derivative
if (item_id_x < Q_1D && item_id_y < Q_1D) {
scratch[item_id_x + item_id_y * T_1D] = r_U[comp + 0 * num_comp];
if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < Q_1D) {
scratch[item_id_x + (item_id_y + item_id_z * T_1D) * T_1D] = r_U[comp + 0 * num_comp];
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);

if (item_id_x < Q_1D && item_id_y < Q_1D) {
if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < Q_1D) {
for (CeedInt i = 0; i < Q_1D; ++i)
r_V[q + comp * Q_1D] += s_G[item_id_x + i * Q_1D] * scratch[i + item_id_y * T_1D]; // Contract x direction (X derivative)
r_V[comp] += s_G[item_id_x + i * Q_1D] * scratch[i + (item_id_y + item_id_z * T_1D) * T_1D]; // Contract x direction (X derivative)
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);

// Y derivative
if (item_id_x < Q_1D && item_id_y < Q_1D) {
scratch[item_id_x + item_id_y * T_1D] = r_U[comp + 1 * num_comp];
if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < Q_1D) {
scratch[item_id_x + (item_id_y + item_id_z * T_1D) * T_1D] = r_U[comp + 1 * num_comp];
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);

if (item_id_x < Q_1D && item_id_y < Q_1D) {
if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < Q_1D) {
for (CeedInt i = 0; i < Q_1D; ++i)
r_V[q + comp * Q_1D] += s_G[item_id_y + i * Q_1D] * scratch[item_id_x + i * T_1D]; // Contract y direction (Y derivative)
r_V[comp] += s_G[item_id_y + i * Q_1D] * scratch[item_id_x + (i + item_id_z * T_1D) * T_1D]; // Contract y direction (Y derivative)
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);

// Z derivative
if (item_id_x < Q_1D && item_id_y < Q_1D) {
if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < Q_1D) {
scratch[item_id_x + (item_id_y + item_id_z * T_1D) * T_1D] = r_U[comp + 2 * num_comp];
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);

if (item_id_x < Q_1D && item_id_y < Q_1D && item_id_z < Q_1D) {
for (CeedInt i = 0; i < Q_1D; ++i)
r_V[i + comp * Q_1D] += s_G[i + q * Q_1D] * r_U[comp + 2 * num_comp]; // PARTIAL contract z direction (Z derivative)
r_V[comp] += s_G[item_id_z + i * Q_1D] * scratch[item_id_x + (item_id_y + i * T_1D) * T_1D]; // Contract z direction (Z derivative)
}
work_group_barrier(CLK_LOCAL_MEM_FENCE);
}
}

Expand Down

0 comments on commit 4598eb6

Please sign in to comment.