Skip to content

Commit

Permalink
EC/ROCM: Prod overload issue for HIP complex (#783)
Browse files Browse the repository at this point in the history
(cherry picked from commit c2a5062)

Co-authored-by: Pedram Alizadeh <[email protected]>
  • Loading branch information
edgargabriel and PedramAlizadeh authored May 22, 2023
1 parent a036a5f commit c0b5d1f
Showing 1 changed file with 95 additions and 5 deletions.
100 changes: 95 additions & 5 deletions src/components/ec/rocm/kernel/ec_rocm_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ec_rocm.h"
#include "utils/ucc_math_op.h"
#include <inttypes.h>
#include <hip/hip_complex.h>

#define ROCM_REDUCE_WITH_OP_DEFAULT(NAME, _OP) \
template <typename _Type, typename _AlphaType> \
Expand Down Expand Up @@ -54,6 +55,41 @@
} \
}

#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(NAME, _OP) \
template <typename _Type, typename _AlphaType> \
__global__ void UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME(ucc_eee_task_reduce_t task, \
uint16_t flags) \
{ \
size_t start = blockIdx.x * blockDim.x + threadIdx.x; \
size_t step = blockDim.x * gridDim.x; \
size_t count = task.count; \
int n_srcs = task.n_srcs; \
const _Type **s = (const _Type **)task.srcs; \
_Type * d = (_Type *)task.dst; \
size_t i; \
\
switch (n_srcs) { \
case 2: \
for (i = start; i < count; i += step) { \
d[i] = _OP(s[0][i], s[1][i]); \
} \
break; \
default: \
for (i = start; i < count; i += step) { \
d[i] = _OP(s[0][i], s[1][i]); \
for (size_t j = 2; j < n_srcs; j++) { \
d[i] = _OP(d[i], s[j][i]); \
} \
} \
break; \
} \
if (flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA) { \
for (i = start; i < count; i += step) { \
d[i] = d[i] * (_AlphaType)task.alpha; \
} \
} \
}

#define ROCM_REDUCE_WITH_OP_STRIDED(NAME, _OP) \
template <typename _Type, typename _AlphaType> \
__global__ void UCC_REDUCE_ROCM_STRIDED_##NAME( \
Expand Down Expand Up @@ -99,8 +135,45 @@
} \
}

#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(NAME, _OP) \
template <typename _Type, typename _AlphaType> \
__global__ void UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME( \
const _Type *s1, const _Type *s2, _Type *d, size_t count, \
size_t stride, uint16_t n_src2, const bool with_alpha, \
const double alpha) \
{ \
size_t start = blockIdx.x * blockDim.x + threadIdx.x; \
size_t step = blockDim.x * gridDim.x; \
size_t ld = stride / sizeof(_Type); \
size_t i; \
\
ucc_assert_system(stride % sizeof(_Type) == 0); \
switch (n_src2) { \
case 1: \
for (i = start; i < count; i += step) { \
d[i] = _OP(s1[i], s2[i]); \
} \
break; \
default: \
for (i = start; i < count; i += step) { \
d[i] = _OP(s1[i], s2[i]); \
for (size_t j = 1; j < n_src2; j++) { \
d[i] = _OP(d[i], s2[i + j * ld]); \
} \
} \
break; \
} \
if (with_alpha) { \
for (i = start; i < count; i += step) { \
d[i] = d[i] * (_AlphaType)alpha; \
} \
} \
}

ROCM_REDUCE_WITH_OP_DEFAULT(SUM, DO_OP_SUM);
ROCM_REDUCE_WITH_OP_DEFAULT(PROD, DO_OP_PROD);
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(PROD_DOUBLE, hipCmul);
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(PROD_FLOAT, hipCmulf);
ROCM_REDUCE_WITH_OP_DEFAULT(MIN, DO_OP_MIN);
ROCM_REDUCE_WITH_OP_DEFAULT(MAX, DO_OP_MAX);
ROCM_REDUCE_WITH_OP_DEFAULT(LAND, DO_OP_LAND);
Expand All @@ -112,6 +185,8 @@ ROCM_REDUCE_WITH_OP_DEFAULT(BXOR, DO_OP_BXOR);

ROCM_REDUCE_WITH_OP_STRIDED(SUM, DO_OP_SUM);
ROCM_REDUCE_WITH_OP_STRIDED(PROD, DO_OP_PROD);
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(PROD_DOUBLE, hipCmul);
ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(PROD_FLOAT, hipCmulf);
ROCM_REDUCE_WITH_OP_STRIDED(MIN, DO_OP_MIN);
ROCM_REDUCE_WITH_OP_STRIDED(MAX, DO_OP_MAX);
ROCM_REDUCE_WITH_OP_STRIDED(LAND, DO_OP_LAND);
Expand All @@ -136,6 +211,21 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR);
} \
} while (0)

#define LAUNCH_KERNEL_B(NAME, type, _AlphaType, _task, s, b, t) \
do { \
if (_task->task_type == UCC_EE_EXECUTOR_TASK_REDUCE) { \
UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME<type, _AlphaType> \
<<<b, t, 0, s>>>(_task->reduce, _task->flags); \
} else { \
ucc_eee_task_reduce_strided_t *trs = &_task->reduce_strided; \
UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME<type, _AlphaType><<<b, t, 0, s>>>( \
(type *)trs->src1, (type *)trs->src2, (type *)trs->dst, \
trs->count, trs->stride, trs->n_src2, \
(bool)(_task->flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA), \
trs->alpha); \
} \
} while (0)

#define LAUNCH_KERNEL(NAME, type, _task, s, b, t) \
LAUNCH_KERNEL_A(NAME, type, type, _task, s, b, t)

Expand Down Expand Up @@ -207,15 +297,15 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR);
} \
} while (0)

#define DT_REDUCE_FLOAT_COMPLEX(type, _alphaType, _task, _op, s, b, t) \
#define DT_REDUCE_FLOAT_COMPLEX(NAME, type, _alphaType, _task, _op, s, b, t) \
do { \
switch (_op) { \
case UCC_OP_AVG: \
case UCC_OP_SUM: \
LAUNCH_KERNEL_A(SUM, type , _alphaType, _task, s, b, t); \
LAUNCH_KERNEL_A(SUM, type , _alphaType, _task, s, b, t); \
break; \
case UCC_OP_PROD: \
LAUNCH_KERNEL_A(PROD, type, _alphaType, _task, s, b, t); \
LAUNCH_KERNEL_B(NAME, type, _alphaType, _task, s, b, t); \
break; \
default: \
ec_error(&ucc_ec_rocm.super, \
Expand Down Expand Up @@ -299,10 +389,10 @@ ucc_status_t ucc_ec_rocm_reduce(ucc_ee_executor_task_args_t *task,
return UCC_ERR_NOT_SUPPORTED;
#endif
case UCC_DT_FLOAT32_COMPLEX:
DT_REDUCE_FLOAT_COMPLEX(hipFloatComplex, float, task, op, stream, bk, th);
DT_REDUCE_FLOAT_COMPLEX(PROD_FLOAT, hipFloatComplex, float, task, op, stream, bk, th);
break;
case UCC_DT_FLOAT64_COMPLEX:
DT_REDUCE_FLOAT_COMPLEX(hipDoubleComplex, double, task, op, stream, bk, th);
DT_REDUCE_FLOAT_COMPLEX(PROD_DOUBLE, hipDoubleComplex, double, task, op, stream, bk, th);
break;
case UCC_DT_BFLOAT16:
ucc_assert(2 == sizeof(hip_bfloat16));
Expand Down

0 comments on commit c0b5d1f

Please sign in to comment.