From 9bdcbf833a786b8bc06266bd9ee546564e0e828b Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Mon, 22 Jul 2024 15:40:08 +0800
Subject: [PATCH] [Dev] Fix a correctness issue when block reduce is applied
 with pipeline stage (#94)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* disable failure email for ci
* remove email notifications.
* move relax pass from testing to mlc_llm
* Refactor scripts with se check_eual_ref_scripts_with_emitter function
* Lint Fix
* Refactor scripts with se check_eual_ref_scripts_with_emitter function
* bug fix in test
* lint fix.
* test cuda i4 kernel
* Refactor copyright notice in i4matmul.hpp
* Refactor BitBLASLinear test module for improved readability and maintainability
* refactor test as version below python 3.9 cannot handle int32 overflow.
* format lint for test
* Refactor test_int4b_fp16_convert.py for improved readability and maintainability
* remove unused design file
* move tile device from package to base
* dummy impl for codegen
* Refactor file structure for ladder_permutate module
* Refactor backend class and fix typos in comments
* Deep refactor Lib related code.
* remove ci pull.
* LintFix
* refactor builder for whl build
* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module
* Refactor lib_generator to set library and source paths
* lint fix
* BitNet vllm integration
* chore: update codespell to version 2.3.0
* Lintfix
* Bump version to 0.0.1.dev13
* lint fix
* disable fast decoding [u]int4xint8 by default.
* optimize from dict design in Hint
* Implement SplitK
* bitnet benchmark generation.
* Add benchmark script for BitNet integration
* AtomicAdd Support
* LintFix
* ci fix when 3rdparty tvm is initialized.
* bug fix for setup
* fix a bug in block reduce
* typo fix
* BUG Fix for block reduce.
* Lint fix
* Refactor block reduce schedule template
---
 bitblas/base/roller/policy/tensorcore.py | 30 ++++++++++++++----------
 bitblas/gpu/matmul_mma_dequantize.py     |  6 ++++-
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py
index e69bcabc3..9e6fff9ee 100644
--- a/bitblas/base/roller/policy/tensorcore.py
+++ b/bitblas/base/roller/policy/tensorcore.py
@@ -4,7 +4,7 @@
 from bitblas import tvm
 from typing import Dict, List, Tuple, Optional
 import numpy as np
-
+import logging
 from ...arch import TileDevice
 from ..hint import Hint, Stride, TileDict, IntrinInfo
 from ..node import PrimFuncNode
@@ -12,6 +12,8 @@
 from .default import DefaultPolicy
 from ..rasterization import NoRasterization, Rasterization2DColumn
 
+logger = logging.getLogger(__name__)
+
 
 class TensorCorePolicy(DefaultPolicy):
 
@@ -47,9 +49,9 @@ def _legalize_info(self):
             self.use_async_copy = False
         # TODO: block reduction depth is not used for now.
         # As there still exists some performance issues for block reduction.
-        # block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth")
-        # if block_reduction_depth:
-        #     self.block_reduction_depth = block_reduction_depth
+        block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth")
+        if block_reduction_depth:
+            self.block_reduction_depth = block_reduction_depth
 
     def _compute_tc_strides(
         self,
@@ -120,7 +122,6 @@ def _check_small_tile(td: TileDict):
 
         smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
         rstep_map = td.rstep_map.copy()
-        is_block_reduction = self.block_reduction_depth is not None
 
         def _optimize(node, rstep):
             all_steps = self.get_node_reduce_step_candidates(node)
@@ -185,12 +186,12 @@ def _enlarge(rstep_id):
             rstep = _optimize(node, rstep_map)
             rstep_map = rstep
 
-        if is_block_reduction:
-            # If block reduction, we should constrain the max value is 64
-            # Otherwise it will introduce an issue of cuda invalid args.
-            MAX_REDUCE_K = 64
-            for k in rstep_map:
-                rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K)
+        # if is_block_reduction:
+        #     # If block reduction, we should constrain the max value is 64
+        #     # Otherwise it will introduce an issue of cuda invalid args.
+        #     MAX_REDUCE_K = 64
+        #     for k in rstep_map:
+        #         rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K)
         td.rstep_map = rstep_map
         td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
         return
@@ -315,7 +316,12 @@ def _score(node, thread):  # small is better
         if intrin_info["out_dtype"] in ["float32"]:
             codegen_dict.shared_scope = "shared.dyn"
         # smem capacity
-        if td.smem_cost > self.arch.smem_cap:
+        # TODO: This is a dummy mul which avoid reusing some shared memory.
+        # Should be removed in the future.
+        if td.smem_cost > (self.arch.smem_cap * 1.3):
+            info_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \
+                " use dynamic shared memory."
+            logger.info(info_message)
             codegen_dict.shared_scope = "shared.dyn"
 
         codegen_dict.complete_config(node)
diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py
index f7dede4a1..de1b5b896 100644
--- a/bitblas/gpu/matmul_mma_dequantize.py
+++ b/bitblas/gpu/matmul_mma_dequantize.py
@@ -1986,7 +1986,7 @@ def get_param_indices(
         k0, kr = sch.split(k0, [None, reduce_k])
 
         sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3)
-        # sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3)
+
         block_idy = sch.fuse(i0, j0)
         block_idx = sch.fuse(i1, j1)
         thread_idy = i2
@@ -1998,6 +1998,10 @@ def get_param_indices(
         thread_idz = j2
         thread_idy = sch.fuse(thread_idy, thread_idz)
         sch.bind(thread_idy, "threadIdx.y")
+        # Put the thread binding after the shared memory prefetch
+        # Otherwise there's a axis missing bug behind tvm
+        sch.bind(kr, "threadIdx.z")
+
         def smooth_layout_recover(block, scope, l=16, r=16, enable=True):  # noqa: E741
             if not enable:
                 return
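
Note on the policy change above: the new shared-memory handling in tensorcore.py only moves a tile to dynamic shared memory ("shared.dyn") once the estimated shared-memory cost exceeds the static capacity with some headroom, and logs the decision. The following is a minimal standalone Python sketch (not part of the patch) of that decision. The function name choose_shared_scope and the constant SMEM_HEADROOM are hypothetical names introduced here for illustration; the 1.3 factor, the float32 special case, and the "shared"/"shared.dyn" scopes come from the diff.

import logging

logger = logging.getLogger(__name__)

# Headroom factor from the patch; the diff's TODO notes it is a "dummy mul"
# that avoids reusing some shared memory and should be removed later.
SMEM_HEADROOM = 1.3


def choose_shared_scope(smem_cost: int, smem_cap: int, out_dtype: str) -> str:
    """Pick the shared-memory scope for a tile configuration (sketch)."""
    # float32 accumulation always uses dynamic shared memory.
    if out_dtype == "float32":
        return "shared.dyn"
    # Fall back to dynamic shared memory only when the estimated cost
    # exceeds the static capacity with headroom.
    if smem_cost > smem_cap * SMEM_HEADROOM:
        logger.info("Shared memory exceeds the static capacity, use dynamic shared memory.")
        return "shared.dyn"
    return "shared"


if __name__ == "__main__":
    # e.g. a 48 KB static shared-memory budget per block
    print(choose_shared_scope(40 * 1024, 48 * 1024, "float16"))  # -> shared
    print(choose_shared_scope(96 * 1024, 48 * 1024, "float16"))  # -> shared.dyn

In the actual policy this check runs on td.smem_cost against self.arch.smem_cap after _compute_shared_memory_usage, so the hard MAX_REDUCE_K clamp that the patch comments out is no longer needed to keep launches within the static limit.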