Skip to content

Commit

Permalink
Improve KE for commandline and programmatically tuning dispatch (#18778)
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan authored Apr 8, 2024
1 parent cc3faba commit e19c778
Show file tree
Hide file tree
Showing 17 changed files with 438 additions and 170 deletions.
31 changes: 28 additions & 3 deletions onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <pybind11/pybind11.h>
#include <pybind11/embed.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>

#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"

Expand All @@ -13,6 +15,10 @@ namespace onnxruntime {

static py::module::module_def _kernel_explorer_module_def;

// Out-of-class definitions for TuningInfo's static data members.
// Collection flag; starts out false.
bool TuningInfo::collect_enabled_{false};
// Collected tuning results; starts out empty.
std::vector<TuningResults> TuningInfo::collected_tuning_results_ = {};
// Optional tuning duration limit in milliseconds; starts out unset.
std::optional<int> TuningInfo::max_tuning_duration_ms_ = {};

py::module GetKernelExplorerModule() {
static pybind11::module_ m = []() {
auto tmp = pybind11::module_::create_extension_module(
Expand All @@ -36,29 +42,48 @@ KE_REGISTER(m) {
.def("UpdateHostNumpyArray", &DeviceArray::UpdateHostNumpyArray)
.def("UpdateDeviceArray", &DeviceArray::UpdateDeviceArray);

// Bind TuningInfo::EnableCollect; the "enable" keyword argument defaults to true.
m.def("enable_collect_tuning_results", TuningInfo::EnableCollect, pybind11::arg("enable") = true);

// Bind TuningInfo::SetMaxTuningDurationMs, which takes the limit in milliseconds.
m.def("max_tuning_duration_ms", TuningInfo::SetMaxTuningDurationMs);

// Returns a Python list with one dict per collected TuningResults entry.
// Each dict carries the keys "ep", "results" and "validators".
m.def("get_collected_tuning_results", []() {
  py::list collected;
  const auto all_results = TuningInfo::GetCollectedTuningResults();
  for (const auto& entry : all_results) {
    py::dict item;
    item["ep"] = entry.ep;
    item["results"] = entry.results;
    item["validators"] = entry.validators;
    collected.append(std::move(item));
  }
  return collected;
});

// clang-format versions below 18 mis-format the following code
// clang-format off
// Reports whether this build was compiled with USE_COMPOSABLE_KERNEL defined.
m.def("is_composable_kernel_available", []() {
#ifdef USE_COMPOSABLE_KERNEL
  return true;
#else
  return false;
#endif
});

// Reports whether this build was compiled with USE_HIPBLASLT defined.
m.def("is_hipblaslt_available", []() {
#ifdef USE_HIPBLASLT
  return true;
#else
  return false;
#endif
});

// Reports whether float8 types are available, i.e. DISABLE_FLOAT8_TYPES is not defined.
m.def("is_float8_available", []() {
#ifndef DISABLE_FLOAT8_TYPES
  return true;
#else
  return false;
#endif
});
// clang-format on
}

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,24 @@ using TuningContextT = onnxruntime::rocm::tunable::RocmTuningContext;

namespace onnxruntime {

/// Holder for tuning-related settings and collected results.
/// Everything is static; the data members are defined in kernel_explorer.cc.
struct TuningInfo {
  /// When true, explorers append their tuning results on destruction.
  static bool collect_enabled_;
  /// Tuning results gathered so far.
  static std::vector<TuningResults> collected_tuning_results_;
  /// Tuning duration limit in milliseconds; empty when no limit was set.
  static std::optional<int> max_tuning_duration_ms_;

  /// Turns collection of tuning results on or off.
  static void EnableCollect(bool b) { collect_enabled_ = b; }

  /// Returns a copy of the tuning results collected so far.
  static std::vector<TuningResults> GetCollectedTuningResults() { return collected_tuning_results_; }

  /// Stores the tuning duration limit, in milliseconds.
  static void SetMaxTuningDurationMs(int milliseconds) { max_tuning_duration_ms_ = milliseconds; }
};

/// Wrapper around Op and TunableOp
class IKernelExplorer {
public:
Expand All @@ -59,7 +77,11 @@ class IKernelExplorer {
return timer.Duration() / repeats_;
}

virtual ~IKernelExplorer() = default;
/// Destroys the explorer. When TuningInfo::collect_enabled_ is set, the tuning
/// results of this explorer's tuning context are appended to
/// TuningInfo::collected_tuning_results_ before destruction completes.
virtual ~IKernelExplorer() {
  // Nothing to collect when collection is off or no execution provider exists.
  // NOTE(review): ep_ appears to be created lazily in GetEp(), so it can be
  // null if GetEp() was never called -- confirm.
  if (!TuningInfo::collect_enabled_ || this->ep_ == nullptr) {
    return;
  }
  // GetTuningContext() can return nullptr (the initialization path checks for
  // it), so guard before dereferencing.
  auto tuning_ctx = this->ep_->GetTuningContext();
  if (nullptr != tuning_ctx) {
    TuningInfo::collected_tuning_results_.emplace_back(tuning_ctx->GetTuningResults());
  }
}

protected:
ExecutionProvider* GetEp() {
Expand All @@ -73,6 +95,15 @@ class IKernelExplorer {
auto tuning_ctx = this->ep_->GetTuningContext();
if (nullptr != tuning_ctx) {
tuning_ctx->RegisterAllocatorsView(&this->allocators_);
for (const auto& tr : TuningInfo::collected_tuning_results_) {
auto status = tuning_ctx->LoadTuningResults(tr);
if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << status;
}
}
if (TuningInfo::max_tuning_duration_ms_.has_value()) {
tuning_ctx->SetMaxTuningDurationMs(*TuningInfo::max_tuning_duration_ms_);
}
}
stream_ = std::make_unique<onnxruntime::Stream>(nullptr, this->ep_->GetOrtDeviceByMemType(OrtMemTypeDefault));
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ class qkv_format: # noqa: N801
Q_KV_BSNH_BSN2H: int

def is_composable_kernel_available(*args, **kwargs): ...
def is_hipblaslt_available(*args, **kwargs): ...

# Stubs for the tuning-related functions registered by the native module.
def enable_collect_tuning_results(*args, **kwargs): ...
def get_collected_tuning_results(*args, **kwargs): ...
def max_tuning_duration_ms(*args, **kwargs): ...
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# --------------------------------------------------------------------------

import os
import sys
from dataclasses import dataclass
from itertools import product

Expand All @@ -23,6 +22,7 @@ def dtype_to_suffix(dtype):
}[dtype]


@ke.dispatchable
def _test_batched_gemm(
func, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int, alpha=1.0, beta=0.0
):
Expand Down Expand Up @@ -148,6 +148,7 @@ def report(self):
return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common


@ke.dispatchable(pattern_arg=0)
def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int, k: int, batch: int):
a_shape = (k, m) if transa else (m, k)
b_shape = (n, k) if transb else (k, n)
Expand Down Expand Up @@ -177,12 +178,13 @@ def profile_gemm_func(f, dtype: str, transa: bool, transb: bool, m: int, n: int,
ke.report(BatchedGemmMetric(impl, dtype, duration_ms, flops, transa, transb, m, n, k, batch))


def profile_with_args(dtype, transa, transb, m, n, k, batch, sort):
@ke.dispatchable
def profile_with_args(dtype, transa, transb, m, n, k, batch):
dtype_suffix = "_" + dtype_to_suffix(dtype)
transab_suffix = "_" + transab_to_suffix((transa, transb))
fn_rocblas = getattr(ke, "RocBlasBatchedGemm" + dtype_suffix)
fn_tunable = getattr(ke, "BatchedGemmTunable" + dtype_suffix + transab_suffix)
with ke.benchmark(sort):
with ke.benchmark():
profile_gemm_func(fn_rocblas, dtype, transa, transb, m, n, k, batch)
profile_gemm_func(fn_tunable, dtype, transa, transb, m, n, k, batch)
print()
Expand All @@ -192,27 +194,22 @@ def profile():
for dtype in dtypes:
for m, n, k in get_gemm_bert_sizes(full=False):
for batch in [1, 32, 64]:
profile_with_args(dtype, False, False, m, n, k, batch, True)
profile_with_args(dtype, False, False, m, n, k, batch)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("profile with args")
parser = ke.get_argument_parser()
group = parser.add_argument_group()
group.add_argument("dtype", choices=dtypes)
group.add_argument("transa", choices="NT")
group.add_argument("transb", choices="NT")
group.add_argument("m", type=int)
group.add_argument("n", type=int)
group.add_argument("k", type=int)
group.add_argument("batch", type=int)
group.add_argument("--sort", action="store_true")

if len(sys.argv) == 1:
if not ke.has_args():
profile()
else:
args = parser.parse_args()
profile_with_args(
args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k, args.batch, args.sort
)
args.dispatch(args.dtype, args.transa == "T", args.transb == "T", args.m, args.n, args.k, args.batch)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import sys
from dataclasses import dataclass

import kernel_explorer as ke
Expand Down Expand Up @@ -31,6 +30,7 @@ def report(self):
return f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} n={self.n} k={self.k} {self.name}"


@ke.dispatchable(pattern_arg=3)
def profile_dequantize_int4_func(n, k, dtype, func):
np.random.seed(0)
output = np.random.rand(n, k).astype(dtype)
Expand All @@ -48,31 +48,29 @@ def profile_dequantize_int4_func(n, k, dtype, func):
ke.report(DequantizeInt4Metric(func, dtype, duration_ms, total_bytes, n, k))


def profile_with_args(n, k, dtype, sort):
with ke.benchmark(sort):
@ke.dispatchable
def profile_with_args(n, k, dtype):
with ke.benchmark():
for func in dtype_to_funcs(dtype):
profile_dequantize_int4_func(n, k, dtype, func)


def profile():
for dt in dtypes:
for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)):
profile_with_args(n, k, dt, True)
profile_with_args(n, k, dt)
print()


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("profile with args")
parser = ke.get_argument_parser()
group = parser.add_argument_group()
group.add_argument("n", type=int)
group.add_argument("k", type=int)
group.add_argument("dtype", choices=dtypes)
group.add_argument("--sort", action="store_true")

if len(sys.argv) == 1:
if not ke.has_args():
profile()
else:
args = parser.parse_args()
profile_with_args(args.n, args.k, args.dtype, args.sort)
args.dispatch(args.n, args.k, args.dtype)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# --------------------------------------------------------------------------

import re
import sys
from dataclasses import dataclass
from itertools import product

Expand Down Expand Up @@ -90,6 +89,7 @@ def report(self):
return "not supported " + common


@ke.dispatchable(pattern_arg=4)
def profile_elementwise_func(batch_size, seq_len, hidden_size, dtype, func):
x_size = [batch_size, seq_len, hidden_size]
bias_size = hidden_size
Expand All @@ -112,33 +112,31 @@ def profile_elementwise_func(batch_size, seq_len, hidden_size, dtype, func):
ke.report(ElementwiseMetric(func, dtype, duration_ms, total_bytes, batch_size, seq_len, hidden_size))


def profile_with_args(batch_size, seq_len, hidden_size, fn_name, dtype, sort):
with ke.benchmark(sort):
@ke.dispatchable
def profile_with_args(batch_size, seq_len, hidden_size, fn_name, dtype):
with ke.benchmark():
for func in dtype_to_funcs(fn_name, dtype):
profile_elementwise_func(batch_size, seq_len, hidden_size, dtype, func)


def profile():
for dtype in dtypes:
for bert_size in get_bert_sizes():
profile_with_args(*bert_size, "FastGeLU", dtype, True)
profile_with_args(*bert_size, "FastGeLU", dtype)
print()


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("profile with args")
parser = ke.get_argument_parser()
group = parser.add_argument_group()
group.add_argument("batch_size", type=int)
group.add_argument("seq_len", type=int)
group.add_argument("hidden_size", type=int)
group.add_argument("fn_name", choices=fn_names)
group.add_argument("dtype", choices=dtypes)
group.add_argument("--sort", action="store_true")

if len(sys.argv) == 1:
if not ke.has_args():
profile()
else:
args = parser.parse_args()
profile_with_args(args.batch_size, args.seq_len, args.hidden_size, args.fn_name, args.dtype, args.sort)
args.dispatch(args.batch_size, args.seq_len, args.hidden_size, args.fn_name, args.dtype)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import sys
from dataclasses import dataclass
from itertools import product

Expand Down Expand Up @@ -120,6 +119,7 @@ def report(self):
return f"{self.duration:>6.2f} us {self.tflops:>5.2f} tflops " + common


@ke.dispatchable(pattern_arg=0)
def profile_gemmfastgelu_func(my_func, dtype: str, m: int, n: int, k: int, transa: bool, transb: bool):
a_shape = (k, m) if transa else (m, k)
b_shape = (n, k) if transb else (k, n)
Expand Down Expand Up @@ -153,10 +153,11 @@ def profile_gemmfastgelu_func(my_func, dtype: str, m: int, n: int, k: int, trans
ke.report(GemmFastGeluMetric(impl, dtype, duration_ms, floating_point_operations, transa, transb, m, n, k))


def profile_with_args(transa, transb, dtype, m, n, k, sort):
@ke.dispatchable
def profile_with_args(transa, transb, dtype, m, n, k):
dtype_suffix = "_" + dtype_to_suffix(dtype)
transab_suffix = "_" + transab_to_suffix((transa, transb))
with ke.benchmark(sort):
with ke.benchmark():
profile_gemmfastgelu_func(getattr(ke, "GemmFastGeluUnfused" + dtype_suffix), dtype, m, n, k, transa, transb)
profile_gemmfastgelu_func(
getattr(ke, "CKGemmFastGelu" + dtype_suffix + transab_suffix), dtype, m, n, k, transa, transb
Expand All @@ -173,24 +174,22 @@ def profile_with_args(transa, transb, dtype, m, n, k, sort):
def profile():
for dtype in dtypes:
for m, n, k in get_gemm_bert_sizes(full=True):
profile_with_args(False, False, dtype, m, n, k, True)
profile_with_args(False, False, dtype, m, n, k)
print()


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("profile with args")
parser = ke.get_argument_parser()
group = parser.add_argument_group()
group.add_argument("transa", choices="NT")
group.add_argument("transb", choices="NT")
group.add_argument("dtype", choices=dtypes)
group.add_argument("m", type=int)
group.add_argument("n", type=int)
group.add_argument("k", type=int)
group.add_argument("--sort", action="store_true")
if len(sys.argv) == 1:

if not ke.has_args():
profile()
else:
args = parser.parse_args()
profile_with_args(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k, args.sort)
args.dispatch(args.transa == "T", args.transb == "T", args.dtype, args.m, args.n, args.k)
Loading

0 comments on commit e19c778

Please sign in to comment.