From 374065e69f0cf48a7bc8bf4e7a6aec73d763bfc9 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 19 Aug 2024 16:48:00 +0000 Subject: [PATCH] bug fix for scale only case --- bitblas/gpu/intrin/lop3.py | 8 ++++---- bitblas/gpu/matmul_mma_dequantize.py | 18 +++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index aee3eac8b..b58aabc71 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -1379,7 +1379,7 @@ def fast_decode_desc( 1, ], dtype=target_dtype, - scope="global", + scope="local", ) Zeros = T.match_buffer( zeros, @@ -1387,7 +1387,7 @@ def fast_decode_desc( 1, ], dtype=target_dtype, - scope="global", + scope="local", ) with T.block("root"): T.reads(*get_dequantize_buffers_list( @@ -1447,7 +1447,7 @@ def fast_decode_impl( dtype=target_dtype, offset_factor=1, strides=[s0], - scope="global", + scope="local", ) Zeros = T.match_buffer( zeros, @@ -1457,7 +1457,7 @@ def fast_decode_impl( dtype=target_dtype, offset_factor=1, strides=[s1], - scope="global", + scope="local", ) with T.block("root"): T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1]) diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 7c40d3243..b2789da37 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -1571,16 +1571,16 @@ def get_idx(): sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) dequantize_block_local = block_shared_local - if is_qzeros: - if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): - block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") - sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_scales) - if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): - block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") - sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_zeros) + if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): + block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") + sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_scales) + + if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): + block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") + sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) for producer in weight_producers: with suppress(Exception):