Commit
[Dev] Refactor Modeling BitNet to support vLLM quant linear (#84)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix in test

* lint fix.

* test cuda i4 kernel

* Refactor copyright notice in i4matmul.hpp

* Refactor BitBLASLinear test module for improved readability and maintainability

* refactor test as Python versions below 3.9 cannot handle int32 overflow.

* format lint for test

* Refactor test_int4b_fp16_convert.py for improved readability and maintainability

* remove unused design file

* move tile device from package to base

* dummy impl for codegen

* Refactor file structure for ladder_permutate module

* Refactor backend class and fix typos in comments

* Deep refactor Lib related code.

* remove ci pull.

* LintFix

* refactor builder for whl build

* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module

* Refactor lib_generator to set library and source paths

* lint fix

* BitNet vLLM integration

* chore: update codespell to version 2.3.0

* Lintfix
LeiWang1999 authored Jul 16, 2024
1 parent 66304d2 commit d2a86ea
Showing 8 changed files with 592 additions and 174 deletions.
16 changes: 9 additions & 7 deletions bitblas/cache/operator.py
@@ -15,6 +15,8 @@
logger = logging.getLogger(__name__)

BITBLAS_DATABASE_PATH = os.path.expanduser("~/.cache/bitblas")
BITBLAS_WRAPPED_SOURCE_NAME = "wrapper_source.cu"
BITBLAS_WRAPPED_COMPILED_NAME = "wrapper_compiled.so"


class OperatorCache:
@@ -107,17 +109,17 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path):
with open(optimized_file_path, "w") as optimized_file:
if op_inst.optimized_func is not None:
optimized_file.write(op_inst.optimized_func.script(show_meta=False))
if op_inst.wrapper.libpath is not None:
if op_inst.libpath is not None:
# copy lib name to the same directory as the artifact
srcpath = op_inst.wrapper.srcpath
srcpath = op_inst.srcpath
shutil.copy(
srcpath,
os.path.join(config_path, os.path.basename("wrapper_source.cu")),
os.path.join(config_path, os.path.basename(BITBLAS_WRAPPED_SOURCE_NAME)),
)
libpath = op_inst.wrapper.libpath
libpath = op_inst.libpath
shutil.copy(
libpath,
os.path.join(config_path, os.path.basename("wrapper_compiled.so")),
os.path.join(config_path, os.path.basename(BITBLAS_WRAPPED_COMPILED_NAME)),
)

def _determine_target_arch_str(self, target):
@@ -141,9 +143,9 @@ def _load_operator(self, config_path, target):
config = json.load(f)
elif file.endswith(".tar"):
rt_mod = tvm.runtime.load_module(full_path)
elif file == "wrapper_compiled.so":
elif file == BITBLAS_WRAPPED_COMPILED_NAME:
libpath = full_path
elif file == "wrapper_source.cu":
elif file == BITBLAS_WRAPPED_SOURCE_NAME:
srcpath = full_path

if mapping and config and rt_mod:
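The hunks above replace hard-coded artifact filenames with the module-level constants BITBLAS_WRAPPED_SOURCE_NAME and BITBLAS_WRAPPED_COMPILED_NAME, and read srcpath/libpath from the operator instance instead of its wrapper. A minimal sketch of the save/load naming roundtrip this enables (save_artifacts and load_artifacts are hypothetical helper names for illustration, not BitBLAS API):

import os
import shutil

# Constants as introduced in the diff above.
BITBLAS_WRAPPED_SOURCE_NAME = "wrapper_source.cu"
BITBLAS_WRAPPED_COMPILED_NAME = "wrapper_compiled.so"


def save_artifacts(srcpath, libpath, config_path):
    # Copy the generated CUDA source and the compiled shared library into the
    # cache directory under the shared, well-known names.
    shutil.copy(srcpath, os.path.join(config_path, BITBLAS_WRAPPED_SOURCE_NAME))
    shutil.copy(libpath, os.path.join(config_path, BITBLAS_WRAPPED_COMPILED_NAME))


def load_artifacts(config_path):
    # Resolve the same constants on load, so a future rename only has to happen in one place.
    srcpath = os.path.join(config_path, BITBLAS_WRAPPED_SOURCE_NAME)
    libpath = os.path.join(config_path, BITBLAS_WRAPPED_COMPILED_NAME)
    return srcpath, libpath
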
1 change: 1 addition & 0 deletions integration/BitNet/.gitignore
@@ -0,0 +1 @@
models/
110 changes: 110 additions & 0 deletions integration/BitNet/create_bitblas_ckpt.py
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import bitblas
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
from transformers.utils.hub import cached_file
import os
from transformers import GenerationConfig
import time
import json

filepath = os.path.abspath(__file__)
dirpath = os.path.dirname(filepath)

torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")

model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits"
saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas")


def generate_text(model, tokenizer, prompt, max_length=100):
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
# Generate cos and sin values
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

generation_config = GenerationConfig(
max_length=max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1,
)

start_time = time.time()
output_ids = model.generate(input_ids, generation_config=generation_config)
end_time = time.time()

generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

generation_time = end_time - start_time
num_tokens = len(output_ids[0])
tokens_per_second = num_tokens / generation_time

print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
print(f"Tokens per second: {tokens_per_second:.2f}")

return generated_text


def main():
model = (
BitnetForCausalLM.from_pretrained(
model_name_or_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).cuda().half())
tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)

# print("original model generated text:")
# print(generate_text(model, tokenizer, "Hi, ", max_length=100))
input_ids = torch.ones((1, 1), dtype=torch.long).cuda()
# naive model inference
output = model(input_ids)
print("original model output:", output)

model.quantize()
print("original model generated text:")
print(generate_text(model, tokenizer, "Hi, ", max_length=100))

model.save_pretrained(saved_model_path)

# load quant config
quant_config_path = cached_file(model_name_or_path, "quantize_config.json")
with open(quant_config_path, "r") as f:
quant_config = json.load(f)
print("quant config:")
print(quant_config)
quant_config["checkpoint_format"] = "bitblas"

# save quant config
quant_config_path = os.path.join(saved_model_path, "quantize_config.json")
with open(quant_config_path, "w") as f:
json.dump(quant_config, f)
print("quant config saved to:", quant_config_path)

# copy benchmark files into the saved model path
file_list = [
"configuration_bitnet.py",
"eval_utils.py",
"modeling_bitnet.py",
"tokenization_bitnet.py",
"utils_quant.py",
"README.md",
]
for file in file_list:
file_path = cached_file(model_name_or_path, file)
os.system(f"cp {file_path} {saved_model_path}")
# load quantized model
qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()
print("quantized model generated text:")
print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))


if __name__ == '__main__':
main()
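Once create_bitblas_ckpt.py has run, a quick sanity check (a sketch that assumes the saved_model_path layout used above) is to confirm that the rewritten quantize_config.json now advertises the BitBLAS checkpoint format:

import json
import os

# Assumed to match `saved_model_path` in the script above.
saved_model_path = os.path.join("models", "BitBLASModel/open_llama_3b_1.58bits_bitblas")

with open(os.path.join(saved_model_path, "quantize_config.json"), "r") as f:
    quant_config = json.load(f)

# The script sets this field before saving; fail loudly if it is missing.
assert quant_config.get("checkpoint_format") == "bitblas", quant_config
print("checkpoint_format:", quant_config["checkpoint_format"])
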
59 changes: 48 additions & 11 deletions integration/BitNet/eval_correctness.py
@@ -1,19 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import torch

import bitblas
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
from transformers import GenerationConfig
import time

torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")


def generate_text(model, tokenizer, prompt, max_length=100):
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
# Generate cos and sin values
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# position_embeddings = model.embed_positions(position_ids)
# cos = position_embeddings[:, :, 0::2].cos()
# sin = position_embeddings[:, :, 1::2].sin()

generation_config = GenerationConfig(
max_length=max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1,
)

start_time = time.time()
output_ids = model.generate(input_ids, generation_config=generation_config)
# output_ids = model.generate(input_ids, generation_config=generation_config, cos=cos, sin=sin)
end_time = time.time()

generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

generation_time = end_time - start_time
num_tokens = len(output_ids[0])
tokens_per_second = num_tokens / generation_time

parser = argparse.ArgumentParser()
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
print(f"Tokens per second: {tokens_per_second:.2f}")

return generated_text


def profile(model, input_data):
import time

import numpy as np
model = model.cuda()
@@ -36,23 +70,26 @@ def get_runtime(num_repeats=1):
return np.mean(times)


model_path = '1bitLLM/bitnet_b1_58-3B'


def main():
model = BitnetForCausalLM.from_pretrained(
'1bitLLM/bitnet_b1_58-3B',
model_path,
use_flash_attention_2=True,
torch_dtype=torch.float16,
).cuda().half()
with torch.no_grad():
model._post_process_weights()

input_id = torch.ones(1, 1).long().cuda()

# test forward
tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False)
input_id = tokenizer("Hello")['input_ids']
input_id = torch.tensor(input_id).unsqueeze(0).cuda()
output = model(input_id)

# make sure the output is the same as the simulated output
print(output)

print(generate_text(model, tokenizer, "Hello", max_length=100))


if __name__ == '__main__':
main()
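The profile() helper is only partially visible in this hunk; a minimal timing loop in the same spirit (an illustrative sketch, not the committed code) could look like:

import time

import numpy as np
import torch


def measure_runtime(model, input_data, num_repeats=10, num_warmup=3):
    # Warm-up passes amortize CUDA kernel compilation and cache effects, then
    # the mean latency over num_repeats forward passes is reported.
    model = model.cuda().eval()
    with torch.no_grad():
        for _ in range(num_warmup):
            model(input_data)
        torch.cuda.synchronize()
        times = []
        for _ in range(num_repeats):
            start = time.time()
            model(input_data)
            torch.cuda.synchronize()  # wait for asynchronous CUDA work to finish
            times.append(time.time() - start)
    return float(np.mean(times)) * 1000.0  # mean latency in milliseconds
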
68 changes: 68 additions & 0 deletions integration/BitNet/load_from_quantized.py
@@ -0,0 +1,68 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import bitblas
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer
import os
from transformers import GenerationConfig
import time

filepath = os.path.abspath(__file__)
dirpath = os.path.dirname(filepath)

torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")

model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits"
saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas")


def generate_text(model, tokenizer, prompt, max_length=100):
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
# Generate cos and sin values
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

generation_config = GenerationConfig(
max_length=max_length,
do_sample=True,
top_k=50,
top_p=0.95,
num_return_sequences=1,
)

start_time = time.time()
output_ids = model.generate(input_ids, generation_config=generation_config)
end_time = time.time()

generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

generation_time = end_time - start_time
num_tokens = len(output_ids[0])
tokens_per_second = num_tokens / generation_time

print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
print(f"Tokens per second: {tokens_per_second:.2f}")

return generated_text


def main():
# load quantized model
qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()
tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
# print("original model generated text:")
# print(generate_text(model, tokenizer, "Hi, ", max_length=100))
input_ids = torch.ones((1, 1), dtype=torch.long).cuda()
# naive model inference
output = qmodel(input_ids)
print("original model output:", output)
print("quantized model generated text:")
print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))


if __name__ == "__main__":
main()
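load_from_quantized.py expects the checkpoint produced by create_bitblas_ckpt.py to already exist under models/. A hedged two-step driver (running the scripts via subprocess is an illustrative choice, not part of this commit):

import subprocess

# Step 1: quantize the HF checkpoint and save it in BitBLAS format.
subprocess.run(["python", "create_bitblas_ckpt.py"], check=True)
# Step 2: reload the saved checkpoint and generate text from it.
subprocess.run(["python", "load_from_quantized.py"], check=True)
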