
bug fix for test
LeiWang1999 committed Aug 6, 2024
1 parent a30dae9 commit b913d26
Showing 4 changed files with 28 additions and 4 deletions.
1 change: 1 addition & 0 deletions .github/workflows/benchmark.yml
@@ -62,6 +62,7 @@ jobs:
         uses: actions/checkout@v2
         with:
           fetch-depth: 0
+          ref: ${{ github.event.pull_request.head.ref }}
 
       - name: Get PR branch commit ID
         id: get_pr_commit
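Note on the added line: for workflows triggered by pull_request events, actions/checkout checks out the auto-generated merge commit by default, not the branch itself. Setting ref to ${{ github.event.pull_request.head.ref }} pins the checkout to the PR's head branch, so the subsequent "Get PR branch commit ID" step resolves the branch tip rather than the merge commit.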
12 changes: 11 additions & 1 deletion bitblas/gpu/matmul_mma.py
@@ -315,7 +315,17 @@ def store_output(block_outer, write_buffer_idx):
 
         sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"])
         sch.tensorize(sch.get_loops(block_read_reg_a)[-2], intrin_group["load_a"])
-        sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"])
+        weight_transform_kind = 0
+        if hasattr(func, "attrs") and "weight_transform_kind" in func.attrs:
+            weight_transform_kind = func.attrs["weight_transform_kind"]
+        if weight_transform_kind >= TransformKind.LDMatrixTransform:
+            fused = sch.fuse(*sch.get_loops(block_read_reg_b)[-2:])
+            vec_len = get_coalesced_veclen(sch.get(block_read_reg_b))
+            f0, f1, f2 = sch.split(fused, factors=[None, 32, vec_len])
+            sch.bind(f1, "threadIdx.x")
+            sch.vectorize(f2)
+        else:
+            sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"])
         sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"])
         sch.tensorize(sch.get_loops(block_write_reg)[-2], intrin_group["store"])
 
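For readers following the scheduling change: below is a minimal, self-contained sketch (assuming tvm is installed) of the fuse -> split -> bind -> vectorize pattern the new LDMatrixTransform branch applies in place of ldmatrix tensorization. The toy copy kernel and the fixed 8-wide vector length are illustrative only; the real code derives vec_len with get_coalesced_veclen and operates on the matmul's register-load block.

    import tvm
    from tvm.script import tir as T

    @tvm.script.ir_module
    class CopyModule:
        @T.prim_func
        def main(A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16")):
            for i, j in T.grid(16, 16):
                with T.block("load"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj]

    sch = tvm.tir.Schedule(CopyModule)
    i, j = sch.get_loops(sch.get_block("load"))
    fused = sch.fuse(i, j)                                # 16 * 16 = 256 elements
    f0, f1, f2 = sch.split(fused, factors=[None, 32, 8])  # -> [1, 32 lanes, 8-wide]
    sch.bind(f1, "threadIdx.x")                           # one warp of 32 threads
    sch.vectorize(f2)                                     # coalesced vector access
    print(sch.mod.script())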
16 changes: 13 additions & 3 deletions bitblas/gpu/matmul_mma_dequantize.py
@@ -628,9 +628,19 @@ def get_idx():
         i0, i1 = sch.split(i, factors=[None, b_lr[0]])
         j0, j1 = sch.split(j, factors=[None, b_lr[1]])
         sch.reorder(i0, j0, i1, j1)
-        bb = sch.blockize(i1)
-        sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b)
-        sch.tensorize(bb, intrin_group["load_b"])
+        weight_transform_kind = 0
+        if hasattr(func, "attrs") and "weight_transform_kind" in func.attrs:
+            weight_transform_kind = func.attrs["weight_transform_kind"]
+        if weight_transform_kind >= TransformKind.LDMatrixTransform:
+            fused = sch.fuse(i1, j1)
+            vec_len = get_coalesced_veclen(sch.get(B_mat))
+            f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
+            sch.bind(f1, "threadIdx.x")
+            sch.vectorize(f2)
+        else:
+            bb = sch.blockize(i1)
+            sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b)
+            sch.tensorize(bb, intrin_group["load_b"])
 
         def tensorize_init_store_compute():
             sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"])
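The vectorize width in both branches comes from get_coalesced_veclen. A hypothetical sketch of the quantity such a helper computes (the real bitblas implementation reads the buffer's dtype from the scheduled block; the helper name and the 128-bit constant below are assumptions for illustration):

    # Hypothetical helper: widest per-thread global-memory access on NVIDIA
    # GPUs is 128 bits (one float4 / ldg.128), so the vector length is the
    # number of elements that fit in one such access.
    def coalesced_veclen(dtype_bits: int, max_access_bits: int = 128) -> int:
        return max_access_bits // dtype_bits

    assert coalesced_veclen(16) == 8    # float16 -> 8 elements per vector access
    assert coalesced_veclen(8) == 16    # int8    -> 16 elements per vector access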
3 changes: 3 additions & 0 deletions testing/python/operators/test_general_matmul_ops.py
@@ -134,6 +134,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
         with_scaling=with_scaling,
         with_zeros=with_zeros,
         zeros_mode=zeros_mode,
+        propagate_a=False,
     )
     matmul = Matmul(config=matmul_config, enable_tuning=False)
 
@@ -194,6 +195,8 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
         permuted_inputs.append(bias)
     permuted_inputs.append(inputs[2])
     matmul(*permuted_inputs[:-1], output=permuted_inputs[-1])
+    print(permuted_inputs[-1])
+    print(ref_result)
     if zeros_mode == "rescale":
         torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0)
     else:
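On the "rescale" check: torch.testing.assert_close treats actual and expected as close when |actual - expected| <= atol + rtol * |expected|, so rtol=1e2 with atol=1e0 is an extremely permissive bound. A small illustration (assumes torch is installed; the tensor values are made up):

    import torch

    a = torch.tensor([1.00, 2.00])
    b = torch.tensor([1.05, 2.10])
    torch.testing.assert_close(a, b, rtol=1e2, atol=1e0)  # passes: bound is 1 + 100 * |b|
    try:
        torch.testing.assert_close(a, b)  # default float32 rtol=1.3e-6, atol=1e-5
    except AssertionError:
        print("default tolerances reject the 5% mismatch")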
