Skip to content

Commit

Permalink
Refactor: extracted transpose parameter checking routine of gemm/v (#…
Browse files Browse the repository at this point in the history
…4279)

* Refactor: extracted transpose parameter checking routine of gemm and gemv to function

* Fix: Corrected error log when provided with wrong param

* Fix: Correct ROCm code returning type and variable name

---------

Co-authored-by: Mohan Chen <[email protected]>
  • Loading branch information
OldDriver233 and mohanchen authored Jun 17, 2024
1 parent bc62cd2 commit 91b3d87
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 231 deletions.
143 changes: 29 additions & 114 deletions source/module_hsolver/kernels/cuda/math_kernel_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,26 @@ void axpy_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
cublasErrcheck(cublasZaxpy(cublas_handle, N, (double2*)alpha, (double2*)X, incX, (double2*)Y, incY));
}

/**
 * @brief Map a BLAS-style transpose flag to the corresponding cuBLAS operation.
 *
 * @param is_complex  When true, 'C' is also accepted and mapped to CUBLAS_OP_C;
 *                    when false, 'C' is treated as an unknown flag.
 * @param trans       Transpose flag; the upper-case characters 'N', 'T' and
 *                    (for complex callers) 'C' are recognised.
 * @param name        Name of the calling routine, forwarded to
 *                    ModuleBase::WARNING_QUIT for the error report.
 * @return CUBLAS_OP_N, CUBLAS_OP_T or CUBLAS_OP_C for a recognised flag.
 *
 * Any other flag is reported through ModuleBase::WARNING_QUIT.
 */
cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
{
    if (trans == 'N')
    {
        return CUBLAS_OP_N;
    }
    if (trans == 'T')
    {
        return CUBLAS_OP_T;
    }
    if (is_complex && trans == 'C')
    {
        return CUBLAS_OP_C;
    }

    ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));

    // NOTE(review): WARNING_QUIT presumably terminates the program, but its
    // declaration is not visible here. Return a value anyway so that no control
    // path flows off the end of this non-void function (undefined behaviour).
    return CUBLAS_OP_N;
}

template <>
void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
const char& trans,
Expand All @@ -731,16 +751,7 @@ void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
double* Y,
const int& incy)
{
cublasOperation_t cutrans = {};
if (trans == 'N') {
cutrans = CUBLAS_OP_N;
}
else if (trans == 'T') {
cutrans = CUBLAS_OP_T;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
cublasErrcheck(cublasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
}

Expand All @@ -758,19 +769,7 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
std::complex<float>* Y,
const int& incy)
{
cublasOperation_t cutrans = {};
if (trans == 'N'){
cutrans = CUBLAS_OP_N;
}
else if (trans == 'T'){
cutrans = CUBLAS_OP_T;
}
else if (trans == 'C'){
cutrans = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, (float2*)alpha, (float2*)A, lda, (float2*)X, incx, (float2*)beta, (float2*)Y, incx));
}

Expand All @@ -788,19 +787,7 @@ void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
std::complex<double>* Y,
const int& incy)
{
cublasOperation_t cutrans = {};
if (trans == 'N'){
cutrans = CUBLAS_OP_N;
}
else if (trans == 'T'){
cutrans = CUBLAS_OP_T;
}
else if (trans == 'C'){
cutrans = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, (double2*)alpha, (double2*)A, lda, (double2*)X, incx, (double2*)beta, (double2*)Y, incx));
}

Expand Down Expand Up @@ -840,28 +827,8 @@ void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
double* c,
const int& ldc)
{
cublasOperation_t cutransA;
cublasOperation_t cutransB;
// cutransA
if (transa == 'N') {
cutransA = CUBLAS_OP_N;
}
else if (transa == 'T') {
cutransA = CUBLAS_OP_T;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N') {
cutransB = CUBLAS_OP_N;
}
else if (transb == 'T') {
cutransB = CUBLAS_OP_T;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
cublasErrcheck(cublasDgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
}
template <>
Expand All @@ -880,34 +847,8 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
std::complex<float>* c,
const int& ldc)
{
cublasOperation_t cutransA = {};
cublasOperation_t cutransB = {};
// cutransA
if (transa == 'N'){
cutransA = CUBLAS_OP_N;
}
else if (transa == 'T'){
cutransA = CUBLAS_OP_T;
}
else if (transa == 'C'){
cutransA = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N'){
cutransB = CUBLAS_OP_N;
}
else if (transb == 'T'){
cutransB = CUBLAS_OP_T;
}
else if (transb == 'C'){
cutransB = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
cublasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
cublasErrcheck(cublasCgemm(cublas_handle, cutransA, cutransB, m, n ,k, (float2*)alpha, (float2*)a , lda, (float2*)b, ldb, (float2*)beta, (float2*)c, ldc));
}

Expand All @@ -927,34 +868,8 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
std::complex<double>* c,
const int& ldc)
{
cublasOperation_t cutransA;
cublasOperation_t cutransB;
// cutransA
if (transa == 'N'){
cutransA = CUBLAS_OP_N;
}
else if (transa == 'T'){
cutransA = CUBLAS_OP_T;
}
else if (transa == 'C'){
cutransA = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N'){
cutransB = CUBLAS_OP_N;
}
else if (transb == 'T'){
cutransB = CUBLAS_OP_T;
}
else if (transb == 'C'){
cutransB = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
cublasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
cublasErrcheck(cublasZgemm(cublas_handle, cutransA, cutransB, m, n ,k, (double2*)alpha, (double2*)a , lda, (double2*)b, ldb, (double2*)beta, (double2*)c, ldc));
}

Expand Down
146 changes: 29 additions & 117 deletions source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,26 @@ void axpy_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
hipblasErrcheck(hipblasZaxpy(cublas_handle, N, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)X, incX, (hipblasDoubleComplex*)Y, incY));
}

/**
 * @brief Map a BLAS-style transpose flag to the corresponding hipBLAS operation.
 *
 * @param is_complex  When true, 'C' is also accepted and mapped to HIPBLAS_OP_C;
 *                    when false, 'C' is treated as an unknown flag.
 * @param trans       Transpose flag; the upper-case characters 'N', 'T' and
 *                    (for complex callers) 'C' are recognised.
 * @param name        Name of the calling routine, forwarded to
 *                    ModuleBase::WARNING_QUIT for the error report.
 * @return HIPBLAS_OP_N, HIPBLAS_OP_T or HIPBLAS_OP_C for a recognised flag.
 *
 * Any other flag is reported through ModuleBase::WARNING_QUIT.
 */
hipblasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
{
    if (trans == 'N')
    {
        return HIPBLAS_OP_N;
    }
    if (trans == 'T')
    {
        return HIPBLAS_OP_T;
    }
    if (is_complex && trans == 'C')
    {
        return HIPBLAS_OP_C;
    }

    ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));

    // NOTE(review): WARNING_QUIT presumably terminates the program, but its
    // declaration is not visible here. Return a value anyway so that no control
    // path flows off the end of this non-void function (undefined behaviour).
    return HIPBLAS_OP_N;
}

template <>
void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
const char& trans,
Expand All @@ -655,19 +675,7 @@ void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
double* Y,
const int& incy)
{
hipblasOperation_t cutrans = {};
if (trans == 'N') {
cutrans = HIPBLAS_OP_N;
}
else if (trans == 'T') {
cutrans = HIPBLAS_OP_T;
}
else if (trans == 'C') {
cutrans = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
hipblasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
hipblasErrcheck(hipblasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
}

Expand All @@ -685,19 +693,7 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
std::complex<float>* Y,
const int& incy)
{
hipblasOperation_t cutrans = {};
if (trans == 'N') {
cutrans = HIPBLAS_OP_N;
}
else if (trans == 'T') {
cutrans = HIPBLAS_OP_T;
}
else if (trans == 'C') {
cutrans = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
hipblasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
hipblasErrcheck(hipblasCgemv(cublas_handle, cutrans, m, n, (hipblasComplex*)alpha, (hipblasComplex*)A, lda, (hipblasComplex*)X, incx, (hipblasComplex*)beta, (hipblasComplex*)Y, incx));
}

Expand All @@ -715,19 +711,7 @@ void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
std::complex<double>* Y,
const int& incy)
{
hipblasOperation_t cutrans = {};
if (trans == 'N'){
cutrans = HIPBLAS_OP_N;
}
else if (trans == 'T'){
cutrans = HIPBLAS_OP_T;
}
else if (trans == 'C'){
cutrans = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
hipblasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
hipblasErrcheck(hipblasZgemv(cublas_handle, cutrans, m, n, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)A, lda, (hipblasDoubleComplex*)X, incx, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)Y, incx));
}

Expand Down Expand Up @@ -767,28 +751,8 @@ void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
double* c,
const int& ldc)
{
hipblasOperation_t cutransA;
hipblasOperation_t cutransB;
// cutransA
if (transa == 'N') {
cutransA = HIPBLAS_OP_N;
}
else if (transa == 'T') {
cutransA = HIPBLAS_OP_T;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N') {
cutransB = HIPBLAS_OP_N;
}
else if (transb == 'T') {
cutransB = HIPBLAS_OP_T;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
hipblasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
hipblasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
hipblasErrcheck(hipblasDgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
}

Expand All @@ -808,34 +772,8 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
std::complex<float>* c,
const int& ldc)
{
hipblasOperation_t cutransA = {};
hipblasOperation_t cutransB = {};
// cutransA
if (transa == 'N'){
cutransA = HIPBLAS_OP_N;
}
else if (transa == 'T'){
cutransA = HIPBLAS_OP_T;
}
else if (transa == 'C'){
cutransA = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N'){
cutransB = HIPBLAS_OP_N;
}
else if (transb == 'T'){
cutransB = HIPBLAS_OP_T;
}
else if (transb == 'C'){
cutransB = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
hipblasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
hipblasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
hipblasErrcheck(hipblasCgemm(cublas_handle, cutransA, cutransB, m, n ,k, (hipblasComplex*)alpha, (hipblasComplex*)a , lda, (hipblasComplex*)b, ldb, (hipblasComplex*)beta, (hipblasComplex*)c, ldc));
}

Expand All @@ -855,34 +793,8 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
std::complex<double>* c,
const int& ldc)
{
hipblasOperation_t cutransA;
hipblasOperation_t cutransB;
// cutransA
if (transa == 'N'){
cutransA = HIPBLAS_OP_N;
}
else if (transa == 'T'){
cutransA = HIPBLAS_OP_T;
}
else if (transa == 'C'){
cutransA = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N'){
cutransB = HIPBLAS_OP_N;
}
else if (transb == 'T'){
cutransB = HIPBLAS_OP_T;
}
else if (transb == 'C'){
cutransB = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
hipblasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
hipblasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
hipblasErrcheck(hipblasZgemm(cublas_handle, cutransA, cutransB, m, n ,k, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)a , lda, (hipblasDoubleComplex*)b, ldb, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)c, ldc));
}

Expand Down

0 comments on commit 91b3d87

Please sign in to comment.