Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: extracted transpose parameter checking routine of gemm/v #4279

Merged
merged 5 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 29 additions & 114 deletions source/module_hsolver/kernels/cuda/math_kernel_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,26 @@ void axpy_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
cublasErrcheck(cublasZaxpy(cublas_handle, N, (double2*)alpha, (double2*)X, incX, (double2*)Y, incY));
}

/**
 * @brief Map a single-character transpose flag onto the matching cuBLAS operation.
 *
 * @param is_complex When true, the conjugate-transpose flag 'C' is accepted in
 *                   addition to 'N' and 'T'; when false, 'C' is rejected.
 * @param trans      Transpose flag; the comparison is case-sensitive ('N', 'T', 'C').
 * @param name       Caller name, forwarded as the first argument of
 *                   ModuleBase::WARNING_QUIT when the flag is not recognised.
 * @return CUBLAS_OP_N for 'N', CUBLAS_OP_T for 'T', CUBLAS_OP_C for 'C' (complex only).
 */
cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
{
    if (trans == 'N')
    {
        return CUBLAS_OP_N;
    }
    else if (trans == 'T')
    {
        return CUBLAS_OP_T;
    }
    else if (is_complex && trans == 'C')
    {
        return CUBLAS_OP_C;
    }
    else
    {
        ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));
    }
    // NOTE(review): WARNING_QUIT presumably terminates the program -- confirm.
    // The explicit return guarantees every control path yields a value; flowing
    // off the end of a non-void function is undefined behaviour.
    return CUBLAS_OP_N;
}

template <>
void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
const char& trans,
Expand All @@ -731,16 +751,7 @@ void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
double* Y,
const int& incy)
{
cublasOperation_t cutrans = {};
if (trans == 'N') {
cutrans = CUBLAS_OP_N;
}
else if (trans == 'T') {
cutrans = CUBLAS_OP_T;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
cublasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
cublasErrcheck(cublasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
}

Expand All @@ -758,19 +769,7 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
std::complex<float>* Y,
const int& incy)
{
cublasOperation_t cutrans = {};
if (trans == 'N'){
cutrans = CUBLAS_OP_N;
}
else if (trans == 'T'){
cutrans = CUBLAS_OP_T;
}
else if (trans == 'C'){
cutrans = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
cublasErrcheck(cublasCgemv(cublas_handle, cutrans, m, n, (float2*)alpha, (float2*)A, lda, (float2*)X, incx, (float2*)beta, (float2*)Y, incx));
}

Expand All @@ -788,19 +787,7 @@ void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
std::complex<double>* Y,
const int& incy)
{
cublasOperation_t cutrans = {};
if (trans == 'N'){
cutrans = CUBLAS_OP_N;
}
else if (trans == 'T'){
cutrans = CUBLAS_OP_T;
}
else if (trans == 'C'){
cutrans = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
cublasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
cublasErrcheck(cublasZgemv(cublas_handle, cutrans, m, n, (double2*)alpha, (double2*)A, lda, (double2*)X, incx, (double2*)beta, (double2*)Y, incx));
}

Expand Down Expand Up @@ -840,28 +827,8 @@ void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
double* c,
const int& ldc)
{
cublasOperation_t cutransA;
cublasOperation_t cutransB;
// cutransA
if (transa == 'N') {
cutransA = CUBLAS_OP_N;
}
else if (transa == 'T') {
cutransA = CUBLAS_OP_T;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N') {
cutransB = CUBLAS_OP_N;
}
else if (transb == 'T') {
cutransB = CUBLAS_OP_T;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
cublasErrcheck(cublasDgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
}
template <>
Expand All @@ -880,34 +847,8 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
std::complex<float>* c,
const int& ldc)
{
cublasOperation_t cutransA = {};
cublasOperation_t cutransB = {};
// cutransA
if (transa == 'N'){
cutransA = CUBLAS_OP_N;
}
else if (transa == 'T'){
cutransA = CUBLAS_OP_T;
}
else if (transa == 'C'){
cutransA = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N'){
cutransB = CUBLAS_OP_N;
}
else if (transb == 'T'){
cutransB = CUBLAS_OP_T;
}
else if (transb == 'C'){
cutransB = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
cublasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
cublasErrcheck(cublasCgemm(cublas_handle, cutransA, cutransB, m, n ,k, (float2*)alpha, (float2*)a , lda, (float2*)b, ldb, (float2*)beta, (float2*)c, ldc));
}

Expand All @@ -927,34 +868,8 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
std::complex<double>* c,
const int& ldc)
{
cublasOperation_t cutransA;
cublasOperation_t cutransB;
// cutransA
if (transa == 'N'){
cutransA = CUBLAS_OP_N;
}
else if (transa == 'T'){
cutransA = CUBLAS_OP_T;
}
else if (transa == 'C'){
cutransA = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N'){
cutransB = CUBLAS_OP_N;
}
else if (transb == 'T'){
cutransB = CUBLAS_OP_T;
}
else if (transb == 'C'){
cutransB = CUBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
cublasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
cublasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
cublasErrcheck(cublasZgemm(cublas_handle, cutransA, cutransB, m, n ,k, (double2*)alpha, (double2*)a , lda, (double2*)b, ldb, (double2*)beta, (double2*)c, ldc));
}

Expand Down
146 changes: 29 additions & 117 deletions source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,26 @@ void axpy_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
hipblasErrcheck(hipblasZaxpy(cublas_handle, N, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)X, incX, (hipblasDoubleComplex*)Y, incY));
}

/**
 * @brief Map a single-character transpose flag onto the matching hipBLAS operation.
 *
 * @param is_complex When true, the conjugate-transpose flag 'C' is accepted in
 *                   addition to 'N' and 'T'; when false, 'C' is rejected.
 * @param trans      Transpose flag; the comparison is case-sensitive ('N', 'T', 'C').
 * @param name       Caller name, forwarded as the first argument of
 *                   ModuleBase::WARNING_QUIT when the flag is not recognised.
 * @return HIPBLAS_OP_N for 'N', HIPBLAS_OP_T for 'T', HIPBLAS_OP_C for 'C' (complex only).
 */
hipblasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
{
    if (trans == 'N')
    {
        return HIPBLAS_OP_N;
    }
    else if (trans == 'T')
    {
        return HIPBLAS_OP_T;
    }
    else if (is_complex && trans == 'C')
    {
        return HIPBLAS_OP_C;
    }
    else
    {
        ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));
    }
    // NOTE(review): WARNING_QUIT presumably terminates the program -- confirm.
    // The explicit return guarantees every control path yields a value; flowing
    // off the end of a non-void function is undefined behaviour.
    return HIPBLAS_OP_N;
}

template <>
void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
const char& trans,
Expand All @@ -655,19 +675,7 @@ void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
double* Y,
const int& incy)
{
hipblasOperation_t cutrans = {};
if (trans == 'N') {
cutrans = HIPBLAS_OP_N;
}
else if (trans == 'T') {
cutrans = HIPBLAS_OP_T;
}
else if (trans == 'C') {
cutrans = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
hipblasOperation_t cutrans = judge_trans_op(false, trans, "gemv_op");
hipblasErrcheck(hipblasDgemv(cublas_handle, cutrans, m, n, alpha, A, lda, X, incx, beta, Y, incx));
}

Expand All @@ -685,19 +693,7 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
std::complex<float>* Y,
const int& incy)
{
hipblasOperation_t cutrans = {};
if (trans == 'N') {
cutrans = HIPBLAS_OP_N;
}
else if (trans == 'T') {
cutrans = HIPBLAS_OP_T;
}
else if (trans == 'C') {
cutrans = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
hipblasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
hipblasErrcheck(hipblasCgemv(cublas_handle, cutrans, m, n, (hipblasComplex*)alpha, (hipblasComplex*)A, lda, (hipblasComplex*)X, incx, (hipblasComplex*)beta, (hipblasComplex*)Y, incx));
}

Expand All @@ -715,19 +711,7 @@ void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
std::complex<double>* Y,
const int& incy)
{
hipblasOperation_t cutrans = {};
if (trans == 'N'){
cutrans = HIPBLAS_OP_N;
}
else if (trans == 'T'){
cutrans = HIPBLAS_OP_T;
}
else if (trans == 'C'){
cutrans = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemv_op", std::string("Unknown trans type ") + trans + std::string(" !"));
}
hipblasOperation_t cutrans = judge_trans_op(true, trans, "gemv_op");
hipblasErrcheck(hipblasZgemv(cublas_handle, cutrans, m, n, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)A, lda, (hipblasDoubleComplex*)X, incx, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)Y, incx));
}

Expand Down Expand Up @@ -767,28 +751,8 @@ void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
double* c,
const int& ldc)
{
hipblasOperation_t cutransA;
hipblasOperation_t cutransB;
// cutransA
if (transa == 'N') {
cutransA = HIPBLAS_OP_N;
}
else if (transa == 'T') {
cutransA = HIPBLAS_OP_T;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N') {
cutransB = HIPBLAS_OP_N;
}
else if (transb == 'T') {
cutransB = HIPBLAS_OP_T;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
hipblasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
hipblasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
hipblasErrcheck(hipblasDgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
}

Expand All @@ -808,34 +772,8 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
std::complex<float>* c,
const int& ldc)
{
hipblasOperation_t cutransA = {};
hipblasOperation_t cutransB = {};
// cutransA
if (transa == 'N'){
cutransA = HIPBLAS_OP_N;
}
else if (transa == 'T'){
cutransA = HIPBLAS_OP_T;
}
else if (transa == 'C'){
cutransA = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N'){
cutransB = HIPBLAS_OP_N;
}
else if (transb == 'T'){
cutransB = HIPBLAS_OP_T;
}
else if (transb == 'C'){
cutransB = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
hipblasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
hipblasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
hipblasErrcheck(hipblasCgemm(cublas_handle, cutransA, cutransB, m, n ,k, (hipblasComplex*)alpha, (hipblasComplex*)a , lda, (hipblasComplex*)b, ldb, (hipblasComplex*)beta, (hipblasComplex*)c, ldc));
}

Expand All @@ -855,34 +793,8 @@ void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
std::complex<double>* c,
const int& ldc)
{
hipblasOperation_t cutransA;
hipblasOperation_t cutransB;
// cutransA
if (transa == 'N'){
cutransA = HIPBLAS_OP_N;
}
else if (transa == 'T'){
cutransA = HIPBLAS_OP_T;
}
else if (transa == 'C'){
cutransA = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transa type ") + transa + std::string(" !"));
}
// cutransB
if (transb == 'N'){
cutransB = HIPBLAS_OP_N;
}
else if (transb == 'T'){
cutransB = HIPBLAS_OP_T;
}
else if (transb == 'C'){
cutransB = HIPBLAS_OP_C;
}
else {
ModuleBase::WARNING_QUIT("gemm_op", std::string("Unknown transb type ") + transb + std::string(" !"));
}
hipblasOperation_t cutransA = judge_trans_op(true, transa, "gemm_op");
hipblasOperation_t cutransB = judge_trans_op(true, transb, "gemm_op");
hipblasErrcheck(hipblasZgemm(cublas_handle, cutransA, cutransB, m, n ,k, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)a , lda, (hipblasDoubleComplex*)b, ldb, (hipblasDoubleComplex*)beta, (hipblasDoubleComplex*)c, ldc));
}

Expand Down
Loading