diff --git a/src/components/ec/rocm/kernel/ec_rocm_reduce.cu b/src/components/ec/rocm/kernel/ec_rocm_reduce.cu index 01f10e4017..099ee82d95 100644 --- a/src/components/ec/rocm/kernel/ec_rocm_reduce.cu +++ b/src/components/ec/rocm/kernel/ec_rocm_reduce.cu @@ -8,6 +8,7 @@ #include "ec_rocm.h" #include "utils/ucc_math_op.h" #include +#include #define ROCM_REDUCE_WITH_OP_DEFAULT(NAME, _OP) \ template \ @@ -54,6 +55,41 @@ } \ } +#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(NAME, _OP) \ + template \ + __global__ void UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME(ucc_eee_task_reduce_t task, \ + uint16_t flags) \ + { \ + size_t start = blockIdx.x * blockDim.x + threadIdx.x; \ + size_t step = blockDim.x * gridDim.x; \ + size_t count = task.count; \ + int n_srcs = task.n_srcs; \ + const _Type **s = (const _Type **)task.srcs; \ + _Type * d = (_Type *)task.dst; \ + size_t i; \ + \ + switch (n_srcs) { \ + case 2: \ + for (i = start; i < count; i += step) { \ + d[i] = _OP(s[0][i], s[1][i]); \ + } \ + break; \ + default: \ + for (i = start; i < count; i += step) { \ + d[i] = _OP(s[0][i], s[1][i]); \ + for (size_t j = 2; j < n_srcs; j++) { \ + d[i] = _OP(d[i], s[j][i]); \ + } \ + } \ + break; \ + } \ + if (flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA) { \ + for (i = start; i < count; i += step) { \ + d[i] = d[i] * (_AlphaType)task.alpha; \ + } \ + } \ + } + #define ROCM_REDUCE_WITH_OP_STRIDED(NAME, _OP) \ template \ __global__ void UCC_REDUCE_ROCM_STRIDED_##NAME( \ @@ -99,8 +135,45 @@ } \ } +#define ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(NAME, _OP) \ + template \ + __global__ void UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME( \ + const _Type *s1, const _Type *s2, _Type *d, size_t count, \ + size_t stride, uint16_t n_src2, const bool with_alpha, \ + const double alpha) \ + { \ + size_t start = blockIdx.x * blockDim.x + threadIdx.x; \ + size_t step = blockDim.x * gridDim.x; \ + size_t ld = stride / sizeof(_Type); \ + size_t i; \ + \ + ucc_assert_system(stride % sizeof(_Type) == 0); \ + switch (n_src2) { \ + case 1: \ + for (i = start; i < count; i += step) { \ + d[i] = _OP(s1[i], s2[i]); \ + } \ + break; \ + default: \ + for (i = start; i < count; i += step) { \ + d[i] = _OP(s1[i], s2[i]); \ + for (size_t j = 1; j < n_src2; j++) { \ + d[i] = _OP(d[i], s2[i + j * ld]); \ + } \ + } \ + break; \ + } \ + if (with_alpha) { \ + for (i = start; i < count; i += step) { \ + d[i] = d[i] * (_AlphaType)alpha; \ + } \ + } \ + } + ROCM_REDUCE_WITH_OP_DEFAULT(SUM, DO_OP_SUM); ROCM_REDUCE_WITH_OP_DEFAULT(PROD, DO_OP_PROD); +ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(PROD_DOUBLE, hipCmul); +ROCM_REDUCE_WITH_COMPLEX_PRODUCT_DEFAULT(PROD_FLOAT, hipCmulf); ROCM_REDUCE_WITH_OP_DEFAULT(MIN, DO_OP_MIN); ROCM_REDUCE_WITH_OP_DEFAULT(MAX, DO_OP_MAX); ROCM_REDUCE_WITH_OP_DEFAULT(LAND, DO_OP_LAND); @@ -112,6 +185,8 @@ ROCM_REDUCE_WITH_OP_DEFAULT(BXOR, DO_OP_BXOR); ROCM_REDUCE_WITH_OP_STRIDED(SUM, DO_OP_SUM); ROCM_REDUCE_WITH_OP_STRIDED(PROD, DO_OP_PROD); +ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(PROD_DOUBLE, hipCmul); +ROCM_REDUCE_WITH_COMPLEX_PRODUCT_STRIDED(PROD_FLOAT, hipCmulf); ROCM_REDUCE_WITH_OP_STRIDED(MIN, DO_OP_MIN); ROCM_REDUCE_WITH_OP_STRIDED(MAX, DO_OP_MAX); ROCM_REDUCE_WITH_OP_STRIDED(LAND, DO_OP_LAND); @@ -136,6 +211,21 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR); } \ } while (0) +#define LAUNCH_KERNEL_B(NAME, type, _AlphaType, _task, s, b, t) \ + do { \ + if (_task->task_type == UCC_EE_EXECUTOR_TASK_REDUCE) { \ + UCC_REDUCE_ROCM_DEFAULT_COMPLEX_##NAME \ + <<>>(_task->reduce, _task->flags); \ + } else { \ + ucc_eee_task_reduce_strided_t *trs = &_task->reduce_strided; \ + UCC_REDUCE_ROCM_STRIDED_COMPLEX_##NAME<<>>( \ + (type *)trs->src1, (type *)trs->src2, (type *)trs->dst, \ + trs->count, trs->stride, trs->n_src2, \ + (bool)(_task->flags & UCC_EEE_TASK_FLAG_REDUCE_WITH_ALPHA), \ + trs->alpha); \ + } \ + } while (0) + #define LAUNCH_KERNEL(NAME, type, _task, s, b, t) \ LAUNCH_KERNEL_A(NAME, type, type, _task, s, b, t) @@ -207,15 +297,15 @@ ROCM_REDUCE_WITH_OP_STRIDED(BXOR, DO_OP_BXOR); } \ } while (0) -#define DT_REDUCE_FLOAT_COMPLEX(type, _alphaType, _task, _op, s, b, t) \ +#define DT_REDUCE_FLOAT_COMPLEX(NAME, type, _alphaType, _task, _op, s, b, t) \ do { \ switch (_op) { \ case UCC_OP_AVG: \ case UCC_OP_SUM: \ - LAUNCH_KERNEL_A(SUM, type , _alphaType, _task, s, b, t); \ + LAUNCH_KERNEL_A(SUM, type , _alphaType, _task, s, b, t); \ break; \ case UCC_OP_PROD: \ - LAUNCH_KERNEL_A(PROD, type, _alphaType, _task, s, b, t); \ + LAUNCH_KERNEL_B(NAME, type, _alphaType, _task, s, b, t); \ break; \ default: \ ec_error(&ucc_ec_rocm.super, \ @@ -299,10 +389,10 @@ ucc_status_t ucc_ec_rocm_reduce(ucc_ee_executor_task_args_t *task, return UCC_ERR_NOT_SUPPORTED; #endif case UCC_DT_FLOAT32_COMPLEX: - DT_REDUCE_FLOAT_COMPLEX(hipFloatComplex, float, task, op, stream, bk, th); + DT_REDUCE_FLOAT_COMPLEX(PROD_FLOAT, hipFloatComplex, float, task, op, stream, bk, th); break; case UCC_DT_FLOAT64_COMPLEX: - DT_REDUCE_FLOAT_COMPLEX(hipDoubleComplex, double, task, op, stream, bk, th); + DT_REDUCE_FLOAT_COMPLEX(PROD_DOUBLE, hipDoubleComplex, double, task, op, stream, bk, th); break; case UCC_DT_BFLOAT16: ucc_assert(2 == sizeof(hip_bfloat16));