
Commit

lint fix
LeiWang1999 committed Aug 5, 2024
1 parent ac316fd commit 71c1d6e
Showing 4 changed files with 93 additions and 79 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
158 changes: 85 additions & 73 deletions bitblas/base/roller/policy/tensorcore.py
@@ -117,83 +117,92 @@ def _check_small_tile(td: TileDict):
return True
return False

if not _check_small_tile(td):
return None
if _check_small_tile(td):

smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
rstep_map = td.rstep_map.copy()

def _optimize(node, rstep):
all_steps = self.get_node_reduce_step_candidates(node)
# todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k]
for k in all_steps:
all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))
if any([v == [] for v in all_steps.values()]):
return rstep

def _shared_memory_usage(td: TileDict):
return node.footprint(td.output_tile, new_rstep_map,
td.tensor_strides_map[node])

def _score(rstep_id):
rstep = {
k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
}
score = 0
shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
for i, input_buffer in enumerate(input_buffers):
score += coalesced_factor(shape[i], input_buffer.shape)
return score

def _enlarge(rstep_id):
candidates = []
for ax in rstep_id:
if rstep_id[ax] + 1 == len(all_steps[ax]):
continue
r = rstep_id.copy()
r[ax] += 1
candidates.append((r, _score(r)))
if len(candidates) == 0:
return None
return max(candidates, key=lambda x: x[1])[0]

cur_rstep_id = {
k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
}
new_rstep_map = rstep_map.copy()
while True:
new_rstep_id = _enlarge(cur_rstep_id)
if new_rstep_id is None:
break
new_rstep_map = {
k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]]
for k in node.raxis
}
old_rstep_map = td.rstep_map
td.rstep_map = new_rstep_map
smem_usage, _ = _shared_memory_usage(td)
td.rstep_map = old_rstep_map
if smem_usage > smem_limit:
break
else:
cur_rstep_id = new_rstep_id
rstep = {
k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
}
return rstep

smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
rstep_map = td.rstep_map.copy()
for node in self.ordered_nodes:
if len(node.raxis) > 0:
rstep = _optimize(node, rstep_map)
rstep_map = rstep

def _optimize(node, rstep):
all_steps = self.get_node_reduce_step_candidates(node)
# todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k]
for k in all_steps:
all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))
if any([v == [] for v in all_steps.values()]):
return rstep
td.rstep_map = rstep_map
td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)

def _shared_memory_usage(td: TileDict):
return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node])
if self.block_reduction_depth is not None:

def _score(rstep_id):
rstep = {
k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
}
score = 0
shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
for i, input_buffer in enumerate(input_buffers):
score += coalesced_factor(shape[i], input_buffer.shape)
return score

def _enlarge(rstep_id):
candidates = []
for ax in rstep_id:
if rstep_id[ax] + 1 == len(all_steps[ax]):
continue
r = rstep_id.copy()
r[ax] += 1
candidates.append((r, _score(r)))
if len(candidates) == 0:
return None
return max(candidates, key=lambda x: x[1])[0]

cur_rstep_id = {
k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
}
new_rstep_map = rstep_map.copy()
while True:
new_rstep_id = _enlarge(cur_rstep_id)
if new_rstep_id is None:
break
new_rstep_map = {
k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis
}
old_rstep_map = td.rstep_map
td.rstep_map = new_rstep_map
smem_usage, _ = _shared_memory_usage(td)
td.rstep_map = old_rstep_map
if smem_usage > smem_limit:
break
else:
cur_rstep_id = new_rstep_id
rstep = {
k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
}
return rstep
def _expand_with_tags(rstep):
new_rstep = {k: v * self.block_reduction_depth for k, v in rstep.items()}
return new_rstep

rstep_map = td.rstep_map.copy()
for node in self.ordered_nodes:
if len(node.raxis) > 0:
rstep = _expand_with_tags(rstep_map)
rstep_map = rstep
td.rstep_map = rstep_map

for node in self.ordered_nodes:
if len(node.raxis) > 0:
rstep = _optimize(node, rstep_map)
rstep_map = rstep

# if is_block_reduction:
# # If block reduction, we should constrain the max value to 64
# # Otherwise it will introduce an issue of cuda invalid args.
# MAX_REDUCE_K = 64
# for k in rstep_map:
# rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K)
td.rstep_map = rstep_map
td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
return

def get_node_reduce_step_candidates(self, node):
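For orientation, the loop this hunk re-indents under _check_small_tile greedily enlarges the reduce step of one axis at a time, scores each candidate by coalesced access, and stops once the projected shared-memory footprint would exceed the per-block limit. A minimal, self-contained sketch of that pattern (enlarge_rstep, score, and smem_of are illustrative stand-ins for _enlarge/_score/_shared_memory_usage, not names from this file):

def enlarge_rstep(steps, score, smem_of, smem_limit, start):
    """Greedy enlargement sketch: steps[ax] lists candidate reduce steps per axis,
    start[ax] is the index of the current step, score/smem_of rate a candidate map."""
    cur = dict(start)
    while True:
        # Bump each axis to its next candidate and keep the best-scoring bump.
        candidates = []
        for ax in cur:
            if cur[ax] + 1 == len(steps[ax]):
                continue  # axis already at its largest candidate
            nxt = dict(cur)
            nxt[ax] += 1
            candidates.append((nxt, score(nxt)))
        if not candidates:
            break
        best = max(candidates, key=lambda c: c[1])[0]
        rstep = {ax: steps[ax][best[ax]] for ax in best}
        if smem_of(rstep) > smem_limit:
            break  # enlarging further would overflow shared memory
        cur = best
    return {ax: steps[ax][cur[ax]] for ax in cur}

# Toy usage: two reduction axes, a score that favors larger steps, 4 KB budget.
steps = {"k": [16, 32, 64], "kk": [16, 32]}
picked = enlarge_rstep(
    steps,
    score=lambda idx: sum(steps[ax][i] for ax, i in idx.items()),
    smem_of=lambda r: 16 * sum(r.values()),  # toy footprint model
    smem_limit=4096,
    start={"k": 0, "kk": 0},
)  # -> {"k": 64, "kk": 32}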
@@ -318,12 +327,15 @@ def _score(node, thread): # small is better
# smem capacity
# TODO: This is a dummy mul which avoid reusing some shared memory.
# Should be removed in the future.
if td.smem_cost > (self.arch.smem_cap * 1.3):
if td.smem_cost > (self.arch.smem_cap):
info_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \
" use dynamic shared memory."
logger.info(info_message)
codegen_dict.shared_scope = "shared.dyn"

# Or assume we always use shared memory
# codegen_dict.shared_scope = "shared.dyn"

codegen_dict.complete_config(node)
codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size)
codegen_dict.arch = self.arch
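The second hunk tightens the fallback to dynamic shared memory: previously the scope switched to shared.dyn only once the tile's footprint exceeded 1.3x the static capacity, now it switches as soon as the capacity itself is exceeded. A rough illustration (the 48 KB figure is an assumed smem_cap, not taken from this commit):

SMEM_CAP = 48 * 1024      # assumed arch.smem_cap in bytes
smem_cost = 52 * 1024     # hypothetical tile footprint

shared_scope = "shared"
if smem_cost > SMEM_CAP:  # old condition was: smem_cost > SMEM_CAP * 1.3
    shared_scope = "shared.dyn"  # static allocation won't fit; request dynamic shared memory

Under the old 1.3x threshold this 52 KB example would have kept the static shared scope.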
6 changes: 4 additions & 2 deletions bitblas/gpu/matmul_analysis.py
@@ -622,14 +622,16 @@ def check_last_trait(region: List[Range]):
# Analysis Block Reduction Optimization
# Currently, we only support block reduction depth 2 for small M
# When the func is a dequantize like ops, we should consider the M
require_block_reduce = False
if hasattr(func.attrs, "dequantize_info"):
for arg in func.params:
inp_shape = func.buffer_map[arg].shape
M = inp_shape[0]
if isinstance(M, tir.IntImm) and M <= 128:
tags["block_reduction_depth"] = 2
require_block_reduce = True
break

if require_block_reduce and check_sm_version(target.arch) == 80:
tags["block_reduction_depth"] = 2
return tags

(main_block,) = reduction_blocks
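The matmul_analysis.py hunk changes when the block_reduction_depth tag is emitted: finding a small M (<= 128) on a dequantize input now only records require_block_reduce, and the tag is applied only if the target is additionally sm_80. A hedged sketch of that gating (plain ints stand in for tir.IntImm shapes):

def tag_block_reduction(m_values, sm_version, tags):
    # m_values: leading dims (M) of the dequantize inputs; sm_version: e.g. 80.
    require_block_reduce = any(isinstance(m, int) and m <= 128 for m in m_values)
    if require_block_reduce and sm_version == 80:
        tags["block_reduction_depth"] = 2
    return tags

tag_block_reduction([16, 1024], sm_version=80, tags={})  # -> {"block_reduction_depth": 2}
tag_block_reduction([16, 1024], sm_version=89, tags={})  # -> {} (no tag off sm_80)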
@@ -515,7 +515,7 @@ def matmul_nt_dequantize_b_propagate_b(
fast_decoding=False,
with_bias=False,
zeros_mode="original",
transform_kind: Union[int, TransformKind] = TransformKind.NonTransform,
transform_kind: Union[int, TransformKind] = TransformKind.IntraWarpTransform,
):
if isinstance(transform_kind, int):
transform_kind = TransformKind(transform_kind)
@@ -699,8 +699,8 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b(
fast_decoding=False,
with_bias=False,
zeros_mode="original",
transform_kind_input: Union[int, TransformKind] = TransformKind.NonTransform,
transform_kind_weight: Union[int, TransformKind] = TransformKind.NonTransform,
transform_kind_input: Union[int, TransformKind] = TransformKind.IntraWarpTransform,
transform_kind_weight: Union[int, TransformKind] = TransformKind.IntraWarpTransform,
):
if isinstance(transform_kind_input, int):
transform_kind_input = TransformKind(transform_kind_input)
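The final hunk only changes defaults: the propagate variants now default transform_kind (and transform_kind_input / transform_kind_weight) to TransformKind.IntraWarpTransform instead of NonTransform, while integer arguments are still normalized through TransformKind(...). A small sketch of that normalization (the enum values are assumptions, not verified against this revision):

from enum import IntEnum

class TransformKind(IntEnum):  # stand-in for bitblas' enum; values assumed
    NonTransform = 0
    InterWarpTransform = 1
    IntraWarpTransform = 2

def normalize(transform_kind=TransformKind.IntraWarpTransform):
    # Callers may still pass a plain int; the kernel builders convert it up front.
    if isinstance(transform_kind, int):
        transform_kind = TransformKind(transform_kind)
    return transform_kind

assert normalize(2) is TransformKind.IntraWarpTransform
assert normalize() is TransformKind.IntraWarpTransform  # new default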
