Skip to content

Commit

Permalink
Refactoring shared backend to 3d threads
Browse files Browse the repository at this point in the history
  • Loading branch information
uumesh committed Jul 25, 2023
1 parent 887d617 commit ca6301b
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 147 deletions.
8 changes: 6 additions & 2 deletions backends/sycl-shared/ceed-sycl-shared-basis.sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@
//------------------------------------------------------------------------------
// Compute the local range of for basis kernels
//------------------------------------------------------------------------------
static int ComputeLocalRange(Ceed ceed, CeedInt dim, CeedInt thread_1d, CeedInt *local_range, CeedInt max_group_size = 256) {
static int ComputeLocalRange(Ceed ceed, CeedInt dim, CeedInt thread_1d, CeedInt *local_range, CeedInt max_group_size = 128) {
local_range[0] = thread_1d;
local_range[1] = (dim > 1) ? thread_1d : 1;
local_range[1] = (dim > 2 ? thread_1d : 1) * (dim > 1 ? thread_1d : 1);

const CeedInt min_group_size = local_range[0] * local_range[1];

if (min_group_size > max_group_size) max_group_size = 256;
if (min_group_size > max_group_size) max_group_size = 512;
if (min_group_size > max_group_size) max_group_size = 1024;
CeedCheck(min_group_size <= max_group_size, ceed, CEED_ERROR_BACKEND, "Requested group size is smaller than the required minimum.");

local_range[2] = max_group_size / min_group_size; // elements per group
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,15 @@ inline void ReadElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, con
const CeedInt strides_comp, const CeedInt strides_elem, global const CeedScalar* restrict d_u,
private CeedScalar* restrict r_u) {
const CeedInt item_id_x = get_local_id(0);
const CeedInt item_id_y = get_local_id(1);
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;
const CeedInt elem = get_global_id(2);

if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
for (CeedInt z = 0; z < P_1D; z++) {
const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
const CeedInt ind = node * strides_node + elem * strides_elem;
for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
r_u[z + comp * P_1D] = d_u[ind + comp * strides_comp];
}
if (item_id_x < P_1D && item_id_y < P_1D && item_id_z < P_1D && elem < num_elem) {
const CeedInt node = item_id_x + P_1D * (item_id_y + P_1D * item_id_z);
const CeedInt ind = node * strides_node + elem * strides_elem;
for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
r_u[comp] = d_u[ind + comp * strides_comp];
}
}
}
Expand All @@ -136,16 +135,15 @@ inline void WriteElementStrided3d(const CeedInt NUM_COMP, const CeedInt P_1D, co
const CeedInt strides_comp, const CeedInt strides_elem, private const CeedScalar* restrict r_v,
global CeedScalar* restrict d_v) {
const CeedInt item_id_x = get_local_id(0);
const CeedInt item_id_y = get_local_id(1);
const CeedInt item_id_y = get_local_id(1) % T_1D;
const CeedInt item_id_z = get_local_id(1) / T_1D;
const CeedInt elem = get_global_id(2);

if (item_id_x < P_1D && item_id_y < P_1D && elem < num_elem) {
for (CeedInt z = 0; z < P_1D; z++) {
const CeedInt node = item_id_x + item_id_y * P_1D + z * P_1D * P_1D;
const CeedInt ind = node * strides_node + elem * strides_elem;
for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
d_v[ind + comp * strides_comp] = r_v[z + comp * P_1D];
}
if (item_id_x < P_1D && item_id_y < P_1D && item_id_z < P_1D && elem < num_elem) {
const CeedInt node = item_id_x + P_1D * (item_id_y + P_1D * item_id_z);
const CeedInt ind = node * strides_node + elem * strides_elem;
for (CeedInt comp = 0; comp < NUM_COMP; comp++) {
d_v[ind + comp * strides_comp] = r_v[comp];
}
}
}
Expand Down
Loading

0 comments on commit ca6301b

Please sign in to comment.