From 9e6fd67ea0463538c1ffab2b12ac7f5abfdee438 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Thu, 15 Aug 2024 13:33:38 +0800 Subject: [PATCH] [Integration] Compress Gateup and QKV for bitnet integration (#144) * Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability * Refactor import statements for improved readability and maintainability * Refactor import statements for improved readability and maintainability * disable failure email for ci * remove email notifications. * move relax pass from testing to mlc_llm * Refactor scripts with se check_eual_ref_scripts_with_emitter function * Lint Fix * Refactor scripts with se check_eual_ref_scripts_with_emitter function * bug fix in test * lint fix. * test cuda i4 kernel * Refactor copyright notice in i4matmul.hpp * Refactor BitBLASLinear test module for improved readability and maintainability * refactor test as version below python 3.9 cannot handle int32 overflow. * format lint for test * Refactor test_int4b_fp16_convert.py for improved readability and maintainability * remove unused design file * move tile device from package to base * dummy impl for codegen * Refactor file structure for ladder_permutate module * Refactor backend class and fix typos in comments * Deep refactor Lib related code. * remove ci pull. * LintFix * refactor builder for whl build * Refactor TIRWrapper.wrap() method to include an assertion for the optimized module * Refactor lib_generator to set library and source paths * lint fix * BitNet vllm integration * chore: update codespell to version 2.3.0 * Lintfix * Bump version to 0.0.1.dev13 * lint fix * disable fast decoding [u]int4xint8 by default. * optimize from dict design in Hint * Implement SplitK * bitnet benchmark generation. * Add benchmark script for BitNet integration * AtomicAdd Support * LintFix * ci fix when 3rdparty tvm is initialized. * bug fix for setup * fix a bug in block reduce * typo fix * BUG Fix for block reduce. * Lint fix * Refactor block reduce schedule template * transform branch from bitblas to bitblas_tl * Fix subproject commit reference in 3rdparty/tvm * chore: update submodule branch from bitblas to bitblas_tl * force update config.cmake * Bug fix * Fix subproject commit reference in 3rdparty/cutlass * chore: Add submodule for cutlass library * update tl cutlass path * Refactor BitBLASLinear test module for improved readability and maintainability * format fix * Copy CUTLASS to the package directory * Refactor setup.py to include additional TVM header files * lint fix * bug fix * Refactor BitBLASLinear test module for improved readability and maintainability * Implement Matmul Benchmark Design * chore: Update BitBLAS Matmul benchmark script * lint fix * Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * lint fix * Benchmark bot test * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * int8 test case * Refactor compare_benchmark.py to handle missing benchmark results gracefully * ci fix * disable ci for test benchmark * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * remove cli installation * chore: Create virtual environment and install dependencies for benchmark * chore: Update benchmark workflow to include comparison step * Lint fix * upodate tvm cmmit * Imporve lower warp memory pass * Bug fix * Enhance to support warp schedule. * Enhance LOP3 Instructions * Enhance LOP3 Instructions * add test for stage3 propagate * implement propagate func * Stage3 Ladder Permutate integration * get_ladder_stage3_propagate * comments benchmark scirpts as the setting is too big * ci fix for benchmark * lint fix * chore: Update benchmark workflow to trigger on pull request comments * Add LDMatrix Transform 3 * Support GPTQ Test * Fuse BlockReduce Schedule * Support mma propagate 3 * Support MMA Propagate Stage 3 * Lint Fix * Merge block reduce for dequantze config. * fix codeql * chore: Update submodule reference to latest commit * chore: Disable common subexpression elimination in TIR passes * Lint Fix * 4bit related lop3 updates. * lint fix * gptq test fix * Fix for test * lint fix * lint fix * typofix * QuantCompress Test * chore: Refactor quant_compress_impl.py for readability and maintainability * Enhance docs to update latest works. * Refactor weight executors in Matmul class for improved readability and maintainability * Refactor weight executors in Matmul class for improved readability and maintainability * Refactor weight executors in Matmul class for improved readability and maintainability * removed legacy operator * Refactor weight executors in Matmul class for improved readability and maintainability * LintFix * Fix GPTQ Repack with the latest weight transform * lint fix * bug fix for rescale dequantize * test fix * typo fix * lint fix * Set default weight propagate kind into LDMatrixTransform * lint fix * bug fix * bug fix for test * set default to stage3 * revert change * lint fix * case fix * bug fix * fix for legalize * bug fix * chore: Clear global operator cache before running tests * revert optimize_stratety into SingleBatchDecodeOnly * typofix * update benchmark scripts * chore: Refactor benchmark scripts and fix typos * fix for testing * lint fix * fix import. * typo * operator benchmark * optimize * always with shared.dyn * optimize cache. * dsl fix * tqdm * chore: Add serialize_results method to benchmark_matmul_strategies.py * fix performance issue for dynamic async copy * chore: Refactor benchmark_matmul_strategies.py for improved performance and code readability * bug fix * update readme * disable block reduce for int8 * bugfix for bitnet * annotatte todo. * lint fix * regist fast_decode for int8xint4 * Refactor CUDA code to use sm architecture instead of compute architecture * compress qkv and gate up for bitnet --- .../BitNet/maint/create_bitblas_ckpt.py | 4 +- integration/BitNet/modeling_bitnet.py | 229 +++++++++++++++++- integration/BitNet/utils_quant.py | 34 ++- 3 files changed, 255 insertions(+), 12 deletions(-) diff --git a/integration/BitNet/maint/create_bitblas_ckpt.py b/integration/BitNet/maint/create_bitblas_ckpt.py index 6ddb04cba..4f0555430 100644 --- a/integration/BitNet/maint/create_bitblas_ckpt.py +++ b/integration/BitNet/maint/create_bitblas_ckpt.py @@ -80,7 +80,7 @@ def main(): output = model(input_ids) print("original model output:", output) - model.quantize() + model.quantize(fuse_qkv=True, fuse_gateup=True) print("original model generated text:") print(generate_text(model, tokenizer, "Hi, ", max_length=100)) @@ -93,6 +93,8 @@ def main(): print("quant config:") print(quant_config) quant_config["checkpoint_format"] = "bitblas" + quant_config["fuse_qkv"] = True + quant_config["fuse_gateup"] = True # save quant config quant_config_path = os.path.join(saved_model_path, "quantize_config.json") diff --git a/integration/BitNet/modeling_bitnet.py b/integration/BitNet/modeling_bitnet.py index e4e1d88ea..22a985ce0 100644 --- a/integration/BitNet/modeling_bitnet.py +++ b/integration/BitNet/modeling_bitnet.py @@ -244,6 +244,49 @@ def forward(self, x): return x +class BitnetMLPFuseGateUp(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = BitLinear( + self.hidden_size, + self.intermediate_size * 2, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.down_proj = BitLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.act_fn = ACT2FN[config.hidden_act] + self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps) + + @classmethod + def from_bit_mlp(cls, bit_mlp: BitnetMLP): + module = cls(bit_mlp.config) + # assign the weights + module.gate_up_proj.weight = nn.Parameter( + torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) + module.down_proj = bit_mlp.down_proj + module.ffn_layernorm = bit_mlp.ffn_layernorm + return module + + def forward(self, x): + gate_up = self.gate_up_proj(x) + gate, up = torch.chunk(gate_up, chunks=2, dim=-1) + x = self.act_fn(gate) * up + x = self.ffn_layernorm(x) + x = self.down_proj(x) + return x + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -394,6 +437,153 @@ def forward( return attn_output, attn_weights, past_key_value +class BitnetAttentionQKVFused(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class.") + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads}).") + + self.qkv_proj = BitLinear( + self.hidden_size, + self.num_heads * self.head_dim + (self.num_key_value_heads * self.head_dim) * 2, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.o_proj = BitLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self._init_rope() + self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = BitnetRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + raise NotImplementedError + + @classmethod + def from_bit_attention(cls, bit_attention: BitnetAttention): + module = cls(bit_attention.config, bit_attention.layer_idx) + # assign the weights + module.qkv_proj.weight = nn.Parameter( + torch.cat([ + bit_attention.q_proj.weight, bit_attention.k_proj.weight, + bit_attention.v_proj.weight + ], + dim=0)) + if bit_attention.q_proj.bias is not None and bit_attention.k_proj.bias is not None and bit_attention.v_proj.bias is not None: + module.qkv_proj.bias = nn.Parameter( + torch.cat([ + bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias + ], + dim=0)) + module.o_proj = bit_attention.o_proj + module.inner_attn_ln = bit_attention.inner_attn_ln + if bit_attention.config.rope_scaling is None: + module.rotary_emb = bit_attention.rotary_emb + return module + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv_states = self.qkv_proj(hidden_states) + query_states, key_states, value_states = torch.split( + qkv_states, [ + self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim + ], + dim=-1) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.inner_attn_ln(attn_output) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class BitnetFlashAttention2(BitnetAttention): """ Bitnet flash attention module. This module inherits from `BitnetAttention` as the weights of the module stays @@ -1240,13 +1430,30 @@ def recursive_set(model, name, attr): obj = getattr(obj, n) setattr(obj, names[-1], attr) - def quantize(self): + def quantize(self, fuse_qkv=True, fuse_gateup=True): + for name, module in self.model.named_modules(): + # if is bitnet layer + if fuse_qkv and isinstance(module, BitnetAttention): + # create quantized version of the layer + print("Replacing BitnetAttention", name) + bitnet_attenion_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module) + self.recursive_set(self.model, name, bitnet_attenion_qkv_fused) + if fuse_gateup and isinstance(module, BitnetMLP): + # create quantized version of the layer + print("Replacing BitnetMLP", name) + bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module) + self.recursive_set(self.model, name, bitnet_mlp_fused) for name, module in self.model.named_modules(): # if is bitnet layer if isinstance(module, BitLinear): # create quantized version of the layer print("Quantizing module", name) - bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) + if name.endswith(".qkv_proj"): + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=3) + elif name.endswith(".gate_up_proj"): + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=2) + else: + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) print("Replacing module", name, "with a quantized version") self.recursive_set(self.model, name, bitblas_linear) self.quantized = True @@ -1300,20 +1507,34 @@ def from_quantized( trust_remote_code=trust_remote_code, **cached_file_kwargs, ) - # only load from remote instead of local - # TODO(lei): add local support + # load quantize config quantize_file = cached_file(model_name_or_path, "quantize_config.json") assert quantize_file is not None, "quantize config file not found" import json + # get quantize format with open(quantize_file, "r") as f: quant_config = json.load(f) checkpoint_format = quant_config["checkpoint_format"] assert checkpoint_format in ["bitblas"], "quantize format not supported" + fuse_qkv = quant_config.get("fuse_qkv", True) + fuse_gateup = quant_config.get("fuse_gateup", True) import accelerate if checkpoint_format == "bitblas": model = cls(config) + for name, module in model.named_modules(): + # if is bitnet layer + if fuse_qkv and isinstance(module, BitnetAttention): + # create quantized version of the layer + print("Replacing BitnetAttention", name) + bitnet_attenion_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module) + model.recursive_set(model, name, bitnet_attenion_qkv_fused) + if fuse_gateup and isinstance(module, BitnetMLP): + # create quantized version of the layer + print("Replacing BitnetMLP", name) + bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module) + model.recursive_set(model, name, bitnet_mlp_fused) for name, module in model.named_modules(): if isinstance(module, BitLinear): # create quantized version of the layer diff --git a/integration/BitNet/utils_quant.py b/integration/BitNet/utils_quant.py index cb0c0f50b..3da74c213 100644 --- a/integration/BitNet/utils_quant.py +++ b/integration/BitNet/utils_quant.py @@ -101,10 +101,10 @@ def replace_weight_param_with_qweight(self): self.format = "bitblas" @classmethod - def from_bit_linear(cls, bitlinear): + def from_bit_linear(cls, bitlinear, weight_group=1): bitblas_linear = cls( bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) - sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight) + sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group) bitblas_linear.register_buffer("qweight", qweight) bitblas_linear.register_buffer("sw", sw) if bitlinear.bias is not None: @@ -113,11 +113,31 @@ def from_bit_linear(cls, bitlinear): bitblas_linear.bias = None return bitblas_linear - def create_bitblas_weights(self, weight): - sw = 1 / weight.abs().mean().clamp(min=1e-5) - quant_weight = self.weight_quant(weight).detach() - quant_weight = self.bitblas_matmul.transform_weight(quant_weight) - qweight = nn.Parameter(quant_weight, requires_grad=False) + def create_bitblas_weights(self, weight, weight_group=1): + if weight_group: + hidden_size = weight.size(0) + group_size = hidden_size // weight_group + + sw_list = [] + qweight_list = [] + + for i in range(weight_group): + start_idx = i * group_size + end_idx = (i + 1) * group_size + + sw = 1 / weight[start_idx:end_idx].abs().mean().clamp(min=1e-5) + sw_list.append(sw.repeat(group_size)) + + qweight = self.weight_quant(weight[start_idx:end_idx]).detach() + qweight_list.append(qweight) + + sw = torch.cat(sw_list, dim=0) + qweight = torch.cat(qweight_list, dim=0) + else: + sw = 1 / weight.abs().mean().clamp(min=1e-5) + qweight = self.weight_quant(weight).detach() + qweight = self.bitblas_matmul.transform_weight(qweight) + qweight = nn.Parameter(qweight, requires_grad=False) return sw, qweight def post_process_weights(self):