[Dev] Fix a correctness issue when block reduce is applied with pipeline stage (#94)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix in test

* lint fix.

* test cuda i4 kernel

* Refactor copyright notice in i4matmul.hpp

* Refactor BitBLASLinear test module for improved readability and maintainability

* refactor test as Python versions below 3.9 cannot handle int32 overflow.

* format lint for test

* Refactor test_int4b_fp16_convert.py for improved readability and maintainability

* remove unused design file

* move tile device from package to base

* dummy impl for codegen

* Refactor file structure for ladder_permutate module

* Refactor backend class and fix typos in comments

* Deep refactor Lib related code.

* remove ci pull.

* LintFix

* refactor builder for whl build

* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module

* Refactor lib_generator to set library and source paths

* lint fix

* BitNet vllm integration

* chore: update codespell to version 2.3.0

* Lintfix

* Bump version to 0.0.1.dev13

* lint fix

* disable fast decoding [u]int4xint8 by default.

* optimize from dict design in Hint

* Implement SplitK

* bitnet benchmark generation.

* Add benchmark script for BitNet integration

* AtomicAdd Support

* LintFix

* ci fix when 3rdparty tvm is initialized.

* bug fix for setup

* fix a bug in block reduce

* typo fix

* BUG Fix for block reduce.

* Lint fix

* Refactor block reduce schedule template
LeiWang1999 authored Jul 22, 2024
1 parent 853522d commit 9bdcbf8
Showing 2 changed files with 23 additions and 13 deletions.
30 changes: 18 additions & 12 deletions bitblas/base/roller/policy/tensorcore.py
@@ -4,14 +4,16 @@
 from bitblas import tvm
 from typing import Dict, List, Tuple, Optional
 import numpy as np
-
+import logging
 from ...arch import TileDevice
 from ..hint import Hint, Stride, TileDict, IntrinInfo
 from ..node import PrimFuncNode
 from .common import coalesced_factor, factorize, get_all_factors
 from .default import DefaultPolicy
 from ..rasterization import NoRasterization, Rasterization2DColumn
 
+logger = logging.getLogger(__name__)
+
 
 class TensorCorePolicy(DefaultPolicy):
 
@@ -47,9 +49,9 @@ def _legalize_info(self):
         self.use_async_copy = False
         # TODO: block reduction depth is not used for now.
         # As there still exists some performance issues for block reduction.
-        # block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth")
-        # if block_reduction_depth:
-        #     self.block_reduction_depth = block_reduction_depth
+        block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth")
+        if block_reduction_depth:
+            self.block_reduction_depth = block_reduction_depth
 
     def _compute_tc_strides(
         self,
@@ -120,7 +122,6 @@ def _check_small_tile(td: TileDict):
 
         smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
         rstep_map = td.rstep_map.copy()
-        is_block_reduction = self.block_reduction_depth is not None
 
         def _optimize(node, rstep):
             all_steps = self.get_node_reduce_step_candidates(node)
@@ -185,12 +186,12 @@ def _enlarge(rstep_id):
         rstep = _optimize(node, rstep_map)
         rstep_map = rstep
 
-        if is_block_reduction:
-            # If block reduction, we should constrain the max value is 64
-            # Otherwise it will introduce an issue of cuda invalid args.
-            MAX_REDUCE_K = 64
-            for k in rstep_map:
-                rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K)
+        # if is_block_reduction:
+        #     # If block reduction, we should constrain the max value is 64
+        #     # Otherwise it will introduce an issue of cuda invalid args.
+        #     MAX_REDUCE_K = 64
+        #     for k in rstep_map:
+        #         rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K)
         td.rstep_map = rstep_map
         td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
         return
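The hunks above re-enable reading the block_reduction_depth tag in _legalize_info and comment out the MAX_REDUCE_K = 64 clamp (along with the now-unused is_block_reduction flag) that previously worked around a CUDA invalid-argument error. Below is a minimal, self-contained sketch of that behaviour; PolicySketch, tags, and assign_reduce_step are hypothetical names used only for illustration, not the BitBLAS TensorCorePolicy API.

from typing import Dict, Optional


class PolicySketch:
    """Hypothetical stand-in for the policy object touched in the hunks above."""

    def __init__(self, tags: Dict[str, int]):
        self.tags = tags
        self.block_reduction_depth: Optional[int] = None

    def legalize_info(self) -> None:
        # Re-enabled in this commit: honor the block_reduction_depth hint when present.
        block_reduction_depth = self.tags.get("block_reduction_depth")
        if block_reduction_depth:
            self.block_reduction_depth = block_reduction_depth

    def assign_reduce_step(self, rstep_map: Dict[str, int]) -> Dict[str, int]:
        # Before this commit, block reduction clamped every reduce step to
        # MAX_REDUCE_K = 64; that clamp is now commented out in the diff.
        return dict(rstep_map)


if __name__ == "__main__":
    policy = PolicySketch({"block_reduction_depth": 2})
    policy.legalize_info()
    print(policy.block_reduction_depth)           # 2
    print(policy.assign_reduce_step({"k": 128}))  # {'k': 128}, no clamp to 64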
@@ -315,7 +316,12 @@ def _score(node, thread):  # small is better
         if intrin_info["out_dtype"] in ["float32"]:
             codegen_dict.shared_scope = "shared.dyn"
         # smem capacity
-        if td.smem_cost > self.arch.smem_cap:
+        # TODO: This is a dummy mul which avoid reusing some shared memory.
+        # Should be removed in the future.
+        if td.smem_cost > (self.arch.smem_cap * 1.3):
+            info_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \
+                " use dynamic shared memory."
+            logger.info(info_message)
             codegen_dict.shared_scope = "shared.dyn"
 
         codegen_dict.complete_config(node)
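The last hunk also loosens the shared-memory capacity check: the static estimate may exceed the hardware cap by a 1.3x slack factor (a temporary allowance, per the TODO, since the estimate does not yet model some reuse) before the kernel falls back to dynamic shared memory. A minimal sketch of that decision, using hypothetical names (choose_shared_scope, SMEM_SLACK) rather than the BitBLAS API:

import logging

logger = logging.getLogger(__name__)

SMEM_SLACK = 1.3  # the "dummy mul" from the diff; to be removed once reuse is modeled


def choose_shared_scope(smem_cost: int, smem_cap: int, output_tile=None) -> str:
    """Pick the shared-memory scope for an emitted kernel (illustrative only)."""
    if smem_cost > smem_cap * SMEM_SLACK:
        logger.info(
            "Tile Dict: %s Shared memory exceeds the static capacity, "
            "use dynamic shared memory.", output_tile)
        return "shared.dyn"
    return "shared"


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # A 100 KB estimate against a 96 KB cap stays static under the 1.3x slack...
    print(choose_shared_scope(100 * 1024, 96 * 1024, output_tile=[128, 128]))  # shared
    # ...while a 160 KB estimate falls back to dynamic shared memory.
    print(choose_shared_scope(160 * 1024, 96 * 1024, output_tile=[128, 128]))  # shared.dyn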
6 changes: 5 additions & 1 deletion bitblas/gpu/matmul_mma_dequantize.py
@@ -1986,7 +1986,7 @@ def get_param_indices(
         k0, kr = sch.split(k0, [None, reduce_k])
 
         sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3)
-        # sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3)
+
         block_idy = sch.fuse(i0, j0)
         block_idx = sch.fuse(i1, j1)
         thread_idy = i2
@@ -1998,6 +1998,10 @@
         thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz)
         sch.bind(thread_idy, "threadIdx.y")
 
+        # Put the thread binding after the shared memory prefetch
+        # Otherwise there's a axis missing bug behind tvm
+        sch.bind(kr, "threadIdx.z")
+
         def smooth_layout_recover(block, scope, l=16, r=16, enable=True):  # noqa: E741
             if not enable:
                 return
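For the second file, the added sch.bind(kr, "threadIdx.z") is what distributes the reduction across threadIdx.z (the block reduction), and, per the new comment, that binding has to come after the shared-memory prefetch is set up. Below is a self-contained sketch of the same scheduling pattern on a plain fp32 matmul, assuming a recent TVM; the kernel and block_reduce_depth are illustrative and are not the schedule this file actually builds.

import tvm
from tvm.script import tir as T


@T.prim_func
def matmul(A: T.Buffer((128, 128), "float32"),
           B: T.Buffer((128, 128), "float32"),
           C: T.Buffer((128, 128), "float32")):
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


block_reduce_depth = 2  # illustrative block-reduction factor

sch = tvm.tir.Schedule(matmul)
main_block = sch.get_block("C")
i, j, k = sch.get_loops(main_block)
bx, ty = sch.split(i, factors=[None, 16])
by, tx = sch.split(j, factors=[None, 16])
k0, kr = sch.split(k, factors=[None, block_reduce_depth])
# Keep the block-reduction axis (kr) outside the serial K loop, mirroring
# sch.reorder(..., kr, k0, k1, ...) in the diff above.
sch.reorder(bx, by, ty, tx, kr, k0)
sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
# Bind the block-reduction axis last (after the other bindings and any
# prefetch/pipeline transforms), as the commit does for threadIdx.z.
sch.bind(kr, "threadIdx.z")
print(sch.mod.script())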
