diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp
index 818e2fa6c2..3c4404de5f 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp
@@ -2,6 +2,7 @@
 #include "modulated_deform_conv.h"
 
 #include <cmath>
+#include <thread>
 #include <vector>
 
 #include "modulated_deform_conv/modulated_deform_conv_cpu.h"
@@ -9,63 +10,40 @@
 namespace mmdeploy {
 
-void gemm_ref_fp32(const float *A, const float *B, const float *V, const float *H,
-                   const int32_t trans_A, const int32_t trans_B, const int32_t M, const int32_t N,
-                   const int32_t K, const float alpha, const float beta, float *Y) {
-  if (!trans_A && !trans_B) {  // MK, KN; NN
-    for (int64_t m = 0; m < M; ++m) {
-      for (int64_t n = 0; n < N; ++n) {
-        float y = 0.0f;
-        for (int64_t k = 0; k < K; ++k) {
-          y += A[m * K + k] * B[k * N + n];
-        }
-        y *= alpha;
-        if (V) y += beta * V[n];
-        if (H) y += beta * H[m * N + n];
-        Y[m * N + n] = y;
-      }
-    }
-  }
-  if (trans_A && !trans_B) {  // KM, KN; TN
-    for (int64_t m = 0; m < M; ++m) {
-      for (int64_t n = 0; n < N; ++n) {
-        float y = 0.0f;
-        for (int64_t k = 0; k < K; ++k) {
-          y += A[k * M + m] * B[k * N + n];
-        }
-        y *= alpha;
-        if (V) y += beta * V[n];
-        if (H) y += beta * H[m * N + n];
-        Y[m * N + n] = y;
-      }
-    }
-  }
-  if (trans_A && trans_B) {  // KM, NK; TT
-    for (int64_t m = 0; m < M; ++m) {
-      for (int64_t n = 0; n < N; ++n) {
-        float y = 0.0f;
-        for (int64_t k = 0; k < K; ++k) {
-          y += A[k * M + m] * B[n * K + k];
-        }
-        y *= alpha;
-        if (V) y += beta * V[n];
-        if (H) y += beta * H[m * N + n];
-        Y[m * N + n] = y;
-      }
-    }
-  }
-  if (!trans_A && trans_B) {  // MK, NK; NT
-    for (int64_t m = 0; m < M; ++m) {
-      for (int64_t n = 0; n < N; ++n) {
-        float y = 0.0f;
-        for (int64_t k = 0; k < K; ++k) {
-          y += A[m * K + k] * B[n * K + k];
-        }
-        y *= alpha;
-        if (V) y += beta * V[n];
-        if (H) y += beta * H[m * N + n];
-        Y[m * N + n] = y;
-      }
-    }
-  }
-}
+void parallel_unroll_gemm(const float *A, const float *B, const float *V, const float *H,
+                          const int32_t M, const int32_t N, const int32_t K, const float alpha,
+                          const float beta, float *Y, const int32_t start_row,
+                          const int32_t end_row) {
+  std::vector<float> tmp(N);  // per-row accumulator (N is a runtime value, so a plain array would be a non-standard VLA)
+  for (int32_t m = start_row; m < end_row; ++m) {
+    for (int32_t n = 0; n < N; n++) {
+      tmp[n] = 0;
+    }
+    {
+      int32_t remainder = K % 8;  // iterations left over after unrolling by 8
+      for (int32_t k = 0; k < K - remainder; k += 8) {
+        for (int32_t n = 0; n < N; n++) {
+          tmp[n] += A[m * K + k] * B[k * N + n];
+          tmp[n] += A[m * K + k + 1] * B[k * N + N + n];
+          tmp[n] += A[m * K + k + 2] * B[k * N + 2 * N + n];
+          tmp[n] += A[m * K + k + 3] * B[k * N + 3 * N + n];
+          tmp[n] += A[m * K + k + 4] * B[k * N + 4 * N + n];
+          tmp[n] += A[m * K + k + 5] * B[k * N + 5 * N + n];
+          tmp[n] += A[m * K + k + 6] * B[k * N + 6 * N + n];
+          tmp[n] += A[m * K + k + 7] * B[k * N + 7 * N + n];
+        }
+      }
+      for (int32_t k = K - remainder; k < K; k++) {
+        for (int32_t n = 0; n < N; n++) {
+          tmp[n] += A[m * K + k] * B[k * N + n];
+        }
+      }
+    }
+    for (int32_t n = 0; n < N; n++) {
+      tmp[n] *= alpha;
+      if (V) tmp[n] += beta * V[n];
+      if (H) tmp[n] += beta * H[m * N + n];
+      Y[m * N + n] = tmp[n];
+    }
+  }
+}
 
@@ -82,6 +60,10 @@ void deformable_conv2d_ref_fp32(const float *src, const float *offset, const flo
                                 const int64_t dilation_w, float *columns, float *dst) {
   const int64_t ic_per_gp = channels / group;
   const int64_t oc_per_gp = num_output / group;
+  // Set up for launching threads
+  std::size_t num_threads = std::thread::hardware_concurrency();  // may be 0 if undetectable; the serial path below covers that
+  std::vector<std::thread> threads;
+  threads.reserve(num_threads);
 
   for (int64_t b = 0; b < batch; ++b) {
     for (int64_t g = 0; g < group; ++g) {
@@ -102,9 +84,28 @@
       } else {
         memset(dst_ptr, 0.0f, sizeof(float) * oc_per_gp * dst_h * dst_w);
       }
-      gemm_ref_fp32(filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns, nullptr,
-                    dst_ptr, 0, 0, oc_per_gp, dst_h * dst_w, ic_per_gp * kernel_h * kernel_w, 1.0f,
-                    1.0f, dst_ptr);
+      if (num_threads > 1) {
+        // Split the GEMM's output rows into ceil(oc_per_gp / num_threads) chunks, one per thread
+        int32_t n_rows = (oc_per_gp + num_threads - 1) / num_threads;
+        int32_t end_row = 0;
+        for (int32_t i = 0; i < num_threads; i++) {
+          auto start_row = i * n_rows;
+          end_row = start_row + n_rows;
+          if (end_row > oc_per_gp) end_row = oc_per_gp;
+          std::thread t(parallel_unroll_gemm,
+                        filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns, nullptr,
+                        dst_ptr, oc_per_gp, dst_h * dst_w, ic_per_gp * kernel_h * kernel_w, 1.0f,
+                        1.0f, dst_ptr, start_row, end_row);
+          threads.emplace_back(std::move(t));
+        }
+        // Wait for all threads to complete
+        for (auto &t : threads) t.join();
+        threads.clear();
+      } else {  // the parallel GEMM degrades to a serial GEMM over rows [0, oc_per_gp)
+        parallel_unroll_gemm(filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns,
+                             nullptr, dst_ptr, oc_per_gp, dst_h * dst_w,
+                             ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, dst_ptr, 0, oc_per_gp);
+      }
     }
   }
 }
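Note for reviewers: the patch parallelizes the GEMM by splitting its M output rows into ceil(M / num_threads) contiguous chunks, handing each chunk to one std::thread, and clamping the final chunk to M. Below is a minimal standalone sketch of just that partitioning scheme; the helper gemm_rows, the plain (non-unrolled) inner loop, and the demo sizes are illustrative, not part of the patch.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

// Each worker computes rows [start_row, end_row) of Y = A (M x K) * B (K x N).
void gemm_rows(const float *A, const float *B, float *Y, int32_t N, int32_t K,
               int32_t start_row, int32_t end_row) {
  for (int32_t m = start_row; m < end_row; ++m) {
    for (int32_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int32_t k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      Y[m * N + n] = acc;
    }
  }
}

int main() {
  const int32_t M = 7, N = 5, K = 3;
  std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), Y(M * N, 0.0f);

  const int32_t num_threads = std::max(1u, std::thread::hardware_concurrency());
  const int32_t n_rows = (M + num_threads - 1) / num_threads;  // ceil division
  std::vector<std::thread> threads;
  for (int32_t i = 0; i < num_threads; ++i) {
    const int32_t start_row = i * n_rows;
    const int32_t end_row = std::min(start_row + n_rows, M);  // clamp the last chunk
    if (start_row >= end_row) break;  // more threads than row chunks
    threads.emplace_back(gemm_rows, A.data(), B.data(), Y.data(), N, K, start_row, end_row);
  }
  for (auto &t : threads) t.join();

  std::printf("Y[0] = %.1f (expected %.1f)\n", Y[0], 2.0f * K);
  return 0;
}
```

Because end_row is clamped, any chunk that starts past the last row collapses to an empty [start_row, end_row) interval, so the scheme stays correct even when num_threads does not evenly divide the row count. The same property is what keeps the patch's worker loops safe when oc_per_gp is smaller than the thread count.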