[Integration] Compress Gateup and QKV for bitnet integration (#144)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix in test

* lint fix.

* test cuda i4 kernel

* Refactor copyright notice in i4matmul.hpp

* Refactor BitBLASLinear test module for improved readability and maintainability

* refactor test as versions below Python 3.9 cannot handle int32 overflow.

* format lint for test

* Refactor test_int4b_fp16_convert.py for improved readability and maintainability

* remove unused design file

* move tile device from package to base

* dummy impl for codegen

* Refactor file structure for ladder_permutate module

* Refactor backend class and fix typos in comments

* Deep refactor of Lib-related code.

* remove ci pull.

* LintFix

* refactor builder for whl build

* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module

* Refactor lib_generator to set library and source paths

* lint fix

* BitNet vllm integration

* chore: update codespell to version 2.3.0

* Lintfix

* Bump version to 0.0.1.dev13

* lint fix

* disable fast decoding [u]int4xint8 by default.

* optimize from dict design in Hint

* Implement SplitK

* bitnet benchmark generation.

* Add benchmark script for BitNet integration

* AtomicAdd Support

* LintFix

* ci fix when 3rdparty tvm is initialized.

* bug fix for setup

* fix a bug in block reduce

* typo fix

* BUG Fix for block reduce.

* Lint fix

* Refactor block reduce schedule template

* transform branch from bitblas to bitblas_tl

* Fix subproject commit reference in 3rdparty/tvm

* chore: update submodule branch from bitblas to bitblas_tl

* force update config.cmake

* Bug fix

* Fix subproject commit reference in 3rdparty/cutlass

* chore: Add submodule for cutlass library

* update tl cutlass path

* Refactor BitBLASLinear test module for improved readability and maintainability

* format fix

* Copy CUTLASS to the package directory

* Refactor setup.py to include additional TVM header files

* lint fix

* bug fix

* Refactor BitBLASLinear test module for improved readability and maintainability

* Implement Matmul Benchmark Design

* chore: Update BitBLAS Matmul benchmark script

* lint fix

* Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* lint fix

* Benchmark bot test

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* int8 test case

* Refactor compare_benchmark.py to handle missing benchmark results gracefully

* ci fix

* disable ci for test benchmark

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* remove cli installation

* chore: Create virtual environment and install dependencies for benchmark

* chore: Update benchmark workflow to include comparison step

* Lint fix

* update tvm commit

* Improve lower warp memory pass

* Bug fix

* Enhance to support warp schedule.

* Enhance LOP3 Instructions

* Enhance LOP3 Instructions

* add test for stage3 propagate

* implement propagate func

* Stage3 Ladder Permutate integration

* get_ladder_stage3_propagate

* comment out benchmark scripts as the setting is too big

* ci fix for benchmark

* lint fix

* chore: Update benchmark workflow to trigger on pull request comments

* Add LDMatrix Transform 3

* Support GPTQ Test

* Fuse BlockReduce Schedule

* Support mma propagate 3

* Support MMA Propagate Stage 3

* Lint Fix

* Merge block reduce for dequantize config.

* fix codeql

* chore: Update submodule reference to latest commit

* chore: Disable common subexpression elimination in TIR passes

* Lint Fix

* 4bit related lop3 updates.

* lint fix

* gptq test fix

* Fix for test

* lint fix

* lint fix

* typofix

* QuantCompress Test

* chore: Refactor quant_compress_impl.py for readability and maintainability

* Enhance docs to update latest works.

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* removed legacy operator

* Refactor weight executors in Matmul class for improved readability and maintainability

* LintFix

* Fix GPTQ Repack with the latest weight transform

* lint fix

* bug fix for rescale dequantize

* test fix

* typo fix

* lint fix

* Set default weight propagate kind into LDMatrixTransform

* lint fix

* bug fix

* bug fix for test

* set default to stage3

* revert change

* lint fix

* case fix

* bug fix

* fix for legalize

* bug fix

* chore: Clear global operator cache before running tests

* revert optimize_stratety into SingleBatchDecodeOnly

* typofix

* update benchmark scripts

* chore: Refactor benchmark scripts and fix typos

* fix for testing

* lint fix

* fix import.

* typo

* operator benchmark

* optimize

* always with shared.dyn

* optimize cache.

* dsl fix

* tqdm

* chore: Add serialize_results method to benchmark_matmul_strategies.py

* fix performance issue for dynamic async copy

* chore: Refactor benchmark_matmul_strategies.py for improved performance and code readability

* bug fix

* update readme

* disable block reduce for int8

* bugfix for bitnet

* annotate todo.

* lint fix

* register fast_decode for int8xint4

* Refactor CUDA code to use sm architecture instead of compute architecture

* compress qkv and gate up for bitnet
LeiWang1999 authored Aug 15, 2024
1 parent 60f3e5d commit 9e6fd67
Showing 3 changed files with 255 additions and 12 deletions.
4 changes: 3 additions & 1 deletion integration/BitNet/maint/create_bitblas_ckpt.py
@@ -80,7 +80,7 @@ def main():
output = model(input_ids)
print("original model output:", output)

model.quantize()
model.quantize(fuse_qkv=True, fuse_gateup=True)
print("original model generated text:")
print(generate_text(model, tokenizer, "Hi, ", max_length=100))

@@ -93,6 +93,8 @@ def main():
print("quant config:")
print(quant_config)
quant_config["checkpoint_format"] = "bitblas"
quant_config["fuse_qkv"] = True
quant_config["fuse_gateup"] = True

# save quant config
quant_config_path = os.path.join(saved_model_path, "quantize_config.json")
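For context, a minimal sketch (not part of this diff; the checkpoint id and output path are placeholders) of how the new flags flow from quantization into the saved quantize_config.json, which from_quantized() later reads back:

    import json
    import os
    import torch
    from transformers import AutoTokenizer
    from modeling_bitnet import BitnetForCausalLM

    model_path = "1bitLLM/bitnet_b1_58-3B"          # placeholder checkpoint id
    saved_model_path = "./bitnet_b1_58-3B_bitblas"  # placeholder output path

    model = BitnetForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).cuda().half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    # Fuse q/k/v and gate/up before the BitBLAS repack, matching the call in the diff above.
    model.quantize(fuse_qkv=True, fuse_gateup=True)

    model.save_pretrained(saved_model_path)
    tokenizer.save_pretrained(saved_model_path)

    # Record the fusion flags so from_quantized() rebuilds the fused modules on load.
    quant_config = {"checkpoint_format": "bitblas", "fuse_qkv": True, "fuse_gateup": True}
    with open(os.path.join(saved_model_path, "quantize_config.json"), "w") as f:
        json.dump(quant_config, f, indent=2)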
229 changes: 225 additions & 4 deletions integration/BitNet/modeling_bitnet.py
@@ -244,6 +244,49 @@ def forward(self, x):
return x


class BitnetMLPFuseGateUp(nn.Module):

def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = BitLinear(
self.hidden_size,
self.intermediate_size * 2,
bias=False,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.down_proj = BitLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.act_fn = ACT2FN[config.hidden_act]
self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps)

@classmethod
def from_bit_mlp(cls, bit_mlp: BitnetMLP):
module = cls(bit_mlp.config)
# assign the weights
module.gate_up_proj.weight = nn.Parameter(
torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0))
module.down_proj = bit_mlp.down_proj
module.ffn_layernorm = bit_mlp.ffn_layernorm
return module

def forward(self, x):
gate_up = self.gate_up_proj(x)
gate, up = torch.chunk(gate_up, chunks=2, dim=-1)
x = self.act_fn(gate) * up
x = self.ffn_layernorm(x)
x = self.down_proj(x)
return x


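A standalone sanity check (not part of the diff; plain nn.Linear stands in for BitLinear) of the identity that from_bit_mlp relies on: concatenating the gate and up weights along dim 0 and chunking the fused output along the last dim reproduces the two separate projections.

    import torch
    import torch.nn as nn

    hidden_size, intermediate_size = 8, 16
    x = torch.randn(2, 4, hidden_size)

    gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
    up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)

    # Fused projection carries [gate; up] stacked along the output dimension.
    gate_up_proj = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
    gate_up_proj.weight = nn.Parameter(torch.cat([gate_proj.weight, up_proj.weight], dim=0))

    gate, up = torch.chunk(gate_up_proj(x), chunks=2, dim=-1)
    assert torch.allclose(gate, gate_proj(x), atol=1e-6)
    assert torch.allclose(up, up_proj(x), atol=1e-6)

With BitLinear the fused weight would otherwise share a single per-tensor scale, which is why create_bitblas_weights in utils_quant.py gains a weight_group argument to keep one scale per fused sub-matrix.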
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -394,6 +437,153 @@ def forward(
return attn_output, attn_weights, past_key_value


class BitnetAttentionQKVFused(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class.")

self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")

self.qkv_proj = BitLinear(
self.hidden_size,
self.num_heads * self.head_dim + (self.num_key_value_heads * self.head_dim) * 2,
bias=config.attention_bias,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self.o_proj = BitLinear(
self.hidden_size,
self.hidden_size,
bias=config.attention_bias,
weight_bits=config.weight_bits,
input_bits=config.input_bits,
)
self._init_rope()
self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps)

def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = BitnetRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
raise NotImplementedError

@classmethod
def from_bit_attention(cls, bit_attention: BitnetAttention):
module = cls(bit_attention.config, bit_attention.layer_idx)
# assign the weights
module.qkv_proj.weight = nn.Parameter(
torch.cat([
bit_attention.q_proj.weight, bit_attention.k_proj.weight,
bit_attention.v_proj.weight
],
dim=0))
if bit_attention.q_proj.bias is not None and bit_attention.k_proj.bias is not None and bit_attention.v_proj.bias is not None:
module.qkv_proj.bias = nn.Parameter(
torch.cat([
bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias
],
dim=0))
module.o_proj = bit_attention.o_proj
module.inner_attn_ln = bit_attention.inner_attn_ln
if bit_attention.config.rope_scaling is None:
module.rotary_emb = bit_attention.rotary_emb
return module

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
qkv_states = self.qkv_proj(hidden_states)
query_states, key_states, value_states = torch.split(
qkv_states, [
self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim,
self.num_key_value_heads * self.head_dim
],
dim=-1)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)

past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, :key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.inner_attn_ln(attn_output)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


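The fused projection width is num_heads*head_dim + 2*num_key_value_heads*head_dim, so the split in forward() also holds for grouped-query configs where K/V are narrower than Q. A shape-only sketch with illustrative numbers (not taken from any real config):

    import torch

    hidden_size, num_heads, num_key_value_heads = 2048, 16, 4
    head_dim = hidden_size // num_heads                                    # 128
    qkv_width = num_heads * head_dim + 2 * num_key_value_heads * head_dim  # 2048 + 1024

    qkv_states = torch.randn(1, 10, qkv_width)
    query_states, key_states, value_states = torch.split(
        qkv_states,
        [num_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim],
        dim=-1)
    print(query_states.shape, key_states.shape, value_states.shape)
    # torch.Size([1, 10, 2048]) torch.Size([1, 10, 512]) torch.Size([1, 10, 512])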
class BitnetFlashAttention2(BitnetAttention):
"""
Bitnet flash attention module. This module inherits from `BitnetAttention` as the weights of the module stays
@@ -1240,13 +1430,30 @@ def recursive_set(model, name, attr):
obj = getattr(obj, n)
setattr(obj, names[-1], attr)

def quantize(self):
def quantize(self, fuse_qkv=True, fuse_gateup=True):
for name, module in self.model.named_modules():
# if is bitnet layer
if fuse_qkv and isinstance(module, BitnetAttention):
# create quantized version of the layer
print("Replacing BitnetAttention", name)
bitnet_attenion_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module)
self.recursive_set(self.model, name, bitnet_attenion_qkv_fused)
if fuse_gateup and isinstance(module, BitnetMLP):
# create quantized version of the layer
print("Replacing BitnetMLP", name)
bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module)
self.recursive_set(self.model, name, bitnet_mlp_fused)
for name, module in self.model.named_modules():
# if is bitnet layer
if isinstance(module, BitLinear):
# create quantized version of the layer
print("Quantizing module", name)
bitblas_linear = BitLinearBitBLAS.from_bit_linear(module)
if name.endswith(".qkv_proj"):
bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=3)
elif name.endswith(".gate_up_proj"):
bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=2)
else:
bitblas_linear = BitLinearBitBLAS.from_bit_linear(module)
print("Replacing module", name, "with a quantized version")
self.recursive_set(self.model, name, bitblas_linear)
self.quantized = True
Expand Down Expand Up @@ -1300,20 +1507,34 @@ def from_quantized(
trust_remote_code=trust_remote_code,
**cached_file_kwargs,
)
# only load from remote instead of local
# TODO(lei): add local support
# load quantize config
quantize_file = cached_file(model_name_or_path, "quantize_config.json")
assert quantize_file is not None, "quantize config file not found"
import json

# get quantize format
with open(quantize_file, "r") as f:
quant_config = json.load(f)
checkpoint_format = quant_config["checkpoint_format"]
assert checkpoint_format in ["bitblas"], "quantize format not supported"
fuse_qkv = quant_config.get("fuse_qkv", True)
fuse_gateup = quant_config.get("fuse_gateup", True)

import accelerate
if checkpoint_format == "bitblas":
model = cls(config)
for name, module in model.named_modules():
# if is bitnet layer
if fuse_qkv and isinstance(module, BitnetAttention):
# create quantized version of the layer
print("Replacing BitnetAttention", name)
bitnet_attenion_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module)
model.recursive_set(model, name, bitnet_attenion_qkv_fused)
if fuse_gateup and isinstance(module, BitnetMLP):
# create quantized version of the layer
print("Replacing BitnetMLP", name)
bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module)
model.recursive_set(model, name, bitnet_mlp_fused)
for name, module in model.named_modules():
if isinstance(module, BitLinear):
# create quantized version of the layer
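Loading mirrors quantize(): from_quantized() first swaps in the fused modules, driven by the fuse_qkv and fuse_gateup flags read from quantize_config.json, and then replaces every BitLinear with BitLinearBitBLAS. A hedged usage sketch (the checkpoint path is a placeholder):

    import torch
    from transformers import AutoTokenizer
    from modeling_bitnet import BitnetForCausalLM

    ckpt = "./bitnet_b1_58-3B_bitblas"  # placeholder: a checkpoint produced by create_bitblas_ckpt.py
    model = BitnetForCausalLM.from_quantized(ckpt).cuda().half()
    tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)

    inputs = tokenizer("Hi, ", return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_length=100)
    print(tokenizer.decode(output[0], skip_special_tokens=True))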
34 changes: 27 additions & 7 deletions integration/BitNet/utils_quant.py
@@ -101,10 +101,10 @@ def replace_weight_param_with_qweight(self):
self.format = "bitblas"

@classmethod
def from_bit_linear(cls, bitlinear):
def from_bit_linear(cls, bitlinear, weight_group=1):
bitblas_linear = cls(
bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8)
sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight)
sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group)
bitblas_linear.register_buffer("qweight", qweight)
bitblas_linear.register_buffer("sw", sw)
if bitlinear.bias is not None:
@@ -113,11 +113,31 @@ def from_bit_linear(cls, bitlinear):
bitblas_linear.bias = None
return bitblas_linear

def create_bitblas_weights(self, weight):
sw = 1 / weight.abs().mean().clamp(min=1e-5)
quant_weight = self.weight_quant(weight).detach()
quant_weight = self.bitblas_matmul.transform_weight(quant_weight)
qweight = nn.Parameter(quant_weight, requires_grad=False)
def create_bitblas_weights(self, weight, weight_group=1):
if weight_group:
hidden_size = weight.size(0)
group_size = hidden_size // weight_group

sw_list = []
qweight_list = []

for i in range(weight_group):
start_idx = i * group_size
end_idx = (i + 1) * group_size

sw = 1 / weight[start_idx:end_idx].abs().mean().clamp(min=1e-5)
sw_list.append(sw.repeat(group_size))

qweight = self.weight_quant(weight[start_idx:end_idx]).detach()
qweight_list.append(qweight)

sw = torch.cat(sw_list, dim=0)
qweight = torch.cat(qweight_list, dim=0)
else:
sw = 1 / weight.abs().mean().clamp(min=1e-5)
qweight = self.weight_quant(weight).detach()
qweight = self.bitblas_matmul.transform_weight(qweight)
qweight = nn.Parameter(qweight, requires_grad=False)
return sw, qweight

def post_process_weights(self):
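The weight_group path keeps one BitNet scale per fused sub-matrix (3 for qkv_proj, 2 for gate_up_proj) instead of a single scale over the concatenated weight, so each original projection keeps the scale it would have had unfused. A standalone sketch of the grouping (weight_quant below is a stand-in for BitLinear's ternary quantizer, not the repo's exact implementation):

    import torch

    def weight_quant(w):
        # BitNet-style ternary quantization sketch: scale by mean |w|, round, clamp to {-1, 0, 1}.
        s = 1.0 / w.abs().mean().clamp(min=1e-5)
        return (w * s).round().clamp(-1, 1) / s

    def grouped_scales(weight, weight_group):
        group_size = weight.size(0) // weight_group
        sw_list, qweight_list = [], []
        for i in range(weight_group):
            chunk = weight[i * group_size:(i + 1) * group_size]
            # One scale per group, broadcast across that group's output rows.
            sw_list.append((1.0 / chunk.abs().mean().clamp(min=1e-5)).repeat(group_size))
            qweight_list.append(weight_quant(chunk).detach())
        return torch.cat(sw_list, dim=0), torch.cat(qweight_list, dim=0)

    w_qkv = torch.randn(3 * 128, 256)           # fused [q; k; v] weight with illustrative sizes
    sw, qweight = grouped_scales(w_qkv, weight_group=3)
    print(sw.shape, qweight.shape)              # torch.Size([384]) torch.Size([384, 256])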
