[Integration] Upload tutorial for making a bitnet ckpt for vLLM (#135)
* fix install with absolute path

* efficient inference with torch compile

* update vllm ckpt tutorial for bitnet
LeiWang1999 authored Aug 9, 2024
1 parent c6cc01e commit 7c6bccf
Showing 12 changed files with 913 additions and 10 deletions.
2 changes: 1 addition & 1 deletion install.sh
@@ -46,7 +46,7 @@ fi

echo "Download and extraction completed successfully."

LLVM_CONFIG_PATH="${EXTRACT_PATH}/$(basename ${FILE_NAME} .tar.xz)/bin/llvm-config"
LLVM_CONFIG_PATH="$(realpath ${EXTRACT_PATH}/$(basename ${FILE_NAME} .tar.xz)/bin/llvm-config)"
echo "LLVM config path: $LLVM_CONFIG_PATH"

# clone and build tvm
33 changes: 33 additions & 0 deletions integration/BitNet/README.md
@@ -2,8 +2,41 @@
license: mit
---

## Latest News

- 08/09/2024 ✨: We provide a more efficient implementation of BitNet with vLLM, which requires special model checkpoints; to make the checkpoints, please refer to [].

This is a BitBLAS implementation of the reproduced 1.58-bit model from [1bitLLM/bitnet_b1_58-3B](https://huggingface.co/1bitLLM/bitnet_b1_58-3B). We replaced the original simulated Int8x3bit quantized inference kernel with the BitBLAS INT8xINT2 kernel. We also evaluate the model's correctness and performance through `eval_correctness.py` and `benchmark_inference_latency.py`.
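
To give a concrete sense of the kernel swap, the sketch below builds an INT8-activation x INT2-weight GEMM with BitBLAS. The `MatmulConfig` / `Matmul` / `transform_weight` names follow the BitBLAS README and are assumptions that may differ across versions; the shapes are illustrative, not the model's real layer sizes.

```python
# Rough sketch (not code from this commit): an INT8 x INT2 GEMM configured with BitBLAS.
import torch
import bitblas

config = bitblas.MatmulConfig(
    M=1,                 # one decoded token
    N=1024,              # output features (illustrative)
    K=1024,              # input features (illustrative)
    A_dtype="int8",      # int8 activations
    W_dtype="int2",      # 2-bit (ternary) weights
    accum_dtype="int32",
    out_dtype="int32",
    layout="nt",
    with_bias=False,
)
matmul = bitblas.Matmul(config=config)

# pack ternary {-1, 0, 1} weights into BitBLAS's INT2 layout
weight = torch.randint(-1, 2, (1024, 1024), dtype=torch.int8).cuda()
packed_weight = matmul.transform_weight(weight)

activation = torch.randint(-128, 128, (1, 1024), dtype=torch.int8).cuda()
acc = matmul(activation, packed_weight)  # int32 accumulators, rescaled to fp16 afterwards
```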

## Make Checkpoints for vLLM

We provide two scripts to make the checkpoints for vLLM. The first, `generate_bitnet_model_native_format.sh`, makes a checkpoint with fp16 uncompressed metadata. The main difference from the original checkpoint is the added `quant_config.json`, which allows vLLM to load the model and execute it with a quant extension.

```bash
# move to the integration directory
cd /path/to/BitBLAS/integration/BitNet
# make the checkpoint
./maint/generate_bitnet_model_native_format.sh
# the output checkpoint will be saved in the `./models/bitnet_3B_1.58bits` directory
```
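
The `quant_config.json` that the script copies into the checkpoint is the small JSON file added under `integration/BitNet/maint/` in this commit. A minimal sketch of inspecting it (the checkpoint path is assumed):

```python
# Minimal sketch (assumed path): inspect the quant_config.json copied into the
# native-format checkpoint, which signals the BitNet quant settings to vLLM's extension.
import json
from pathlib import Path

ckpt_dir = Path("./models/bitnet_3B_1.58bits")  # output of generate_bitnet_model_native_format.sh
cfg = json.loads((ckpt_dir / "quant_config.json").read_text())

print(cfg["bits"], cfg["quant_method"], cfg["checkpoint_format"])  # 2 bitnet bitnet
assert cfg["quant_method"] == "bitnet"
```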

The second script, `generate_bitnet_model_bitblas_format.sh`, makes a checkpoint with BitBLAS compressed metadata, which avoids the online dequantization stage during vLLM profiling and leads to more efficient memory utilization.

```bash
./maint/generate_bitnet_model_bitblas_format.sh ./models/bitnet_3B_1.58bits ./models/bitnet_3B_1.58bits_bitblas
# the output checkpoint will be saved in the `./models/bitnet_3B_1.58bits_bitblas` directory
```
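
To see why the compressed-metadata checkpoint loads with less memory, here is a toy illustration of storing ternary weights four per byte instead of as fp16 values that must be quantized online; this is not the repo's packing code, and the actual BitBLAS layout may differ:

```python
# Toy illustration only: pack ternary {-1, 0, 1} weights into int8, four 2-bit fields per byte.
import torch

def pack_int2(w_ternary: torch.Tensor) -> torch.Tensor:
    w = (w_ternary + 1).to(torch.uint8).flatten()  # map {-1, 0, 1} -> {0, 1, 2}
    w = w.reshape(-1, 4)                           # four values per output byte
    packed = w[:, 0] | (w[:, 1] << 2) | (w[:, 2] << 4) | (w[:, 3] << 6)
    return packed.contiguous().view(torch.int8)

w = torch.randint(-1, 2, (3200, 8640), dtype=torch.int8)  # illustrative layer shape
packed = pack_int2(w)
print(f"packed: {packed.numel()} bytes vs fp16: {w.numel() * 2} bytes")
```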

Finally, you can use the checkpoints in vLLM with:

```bash
cd vllm_workspace
# inference with the ckpt with fp16 uncompressed metadata
python3 inference_with_native_format.py
# inference with the ckpt with BitBLAS compressed metadata
python3 inference_with_bitblas_format.py
```
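
For orientation, a minimal sketch of what such an inference script might boil down to, assuming a vLLM build that includes the BitNet quant extension this tutorial targets; the checkpoint path and sampling settings are placeholders, not the repo's exact code:

```python
# Minimal sketch (assumptions: a vLLM build with the BitNet quant extension, placeholder paths).
from vllm import LLM, SamplingParams

llm = LLM(
    model="./models/bitnet_3B_1.58bits",  # checkpoint produced by the scripts above
    dtype="half",
    trust_remote_code=True,
)
sampling = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The capital of France is"], sampling)
print(outputs[0].outputs[0].text)
```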

## BitBLAS Results

### Performance
4 changes: 0 additions & 4 deletions integration/BitNet/eval_correctness.py
@@ -18,9 +18,6 @@ def generate_text(model, tokenizer, prompt, max_length=100):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# position_embeddings = model.embed_positions(position_ids)
# cos = position_embeddings[:, :, 0::2].cos()
# sin = position_embeddings[:, :, 1::2].sin()

generation_config = GenerationConfig(
max_length=max_length,
@@ -32,7 +29,6 @@ def generate_text(model, tokenizer, prompt, max_length=100):

start_time = time.time()
output_ids = model.generate(input_ids, generation_config=generation_config)
# output_ids = model.generate(input_ids, generation_config=generation_config, cos=cos, sin=sin)
end_time = time.time()

generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import torch
import bitblas
from modeling_bitnet import BitnetForCausalLM
@@ -17,8 +18,13 @@
torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")

model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits"
saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas")
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="BitBLASModel/open_llama_3b_1.58bits")
parser.add_argument("--saved_model_path", type=str, default=None)
args = parser.parse_args()

model_name_or_path = args.model_name_or_path
saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path


def generate_text(model, tokenizer, prompt, max_length=100):
27 changes: 27 additions & 0 deletions integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh
@@ -0,0 +1,27 @@
#!/bin/bash

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# retrieve the native model input and saved model directory
MODEL_DIR=$1
SAVED_MODEL_DIR=$2

# check if the model directory exists
if [ ! -d "$MODEL_DIR" ]; then
    echo "Model directory does not exist!"
    exit 1
fi

# if SAVED_MODEL_DIR is not provided, we do not pass it to the script,
# and the script falls back to its default output directory
if [ -z "$SAVED_MODEL_DIR" ]; then
    python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR
else
    python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR --saved_model_path $SAVED_MODEL_DIR
fi

# get the realpath of the saved model directory, if one was provided
if [ -n "$SAVED_MODEL_DIR" ]; then
    SAVED_MODEL_DIR=$(realpath "$SAVED_MODEL_DIR")
    echo "Model has been converted and saved to $SAVED_MODEL_DIR"
else
    echo "Model has been converted and saved to the script's default output directory"
fi
27 changes: 27 additions & 0 deletions integration/BitNet/maint/generate_bitnet_model_native_format.sh
@@ -0,0 +1,27 @@
#!/bin/bash

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# require git lfs
if ! command -v git-lfs &> /dev/null; then
    echo "Please install git-lfs first by running 'sudo apt install git-lfs'"
    exit 1
fi

mkdir -p models

cd models

# download the model
git clone https://huggingface.co/1bitLLM/bitnet_b1_58-3B bitnet_3B_1.58bits --depth 1

# copy quantized config into the model directory
cp ../maint/quant_config.json bitnet_3B_1.58bits

# get the realpath of the model directory
MODEL_DIR=$(realpath bitnet_3B_1.58bits)

cd ..

echo "Model has been converted and save to $MODEL_DIR"
10 changes: 10 additions & 0 deletions integration/BitNet/maint/quant_config.json
@@ -0,0 +1,10 @@
{
"bits": 2,
"desc_act": false,
"static_groups": false,
"sym": true,
"lm_head": false,
"model_name_or_path": "1bitLLM/bitnet_b1_58-3B",
"quant_method": "bitnet",
"checkpoint_format": "bitnet"
}
13 changes: 10 additions & 3 deletions integration/BitNet/utils_quant.py
@@ -138,6 +138,7 @@ def weight_quant(weight):
result = (weight * s).round().clamp(-1, 1)
return result.type(torch.int8)

@torch.compile
def activation_quant(self, x, num_bits=8):
x = x.float()
Qn = -(2**(num_bits - 1))
@@ -146,6 +147,13 @@ def activation_quant(self, x, num_bits=8):
result = (x * s).round().clamp(Qn, Qp)
return result.type(torch.int8)

@torch.compile
def post_quant_process(self, input, si, sw):
out = input / si
out = out / sw
out = out.half()
return out

# for the correctness evaluation.
def native_forward(self, input):
quant_input = (input + (activation_quant(input, self.input_bits) - input).detach())
@@ -184,9 +192,8 @@ def forward(self, input):
Qp = self.Qp
si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
# dividing by (si * sw) in a single step can overflow to inf in some cases, so divide by si and sw separately
out = fp32_out / si
out = out / sw
out = out.half()
out = self.post_quant_process(fp32_out, si, sw)

if self.bias is not None:
out += self.bias.view(1, -1).expand_as(out)
return out
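
As a standalone illustration of the pattern this diff applies (decorating the small elementwise quantize/rescale helpers with `@torch.compile` so their round/clamp/divide chains are fused), here is a toy sketch; it assumes a CUDA device and is not the repo's exact code:

```python
# Toy sketch of the @torch.compile pattern used above (assumes CUDA; illustrative values).
import torch

@torch.compile
def activation_quant_int8(x: torch.Tensor):
    # per-token absmax quantization to int8, as in BitLinear's activation path
    Qp = 127
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * s).round().clamp(-Qp - 1, Qp).to(torch.int8), s

@torch.compile
def post_quant_rescale(acc: torch.Tensor, si: torch.Tensor, sw: torch.Tensor):
    # divide by si and sw separately, since dividing by si * sw can overflow in some cases
    return ((acc / si) / sw).half()

x = torch.randn(4, 1024, device="cuda")
q, si = activation_quant_int8(x)
out = post_quant_rescale(q.float() * 3.0, si, torch.tensor(2.0, device="cuda"))
print(out.shape, out.dtype)
```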