Skip to content

Commit

Permalink
rocblas trmm now also takes 3 matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpintarelli committed Sep 20, 2024
1 parent fdf35bb commit 4e026ac
Showing 1 changed file with 0 additions and 28 deletions.
28 changes: 0 additions & 28 deletions src/core/acc/acc_blas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,8 @@ strmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, fl
acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
// acc::set_device();
#ifdef SIRIUS_CUDA
CALL_GPU_BLAS(acc::blas_api::strmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__,
lda__, B__, ldb__, B__, ldb__));
#else
// rocblas trmm function does not take three matrices
CALL_GPU_BLAS(acc::blas_api::strmm,
(stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__));
#endif
}

inline void
Expand All @@ -281,14 +275,8 @@ dtrmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, do
acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
// acc::set_device();
#ifdef SIRIUS_CUDA
CALL_GPU_BLAS(acc::blas_api::dtrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__,
lda__, B__, ldb__, B__, ldb__));
#else
// rocblas trmm function does not take three matrices
CALL_GPU_BLAS(acc::blas_api::dtrmm,
(stream_handle(stream_id), side, uplo, transa, diag, m__, n__, alpha__, A__, lda__, B__, ldb__));
#endif
}

inline void
Expand All @@ -300,19 +288,11 @@ ctrmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, ac
acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
// acc::set_device();
#ifdef SIRIUS_CUDA
CALL_GPU_BLAS(acc::blas_api::ctrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
reinterpret_cast<const acc::blas_api::complex_float_t*>(alpha__),
reinterpret_cast<const acc::blas_api::complex_float_t*>(A__), lda__,
reinterpret_cast<acc::blas_api::complex_float_t*>(B__), ldb__,
reinterpret_cast<acc::blas_api::complex_float_t*>(B__), ldb__));
#else
// rocblas trmm function does not take three matrices
CALL_GPU_BLAS(acc::blas_api::ctrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
reinterpret_cast<const acc::blas_api::complex_float_t*>(alpha__),
reinterpret_cast<const acc::blas_api::complex_float_t*>(A__), lda__,
reinterpret_cast<acc::blas_api::complex_float_t*>(B__), ldb__));
#endif
}

inline void
Expand All @@ -324,19 +304,11 @@ ztrmm(char side__, char uplo__, char transa__, char diag__, int m__, int n__, ac
acc::blas_api::operation_t transa = get_gpublasOperation_t(transa__);
acc::blas_api::diagonal_t diag = get_gpublasDiagonal_t(diag__);
// acc::set_device();
#ifdef SIRIUS_CUDA
CALL_GPU_BLAS(acc::blas_api::ztrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
reinterpret_cast<const acc::blas_api::complex_double_t*>(alpha__),
reinterpret_cast<const acc::blas_api::complex_double_t*>(A__), lda__,
reinterpret_cast<acc::blas_api::complex_double_t*>(B__), ldb__,
reinterpret_cast<acc::blas_api::complex_double_t*>(B__), ldb__));
#else
// rocblas trmm function does not take three matrices
CALL_GPU_BLAS(acc::blas_api::ztrmm, (stream_handle(stream_id), side, uplo, transa, diag, m__, n__,
reinterpret_cast<const acc::blas_api::complex_double_t*>(alpha__),
reinterpret_cast<const acc::blas_api::complex_double_t*>(A__), lda__,
reinterpret_cast<acc::blas_api::complex_double_t*>(B__), ldb__));
#endif
}

inline void
Expand Down

0 comments on commit 4e026ac

Please sign in to comment.