lint fix
LeiWang1999 committed Nov 1, 2024
1 parent fd4973c · commit 8f7767b
Showing 3 changed files with 24 additions and 45 deletions.
22 changes: 6 additions & 16 deletions bitblas/tl/macro_generator.py
@@ -473,6 +473,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):

return _warp_mma(A_local_buf, B_local_buf, C_local_buf)


class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):

def mma(self, A_local_buf, B_local_buf, C_local_buf):
@@ -533,9 +534,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out
+ j * local_size_out
+ lift(local_size_out) // 2,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)

@@ -571,18 +570,14 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4,
C_local_buf.data,
i * warp_cols * local_size_out
+ j * local_size_out
+ lift(local_size_out) // 2,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)

return _warp_mma(A_local_buf, B_local_buf, C_local_buf)


class INT4TensorCoreIntrinEmitterWithLadderTransform(
TensorCoreIntrinEmitterWithLadderTransform
):
class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform):

def mma(self, A_local_buf, B_local_buf, C_local_buf):

@@ -643,9 +638,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out
+ j * local_size_out
+ lift(local_size_out) // 2,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)

@@ -681,11 +674,8 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4,
C_local_buf.data,
i * warp_cols * local_size_out
+ j * local_size_out
+ lift(local_size_out) // 2,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)


return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
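Note on the joined index expressions in this file: each one flattens a (warp_rows × warp_cols × local_size_out) accumulator layout into the 1-D C_local buffer, and the `// 2` (and `// 4`) terms step into the later half or quarter of a fragment for the split int4 MMA calls. The following is a minimal, illustrative sketch of that offset arithmetic with placeholder sizes, not values read from the emitter:

# Illustrative sketch only: placeholder sizes, not taken from the emitter.
warp_rows, warp_cols, local_size_out = 2, 2, 8

def c_local_offset(i, j, half=0):
    """Flatten the (i, j) warp-tile coordinate into the 1-D C_local buffer.

    half=1 adds local_size_out // 2, mirroring the second-half offset used by
    the second set of MMA calls in the hunks above.
    """
    return i * warp_cols * local_size_out + j * local_size_out + half * (local_size_out // 2)

# Fragment (i=1, j=0) starts right after the warp_cols fragments of row i=0.
assert c_local_offset(1, 0) == warp_cols * local_size_out
assert c_local_offset(0, 1, half=1) == local_size_out + local_size_out // 2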
31 changes: 11 additions & 20 deletions testing/python/tilelang/test_tilelang_gemm_s4_mma.py
@@ -9,8 +9,7 @@
from tvm import tl as TL
import tvm.tl.language as T
from bitblas.tl.utils import (
make_swizzle_layout,
)
make_swizzle_layout,)

from bitblas.tl.macro_generator import (
INT4TensorCoreIntrinEmitter,
@@ -20,6 +19,7 @@

torch.manual_seed(0)


@simplify_prim_func
def tl_matmul(
M,
@@ -61,8 +61,8 @@ def tl_matmul(
block_N = block_col_warps * warp_col_tiles
block_K = chunk

A_shape = (M, K) # int8 storage represents int4*2
B_shape = (N, K) # int8 storage represents int4*2
A_shape = (M, K) # int8 storage represents int4*2
B_shape = (N, K) # int8 storage represents int4*2
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
@@ -107,9 +107,7 @@ def main(
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local(
(warp_rows * warp_cols * local_size_c), accum_dtype
)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

@@ -180,7 +178,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
# src_code is the generated cuda source
assert src_code is not None


A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
@@ -196,9 +193,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
assert latency is not None

# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(
getattr(torch, accum_dtype)
)
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))

print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@@ -245,7 +240,7 @@ def tl_matmul_weight_only_transform(
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk

is_smooth_a = False
can_swizzle = block_K * DataType(in_dtype).bits == 512
apply_pad_a = not (is_smooth_a or can_swizzle)
@@ -291,6 +286,7 @@ def tl_matmul_weight_only_transform(
chunk=chunk,
transform_kind_b=transform_b,
)

@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
@@ -304,9 +300,7 @@ def main(
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local(
(warp_rows * warp_cols * local_size_c), accum_dtype
)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

@@ -379,7 +373,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
# src_code is the generated cuda source
assert src_code is not None
transform_b = 3

A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
@@ -408,9 +402,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
assert latency is not None

# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(
getattr(torch, accum_dtype)
)
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@@ -420,6 +412,5 @@ def test_assert_tl_matmul_weight_only_transform():
assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, "int8", "int32", "int32")



if __name__ == "__main__":
bitblas.testing.main()
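For reference, both correctness checks in this file follow the same pattern: build small int8 inputs, run the TL kernel, and compare against an fp32 matmul cast to the accumulator dtype. A stand-alone sketch of that reference path (kernel invocation omitted; shapes and dtypes mirror the test call above):

import torch

M = N = K = 128
A = torch.randint(0, 4, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(0, 4, (N, K), device="cuda", dtype=torch.int8)

# The kernel writes C; here it is only a placeholder for the comparison below.
C = torch.zeros(M, N, device="cuda", dtype=torch.int32)

# Reference: accumulate in fp32 against the transposed B, then cast to the
# int32 accumulator dtype, matching the ref_c computation in the tests.
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(torch.int32)

# In the real test this compares the kernel output against ref_c:
# torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)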
16 changes: 7 additions & 9 deletions testing/python/tilelang/test_tilelang_macro_gemm.py
@@ -709,13 +709,9 @@ def main(
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local(
(warp_cols * local_size_b // num_elems_per_byte), storage_dtype
)
B_local = T.alloc_local((warp_cols * local_size_b // num_elems_per_byte), storage_dtype)
B_dequantize_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local(
(warp_rows * warp_cols * local_size_c), accum_dtype
)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
reduced_accum_res = T.alloc_local(0, accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
rk = T.thread_binding(0, reduce_k, "threadIdx.y")
@@ -773,9 +769,11 @@ def main(
)

for j in T.serial(warp_cols):
T.call_extern('handle', 'decode_i4u_to_f16',
T.address_of(B_local[j * mma_emitter.local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * mma_emitter.local_size_b]), 8)
T.call_extern(
'handle', 'decode_i4u_to_f16',
T.address_of(B_local[j * mma_emitter.local_size_b //
num_elems_per_byte]),
T.address_of(B_dequantize_local[j * mma_emitter.local_size_b]), 8)

mma_emitter.mma(A_local, B_dequantize_local, C_local)
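The decode_i4u_to_f16 extern invoked above converts packed unsigned 4-bit values from B_local into fp16 in B_dequantize_local before the MMA consumes them. Below is a host-side sketch of that kind of unpacking, assuming two values per byte with the low nibble first; it only illustrates the idea, and the actual CUDA intrinsic may pack and convert differently:

import numpy as np

def decode_i4u_to_f16_ref(packed: np.ndarray) -> np.ndarray:
    """Unpack two unsigned 4-bit values per byte into float16 (low nibble first)."""
    packed = packed.astype(np.uint8)
    out = np.empty(packed.size * 2, dtype=np.float16)
    out[0::2] = packed & 0x0F          # low nibble of each byte
    out[1::2] = (packed >> 4) & 0x0F   # high nibble of each byte
    return out

packed = np.array([0x21, 0xF0], dtype=np.uint8)  # nibbles 1,2 and 0,15
print(decode_i4u_to_f16_ref(packed))             # -> [ 1.  2.  0. 15.]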
