
Commit

dsl benchmark scripts
LeiWang1999 committed Feb 28, 2024
1 parent 1e93b94 commit ebbd294
Showing 2 changed files with 106 additions and 82 deletions.
77 changes: 28 additions & 49 deletions benchmark/dsl/matmul.py
@@ -8,6 +8,7 @@
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
from bitblas.gpu import Matmul
from bitblas.base.utils import apply_and_build
from bitblas.ops.impl.matmul_impl import matmul_nt_propagate_a_propagate_b
import time


@@ -93,54 +94,8 @@ def main(a: T.handle, b: T.handle, c: T.handle):

return MyModule


def matmul_nt_propagate_a_b(M, N, K, in_dtype="float16", out_dtype="float16"):
wm, wn, wk = 16, 16, 16
if in_dtype == "int8":
wm, wn, wk = 16, 16, 32

@tvm.script.ir_module
class MyModule:
@T.prim_func
def main(a: T.handle, b: T.handle, c: T.handle):
T.func_attr(
{
"global_symbol": "main",
"tir.noalias": True,
"smooth_a": True,
"smooth_b": True,
}
)
A = T.match_buffer(a, [M // wm, K // wk, wm, wk], dtype=in_dtype)
B = T.match_buffer(b, [N // wn, K // wk, wn, wk], dtype=in_dtype)
C = T.match_buffer(c, [M, N], dtype=out_dtype)
A_reindex = T.alloc_buffer([M, K], dtype=in_dtype)
B_reindex = T.alloc_buffer([N, K], dtype=in_dtype)

for i, k in T.grid(M, K):
with T.block("A_reindex"):
vj, vk = T.axis.remap("SS", [i, k])
A_reindex[vj, vk] = A[vj // wm, vk // wk, vj % wm, vk % wk]

for j, k in T.grid(N, K):
with T.block("B_reindex"):
vj, vk = T.axis.remap("SS", [j, k])
B_reindex[vj, vk] = B[vj // wn, vk // wk, vj % wn, vk % wk]

for i, j, k in T.grid(M, N, K):
with T.block("C"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = tvm.tir.const(0, out_dtype)
C[vi, vj] = C[vi, vj] + A_reindex[vi, vk].astype(
out_dtype
) * B_reindex[vj, vk].astype(out_dtype)

return MyModule


# fmt:off
benchmark_sets = [
typical_test_shapes = [
# (prim_func, input_args, default_dlight_schedule),
(matmul_nt, (1024, 1024, 1024, "float16", "float16"), Matmul),
(matmul_nt, (16, 8192, 8192, "float16", "float16"), Matmul),
@@ -152,11 +107,35 @@ def main(a: T.handle, b: T.handle, c: T.handle):
(matmul_nn, (16384, 16384, 16384, "float16", "float16"), Matmul),
(matmul_nt, (1024, 1024, 1024, "float32", "float32"), Matmul),
(matmul_nt_propagate_b_f16_f16_mma, (16384, 16384, 16384), Matmul),
(matmul_nt_propagate_a_b, (16384, 16384, 16384, "int8", "int32"), Matmul),
(matmul_nt_propagate_a_b, (16384, 16384, 16384, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "int8", "int32", "int32"), Matmul),
(matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16"), Matmul),
]
# fmt:on

llm_shapes = [
# square test
(matmul_nt_propagate_a_propagate_b, (16384, 16384, 16384, "float16", "float16"), Matmul),
# BLOOM-176B
(matmul_nt_propagate_a_propagate_b, (8192, 43008, 14336, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 14336, 14336, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 57344, 14336, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 14336, 57344, "float16", "float16"), Matmul),
# # OPT-65B
(matmul_nt_propagate_a_propagate_b, (8192, 9216, 9216, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 36864, 9216, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 9216, 36864, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 22016, 8192, "float16", "float16"), Matmul),
# # LLAMA-70B/65B
(matmul_nt_propagate_a_propagate_b, (8192, 8192, 22016, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 8192, 8192, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 28672, 8192, "float16", "float16"), Matmul),
(matmul_nt_propagate_a_propagate_b, (8192, 8192, 28672, "float16", "float16"), Matmul),
]
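# Note: each tuple is (M, N, K) for the NT GEMM C[M, N] = A[M, K] x B[N, K]^T;
# for the model shapes M = 8192 tokens, and (N, K) appear to come from each
# model's fused-QKV, attention-output and FFN up/down projections
# (e.g. BLOOM-176B: hidden size 14336, so 3*14336 = 43008 and 4*14336 = 57344).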

benchmark_sets = []
benchmark_sets.extend(llm_shapes)


benchmark_results = {}
for get_prim_func, input_args, d_schedule in benchmark_sets:
ir_module = get_prim_func(*input_args)
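
The remainder of the loop body is not shown in this hunk. As a rough sketch of how each entry can be lowered and timed with the default schedule rule alone, using only the names imported at the top of the file plus standard TVM APIs (the d_schedule().apply signature, the target string, and the numpy usage are assumptions, not the committed code):

# sketch only -- assumed continuation of the loop above, not the committed body
# requires: import numpy as np
func = ir_module["main"]
target = tvm.target.Target("cuda")  # assumed target
# d_schedule is the default rule class (e.g. Matmul); apply() is assumed to
# return a tvm.tir.Schedule for these static shapes
sch = d_schedule().apply(func, target, False)
rt_mod = tvm.build(sch.mod, target=target)
dev = tvm.cuda(0)
# random inputs shaped after the PrimFunc's buffers
args = [
    tvm.nd.array(
        np.random.uniform(-1, 1, [int(s) for s in buf.shape]).astype(buf.dtype), dev
    )
    for buf in (func.buffer_map[p] for p in func.params)
]
evaluator = rt_mod.time_evaluator(rt_mod.entry_name, dev, number=10)
benchmark_results[(get_prim_func.__name__, *input_args)] = evaluator(*args).mean * 1e3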
111 changes: 78 additions & 33 deletions python/bitblas/gpu/matmul_mma.py
@@ -24,7 +24,7 @@
get_reduction_blocks,
get_dequantize_block,
normalize_to_matmul,
get_propagate_map
get_propagate_map,
)


@@ -193,7 +193,9 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
sch.bind(j2, "threadIdx.y")

# Step 4. Read/write to shared mem and register
def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is_transpose):
def fetch_input(
block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is_transpose
):
# 1) Read to shared memory
block_read_smem = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn")
sch.compute_at(block_read_smem, k0)
@@ -203,15 +205,21 @@ def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is

# For transposed read, we directly load transposed tensor from global
# Then use ldmatrix.trans to handle transpose later
if (tensor_name == "A" and is_transpose) or (tensor_name == "B" and not is_transpose):
if (tensor_name == "A" and is_transpose) or (
tensor_name == "B" and not is_transpose
):
# special handling for transposed read (for NN matmul or TT matmul)
v0, v1 = sch.get_loops(block_read_smem)[-2:]
sch.reorder(v1, v0)
sch.transform_layout(block_read_smem, ("write", 0), lambda b, i, j: (b, j, i))
sch.transform_layout(
block_read_smem, ("write", 0), lambda b, i, j: (b, j, i)
)

# bind loops
fused = sch.fuse(*sch.get_loops(block_read_smem)[-2:])
f0, f1, f2, f3, f4 = sch.split(fused, [None, thread_z, thread_y, thread_x, vector_size])
f0, f1, f2, f3, f4 = sch.split(
fused, [None, thread_z, thread_y, thread_x, vector_size]
)
sch.bind(f1, "threadIdx.z")
sch.bind(f2, "threadIdx.y")
sch.bind(f3, "threadIdx.x")
@@ -231,8 +239,12 @@ def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is
if not is_transpose
else (micro_size_k, micro_size_spatial)
)
v00, v01 = sch.split(sch.get_loops(block_read_reg)[-2], [None, micro_size_1])
v10, v11 = sch.split(sch.get_loops(block_read_reg)[-1], [None, micro_size_2])
v00, v01 = sch.split(
sch.get_loops(block_read_reg)[-2], [None, micro_size_1]
)
v10, v11 = sch.split(
sch.get_loops(block_read_reg)[-1], [None, micro_size_2]
)
sch.reorder(v00, v10, v01, v11)

# reorder read axis to match the layout of ldmatrix
@@ -243,7 +255,9 @@ def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is
v0,
v1 // micro_size_1,
v2 // micro_size_2,
*shared_16x16_to_mma_32x8_layout(v1 % micro_size_1, v2 % micro_size_2),
*shared_16x16_to_mma_32x8_layout(
v1 % micro_size_1, v2 % micro_size_2
),
),
)

@@ -253,13 +267,19 @@ def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is

return block_read_smem, block_read_reg

block_read_a, block_read_reg_a = fetch_input(block_outer, 0, "A", is_transpose_a)
block_read_b, block_read_reg_b = fetch_input(block_outer, 1, "B", is_transpose_b)
block_read_a, block_read_reg_a = fetch_input(
block_outer, 0, "A", is_transpose_a
)
block_read_b, block_read_reg_b = fetch_input(
block_outer, 1, "B", is_transpose_b
)

# Write to register, and then smem
def store_output(block_outer, write_buffer_idx):
# 1) Write to shared memory
block_write_smem = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn")
block_write_smem = sch.cache_write(
block_outer, write_buffer_idx, "shared.dyn"
)
sch.reverse_compute_at(block_write_smem, block_axis)
auto_inline_consumer_chain(sch, block_write_smem)

@@ -288,13 +308,15 @@ def store_output(block_outer, write_buffer_idx):
v0,
v1 // micro_size_m,
v2 // micro_size_n,
*shared_16x16_to_mma_32x8_layout(v1 % micro_size_m, v2 % micro_size_n),
*shared_16x16_to_mma_32x8_layout(
v1 % micro_size_m, v2 % micro_size_n
),
),
)

return block_write_smem, block_write_reg

block_write_smem, block_write_reg = store_output(block_outer, 0)
_, block_write_reg = store_output(block_outer, 0)

# Step 5. Schedule tensor core computation
block_init = sch.decompose_reduction(block_outer, k0)
@@ -359,10 +381,6 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring
return None

main_block = reduction_blocks[0]

# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J]
if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()):
sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"])

output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)]

@@ -387,6 +405,13 @@ def check_has_dynamic(func: tir.PrimFunc):

cache_write_required = check_require_cache(func)

# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J]
if not (
func.attrs is not None
and "dlight.tensorcore_prenormlized" in func.attrs.keys()
):
sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"])

shared_scope = config.shared_scope

intrin_info = config.intrin_info
@@ -525,11 +550,17 @@ def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True):
),
)

smooth_smem_layout_rewrite(block_outer, ("read", 0), *a_lr, enable=intrin_info.smooth_a)
smooth_smem_layout_rewrite(block_outer, ("read", 1), *b_lr, enable=intrin_info.smooth_b)
smooth_smem_layout_rewrite(
block_outer, ("read", 0), *a_lr, enable=intrin_info.smooth_a
)
smooth_smem_layout_rewrite(
block_outer, ("read", 1), *b_lr, enable=intrin_info.smooth_b
)
smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True)

def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False):
def fetch_to_shared(
block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False
):
block_read = sch.cache_read(block, idx, shared_scope)
sch.compute_at(block_read, k0, preserve_unit_loops=True)
ndim = len(sch.get(block_read).iter_vars)
@@ -574,27 +605,35 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, tra
)

# rewrite global smooth layout
def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False, matrix_name="A"):
def smooth_gmem_layout_rewrite(
sch, block, enable=True, trans=False, matrix_name="A"
):
if not enable:
return
# step1: find the first producer block
# Note: we assume the layout propagation happens in the first producer block;
# otherwise, the layout transform will have no effect, as it would transform both
# the read and write buffers
producers = _collect_producers(sch, block)

propagate_block: tir.Block = producers[-1]
g2s_block = a_g2s if matrix_name == "A" else b_g2s
propagate_block: tir.Block = producers[-1] if len(producers) > 0 else g2s_block

# step2: transform the layout with inverse permutation
_, inverse_indexmap = get_propagate_map(trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name)
_, inverse_indexmap = get_propagate_map(
trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name
)

def inverse_permutation(i, j, ii, jj):
return (i, j, *inverse_indexmap.map_indices([ii, jj]))

sch.transform_layout(propagate_block, ("read", 0), inverse_permutation)

smooth_gmem_layout_rewrite(sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A")
smooth_gmem_layout_rewrite(sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B")
smooth_gmem_layout_rewrite(
sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A"
)
smooth_gmem_layout_rewrite(
sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B"
)
auto_inline_producers(sch, a_g2s)
auto_inline_producers(sch, b_g2s)

@@ -620,13 +659,13 @@ def inverse_permutation(i, j, ii, jj):
if cache_write_required:
auto_inline_consumer_chain(sch, accumulator_shared_to_global)
sch.reverse_compute_at(
accumulator_shared_to_global, sch.get_loops(store)[-5], preserve_unit_loops=True
accumulator_shared_to_global,
sch.get_loops(store)[-5],
preserve_unit_loops=True,
)
vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:])
f0, f1, f2 = sch.split(
fused, factors=[None, warp_size, vec_len]
)
f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
sch.bind(f1, "threadIdx.x")
sch.vectorize(f2)
sch.unroll(f0)
@@ -641,10 +680,14 @@ def inverse_permutation(i, j, ii, jj):
index_map_a, index_map_b, index_map_c = intrin_group["index_map"]

sch.transform_layout(
A_mat, ("write", 0), get_warp_index_map(index_map_a, *a_lr, intrin_info.smooth_a)
A_mat,
("write", 0),
get_warp_index_map(index_map_a, *a_lr, intrin_info.smooth_a),
)
sch.transform_layout(
B_mat, ("write", 0), get_warp_index_map(index_map_b, *b_lr, intrin_info.smooth_b)
B_mat,
("write", 0),
get_warp_index_map(index_map_b, *b_lr, intrin_info.smooth_b),
)
sch.transform_layout(
store,
@@ -676,7 +719,9 @@ def tensorize_init_store_compute():
tensorize_init_store_compute()

if stage > 1:
sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1])
sch.annotate(
k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]
)
sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
if use_async:
sch.annotate(k0, "software_pipeline_async_stages", [0])
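
For reference, the shared_16x16_to_mma_32x8_layout index map used in the transform_layout calls above redistributes each 16x16 tile across a warp as 32 threads x 8 elements so that ldmatrix/mma.sync fragments line up. A minimal sketch of that mapping, assuming it matches TVM's shared_16x16_to_ldmatrix_32x8_layout helper (the exact bitblas definition may differ):

def shared_16x16_to_mma_32x8_layout(i, j):
    # (row, col) inside a 16x16 tile -> (thread_id, local_id) inside a warp,
    # where each of the 32 threads holds 8 elements of the fragment
    thread_id = 4 * (i % 8) + (j % 8) // 2
    local_id = 4 * (j // 8) + (i // 8) * 2 + (j % 2)
    return thread_id, local_id

# e.g. (0, 0) -> (thread 0, slot 0) and (15, 15) -> (thread 31, slot 7)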
