Skip to content

Commit

Permalink
bug fox
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Aug 20, 2024
1 parent fc63748 commit 0ae3953
Showing 1 changed file with 26 additions and 28 deletions.
54 changes: 26 additions & 28 deletions bitblas/gpu/gemv_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,19 @@ def get_idx(weight_decode_info: Dict):

skip_blocks = [block_shared_local_B]

if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized":
if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]:
block_local_scales = sch.cache_read(block_decode_B,
get_idx(weight_decode_info) + 1, "local")
sch.compute_at(block_local_scales, tx, preserve_unit_loops=True)
auto_inline_producers(sch, block_local_scales)
skip_blocks.append(block_local_scales)

if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]:
block_local_zeros = sch.cache_read(block_decode_B,
get_idx(weight_decode_info) + 2, "local")
sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True)
auto_inline_producers(sch, block_local_zeros)
skip_blocks.append(block_local_zeros)
if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]:
block_local_scales = sch.cache_read(block_decode_B,
get_idx(weight_decode_info) + 1, "local")
sch.compute_at(block_local_scales, tx, preserve_unit_loops=True)
auto_inline_producers(sch, block_local_scales)
skip_blocks.append(block_local_scales)

if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]:
block_local_zeros = sch.cache_read(block_decode_B,
get_idx(weight_decode_info) + 2, "local")
sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True)
auto_inline_producers(sch, block_local_zeros)
skip_blocks.append(block_local_zeros)

auto_inline_producers(sch, block_decode_B, skip_blocks)

Expand Down Expand Up @@ -329,20 +328,19 @@ def get_idx(weight_decode_info: Dict):

skip_blocks = [block_shared_local_B]

if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized":
if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]:
block_local_scales = sch.cache_read(block_decode_B,
get_idx(weight_decode_info) + 1, "local")
sch.compute_at(block_local_scales, tx, preserve_unit_loops=True)
auto_inline_producers(sch, block_local_scales)
skip_blocks.append(block_local_scales)

if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]:
block_local_zeros = sch.cache_read(block_decode_B,
get_idx(weight_decode_info) + 2, "local")
sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True)
auto_inline_producers(sch, block_local_zeros)
skip_blocks.append(block_local_zeros)
if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]:
block_local_scales = sch.cache_read(block_decode_B,
get_idx(weight_decode_info) + 1, "local")
sch.compute_at(block_local_scales, tx, preserve_unit_loops=True)
auto_inline_producers(sch, block_local_scales)
skip_blocks.append(block_local_scales)

if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]:
block_local_zeros = sch.cache_read(block_decode_B,
get_idx(weight_decode_info) + 2, "local")
sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True)
auto_inline_producers(sch, block_local_zeros)
skip_blocks.append(block_local_zeros)

auto_inline_producers(sch, block_decode_B, skip_blocks)

Expand Down

0 comments on commit 0ae3953

Please sign in to comment.