Skip to content

Commit

Permalink
refine gemm in ort mdcn (#2292)
Browse files Browse the repository at this point in the history
* refine gemm in ort mdcn

* int64_t -> int32_t
  • Loading branch information
AllentDan committed Sep 5, 2023
1 parent 468c423 commit 6fdf459
Showing 1 changed file with 55 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,70 +2,48 @@
#include "modulated_deform_conv.h"

#include <cmath>
#include <thread>
#include <vector>

#include "modulated_deform_conv/modulated_deform_conv_cpu.h"
#include "ort_utils.h"

namespace mmdeploy {

// Reference single-threaded GEMM on dense row-major float matrices:
//   Y[m][n] = alpha * sum_k opA(m, k) * opB(k, n) + beta * V[n] + beta * H[m][n]
//
// A is indexed as M x K (A[m * K + k]) or, when trans_A is non-zero, as K x M (A[k * M + m]).
// B is indexed as K x N (B[k * N + n]) or, when trans_B is non-zero, as N x K (B[n * K + k]).
// V (N elements) and H (M x N elements) are optional; a null pointer skips that term.
// Each H element is read before the matching Y element is written, so H may alias Y.
void gemm_ref_fp32(const float *A, const float *B, const float *V, const float *H,
                   const int32_t trans_A, const int32_t trans_B, const int32_t M, const int32_t N,
                   const int32_t K, const float alpha, const float beta, float *Y) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float y = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        // The four layout combinations (NN, TN, TT, NT) differ only in how A and B are indexed.
        const float a = trans_A ? A[k * M + m] : A[m * K + k];
        const float b = trans_B ? B[n * K + k] : B[k * N + n];
        y += a * b;
      }
      y *= alpha;
      if (V) y += beta * V[n];
      if (H) y += beta * H[m * N + n];
      Y[m * N + n] = y;
    }
  }
}

// Computes rows [start_row, end_row) of the non-transposed GEMM
//   Y[m][n] = alpha * sum_k A[m * K + k] * B[k * N + n] + beta * V[n] + beta * H[m][n]
// with the K loop unrolled by 8.
//
// Only rows in [start_row, end_row) of Y are written, so calls on disjoint row ranges do not
// write to the same output elements. M is accepted for signature symmetry and is not used.
// V (N elements) and H (M x N elements) are optional; a null pointer skips that term.
// Each H element is read before the matching Y element is written, so H may alias Y.
void parallel_unroll_gemm(const float *A, const float *B, const float *V, const float *H,
                          const int32_t M, const int32_t N, const int32_t K, const float alpha,
                          const float beta, float *Y, const int32_t start_row,
                          const int32_t end_row) {
  (void)M;
  if (N <= 0) return;  // no output columns; also keeps the scratch buffer size valid

  // Heap-backed row accumulator: a variable-length stack array is not standard C++ and can
  // overflow the stack when N is large.
  std::vector<float> acc(N);

  // The unrolled loop must stop at the last full group of 8; the tail loop handles the rest.
  // Running the unrolled loop up to K would read past row m of A and past the end of B, and the
  // tail products would be added twice.
  const int32_t remainder = K > 0 ? K % 8 : 0;
  const int32_t unrolled_end = K > 0 ? K - remainder : 0;
  const int64_t stride = N;  // 64-bit so row offsets cannot overflow int32_t

  for (int32_t m = start_row; m < end_row; ++m) {
    const int64_t a_off = static_cast<int64_t>(m) * K;
    const int64_t y_off = static_cast<int64_t>(m) * stride;

    for (float &v : acc) v = 0.0f;

    for (int32_t k = 0; k < unrolled_end; k += 8) {
      const float *a = A + a_off + k;
      const float *b = B + static_cast<int64_t>(k) * stride;
      for (int32_t n = 0; n < N; ++n) {
        acc[n] += a[0] * b[n];
        acc[n] += a[1] * b[stride + n];
        acc[n] += a[2] * b[2 * stride + n];
        acc[n] += a[3] * b[3 * stride + n];
        acc[n] += a[4] * b[4 * stride + n];
        acc[n] += a[5] * b[5 * stride + n];
        acc[n] += a[6] * b[6 * stride + n];
        acc[n] += a[7] * b[7 * stride + n];
      }
    }
    for (int32_t k = unrolled_end; k < K; ++k) {
      const float a = A[a_off + k];
      const float *b = B + static_cast<int64_t>(k) * stride;
      for (int32_t n = 0; n < N; ++n) {
        acc[n] += a * b[n];
      }
    }

    for (int32_t n = 0; n < N; ++n) {
      float y = acc[n] * alpha;
      if (V) y += beta * V[n];
      if (H) y += beta * H[y_off + n];
      Y[y_off + n] = y;
    }
  }
}
Expand All @@ -82,6 +60,10 @@ void deformable_conv2d_ref_fp32(const float *src, const float *offset, const flo
const int64_t dilation_w, float *columns, float *dst) {
const int64_t ic_per_gp = channels / group;
const int64_t oc_per_gp = num_output / group;
// Set up for launching threads
std::size_t num_threads = std::thread::hardware_concurrency();
std::vector<std::thread> threads;
threads.reserve(num_threads);

for (int64_t b = 0; b < batch; ++b) {
for (int64_t g = 0; g < group; ++g) {
Expand All @@ -102,9 +84,28 @@ void deformable_conv2d_ref_fp32(const float *src, const float *offset, const flo
} else {
memset(dst_ptr, 0.0f, sizeof(float) * oc_per_gp * dst_h * dst_w);
}
gemm_ref_fp32(filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns, nullptr,
dst_ptr, 0, 0, oc_per_gp, dst_h * dst_w, ic_per_gp * kernel_h * kernel_w, 1.0f,
1.0f, dst_ptr);
if (num_threads > 1) {
// Calculate values to pass to threads
int32_t n_rows = (oc_per_gp + num_threads - 1) / num_threads;
int32_t end_row = 0;
for (int32_t i = 0; i < num_threads; i++) {
auto start_row = i * n_rows;
end_row = start_row + n_rows;
if (end_row > oc_per_gp) end_row = oc_per_gp;
std::thread t(parallel_unroll_gemm,
filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns, nullptr,
dst_ptr, oc_per_gp, dst_h * dst_w, ic_per_gp * kernel_h * kernel_w, 1.0f,
1.0f, dst_ptr, start_row, end_row);
threads.emplace_back(std::move(t));
}
// Wait for all threads to complete
for (auto &t : threads) t.join();
threads.clear();
} else { // parallel gemm degrade to serial gemm with start_row=0 and end_row= oc_per_gp
parallel_unroll_gemm(filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns,
nullptr, dst_ptr, oc_per_gp, dst_h * dst_w,
ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, dst_ptr, 0, oc_per_gp);
}
}
}
}
Expand Down

0 comments on commit 6fdf459

Please sign in to comment.