Skip to content

Commit

Permalink
bug fix for scale only case
Browse files (browse the repository at this point in the history)
  • Loading branch information
LeiWang1999 committed Aug 19, 2024
1 parent 67ad761 commit 374065e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
8 changes: 4 additions & 4 deletions bitblas/gpu/intrin/lop3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,15 +1379,15 @@ def fast_decode_desc(
1,
],
dtype=target_dtype,
- scope="global",
+ scope="local",
)
Zeros = T.match_buffer(
zeros,
[
1,
],
dtype=target_dtype,
- scope="global",
+ scope="local",
)
with T.block("root"):
T.reads(*get_dequantize_buffers_list(
Expand Down Expand Up @@ -1447,7 +1447,7 @@ def fast_decode_impl(
dtype=target_dtype,
offset_factor=1,
strides=[s0],
- scope="global",
+ scope="local",
)
Zeros = T.match_buffer(
zeros,
Expand All @@ -1457,7 +1457,7 @@ def fast_decode_impl(
dtype=target_dtype,
offset_factor=1,
strides=[s1],
- scope="global",
+ scope="local",
)
with T.block("root"):
T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1])
Expand Down
18 changes: 9 additions & 9 deletions bitblas/gpu/matmul_mma_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,16 +1571,16 @@ def get_idx():
sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)

dequantize_block_local = block_shared_local
- if is_qzeros:
-     if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
-         block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
-         sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
-         auto_inline_producers(sch, block_local_scales)

-     if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
-         block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
-         sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
-         auto_inline_producers(sch, block_local_zeros)
+ if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
+     block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
+     sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
+     auto_inline_producers(sch, block_local_scales)

+ if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
+     block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
+     sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
+     auto_inline_producers(sch, block_local_zeros)

for producer in weight_producers:
with suppress(Exception):
Expand Down

0 comments on commit 374065e

Please sign in to comment.