diff --git a/VERSION b/VERSION index c27a65ceb..2e60e7919 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.0.1.dev9 \ No newline at end of file +0.0.1.dev10 \ No newline at end of file diff --git a/python/bitblas/__init__.py b/python/bitblas/__init__.py index df2c63199..8100da00f 100644 --- a/python/bitblas/__init__.py +++ b/python/bitblas/__init__.py @@ -81,4 +81,4 @@ def _init_logger(): _init_logger() -__version__ = "0.0.1.dev9" +__version__ = "0.0.1.dev10" diff --git a/python/bitblas/gpu/gemv.py b/python/bitblas/gpu/gemv.py index 7a2880ed1..7b08179d3 100644 --- a/python/bitblas/gpu/gemv.py +++ b/python/bitblas/gpu/gemv.py @@ -775,7 +775,8 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring return None block_info = block_infos[0] - if len(block_info.iters) not in [2, 3]: + if len(block_info.iters) not in [2, 3, 4]: + # either [SK, B, S, R] = [SK, B, S, R] * [SK, B, R] # either [B, S, R] = [B, S, R] * [B, R] # or [S, R] = [S, R] * [R] return None diff --git a/python/bitblas/gpu/gemv_dequantize.py b/python/bitblas/gpu/gemv_dequantize.py index 47e4bf42c..5a6405f52 100644 --- a/python/bitblas/gpu/gemv_dequantize.py +++ b/python/bitblas/gpu/gemv_dequantize.py @@ -110,6 +110,11 @@ def get_vectorize_factor(target_format): if len(sch.get_loops(block_b)) == 3: i = sch.get_loops(block_b)[0] sch.bind(i, "blockIdx.z") + elif len(sch.get_loops(block_b)) == 4: + # splitk case + sk, i = sch.get_loops(block_b)[:2] + sch.bind(sk, "blockIdx.y") + sch.bind(i, "blockIdx.z") # get target dequantize buffer's idx def get_idx(weight_decode_info: Dict): @@ -274,6 +279,14 @@ def get_vectorize_factor(target_format): if len(sch.get_loops(block_b)) == 3: i = sch.get_loops(block_b)[0] sch.bind(i, "blockIdx.z") + elif len(sch.get_loops(block_b)) == 4: + # splitk case + sk, i = sch.get_loops(block_b)[:2] + sch.bind(sk, "blockIdx.y") + sch.bind(i, "blockIdx.z") + assert len(config.thread) == 2, "SplitK only support 2D thread config" + num_warps = int(num_warps // config.thread[0]) + # get target dequantize buffer's idx def get_idx(weight_decode_info: Dict):