[TL][BugFix] Add implementation of TL Gemm and Fix a bug for TL Jit #195

Merged
72 commits merged on Sep 27, 2024
Commits (72)
d8884e6
Refactor BatchMatMulEmitter and BatchMatMulSelector for improved read…
LeiWang1999 Jul 5, 2024
fc84173
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
02f64de
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
397eee6
disable failure email for ci
LeiWang1999 Jul 5, 2024
20f6ad1
remove email notifications.
LeiWang1999 Jul 6, 2024
b93c394
move relax pass from testing to mlc_llm
LeiWang1999 Jul 6, 2024
ba6a6df
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
257693a
Refactor scripts with the check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
9bb7f49
Lint Fix
LeiWang1999 Jul 6, 2024
39e7614
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
93eb5a5
Refactor scripts with the check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
72b9740
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Aug 23, 2024
5b65979
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Aug 27, 2024
d9bd479
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Aug 29, 2024
99515cb
bug fix for matrix support
LeiWang1999 Aug 29, 2024
14406ef
lint fix
LeiWang1999 Aug 29, 2024
d30ec4f
dispatch tensor core based on shapes
LeiWang1999 Aug 29, 2024
fde4029
update install commands
LeiWang1999 Aug 30, 2024
6a04749
import scripts
LeiWang1999 Aug 31, 2024
9d90c40
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into docs
LeiWang1999 Aug 31, 2024
9ef14e9
remove shared mem hack
LeiWang1999 Sep 1, 2024
63f363e
revert change for swizzling
LeiWang1999 Sep 1, 2024
b29c66c
bug fix
LeiWang1999 Sep 1, 2024
4643dd9
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into docs
LeiWang1999 Sep 1, 2024
28beb13
tl examples
LeiWang1999 Sep 2, 2024
c0b476f
Enhance Swizzle
LeiWang1999 Sep 2, 2024
2bf14a8
lint fix
LeiWang1999 Sep 2, 2024
52accbf
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 2, 2024
19aa985
test fix
LeiWang1999 Sep 3, 2024
ef8f93c
lint fix
LeiWang1999 Sep 3, 2024
4015cc4
optimize layout
LeiWang1999 Sep 3, 2024
5c5880c
update tl utils.
LeiWang1999 Sep 3, 2024
1042ffd
macro optimization
LeiWang1999 Sep 3, 2024
1ecd76e
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 3, 2024
7bb21e7
test fix
LeiWang1999 Sep 4, 2024
6a22442
gemm_ss
LeiWang1999 Sep 4, 2024
b9ea093
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 4, 2024
e9b56b4
doc fix
LeiWang1999 Sep 4, 2024
3eb6888
lint fix
LeiWang1999 Sep 6, 2024
5322785
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 6, 2024
6f18d15
lint fix
LeiWang1999 Sep 6, 2024
187f448
remove debug print
LeiWang1999 Sep 6, 2024
e1fac68
remove debug print
LeiWang1999 Sep 6, 2024
4f25626
vectorization init
LeiWang1999 Sep 6, 2024
2686030
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 6, 2024
23a8e8b
lint fix
LeiWang1999 Sep 6, 2024
069ad5e
prelude update
LeiWang1999 Sep 6, 2024
23fe3f8
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 6, 2024
9119dd3
update tvm
LeiWang1999 Sep 16, 2024
15f4c1f
bug fix for reduce_k with shared memory
LeiWang1999 Sep 16, 2024
f8518ae
bug fix
LeiWang1999 Sep 16, 2024
ea50147
bug fix
LeiWang1999 Sep 16, 2024
f888af1
Enhance Macro Generation
LeiWang1999 Sep 16, 2024
a0bfabf
Lift Layout to reduce load time
LeiWang1999 Sep 16, 2024
b1fdbcf
lint fix
LeiWang1999 Sep 16, 2024
137b6fd
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 16, 2024
0acc369
test fix
LeiWang1999 Sep 16, 2024
62de446
red fix
LeiWang1999 Sep 17, 2024
958f6f2
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 26, 2024
f21b25c
tile lang macro example
LeiWang1999 Sep 26, 2024
0fb9535
tile lang macro example
LeiWang1999 Sep 26, 2024
2c93dad
optimize the macro generator related items
LeiWang1999 Sep 26, 2024
e5bbf81
lint fix
LeiWang1999 Sep 26, 2024
5cfce84
Tile Lang Test with Dynamic Symbolic
LeiWang1999 Sep 26, 2024
9bafdef
more test case with block level programming
LeiWang1999 Sep 26, 2024
15f64c1
all dynamic test case
LeiWang1999 Sep 26, 2024
08bc9d4
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 26, 2024
3639e05
simplify the test case for dequantize gemm.
LeiWang1999 Sep 27, 2024
c148a22
dequant gemm update.
LeiWang1999 Sep 27, 2024
f4486f7
Tile Lang GEMM Implementation
LeiWang1999 Sep 27, 2024
6a07890
Tile Lang Gemm Fix
LeiWang1999 Sep 27, 2024
8d45157
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl-test
LeiWang1999 Sep 27, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
4 changes: 4 additions & 0 deletions bitblas/ops/general_matmul/tilelang/dense/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .matmul import matmul_blocked # noqa: F401
189 changes: 189 additions & 0 deletions bitblas/ops/general_matmul/tilelang/dense/matmul.py
@@ -0,0 +1,189 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm as tvm
import tvm.tl.language as T

from bitblas.tl.utils import (
    get_mma_micro_size,
    make_swizzle_layout,
)

from bitblas.tl.macro_generator import (TensorCoreIntrinEmitter)


def maybe_pipeline(
    iterable,
    num_stages,
):
    # Use software pipelining over the reduction loop when more than one
    # stage is requested; otherwise fall back to a plain serial loop.
    enable_pipeline = num_stages > 1
    if enable_pipeline:
        return T.Pipelined(iterable, num_stages=num_stages)
    else:
        return T.serial(iterable)


def matmul_blocked(
    M,
    N,
    K,
    block_M=64,
    block_N=64,
    block_K=32,
    trans_A=False,
    trans_B=False,
    dtypeAB="float16",
    dtypeC="float16",
    accum_dtype="float16",
    num_stages=2,
    threads=128,
    enable_rasterization=False,  # Enhance L2 locality
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, dtypeAB),
            B: T.Buffer(B_shape, dtypeAB),
            C: T.Buffer((M, N), dtypeC),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
            B_shared = T.alloc_shared(B_shared_shape, dtypeAB)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            if enable_rasterization:
                # rasterization factor
                T.use_swizzle(10)

            T.clear(C_local)
            for k in maybe_pipeline(T.ceildiv(K, block_K), num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
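
A minimal usage sketch for the factory above (not part of the diff): the shapes and tile sizes below are illustrative assumptions, and only construction and inspection of the prim_func is shown, since compilation goes through the TL JIT touched elsewhere in this PR.

# Illustrative only: build the TIR prim_func for a 1024x1024x1024 fp16 GEMM.
# The tile sizes here are assumptions, not tuned values.
from bitblas.ops.general_matmul.tilelang.dense import matmul_blocked

func = matmul_blocked(
    M=1024,
    N=1024,
    K=1024,
    block_M=128,
    block_N=128,
    block_K=32,
    num_stages=2,
    threads=128,
    enable_rasterization=True,
)
# func is a TVM PrimFunc; printing its script is one way to inspect the
# generated TileLang kernel before lowering it.
print(func.script())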


def matmul_macro_tensorcore(
    M,
    N,
    K,
    dtypeAB,
    dtypeC,
    accum_dtype,
    block_row_warps,
    block_col_warps,
    warp_row_tiles,
    warp_col_tiles,
    chunk,
    num_stages=2,
    enable_rasterization=False,
):

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB)

    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y)

    warp_size = 32  # NVIDIA GPU warp size is 32
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    shared_scope = "shared.dyn"  # Literal["shared", "shared.dyn"]; "shared" selects statically allocated shared memory
    mma_emitter = TensorCoreIntrinEmitter(
        a_dtype=dtypeAB,
        b_dtype=dtypeAB,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk)

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, dtypeAB),
            B: T.Buffer(B_shape, dtypeAB),
            C: T.Buffer((M, N), dtypeC),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, dtypeAB, shared_scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, dtypeAB, shared_scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, dtypeC, shared_scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size), dtypeAB)
            B_local = T.alloc_local((warp_cols * local_size), dtypeAB)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            if enable_rasterization:
                T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in maybe_pipeline(T.ceildiv(K, block_K), num_stages):

                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):

                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                        thread_bindings=thread_bindings,
                    )

                    # Load B into fragment
                    mma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                        thread_bindings=thread_bindings,
                    )

                    mma_emitter.mma(A_local, B_local, C_local)

            mma_emitter.stmatrix(
                C_local,
                C_shared,
                thread_bindings=thread_bindings,
            )

            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i,
                  bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y,
                                               i % micro_size_x, j % micro_size_y]

    return main
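
For reference, the tile and thread geometry that the macro-based kernel derives from its arguments can be checked with plain arithmetic. The parameter values below are assumptions chosen only to make the numbers concrete; they are not from the PR.

# Worked example of the derived sizes in matmul_macro_tensorcore, assuming
# block_row_warps = block_col_warps = 2, warp_row_tiles = warp_col_tiles = 32,
# chunk = 32, and dtypeAB = "float16" (so the MMA micro shape is 16x16x16).
block_row_warps = block_col_warps = 2
warp_row_tiles = warp_col_tiles = 32
chunk = 32
micro_size_x = micro_size_y = micro_size_k = 16  # get_mma_micro_size("float16")
warp_size = 32

block_M = block_row_warps * warp_row_tiles          # 64
block_N = block_col_warps * warp_col_tiles          # 64
block_K = chunk                                     # 32
threads = warp_size * block_row_warps * block_col_warps  # 128 (4 warps)
local_size = (micro_size_x * micro_size_y) // warp_size  # 8 elements per thread per MMA tile
warp_rows = warp_row_tiles // micro_size_x          # 2 MMA tiles per warp along M
warp_cols = warp_col_tiles // micro_size_y          # 2 MMA tiles per warp along N
assert (block_M, block_N, block_K, threads) == (64, 64, 32, 128)
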
47 changes: 47 additions & 0 deletions bitblas/tl/mma_layout.py
@@ -0,0 +1,47 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm import arith
from tvm import DataType
from typing import Union, Literal


def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = thread_id % 16
    col = 8 * (thread_id // 16) + local_id % 8
    return row, col


def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 8 * ((thread_id % 16) // 8) + local_id % 8
    return row, col


def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
    row = thread_id % 16
    col = local_id + (thread_id // 16) * 16
    return row, col


def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
    row = (thread_id // 16) * 8 + (thread_id % 8)
    col = local_id + 16 * ((thread_id % 16) // 8)
    return row, col


def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = 8 * (local_id % 4 // 2) + (thread_id // 4)
    col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2)
    return row, col


def shared_16x16_to_mma_32x8_smoothlayout(i, j):
    return (i * 2 + j // 8, j % 8)


def shared_16x32_to_mma_32x16_smoothlayout(i, j):
    return (i * 2 + j // 16, j % 16)


def shared_32x16_to_mma_32x16_smoothlayout(i, j):
    return (i * 2 + j // 16, j % 16)
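
The layout helpers above are pure index arithmetic, so they can be checked exhaustively in isolation. A quick sanity sketch (not part of the PR) for the 32x8-to-16x16 mapping, assuming only the new module above:

# Verify that ldmatrix_32x8_to_shared_16x16_layout places each
# (thread_id, local_id) pair on a distinct element of the 16x16 shared tile.
from bitblas.tl.mma_layout import ldmatrix_32x8_to_shared_16x16_layout

seen = set()
for thread_id in range(32):       # one warp
    for local_id in range(8):     # 8 elements per thread
        row, col = ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id)
        assert 0 <= row < 16 and 0 <= col < 16
        seen.add((row, col))
assert len(seen) == 16 * 16       # the mapping covers the tile exactly once
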
76 changes: 34 additions & 42 deletions bitblas/tl/utils.py
@@ -1,8 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from tvm import arith
from tvm import DataType
import tvm.tl.language as T
from typing import Union, Literal
from .mma_layout import (
    ldmatrix_32x8_to_shared_16x16_layout,
    ldmatrix_trans_32x8_to_shared_16x16_layout,
    ldmatrix_32x16_to_shared_16x32_layout_a,
    ldmatrix_32x16_to_shared_16x32_layout_b,
    mma_store_32x8_to_shared_16x16_layout,
)


def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
@@ -61,48 +70,6 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
    return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)


(Deleted from utils.py in this hunk: ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, mma_store_32x8_to_shared_16x16_layout, shared_16x16_to_mma_32x8_smoothlayout, shared_16x32_to_mma_32x16_smoothlayout, and shared_32x16_to_mma_32x16_smoothlayout; they now live in bitblas/tl/mma_layout.py, shown above.)


def get_ldmatrix_offset(
    matrix: Literal["A", "B"],
    row_idx,
@@ -129,3 +96,28 @@ def get_ldmatrix_offset(

def mma_store_index_map(*args, **kwargs):
    return mma_store_32x8_to_shared_16x16_layout(*args, **kwargs)


def get_mma_micro_size(dtype: Literal["float16", "int8"]):
    # TODO(lei): FP8 related precision support.
    # Basic Tensor Core matrix multiply operation unit
    micro_size_x = micro_size_y = 16
    micro_size_k = 16
    if dtype == "int8":
        micro_size_k = 32
    return micro_size_x, micro_size_y, micro_size_k


def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)
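
For intuition, the two helpers added above behave as follows: get_mma_micro_size keeps a 16x16 MMA tile and only widens the k dimension to 32 for int8, and make_swizzle_layout applies the swizzle only when one shared-memory row is exactly 512 bits wide, falling back to an identity layout otherwise. A small illustrative check under those assumptions (not part of the diff):

# Mirrors the can_swizzle condition in make_swizzle_layout without needing a buffer.
from tvm import DataType

def row_is_swizzlable(last_dim, dtype):
    return last_dim * DataType(dtype).bits == 512

assert row_is_swizzlable(32, "float16")      # 32 elements * 16 bits = 512
assert row_is_swizzlable(64, "int8")         # 64 elements * 8 bits  = 512
assert not row_is_swizzlable(16, "float16")  # identity layout is returned instead

# get_mma_micro_size("float16") -> (16, 16, 16)
# get_mma_micro_size("int8")    -> (16, 16, 32)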