From b913d263f7b60abc776d7be232542f80317d17b7 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Tue, 6 Aug 2024 07:04:11 +0000
Subject: [PATCH] bug fix for test

---
 .github/workflows/benchmark.yml                  |  1 +
 bitblas/gpu/matmul_mma.py                        | 12 +++++++++++-
 bitblas/gpu/matmul_mma_dequantize.py             | 16 +++++++++++++---
 .../python/operators/test_general_matmul_ops.py  |  3 +++
 4 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 235b8686d..013345f6f 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -62,6 +62,7 @@ jobs:
         uses: actions/checkout@v2
         with:
           fetch-depth: 0
+          ref: ${{ github.event.pull_request.head.ref }}
 
       - name: Get PR branch commit ID
         id: get_pr_commit
diff --git a/bitblas/gpu/matmul_mma.py b/bitblas/gpu/matmul_mma.py
index 8700e6580..5d92f99b1 100644
--- a/bitblas/gpu/matmul_mma.py
+++ b/bitblas/gpu/matmul_mma.py
@@ -315,7 +315,17 @@ def store_output(block_outer, write_buffer_idx):
 
         sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"])
         sch.tensorize(sch.get_loops(block_read_reg_a)[-2], intrin_group["load_a"])
-        sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"])
+        weight_transform_kind = 0
+        if hasattr(func, "attrs") and "weight_transform_kind" in func.attrs:
+            weight_transform_kind = func.attrs["weight_transform_kind"]
+        if weight_transform_kind >= TransformKind.LDMatrixTransform:
+            fused = sch.fuse(*sch.get_loops(block_read_reg_b)[-2:])
+            vec_len = get_coalesced_veclen(sch.get(block_read_reg_b))
+            f0, f1, f2 = sch.split(fused, factors=[None, 32, vec_len])
+            sch.bind(f1, "threadIdx.x")
+            sch.vectorize(f2)
+        else:
+            sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"])
         sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"])
         sch.tensorize(sch.get_loops(block_write_reg)[-2], intrin_group["store"])
 
diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py
index 2033b8f75..7421cbd47 100644
--- a/bitblas/gpu/matmul_mma_dequantize.py
+++ b/bitblas/gpu/matmul_mma_dequantize.py
@@ -628,9 +628,19 @@ def get_idx():
         i0, i1 = sch.split(i, factors=[None, b_lr[0]])
         j0, j1 = sch.split(j, factors=[None, b_lr[1]])
         sch.reorder(i0, j0, i1, j1)
-        bb = sch.blockize(i1)
-        sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b)
-        sch.tensorize(bb, intrin_group["load_b"])
+        weight_transform_kind = 0
+        if hasattr(func, "attrs") and "weight_transform_kind" in func.attrs:
+            weight_transform_kind = func.attrs["weight_transform_kind"]
+        if weight_transform_kind >= TransformKind.LDMatrixTransform:
+            fused = sch.fuse(i1, j1)
+            vec_len = get_coalesced_veclen(sch.get(B_mat))
+            f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
+            sch.bind(f1, "threadIdx.x")
+            sch.vectorize(f2)
+        else:
+            bb = sch.blockize(i1)
+            sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b)
+            sch.tensorize(bb, intrin_group["load_b"])
 
     def tensorize_init_store_compute():
         sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"])
diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py
index 354914d22..2d6890577 100644
--- a/testing/python/operators/test_general_matmul_ops.py
+++ b/testing/python/operators/test_general_matmul_ops.py
@@ -134,6 +134,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
         with_scaling=with_scaling,
         with_zeros=with_zeros,
         zeros_mode=zeros_mode,
+        propagate_a=False,
     )
 
     matmul = Matmul(config=matmul_config, enable_tuning=False)
@@ -194,6 +195,8 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
         permuted_inputs.append(bias)
     permuted_inputs.append(inputs[2])
     matmul(*permuted_inputs[:-1], output=permuted_inputs[-1])
+    print(permuted_inputs[-1])
+    print(ref_result)
     if zeros_mode == "rescale":
         torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0)
    else:
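
Note: the two schedule hunks share one dispatch pattern: when the weight has already been laid out for ldmatrix (weight_transform_kind >= TransformKind.LDMatrixTransform), the ldmatrix tensorize intrinsic is skipped in favor of a plain fused copy bound to threadIdx.x with a vectorized innermost axis. Below is a minimal standalone sketch of that dispatch for illustration only; `schedule_load_b` and its return strings are hypothetical stand-ins for the real `sch.fuse`/`sch.split`/`sch.bind`/`sch.vectorize` and `sch.tensorize` calls, and the enum values merely assume the ordering implied by the patch.

```python
from enum import IntEnum


class TransformKind(IntEnum):
    # Ordering assumed from the patch: higher kinds mean the weight
    # has already been transformed further toward the ldmatrix layout,
    # so only the relative order matters for the >= comparison.
    NonTransform = 0
    InterWarpTransform = 1
    IntraWarpTransform = 2
    LDMatrixTransform = 3


def schedule_load_b(func_attrs: dict, warp_size: int = 32) -> str:
    """Hypothetical sketch of the load_b dispatch added in this patch.

    Mirrors the branch structure in matmul_mma.py and
    matmul_mma_dequantize.py without depending on TVM.
    """
    kind = func_attrs.get("weight_transform_kind", 0)
    if kind >= TransformKind.LDMatrixTransform:
        # Weight already in ldmatrix layout: a fused loop bound to
        # threadIdx.x with a vectorized inner axis is sufficient.
        return f"vectorized copy across {warp_size} threads"
    # Otherwise lower load_b through the ldmatrix intrinsic group.
    return "tensorized load_b intrinsic"


assert schedule_load_b({"weight_transform_kind": 3}) == "vectorized copy across 32 threads"
assert schedule_load_b({}) == "tensorized load_b intrinsic"
```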