[Fix] Fix scale and zero scopes for scale-only template (#147)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix in test

* lint fix.

* test cuda i4 kernel

* Refactor copyright notice in i4matmul.hpp

* Refactor BitBLASLinear test module for improved readability and maintainability

* refactor test as Python versions below 3.9 cannot handle int32 overflow.

* format lint for test

* Refactor test_int4b_fp16_convert.py for improved readability and maintainability

* remove unused design file

* move tile device from package to base

* dummy impl for codegen

* Refactor file structure for ladder_permutate module

* Refactor backend class and fix typos in comments

* Deep refactor Lib related code.

* remove ci pull.

* LintFix

* refactor builder for whl build

* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module

* Refactor lib_generator to set library and source paths

* lint fix

* BitNet vllm integration

* chore: update codespell to version 2.3.0

* Lintfix

* Bump version to 0.0.1.dev13

* lint fix

* disable fast decoding [u]int4xint8 by default.

* optimize from dict design in Hint

* Implement SplitK

* bitnet benchmark generation.

* Add benchmark script for BitNet integration

* AtomicAdd Support

* LintFix

* ci fix when 3rdparty tvm is initialized.

* bug fix for setup

* fix a bug in block reduce

* typo fix

* BUG Fix for block reduce.

* Lint fix

* Refactor block reduce schedule template

* transform branch from bitblas to bitblas_tl

* Fix subproject commit reference in 3rdparty/tvm

* chore: update submodule branch from bitblas to bitblas_tl

* force update config.cmake

* Bug fix

* Fix subproject commit reference in 3rdparty/cutlass

* chore: Add submodule for cutlass library

* update tl cutlass path

* Refactor BitBLASLinear test module for improved readability and maintainability

* format fix

* Copy CUTLASS to the package directory

* Refactor setup.py to include additional TVM header files

* lint fix

* bug fix

* Refactor BitBLASLinear test module for improved readability and maintainability

* Implement Matmul Benchmark Design

* chore: Update BitBLAS Matmul benchmark script

* lint fix

* Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* lint fix

* Benchmark bot test

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* int8 test case

* Refactor compare_benchmark.py to handle missing benchmark results gracefully

* ci fix

* disable ci for test benchmark

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* remove cli installation

* chore: Create virtual environment and install dependencies for benchmark

* chore: Update benchmark workflow to include comparison step

* Lint fix

* update tvm commit

* Improve lower warp memory pass

* Bug fix

* Enhance to support warp schedule.

* Enhance LOP3 Instructions

* Enhance LOP3 Instructions

* add test for stage3 propagate

* implement propagate func

* Stage3 Ladder Permutate integration

* get_ladder_stage3_propagate

* comment out benchmark scripts as the setting is too big

* ci fix for benchmark

* lint fix

* chore: Update benchmark workflow to trigger on pull request comments

* Add LDMatrix Transform 3

* Support GPTQ Test

* Fuse BlockReduce Schedule

* Support mma propagate 3

* Support MMA Propagate Stage 3

* Lint Fix

* Merge block reduce for dequantize config.

* fix codeql

* chore: Update submodule reference to latest commit

* chore: Disable common subexpression elimination in TIR passes

* Lint Fix

* 4bit related lop3 updates.

* lint fix

* gptq test fix

* Fix for test

* lint fix

* lint fix

* typo fix

* QuantCompress Test

* chore: Refactor quant_compress_impl.py for readability and maintainability

* Enhance docs to update latest works.

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* removed legacy operator

* Refactor weight executors in Matmul class for improved readability and maintainability

* LintFix

* Fix GPTQ Repack with the latest weight transform

* lint fix

* bug fix for rescale dequantize

* test fix

* typo fix

* lint fix

* Set default weight propagate kind into LDMatrixTransform

* lint fix

* bug fix

* bug fix for test

* set default to stage3

* revert change

* lint fix

* case fix

* bug fix

* fix for legalize

* bug fix

* chore: Clear global operator cache before running tests

* revert optimize_stratety into SingleBatchDecodeOnly

* typo fix

* update benchmark scripts

* chore: Refactor benchmark scripts and fix typos

* fix for testing

* lint fix

* fix import.

* typo

* operator benchmark

* optimize

* always with shared.dyn

* optimize cache.

* dsl fix

* tqdm

* chore: Add serialize_results method to benchmark_matmul_strategies.py

* fix performance issue for dynamic async copy

* chore: Refactor benchmark_matmul_strategies.py for improved performance and code readability

* bug fix

* update readme

* disable block reduce for int8

* bugfix for bitnet

* annotate todo.

* lint fix

* register fast_decode for int8xint4

* Refactor CUDA code to use sm architecture instead of compute architecture

* compress qkv and gate up for bitnet

* improve elementwise schedule

* Refactor BitNet model checkpoint generation scripts

* cross thread reduce for tl

* fix scale only lop3 tensorize instructions.

* bug fix for scale only case

* fix scale for warp memory dequantize

* lint fix

* bug fix

* format
LeiWang1999 authored Aug 20, 2024
1 parent 01c7a80 commit ef28a5d
Showing 3 changed files with 57 additions and 59 deletions.
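In outline, the schedule-side half of this fix removes an outer guard: local cache reads for scales and zeros used to be scheduled only when zeros_mode == "quantized", so a scale-only template (scaling enabled, zeros disabled) never got its local scale cache. Below is a minimal, self-contained sketch of the predicate change, assuming only a weight_decode_info dict shaped like the one in the diffs that follow; the helper functions are illustrative stand-ins, not BitBLAS APIs.

# Illustrative sketch only: the real code issues sch.cache_read/sch.compute_at
# on a TIR schedule; here we only model which cache reads get scheduled.

def cached_buffers_before(info: dict) -> list:
    caches = []
    if info.get("zeros_mode") == "quantized":  # outer guard removed by this commit
        if info.get("with_scaling"):
            caches.append("scales")
        if info.get("with_zeros"):
            caches.append("zeros")
    return caches

def cached_buffers_after(info: dict) -> list:
    caches = []
    if info.get("with_scaling"):
        caches.append("scales")
    if info.get("with_zeros"):
        caches.append("zeros")
    return caches

scale_only = {"with_scaling": True, "with_zeros": False, "zeros_mode": "original"}
assert cached_buffers_before(scale_only) == []          # scale cache was silently skipped
assert cached_buffers_after(scale_only) == ["scales"]   # now scheduled as intended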
54 changes: 26 additions & 28 deletions bitblas/gpu/gemv_dequantize.py
@@ -158,20 +158,19 @@ def get_idx(weight_decode_info: Dict):
 
         skip_blocks = [block_shared_local_B]
 
-        if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized":
-            if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]:
-                block_local_scales = sch.cache_read(block_decode_B,
-                                                    get_idx(weight_decode_info) + 1, "local")
-                sch.compute_at(block_local_scales, tx, preserve_unit_loops=True)
-                auto_inline_producers(sch, block_local_scales)
-                skip_blocks.append(block_local_scales)
-
-            if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]:
-                block_local_zeros = sch.cache_read(block_decode_B,
-                                                   get_idx(weight_decode_info) + 2, "local")
-                sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True)
-                auto_inline_producers(sch, block_local_zeros)
-                skip_blocks.append(block_local_zeros)
+        if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]:
+            block_local_scales = sch.cache_read(block_decode_B,
+                                                get_idx(weight_decode_info) + 1, "local")
+            sch.compute_at(block_local_scales, tx, preserve_unit_loops=True)
+            auto_inline_producers(sch, block_local_scales)
+            skip_blocks.append(block_local_scales)
+
+        if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]:
+            block_local_zeros = sch.cache_read(block_decode_B,
+                                                get_idx(weight_decode_info) + 2, "local")
+            sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True)
+            auto_inline_producers(sch, block_local_zeros)
+            skip_blocks.append(block_local_zeros)
 
         auto_inline_producers(sch, block_decode_B, skip_blocks)
 
@@ -329,20 +328,19 @@ def get_idx(weight_decode_info: Dict):
 
         skip_blocks = [block_shared_local_B]
 
-        if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized":
-            if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]:
-                block_local_scales = sch.cache_read(block_decode_B,
-                                                    get_idx(weight_decode_info) + 1, "local")
-                sch.compute_at(block_local_scales, tx, preserve_unit_loops=True)
-                auto_inline_producers(sch, block_local_scales)
-                skip_blocks.append(block_local_scales)
-
-            if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]:
-                block_local_zeros = sch.cache_read(block_decode_B,
-                                                   get_idx(weight_decode_info) + 2, "local")
-                sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True)
-                auto_inline_producers(sch, block_local_zeros)
-                skip_blocks.append(block_local_zeros)
+        if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]:
+            block_local_scales = sch.cache_read(block_decode_B,
+                                                get_idx(weight_decode_info) + 1, "local")
+            sch.compute_at(block_local_scales, tx, preserve_unit_loops=True)
+            auto_inline_producers(sch, block_local_scales)
+            skip_blocks.append(block_local_scales)
+
+        if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]:
+            block_local_zeros = sch.cache_read(block_decode_B,
+                                                get_idx(weight_decode_info) + 2, "local")
+            sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True)
+            auto_inline_producers(sch, block_local_zeros)
+            skip_blocks.append(block_local_zeros)
 
         auto_inline_producers(sch, block_decode_B, skip_blocks)
23 changes: 13 additions & 10 deletions bitblas/gpu/intrin/lop3.py
@@ -1028,6 +1028,9 @@ def get_fast_decode_intrin(
     else:
         raise ValueError("Unsupported source_format: {}".format(source_format))
 
+    # As warp memory requires the scale to be read from a scattered memory address, we need to pass the scale as an offset
+    scale_zero_scope = "local" if storage_scope == "local" else "global"
+
     def get_func_arguments(Quant, Dequant, Scale=None, Zeros=None):
         args = [Quant.access_ptr("r"), Dequant.access_ptr("w")]
         if Scale is not None:
@@ -1127,7 +1130,7 @@ def fast_decode_desc(compressed: T.handle, decompressed: T.handle, scale: T.hand
                 1,
             ],
             dtype=target_dtype,
-            scope="global",
+            scope=scale_zero_scope,
         )
         with T.block("root"):
             T.reads(Compressed[0:n_storage_elems], Scale[0:1])
@@ -1173,7 +1176,7 @@ def fast_decode_impl(compressed: T.handle, decompressed: T.handle, scale: T.hand
             dtype=target_dtype,
             offset_factor=1,
             strides=[s0],
-            scope="global",
+            scope=scale_zero_scope,
         )
         with T.block("root"):
             T.reads(Compressed[0:n_storage_elems], Scale[0:1])
@@ -1237,15 +1240,15 @@ def fast_decode_desc(
                 1,
             ],
             dtype=target_dtype,
-            scope=storage_scope,
+            scope=scale_zero_scope,
         )
         Zeros = T.match_buffer(
             zeros,
             [
                 1,
             ],
             dtype=storage_dtype,
-            scope=storage_scope,
+            scope=scale_zero_scope,
         )
         with T.block("root"):
             T.reads(*get_dequantize_buffers_list(
@@ -1306,7 +1309,7 @@ def fast_decode_impl(
             dtype=target_dtype,
             offset_factor=1,
             strides=[s0],
-            scope=storage_scope,
+            scope=scale_zero_scope,
         )
         Zeros = T.match_buffer(
             zeros,
@@ -1316,7 +1319,7 @@
             dtype=storage_dtype,
             offset_factor=1,
             strides=[s1],
-            scope=storage_scope,
+            scope=scale_zero_scope,
         )
         with T.block("root"):
             T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1])
@@ -1379,15 +1382,15 @@ def fast_decode_desc(
                 1,
             ],
             dtype=target_dtype,
-            scope="global",
+            scope=scale_zero_scope,
         )
         Zeros = T.match_buffer(
             zeros,
             [
                 1,
             ],
             dtype=target_dtype,
-            scope="global",
+            scope=scale_zero_scope,
         )
         with T.block("root"):
             T.reads(*get_dequantize_buffers_list(
@@ -1447,7 +1450,7 @@ def fast_decode_impl(
             dtype=target_dtype,
             offset_factor=1,
             strides=[s0],
-            scope="global",
+            scope=scale_zero_scope,
        )
         Zeros = T.match_buffer(
             zeros,
@@ -1457,7 +1460,7 @@
             dtype=target_dtype,
             offset_factor=1,
             strides=[s1],
-            scope="global",
+            scope=scale_zero_scope,
         )
         with T.block("root"):
             T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1])
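The intrinsic-side half of the fix is the scope rule added in lop3.py above. Here is a minimal sketch of that rule, assuming only that storage_scope names the scope of the dequantize output; the function name select_scale_zero_scope is illustrative, not a BitBLAS API.

def select_scale_zero_scope(storage_scope: str) -> str:
    """Mirror the scale_zero_scope rule from get_fast_decode_intrin: scale and
    zeros buffers are matched in "local" scope only when the dequantize output
    itself is "local"; for "warp" (or any other) storage scope the scale would
    have to be gathered from scattered addresses, so it stays in "global" scope
    and is passed via an offset instead."""
    return "local" if storage_scope == "local" else "global"

assert select_scale_zero_scope("local") == "local"
assert select_scale_zero_scope("warp") == "global"
assert select_scale_zero_scope("shared") == "global"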
39 changes: 18 additions & 21 deletions bitblas/gpu/matmul_mma_dequantize.py
@@ -495,18 +495,16 @@ def get_idx():
         sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)
 
         dequantize_block_local = block_shared_local
-        if ("zeros_mode" in weight_decode_info and
-                weight_decode_info["zeros_mode"] == "quantized"):
-            if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
-                block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
-                sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
-                # pop the scale block
-                auto_inline_producers(sch, block_local_scales)
+        if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
+            block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
+            sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
+            # pop the scale block
+            auto_inline_producers(sch, block_local_scales)
 
-            if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
-                block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
-                sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
-                auto_inline_producers(sch, block_local_zeros)
+        if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
+            block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
+            sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
+            auto_inline_producers(sch, block_local_zeros)
 
         for producer in weight_producers:
             with suppress(Exception):
@@ -1571,16 +1569,16 @@ def get_idx():
         sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)
 
         dequantize_block_local = block_shared_local
-        if is_qzeros:
-            if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
-                block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
-                sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
-                auto_inline_producers(sch, block_local_scales)
+        if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
+            block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
+            sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
+            auto_inline_producers(sch, block_local_scales)
 
-            if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
-                block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
-                sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
-                auto_inline_producers(sch, block_local_zeros)
+        if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
+            block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
+            sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
+            auto_inline_producers(sch, block_local_zeros)
 
         for producer in weight_producers:
             with suppress(Exception):
@@ -2176,7 +2174,6 @@ def get_idx():
                 warp_size,
                 reduce_k,
             )
-
             return B_dequantized_mat
 
         B_dequantized_mat = warp_memory_dequantize()
