From 91b3d875c380b5caa412fb8118347b09bc534d71 Mon Sep 17 00:00:00 2001
From: Shichao Sun <1147972163@qq.com>
Date: Mon, 17 Jun 2024 16:22:26 +0800
Subject: [PATCH] Refactor: extracted transpose parameter checking routine of gemm/v (#4279)

* Refactor: extracted transpose parameter checking routine of gemm and gemv to function

* Fix: Corrected error log when provided with wrong param

* Fix: Correct ROCm code returning type and variable name

* Fix: pass incy (not incx) as the stride of Y in the gemv calls

* Fix: make judge_trans_op return a value on every path

---------

Co-authored-by: Mohan Chen
---
 .../kernels/cuda/math_kernel_op.cu            | 154 +++++------------
 .../kernels/rocm/math_kernel_op.hip.cu        | 157 +++++-------------
 2 files changed, 74 insertions(+), 237 deletions(-)

diff --git a/source/module_hsolver/kernels/cuda/math_kernel_op.cu b/source/module_hsolver/kernels/cuda/math_kernel_op.cu
index ed1b9379f9..c5a49b85e3 100644
--- a/source/module_hsolver/kernels/cuda/math_kernel_op.cu
+++ b/source/module_hsolver/kernels/cuda/math_kernel_op.cu
@@ -717,6 +717,31 @@ void axpy_op, base_device::DEVICE_GPU>::operator()(const ba
     cublasErrcheck(cublasZaxpy(cublas_handle, N, (double2*)alpha, (double2*)X, incX, (double2*)Y, incY));
 }
 
+// Translate a BLAS-style transpose flag into the matching cublasOperation_t.
+//   is_complex : true when the caller operates on complex data; 'C' is accepted only then.
+//   trans      : 'N', 'T', or (complex callers only) 'C'.
+//   name       : routine name forwarded to ModuleBase::WARNING_QUIT when trans is rejected.
+// Returns CUBLAS_OP_N / CUBLAS_OP_T / CUBLAS_OP_C for the accepted flags.
+cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
+{
+    // Single exit point: a value is returned on every path, so correctness does not
+    // depend on WARNING_QUIT never returning (its declaration is not visible here).
+    cublasOperation_t op = CUBLAS_OP_N;
+    if (trans == 'T')
+    {
+        op = CUBLAS_OP_T;
+    }
+    else if (is_complex && trans == 'C')
+    {
+        op = CUBLAS_OP_C;
+    }
+    else if (trans != 'N')
+    {
+        ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));
+    }
+    return op;
+}
+
 template <>
 void gemv_op::operator()(const base_device::DEVICE_GPU* d,
                          const char& trans,
@@ -731,16 +756,7 @@ void gemv_op::operator()(const base_device::DEV
                          double* Y,
                          const int& incy)
 {
-    cublasOperation_t cutrans = {};
-    if (trans == 'N') {
-        cutrans = CUBLAS_OP_N;
-    }
-    else if (trans == 'T') {
-        cutrans = CUBLAS_OP_T;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
-    }
+    cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
-    cublasErrcheck(cublasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
+    cublasErrcheck(cublasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
 }
 
@@ -758,19 +774,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const bas
                          std::complex* Y,
                          const int& incy)
 {
-    cublasOperation_t cutrans = {};
-    if (trans == 'N'){
-        cutrans = CUBLAS_OP_N;
-    }
-    else if (trans == 'T'){
-        cutrans = CUBLAS_OP_T;
-    }
-    else if (trans == 'C'){
-        cutrans = CUBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
-    }
+    cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
-    cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, (float2*)alpha, (float2*)A, lda, (float2*)X, incx, (float2*)beta, (float2*)Y, incx));
+    cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, (float2*)alpha, (float2*)A, lda, (float2*)X, incx, (float2*)beta, (float2*)Y, incy));
 }
 
@@ -788,19 +792,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const ba
                          std::complex* Y,
                          const int& incy)
 {
-    cublasOperation_t cutrans = {};
-    if (trans == 'N'){
-        cutrans = CUBLAS_OP_N;
-    }
-    else if (trans == 'T'){
-        cutrans = CUBLAS_OP_T;
-    }
-    else if (trans == 'C'){
-        cutrans = CUBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
-    }
+    cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
-    cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, (double2*)alpha, (double2*)A, lda, (double2*)X, incx, (double2*)beta, (double2*)Y, incx));
+    cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, (double2*)alpha, (double2*)A, lda, (double2*)X, incx, (double2*)beta, (double2*)Y, incy));
 }
 
@@ -840,28 +832,8 @@ void gemm_op::operator()(const base_device::DEV
                          double* c,
                          const int& ldc)
 {
-    cublasOperation_t cutransA;
-    cublasOperation_t cutransB;
-    // cutransA
-    if (transa == 'N') {
-        cutransA = CUBLAS_OP_N;
-    }
-    else if (transa == 'T') {
-        cutransA = CUBLAS_OP_T;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
-    }
-    // cutransB
-    if (transb == 'N') {
-        cutransB = CUBLAS_OP_N;
-    }
-    else if (transb == 'T') {
-        cutransB = CUBLAS_OP_T;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
-    }
+    cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
+    cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
     cublasErrcheck(cublasDgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
 }
 template <>
@@ -880,34 +852,8 @@ void gemm_op, base_device::DEVICE_GPU>::operator()(const bas
                          std::complex* c,
                          const int& ldc)
 {
-    cublasOperation_t cutransA = {};
-    cublasOperation_t cutransB = {};
-    // cutransA
-    if (transa == 'N'){
-        cutransA = CUBLAS_OP_N;
-    }
-    else if (transa == 'T'){
-        cutransA = CUBLAS_OP_T;
-    }
-    else if (transa == 'C'){
-        cutransA = CUBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
-    }
-    // cutransB
-    if (transb == 'N'){
-        cutransB = CUBLAS_OP_N;
-    }
-    else if (transb == 'T'){
-        cutransB = CUBLAS_OP_T;
-    }
-    else if (transb == 'C'){
-        cutransB = CUBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
-    }
+    cublasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
+    cublasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
     cublasErrcheck(cublasCgemm(cublas_handle, cutransA, cutransB, m, n ,k, (float2*)alpha, (float2*)a , lda, (float2*)b, ldb, (float2*)beta, (float2*)c, ldc));
 }
 
@@ -927,34 +873,8 @@ void gemm_op, base_device::DEVICE_GPU>::operator()(const ba
                          std::complex* c,
                          const int& ldc)
 {
-    cublasOperation_t cutransA;
-    cublasOperation_t cutransB;
-    // cutransA
-    if (transa == 'N'){
-        cutransA = CUBLAS_OP_N;
-    }
-    else if (transa == 'T'){
-        cutransA = CUBLAS_OP_T;
-    }
-    else if (transa == 'C'){
-        cutransA = CUBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
-    }
-    // cutransB
-    if (transb == 'N'){
-        cutransB = CUBLAS_OP_N;
-    }
-    else if (transb == 'T'){
-        cutransB = CUBLAS_OP_T;
-    }
-    else if (transb == 'C'){
-        cutransB = CUBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
-    }
+    cublasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
+    cublasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
     cublasErrcheck(cublasZgemm(cublas_handle, cutransA, cutransB, m, n ,k, (double2*)alpha, (double2*)a , lda, (double2*)b, ldb, (double2*)beta, (double2*)c, ldc));
 }
 
diff --git a/source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu b/source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu
index 89cf59f6fa..ef5a1c1ece 100644
--- a/source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu
+++ b/source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu
@@ -641,6 +641,31 @@ void axpy_op, base_device::DEVICE_GPU>::operator()(const ba
     hipblasErrcheck(hipblasZaxpy(cublas_handle, N, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)X, incX, (hipblasDoubleComplex*)Y, incY));
 }
 
+// Translate a BLAS-style transpose flag into the matching hipblasOperation_t.
+//   is_complex : true when the caller operates on complex data; 'C' is accepted only then.
+//   trans      : 'N', 'T', or (complex callers only) 'C'.
+//   name       : routine name forwarded to ModuleBase::WARNING_QUIT when trans is rejected.
+// Returns HIPBLAS_OP_N / HIPBLAS_OP_T / HIPBLAS_OP_C for the accepted flags.
+hipblasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
+{
+    // Single exit point: a value is returned on every path, so correctness does not
+    // depend on WARNING_QUIT never returning (its declaration is not visible here).
+    hipblasOperation_t op = HIPBLAS_OP_N;
+    if (trans == 'T')
+    {
+        op = HIPBLAS_OP_T;
+    }
+    else if (is_complex && trans == 'C')
+    {
+        op = HIPBLAS_OP_C;
+    }
+    else if (trans != 'N')
+    {
+        ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));
+    }
+    return op;
+}
+
 template <>
 void gemv_op::operator()(const base_device::DEVICE_GPU* d,
                          const char& trans,
@@ -655,19 +680,7 @@ void gemv_op::operator()(const base_device::DEV
                          double* Y,
                          const int& incy)
 {
-    hipblasOperation_t cutrans = {};
-    if (trans == 'N') {
-        cutrans = HIPBLAS_OP_N;
-    }
-    else if (trans == 'T') {
-        cutrans = HIPBLAS_OP_T;
-    }
-    else if (trans == 'C') {
-        cutrans = HIPBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
-    }
+    hipblasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
-    hipblasErrcheck(hipblasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
+    hipblasErrcheck(hipblasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incy));
 }
 
@@ -685,19 +698,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const bas
                          std::complex* Y,
                          const int& incy)
 {
-    hipblasOperation_t cutrans = {};
-    if (trans == 'N') {
-        cutrans = HIPBLAS_OP_N;
-    }
-    else if (trans == 'T') {
-        cutrans = HIPBLAS_OP_T;
-    }
-    else if (trans == 'C') {
-        cutrans = HIPBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
-    }
+    hipblasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
-    hipblasErrcheck(hipblasCgemv(cublas_handle, cutrans, m, n, (hipblasComplex*)alpha, (hipblasComplex*)A, lda, (hipblasComplex*)X, incx, (hipblasComplex*)beta, (hipblasComplex*)Y, incx));
+    hipblasErrcheck(hipblasCgemv(cublas_handle, cutrans, m, n, (hipblasComplex*)alpha, (hipblasComplex*)A, lda, (hipblasComplex*)X, incx, (hipblasComplex*)beta, (hipblasComplex*)Y, incy));
 }
 
@@ -715,19 +716,7 @@ void gemv_op, base_device::DEVICE_GPU>::operator()(const ba
                          std::complex* Y,
                          const int& incy)
 {
-    hipblasOperation_t cutrans = {};
-    if (trans == 'N'){
-        cutrans = HIPBLAS_OP_N;
-    }
-    else if (trans == 'T'){
-        cutrans = HIPBLAS_OP_T;
-    }
-    else if (trans == 'C'){
-        cutrans = HIPBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
-    }
+    hipblasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
-    hipblasErrcheck(hipblasZgemv(cublas_handle, cutrans, m, n, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)A, lda, (hipblasDoubleComplex*)X, incx, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)Y, incx));
+    hipblasErrcheck(hipblasZgemv(cublas_handle, cutrans, m, n, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)A, lda, (hipblasDoubleComplex*)X, incx, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)Y, incy));
 }
 
@@ -767,28 +756,8 @@ void gemm_op::operator()(const base_device::DEV
                          double* c,
                          const int& ldc)
 {
-    hipblasOperation_t cutransA;
-    hipblasOperation_t cutransB;
-    // cutransA
-    if (transa == 'N') {
-        cutransA = HIPBLAS_OP_N;
-    }
-    else if (transa == 'T') {
-        cutransA = HIPBLAS_OP_T;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
-    }
-    // cutransB
-    if (transb == 'N') {
-        cutransB = HIPBLAS_OP_N;
-    }
-    else if (transb == 'T') {
-        cutransB = HIPBLAS_OP_T;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
-    }
+    hipblasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
+    hipblasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
     hipblasErrcheck(hipblasDgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
 }
 
@@ -808,34 +777,8 @@ void gemm_op, base_device::DEVICE_GPU>::operator()(const bas
                          std::complex* c,
                          const int& ldc)
 {
-    hipblasOperation_t cutransA = {};
-    hipblasOperation_t cutransB = {};
-    // cutransA
-    if (transa == 'N'){
-        cutransA = HIPBLAS_OP_N;
-    }
-    else if (transa == 'T'){
-        cutransA = HIPBLAS_OP_T;
-    }
-    else if (transa == 'C'){
-        cutransA = HIPBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
-    }
-    // cutransB
-    if (transb == 'N'){
-        cutransB = HIPBLAS_OP_N;
-    }
-    else if (transb == 'T'){
-        cutransB = HIPBLAS_OP_T;
-    }
-    else if (transb == 'C'){
-        cutransB = HIPBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
-    }
+    hipblasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
+    hipblasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
     hipblasErrcheck(hipblasCgemm(cublas_handle, cutransA, cutransB, m, n ,k, (hipblasComplex*)alpha, (hipblasComplex*)a , lda, (hipblasComplex*)b, ldb, (hipblasComplex*)beta, (hipblasComplex*)c, ldc));
 }
 
@@ -855,34 +798,8 @@ void gemm_op, base_device::DEVICE_GPU>::operator()(const ba
                          std::complex* c,
                          const int& ldc)
 {
-    hipblasOperation_t cutransA;
-    hipblasOperation_t cutransB;
-    // cutransA
-    if (transa == 'N'){
-        cutransA = HIPBLAS_OP_N;
-    }
-    else if (transa == 'T'){
-        cutransA = HIPBLAS_OP_T;
-    }
-    else if (transa == 'C'){
-        cutransA = HIPBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
-    }
-    // cutransB
-    if (transb == 'N'){
-        cutransB = HIPBLAS_OP_N;
-    }
-    else if (transb == 'T'){
-        cutransB = HIPBLAS_OP_T;
-    }
-    else if (transb == 'C'){
-        cutransB = HIPBLAS_OP_C;
-    }
-    else {
-        ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
-    }
+    hipblasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
+    hipblasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
     hipblasErrcheck(hipblasZgemm(cublas_handle, cutransA, cutransB, m, n ,k, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)a , lda, (hipblasDoubleComplex*)b, ldb, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)c, ldc));
 }
 