From 71c1d6eec96fadacab7f96cf0a42455f0d2d2af1 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Mon, 5 Aug 2024 13:45:30 +0000
Subject: [PATCH] lint fix

---
 3rdparty/tvm                              |   2 +-
 bitblas/base/roller/policy/tensorcore.py  | 158 ++++++++++--------
 bitblas/gpu/matmul_analysis.py            |   6 +-
 .../tirscript/matmul_dequantize_impl.py   |   6 +-
 4 files changed, 93 insertions(+), 79 deletions(-)

diff --git a/3rdparty/tvm b/3rdparty/tvm
index c441882e2..6daecacc7 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit c441882e2372deeb33d0eaefd62a133d482ac669
+Subproject commit 6daecacc73c8c8fdea1b9732891e1d4a5ebbf818
diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py
index 9e6fff9ee..468498fbd 100644
--- a/bitblas/base/roller/policy/tensorcore.py
+++ b/bitblas/base/roller/policy/tensorcore.py
@@ -117,83 +117,92 @@ def _check_small_tile(td: TileDict):
                     return True
             return False
 
-        if not _check_small_tile(td):
-            return None
+        if _check_small_tile(td):
+
+            smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
+            rstep_map = td.rstep_map.copy()
+
+            def _optimize(node, rstep):
+                all_steps = self.get_node_reduce_step_candidates(node)
+                # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k]
+                for k in all_steps:
+                    all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))
+                if any([v == [] for v in all_steps.values()]):
+                    return rstep
+
+                def _shared_memory_usage(td: TileDict):
+                    return node.footprint(td.output_tile, new_rstep_map,
+                                          td.tensor_strides_map[node])
+
+                def _score(rstep_id):
+                    rstep = {
+                        k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
+                    }
+                    score = 0
+                    shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
+                    input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
+                    for i, input_buffer in enumerate(input_buffers):
+                        score += coalesced_factor(shape[i], input_buffer.shape)
+                    return score
+
+                def _enlarge(rstep_id):
+                    candidates = []
+                    for ax in rstep_id:
+                        if rstep_id[ax] + 1 == len(all_steps[ax]):
+                            continue
+                        r = rstep_id.copy()
+                        r[ax] += 1
+                        candidates.append((r, _score(r)))
+                    if len(candidates) == 0:
+                        return None
+                    return max(candidates, key=lambda x: x[1])[0]
+
+                cur_rstep_id = {
+                    k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
+                }
+                new_rstep_map = rstep_map.copy()
+                while True:
+                    new_rstep_id = _enlarge(cur_rstep_id)
+                    if new_rstep_id is None:
+                        break
+                    new_rstep_map = {
+                        k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]]
+                        for k in node.raxis
+                    }
+                    old_rstep_map = td.rstep_map
+                    td.rstep_map = new_rstep_map
+                    smem_usage, _ = _shared_memory_usage(td)
+                    td.rstep_map = old_rstep_map
+                    if smem_usage > smem_limit:
+                        break
+                    else:
+                        cur_rstep_id = new_rstep_id
+                rstep = {
+                    k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
+                }
+                return rstep
 
-        smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
-        rstep_map = td.rstep_map.copy()
+            for node in self.ordered_nodes:
+                if len(node.raxis) > 0:
+                    rstep = _optimize(node, rstep_map)
+                    rstep_map = rstep
 
-        def _optimize(node, rstep):
-            all_steps = self.get_node_reduce_step_candidates(node)
-            # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k]
-            for k in all_steps:
-                all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))
-            if any([v == [] for v in all_steps.values()]):
-                return rstep
+            td.rstep_map = rstep_map
+            td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
 
-            def _shared_memory_usage(td: TileDict):
-                return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node])
+            if self.block_reduction_depth is not None:
 
-            def _score(rstep_id):
-                rstep = {
-                    k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
-                }
-                score = 0
-                shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
-                input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
-                for i, input_buffer in enumerate(input_buffers):
-                    score += coalesced_factor(shape[i], input_buffer.shape)
-                return score
-
-            def _enlarge(rstep_id):
-                candidates = []
-                for ax in rstep_id:
-                    if rstep_id[ax] + 1 == len(all_steps[ax]):
-                        continue
-                    r = rstep_id.copy()
-                    r[ax] += 1
-                    candidates.append((r, _score(r)))
-                if len(candidates) == 0:
-                    return None
-                return max(candidates, key=lambda x: x[1])[0]
-
-            cur_rstep_id = {
-                k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
-            }
-            new_rstep_map = rstep_map.copy()
-            while True:
-                new_rstep_id = _enlarge(cur_rstep_id)
-                if new_rstep_id is None:
-                    break
-                new_rstep_map = {
-                    k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis
-                }
-                old_rstep_map = td.rstep_map
-                td.rstep_map = new_rstep_map
-                smem_usage, _ = _shared_memory_usage(td)
-                td.rstep_map = old_rstep_map
-                if smem_usage > smem_limit:
-                    break
-                else:
-                    cur_rstep_id = new_rstep_id
-            rstep = {
-                k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
-            }
-            return rstep
+                def _expand_with_tags(rstep):
+                    new_rstep = {k: v * self.block_reduction_depth for k, v in rstep.items()}
+                    return new_rstep
+
+                rstep_map = td.rstep_map.copy()
+                for node in self.ordered_nodes:
+                    if len(node.raxis) > 0:
+                        rstep = _expand_with_tags(rstep_map)
+                        rstep_map = rstep
+                td.rstep_map = rstep_map
 
-        for node in self.ordered_nodes:
-            if len(node.raxis) > 0:
-                rstep = _optimize(node, rstep_map)
-                rstep_map = rstep
-
-        # if is_block_reduction:
-        #     # If block reduction, we should constrain the max value is 64
-        #     # Otherwise it will introduce an issue of cuda invalid args.
-        #     MAX_REDUCE_K = 64
-        #     for k in rstep_map:
-        #         rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K)
-        td.rstep_map = rstep_map
-        td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
         return
 
     def get_node_reduce_step_candidates(self, node):
@@ -318,12 +327,15 @@ def _score(node, thread):  # small is better
         # smem capacity
        # TODO: This is a dummy mul which avoid reusing some shared memory.
         # Should be removed in the future.
-        if td.smem_cost > (self.arch.smem_cap * 1.3):
+        if td.smem_cost > (self.arch.smem_cap):
             info_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \
                 " use dynamic shared memory."
             logger.info(info_message)
             codegen_dict.shared_scope = "shared.dyn"
 
+        # Or assume we always use shared memory
+        # codegen_dict.shared_scope = "shared.dyn"
+
         codegen_dict.complete_config(node)
         codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size)
         codegen_dict.arch = self.arch
diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py
index 210c560a1..1d0889fa3 100644
--- a/bitblas/gpu/matmul_analysis.py
+++ b/bitblas/gpu/matmul_analysis.py
@@ -622,14 +622,16 @@ def check_last_trait(region: List[Range]):
         # Analysis Block Reduction Optimization
         # Currently, we only support block reduction depth 2 for small M
         # When the func is a dequantize like ops, we should consider the M
+        require_block_reduce = False
         if hasattr(func.attrs, "dequantize_info"):
             for arg in func.params:
                 inp_shape = func.buffer_map[arg].shape
                 M = inp_shape[0]
                 if isinstance(M, tir.IntImm) and M <= 128:
-                    tags["block_reduction_depth"] = 2
+                    require_block_reduce = True
                     break
-
+        if require_block_reduce and check_sm_version(target.arch) == 80:
+            tags["block_reduction_depth"] = 2
         return tags
 
     (main_block,) = reduction_blocks
diff --git a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py
index 17d22dcfe..a86f6469a 100644
--- a/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py
+++ b/bitblas/ops/general_matmul/tirscript/matmul_dequantize_impl.py
@@ -515,7 +515,7 @@ def matmul_nt_dequantize_b_propagate_b(
     fast_decoding=False,
     with_bias=False,
     zeros_mode="original",
-    transform_kind: Union[int, TransformKind] = TransformKind.NonTransform,
+    transform_kind: Union[int, TransformKind] = TransformKind.IntraWarpTransform,
 ):
     if isinstance(transform_kind, int):
         transform_kind = TransformKind(transform_kind)
@@ -699,8 +699,8 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b(
     fast_decoding=False,
     with_bias=False,
     zeros_mode="original",
-    transform_kind_input: Union[int, TransformKind] = TransformKind.NonTransform,
-    transform_kind_weight: Union[int, TransformKind] = TransformKind.NonTransform,
+    transform_kind_input: Union[int, TransformKind] = TransformKind.IntraWarpTransform,
+    transform_kind_weight: Union[int, TransformKind] = TransformKind.IntraWarpTransform,
 ):
     if isinstance(transform_kind_input, int):
         transform_kind_input = TransformKind(transform_kind_input)