[Dev] BUG Fix for bitnet integration (#141)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix in test

* lint fix.

* test cuda i4 kernel

* Refactor copyright notice in i4matmul.hpp

* Refactor BitBLASLinear test module for improved readability and maintainability

* refactor test as version below python 3.9 cannot handle int32 overflow.

* format lint for test

* Refactor test_int4b_fp16_convert.py for improved readability and maintainability

* remove unused design file

* move tile device from package to base

* dummy impl for codegen

* Refactor file structure for ladder_permutate module

* Refactor backend class and fix typos in comments

* Deep refactor Lib related code.

* remove ci pull.

* LintFix

* refactor builder for whl build

* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module

* Refactor lib_generator to set library and source paths

* lint fix

* BitNet vllm integration

* chore: update codespell to version 2.3.0

* Lintfix

* Bump version to 0.0.1.dev13

* lint fix

* disable fast decoding [u]int4xint8 by default.

* optimize from dict design in Hint

* Implement SplitK

* bitnet benchmark generation.

* Add benchmark script for BitNet integration

* AtomicAdd Support

* LintFix

* ci fix when 3rdparty tvm is initialized.

* bug fix for setup

* fix a bug in block reduce

* typo fix

* BUG Fix for block reduce.

* Lint fix

* Refactor block reduce schedule template

* transform branch from bitblas to bitblas_tl

* Fix subproject commit reference in 3rdparty/tvm

* chore: update submodule branch from bitblas to bitblas_tl

* force update config.cmake

* Bug fix

* Fix subproject commit reference in 3rdparty/cutlass

* chore: Add submodule for cutlass library

* update tl cutlass path

* Refactor BitBLASLinear test module for improved readability and maintainability

* format fix

* Copy CUTLASS to the package directory

* Refactor setup.py to include additional TVM header files

* lint fix

* bug fix

* Refactor BitBLASLinear test module for improved readability and maintainability

* Implement Matmul Benchmark Design

* chore: Update BitBLAS Matmul benchmark script

* lint fix

* Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* lint fix

* Benchmark bot test

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* int8 test case

* Refactor compare_benchmark.py to handle missing benchmark results gracefully

* ci fix

* disable ci for test benchmark

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* remove cli installation

* chore: Create virtual environment and install dependencies for benchmark

* chore: Update benchmark workflow to include comparison step

* Lint fix

* update tvm commit

* Improve lower warp memory pass

* Bug fix

* Enhance to support warp schedule.

* Enhance LOP3 Instructions

* Enhance LOP3 Instructions

* add test for stage3 propagate

* implement propagate func

* Stage3 Ladder Permutate integration

* get_ladder_stage3_propagate

* comment out benchmark scripts as the setting is too big

* ci fix for benchmark

* lint fix

* chore: Update benchmark workflow to trigger on pull request comments

* Add LDMatrix Transform 3

* Support GPTQ Test

* Fuse BlockReduce Schedule

* Support mma propagate 3

* Support MMA Propagate Stage 3

* Lint Fix

* Merge block reduce for dequantize config.

* fix codeql

* chore: Update submodule reference to latest commit

* chore: Disable common subexpression elimination in TIR passes

* Lint Fix

* 4bit related lop3 updates.

* lint fix

* gptq test fix

* Fix for test

* lint fix

* lint fix

* typofix

* QuantCompress Test

* chore: Refactor quant_compress_impl.py for readability and maintainability

* Enhance docs to update latest works.

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* removed legacy operator

* Refactor weight executors in Matmul class for improved readability and maintainability

* LintFix

* Fix GPTQ Repack with the latest weight transform

* lint fix

* bug fix for rescale dequantize

* test fix

* typo fix

* lint fix

* Set default weight propagate kind into LDMatrixTransform

* lint fix

* bug fix

* bug fix for test

* set default to stage3

* revert change

* lint fix

* case fix

* bug fix

* fix for legalize

* bug fix

* chore: Clear global operator cache before running tests

* revert optimize_stratety to SingleBatchDecodeOnly

* typofix

* update benchmark scripts

* chore: Refactor benchmark scripts and fix typos

* fix for testing

* lint fix

* fix import.

* typo

* operator benchmark

* optimize

* always with shared.dyn

* optimize cache.

* dsl fix

* tqdm

* chore: Add serialize_results method to benchmark_matmul_strategies.py

* fix performance issue for dynamic async copy

* chore: Refactor benchmark_matmul_strategies.py for improved performance and code readability

* bug fix

* update readme

* disable block reduce for int8

* bugfix for bitnet

* annotate TODO.

* lint fix
LeiWang1999 authored Aug 13, 2024
1 parent fcf3f63 commit af697de
Showing 5 changed files with 13 additions and 6 deletions.
2 changes: 2 additions & 0 deletions bitblas/ops/general_matmul/__init__.py
@@ -85,6 +85,8 @@ class MatmulConfig(OperatorConfig):
         None  # propagate_b is a flag to control the ladder permutation
     )

+    # TODO: This is a temporary solution to legalize the dynamic symbolic.
+    # Maybe we should remove this in the future.
     # optimize strategy, default is SingleBatchDecodeOnly
     optimize_stratety: Union[int, OptimizeStrategy] = OptimizeStrategy.SingleBatchDecodeOnly
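The field the new TODO annotates accepts either a raw int or an `OptimizeStrategy` enum member. A minimal standalone sketch of that pattern (the enum member values and the `ContigousBatching` name here are assumptions, not taken from this diff; the field name keeps the codebase's own `optimize_stratety` spelling):

```python
from dataclasses import dataclass
from enum import IntEnum
from typing import Union


class OptimizeStrategy(IntEnum):
    # Hypothetical values; the real enum lives in bitblas and may differ.
    SingleBatchDecodeOnly = 0
    ContigousBatching = 1


@dataclass
class MatmulConfig:
    # Mirrors the Union[int, OptimizeStrategy] annotation in the diff above.
    optimize_stratety: Union[int, OptimizeStrategy] = OptimizeStrategy.SingleBatchDecodeOnly

    def __post_init__(self):
        # Normalize raw ints to the enum so downstream comparisons are safe.
        self.optimize_stratety = OptimizeStrategy(self.optimize_stratety)
```

Normalizing in `__post_init__` means callers can pass `0` or the enum member interchangeably while internal code only ever sees the enum.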
11 changes: 6 additions & 5 deletions integration/BitNet/eval_correctness.py
@@ -72,18 +72,19 @@ def get_runtime(num_repeats=1):
 def main():
     model = BitnetForCausalLM.from_pretrained(
         model_path,
-        use_flash_attention_2=True,
+        use_flash_attention_2=False,
         torch_dtype=torch.float16,
     ).cuda().half()
     with torch.no_grad():
         model._post_process_weights()

     tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False)
-    input_id = tokenizer("Hello")['input_ids']
-    input_id = torch.tensor(input_id).unsqueeze(0).cuda()
-    output = model(input_id)
-    print(output)
+    print("original model generated text:")
+    print(generate_text(model, tokenizer, "Hello", max_length=100))

+    model.quantize()
+    print("quantized model generated text:")
+    print(generate_text(model, tokenizer, "Hello", max_length=100))

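The rewritten test replaces a single raw forward pass with a before/after comparison: generate text with the original weights, call `model.quantize()`, then generate again with the same prompt. The `generate_text` helper itself is not shown in this hunk; a duck-typed sketch of what such a helper usually does (encode, generate, decode), with the interface assumed rather than taken from the repo:

```python
def generate_text(model, tokenizer, prompt, max_length=100):
    """Generate a continuation of `prompt` and return it as a string.

    Assumes an HF-style interface: tokenizer.encode/decode and
    model.generate; the real helper in eval_correctness.py may differ.
    """
    # Encode the prompt into token ids.
    input_ids = tokenizer.encode(prompt)
    # Decode up to max_length tokens starting from the prompt.
    output_ids = model.generate(input_ids, max_length=max_length)
    # Turn the generated ids back into text.
    return tokenizer.decode(output_ids)
```

Comparing the two printed outputs gives a quick eyeball check that quantization did not wreck generation quality.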
2 changes: 1 addition & 1 deletion integration/BitNet/maint/create_bitblas_ckpt.py
@@ -68,7 +68,7 @@ def main():
     model = (
         BitnetForCausalLM.from_pretrained(
             model_name_or_path,
-            use_flash_attention_2=True,
+            use_flash_attention_2=False,
             torch_dtype=torch.float16,
         ).cuda().half())
     tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
3 changes: 3 additions & 0 deletions (file name not shown in this view)
@@ -18,6 +18,9 @@ fi
 if [ -z "$SAVED_MODEL_DIR" ]; then
     python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR
 else
+    if [ ! -d "$SAVED_MODEL_DIR" ]; then
+        mkdir -p $SAVED_MODEL_DIR
+    fi
     python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR --saved_model_path $SAVED_MODEL_DIR
 fi

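One note on the hunk above: `mkdir -p` already succeeds silently when the target directory exists, so the `[ ! -d ... ]` guard is defensive rather than strictly required. A small self-contained demonstration in a throwaway temp directory:

```shell
# mkdir -p creates missing parent directories and exits 0 even when the
# target already exists, so calling it twice is harmless.
tmp=$(mktemp -d)
mkdir -p "$tmp/saved_model"   # first call creates the directory
mkdir -p "$tmp/saved_model"   # second call is a no-op, still exit status 0
test -d "$tmp/saved_model" && echo "exists"
rm -rf "$tmp"
```

Quoting `"$SAVED_MODEL_DIR"` in the script would additionally guard against paths containing spaces, though the unquoted form matches the rest of the script.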
1 change: 1 addition & 0 deletions integration/BitNet/requirements.txt
@@ -1,2 +1,3 @@
 lm_eval==0.3.0
 flash_attn
+transformers==4.40
