From 8f7767bca01e54998292a16edbd15d8f44576d18 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 1 Nov 2024 16:01:31 +0000
Subject: [PATCH] lint fix

---
 bitblas/tl/macro_generator.py                  | 22 ++++---------
 .../tilelang/test_tilelang_gemm_s4_mma.py      | 31 +++++++------------
 .../tilelang/test_tilelang_macro_gemm.py       | 16 +++++-----
 3 files changed, 24 insertions(+), 45 deletions(-)

diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py
index 25376f294..fd8ec43ae 100644
--- a/bitblas/tl/macro_generator.py
+++ b/bitblas/tl/macro_generator.py
@@ -473,6 +473,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
 
         return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
 
+
 class INT4TensorCoreIntrinEmitter(TensorCoreIntrinEmitter):
 
     def mma(self, A_local_buf, B_local_buf, C_local_buf):
@@ -533,9 +534,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
                     B_local_buf.data,
                     j * local_size_b + lift(local_size_b) // 2,
                     C_local_buf.data,
-                    i * warp_cols * local_size_out
-                    + j * local_size_out
-                    + lift(local_size_out) // 2,
+                    i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                     T.bool(False),
                 )
 
@@ -571,18 +570,14 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
                     B_local_buf.data,
                     j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4,
                     C_local_buf.data,
-                    i * warp_cols * local_size_out
-                    + j * local_size_out
-                    + lift(local_size_out) // 2,
+                    i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                     T.bool(False),
                 )
 
         return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
 
 
-class INT4TensorCoreIntrinEmitterWithLadderTransform(
-    TensorCoreIntrinEmitterWithLadderTransform
-):
+class INT4TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitterWithLadderTransform):
 
     def mma(self, A_local_buf, B_local_buf, C_local_buf):
 
@@ -643,9 +638,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
                     B_local_buf.data,
                     j * local_size_b + lift(local_size_b) // 2,
                     C_local_buf.data,
-                    i * warp_cols * local_size_out
-                    + j * local_size_out
-                    + lift(local_size_out) // 2,
+                    i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                     T.bool(False),
                 )
 
@@ -681,11 +674,8 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
                     B_local_buf.data,
                     j * local_size_b + lift(local_size_b) // 2 + lift(local_size_b) // 4,
                     C_local_buf.data,
-                    i * warp_cols * local_size_out
-                    + j * local_size_out
-                    + lift(local_size_out) // 2,
+                    i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
                     T.bool(False),
                 )
 
-
         return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py
index bda7334eb..37c210b91 100644
--- a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py
+++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py
@@ -9,8 +9,7 @@
 from tvm import tl as TL
 import tvm.tl.language as T
 from bitblas.tl.utils import (
-    make_swizzle_layout,
-)
+    make_swizzle_layout,)
 
 from bitblas.tl.macro_generator import (
     INT4TensorCoreIntrinEmitter,
@@ -20,6 +19,7 @@
 
 torch.manual_seed(0)
 
+
 @simplify_prim_func
 def tl_matmul(
     M,
@@ -61,8 +61,8 @@ def tl_matmul(
     block_N = block_col_warps * warp_col_tiles
     block_K = chunk
 
-    A_shape = (M, K) # int8 storage represents int4*2
-    B_shape = (N, K) # int8 storage represents int4*2
+    A_shape = (M, K)  # int8 storage represents int4*2
+    B_shape = (N, K)  # int8 storage represents int4*2
     A_shared_shape = (block_M, block_K)
     B_shared_shape = (block_N, block_K)
     C_shared_shape = (
@@ -107,9 +107,7 @@ def main(
             C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
             A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
-            C_local = T.alloc_local(
-                (warp_rows * warp_cols * local_size_c), accum_dtype
-            )
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
 
             thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
 
@@ -180,7 +178,6 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
 
     # src_code is the generated cuda source
     assert src_code is not None
-
     A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
     B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
@@ -196,9 +193,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     assert latency is not None
 
     # Get Reference Result
-    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(
-        getattr(torch, accum_dtype)
-    )
+    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
     print(ref_c)
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
 
@@ -245,7 +240,7 @@ def tl_matmul_weight_only_transform(
     block_M = block_row_warps * warp_row_tiles
     block_N = block_col_warps * warp_col_tiles
     block_K = chunk
-    
+
     is_smooth_a = False
     can_swizzle = block_K * DataType(in_dtype).bits == 512
     apply_pad_a = not (is_smooth_a or can_swizzle)
@@ -291,6 +286,7 @@ def tl_matmul_weight_only_transform(
         chunk=chunk,
         transform_kind_b=transform_b,
     )
+
     @T.prim_func
     def main(
             A: T.Buffer(A_shape, in_dtype),
@@ -304,9 +300,7 @@ def main(
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
-           C_local = T.alloc_local(
-               (warp_rows * warp_cols * local_size_c), accum_dtype
-           )
+           C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
 
            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
 
@@ -379,7 +373,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
     # src_code is the generated cuda source
     assert src_code is not None
     transform_b = 3
-    
+
     A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
     B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype))
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
@@ -408,9 +402,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
     assert latency is not None
 
     # Get Reference Result
-    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(
-        getattr(torch, accum_dtype)
-    )
+    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
     print(C)
     print(ref_c)
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@@ -420,6 +412,5 @@ def test_assert_tl_matmul_weight_only_transform():
     assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, "int8", "int32", "int32")
 
 
-
 if __name__ == "__main__":
     bitblas.testing.main()
diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_macro_gemm.py
index 346888f4d..4c4cf8f59 100644
--- a/testing/python/tilelang/test_tilelang_macro_gemm.py
+++ b/testing/python/tilelang/test_tilelang_macro_gemm.py
@@ -709,13 +709,9 @@ def main(
             B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
             A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
-            B_local = T.alloc_local(
-                (warp_cols * local_size_b // num_elems_per_byte), storage_dtype
-            )
+            B_local = T.alloc_local((warp_cols * local_size_b // num_elems_per_byte), storage_dtype)
             B_dequantize_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
-            C_local = T.alloc_local(
-                (warp_rows * warp_cols * local_size_c), accum_dtype
-            )
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
             reduced_accum_res = T.alloc_local(0, accum_dtype)
             thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
             rk = T.thread_binding(0, reduce_k, "threadIdx.y")
@@ -773,9 +769,11 @@ def main(
                     )
 
                     for j in T.serial(warp_cols):
-                        T.call_extern('handle', 'decode_i4u_to_f16',
-                                      T.address_of(B_local[j * mma_emitter.local_size_b // num_elems_per_byte]),
-                                      T.address_of(B_dequantize_local[j * mma_emitter.local_size_b]), 8)
+                        T.call_extern(
+                            'handle', 'decode_i4u_to_f16',
+                            T.address_of(B_local[j * mma_emitter.local_size_b //
+                                                 num_elems_per_byte]),
+                            T.address_of(B_dequantize_local[j * mma_emitter.local_size_b]), 8)
 
                     mma_emitter.mma(A_local, B_dequantize_local, C_local)
 