[Dev] Refactor Modeling BitNet to support vLLM quant linear (#84)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Disable failure email for CI
* Remove email notifications
* Move relax pass from testing to mlc_llm
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Lint fix
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Bug fix in test
* Lint fix
* Test CUDA i4 kernel
* Refactor copyright notice in i4matmul.hpp
* Refactor BitBLASLinear test module for improved readability and maintainability
* Refactor test as Python versions below 3.9 cannot handle int32 overflow
* Format lint for test
* Refactor test_int4b_fp16_convert.py for improved readability and maintainability
* Remove unused design file
* Move tile device from package to base
* Dummy impl for codegen
* Refactor file structure for ladder_permutate module
* Refactor backend class and fix typos in comments
* Deep refactor Lib related code
* Remove CI pull
* Lint fix
* Refactor builder for whl build
* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module
* Refactor lib_generator to set library and source paths
* Lint fix
* BitNet vLLM integration
* chore: update codespell to version 2.3.0
* Lint fix
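For orientation only (not part of this diff): the "vLLM quant linear" mentioned above builds on BitBLAS's mixed-precision matmul. A minimal sketch against the public bitblas.MatmulConfig / bitblas.Matmul API is shown below; the shapes and dtypes are placeholders rather than values taken from this commit, and the actual integration lives in the refactored BitNet modeling code.

import torch
import bitblas

# Illustrative only: an fp16-activation, low-bit-weight matmul kernel.
# For 1.58-bit BitNet weights the weight dtype would typically be "int2";
# "int4" is used here simply to mirror the generic BitBLAS example.
matmul_config = bitblas.MatmulConfig(
    M=1,                  # rows of the activation (e.g. batch * seq)
    N=1024,               # output features
    K=1024,               # input features
    A_dtype="float16",    # activation dtype
    W_dtype="int4",       # low-bit weight dtype
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",          # A non-transposed, W transposed
    with_bias=False,
)
matmul = bitblas.Matmul(config=matmul_config)

activation = torch.rand((1, 1024), dtype=torch.float16).cuda()
weight = torch.randint(0, 7, (1024, 1024), dtype=torch.int8).cuda()
packed_weight = matmul.transform_weight(weight)   # pack weights into the low-bit layout
output = matmul(activation, packed_weight)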
1 parent 66304d2, commit d2a86ea
Showing 8 changed files with 592 additions and 174 deletions.
@@ -0,0 +1 @@
models/
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import bitblas
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
from transformers.utils.hub import cached_file
import os
from transformers import GenerationConfig
import time
import json

filepath = os.path.abspath(__file__)
dirpath = os.path.dirname(filepath)

torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")

model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits"
saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas")


def generate_text(model, tokenizer, prompt, max_length=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
    # Generate cos and sin values (position_ids are built here but not passed to generate())
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

    generation_config = GenerationConfig(
        max_length=max_length,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
    )

    start_time = time.time()
    output_ids = model.generate(input_ids, generation_config=generation_config)
    end_time = time.time()

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    generation_time = end_time - start_time
    num_tokens = len(output_ids[0])
    tokens_per_second = num_tokens / generation_time

    print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
    print(f"Tokens per second: {tokens_per_second:.2f}")

    return generated_text


def main():
    model = (
        BitnetForCausalLM.from_pretrained(
            model_name_or_path,
            use_flash_attention_2=True,
            torch_dtype=torch.float16,
        ).cuda().half())
    tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)

    # print("original model generated text:")
    # print(generate_text(model, tokenizer, "Hi, ", max_length=100))
    input_ids = torch.ones((1, 1), dtype=torch.long).cuda()
    # naive model inference
    output = model(input_ids)
    print("original model output:", output)

    model.quantize()
    print("quantized model generated text:")
    print(generate_text(model, tokenizer, "Hi, ", max_length=100))

    model.save_pretrained(saved_model_path)

    # load quant config
    quant_config_path = cached_file(model_name_or_path, "quantize_config.json")
    with open(quant_config_path, "r") as f:
        quant_config = json.load(f)
    print("quant config:")
    print(quant_config)
    # mark the checkpoint format as bitblas so from_quantized picks the right loader
    quant_config["checkpoint_format"] = "bitblas"

    # save quant config
    quant_config_path = os.path.join(saved_model_path, "quantize_config.json")
    with open(quant_config_path, "w") as f:
        json.dump(quant_config, f)
    print("quant config saved to:", quant_config_path)

    # copy benchmark files into the saved model path
    file_list = [
        "configuration_bitnet.py",
        "eval_utils.py",
        "modeling_bitnet.py",
        "tokenization_bitnet.py",
        "utils_quant.py",
        "README.md",
    ]
    for file in file_list:
        file_path = cached_file(model_name_or_path, file)
        os.system(f"cp {file_path} {saved_model_path}")
    # load quantized model
    qmodel = BitnetForCausalLM.from_quantized(saved_model_path).cuda().half()
    print("quantized model generated text:")
    print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))


if __name__ == '__main__':
    main()
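Not part of the commit: a small sanity check one could append to the script above, comparing the fp16 logits (captured from the forward pass before model.quantize()) against the reloaded BitBLAS model's logits on the same dummy input. Names such as ref_output are illustrative.

import torch

# Illustrative sketch: assumes `ref_output = model(input_ids)` was saved before
# model.quantize(), and `qmodel` was loaded via BitnetForCausalLM.from_quantized
# as in the script above.
input_ids = torch.ones((1, 1), dtype=torch.long).cuda()
ref_logits = ref_output.logits.float()            # fp16 reference logits
quant_logits = qmodel(input_ids).logits.float()   # BitBLAS-quantized logits
max_abs_diff = (ref_logits - quant_logits).abs().max().item()
cosine = torch.nn.functional.cosine_similarity(
    ref_logits.flatten(), quant_logits.flatten(), dim=0).item()
print(f"max abs diff: {max_abs_diff:.4f}, cosine similarity: {cosine:.4f}")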
@@ -0,0 +1,68 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import bitblas
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
import os
from transformers import GenerationConfig
import time

filepath = os.path.abspath(__file__)
dirpath = os.path.dirname(filepath)

torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")

model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits"
saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas")


def generate_text(model, tokenizer, prompt, max_length=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
    # Generate cos and sin values (position_ids are built here but not passed to generate())
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

    generation_config = GenerationConfig(
        max_length=max_length,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
    )

    start_time = time.time()
    output_ids = model.generate(input_ids, generation_config=generation_config)
    end_time = time.time()

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    generation_time = end_time - start_time
    num_tokens = len(output_ids[0])
    tokens_per_second = num_tokens / generation_time

    print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
    print(f"Tokens per second: {tokens_per_second:.2f}")

    return generated_text


def main():
    # load quantized model
    qmodel = BitnetForCausalLM.from_quantized(saved_model_path).cuda().half()
    tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
    input_ids = torch.ones((1, 1), dtype=torch.long).cuda()
    # naive model inference
    output = qmodel(input_ids)
    print("quantized model output:", output)
    print("quantized model generated text:")
    print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))


if __name__ == "__main__":
    main()