From d2a86ea3fbfd3ff0061ccc9f5969affa3ffc799c Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Tue, 16 Jul 2024 23:43:06 +0800
Subject: [PATCH] [Dev] Refactor Modeling BitNet to support vLLM quant linear
 (#84)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* disable failure email for ci
* remove email notifications.
* move relax pass from testing to mlc_llm
* Refactor scripts with se check_eual_ref_scripts_with_emitter function
* Lint Fix
* Refactor scripts with se check_eual_ref_scripts_with_emitter function
* bug fix in test
* lint fix.
* test cuda i4 kernel
* Refactor copyright notice in i4matmul.hpp
* Refactor BitBLASLinear test module for improved readability and maintainability
* refactor test as version below python 3.9 cannot handle int32 overflow.
* format lint for test
* Refactor test_int4b_fp16_convert.py for improved readability and maintainability
* remove unused design file
* move tile device from package to base
* dummy impl for codegen
* Refactor file structure for ladder_permutate module
* Refactor backend class and fix typos in comments
* Deep refactor Lib related code.
* remove ci pull.
* LintFix
* refactor builder for whl build
* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module
* Refactor lib_generator to set library and source paths
* lint fix
* BitNet vllm integration
* chore: update codespell to version 2.3.0
* Lintfix
---
 bitblas/cache/operator.py                 |  16 +-
 integration/BitNet/.gitignore             |   1 +
 integration/BitNet/create_bitblas_ckpt.py | 110 ++++++
 integration/BitNet/eval_correctness.py    |  59 ++-
 integration/BitNet/load_from_quantized.py |  68 ++++
 integration/BitNet/modeling_bitnet.py     | 416 +++++++++++++++-------
 integration/BitNet/utils_quant.py         |  94 +++--
 requirements-dev.txt                      |   2 +-
 8 files changed, 592 insertions(+), 174 deletions(-)
 create mode 100644 integration/BitNet/.gitignore
 create mode 100644 integration/BitNet/create_bitblas_ckpt.py
 create mode 100644 integration/BitNet/load_from_quantized.py

diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py
index 295630f5d..cbb2e0437 100644
--- a/bitblas/cache/operator.py
+++ b/bitblas/cache/operator.py
@@ -15,6 +15,8 @@
 logger = logging.getLogger(__name__)
 
 BITBLAS_DATABASE_PATH = os.path.expanduser("~/.cache/bitblas")
+BITBLAS_WRAPPED_SOURCE_NAME = "wrapper_source.cu"
+BITBLAS_WRAPPED_COMPILED_NAME = "wrapper_compiled.so"
 
 
 class OperatorCache:
@@ -107,17 +109,17 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path):
         with open(optimized_file_path, "w") as optimized_file:
             if op_inst.optimized_func is not None:
                 optimized_file.write(op_inst.optimized_func.script(show_meta=False))
-        if op_inst.wrapper.libpath is not None:
+        if op_inst.libpath is not None:
             # copy lib name to the same directory as the artifact
-            srcpath = op_inst.wrapper.srcpath
+            srcpath = op_inst.srcpath
             shutil.copy(
                 srcpath,
-                os.path.join(config_path, os.path.basename("wrapper_source.cu")),
+                os.path.join(config_path, os.path.basename(BITBLAS_WRAPPED_SOURCE_NAME)),
             )
-            libpath = op_inst.wrapper.libpath
+            libpath = op_inst.libpath
             shutil.copy(
                 libpath,
-                os.path.join(config_path, os.path.basename("wrapper_compiled.so")),
+                os.path.join(config_path, os.path.basename(BITBLAS_WRAPPED_COMPILED_NAME)),
             )
 
     def _determine_target_arch_str(self, target):
@@ -141,9 +143,9 @@ def _load_operator(self, config_path, target):
                     config = json.load(f)
             elif file.endswith(".tar"):
                 rt_mod = tvm.runtime.load_module(full_path)
-            elif file == "wrapper_compiled.so":
+            elif file == BITBLAS_WRAPPED_COMPILED_NAME:
                 libpath = full_path
-            elif file == "wrapper_source.cu":
+            elif file == BITBLAS_WRAPPED_SOURCE_NAME:
                 srcpath = full_path
 
         if mapping and config and rt_mod:
diff --git a/integration/BitNet/.gitignore b/integration/BitNet/.gitignore
new file mode 100644
index 000000000..6ea887496
--- /dev/null
+++ b/integration/BitNet/.gitignore
@@ -0,0 +1 @@
+models/
\ No newline at end of file
diff --git a/integration/BitNet/create_bitblas_ckpt.py b/integration/BitNet/create_bitblas_ckpt.py
new file mode 100644
index 000000000..d443b2e20
--- /dev/null
+++ b/integration/BitNet/create_bitblas_ckpt.py
@@ -0,0 +1,110 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import torch
+import bitblas
+from modeling_bitnet import BitnetForCausalLM
+from tokenization_bitnet import BitnetTokenizer
+from transformers.utils.hub import cached_file
+import os
+from transformers import GenerationConfig
+import time
+import json
+
+filepath = os.path.abspath(__file__)
+dirpath = os.path.dirname(filepath)
+
+torch.set_grad_enabled(False)
+bitblas.set_log_level("INFO")
+
+model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits"
+saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas")
+
+
+def generate_text(model, tokenizer, prompt, max_length=100):
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
+    # Generate cos and sin values
+    seq_length = input_ids.size(1)
+    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+    generation_config = GenerationConfig(
+        max_length=max_length,
+        do_sample=True,
+        top_k=50,
+        top_p=0.95,
+        num_return_sequences=1,
+    )
+
+    start_time = time.time()
+    output_ids = model.generate(input_ids, generation_config=generation_config)
+    end_time = time.time()
+
+    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+    generation_time = end_time - start_time
+    num_tokens = len(output_ids[0])
+    tokens_per_second = num_tokens / generation_time
+
+    print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
+    print(f"Tokens per second: {tokens_per_second:.2f}")
+
+    return generated_text
+
+
+def main():
+    model = (
+        BitnetForCausalLM.from_pretrained(
+            model_name_or_path,
+            use_flash_attention_2=True,
+            torch_dtype=torch.float16,
+        ).cuda().half())
+    tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)
+
+    # print("original model generated text:")
+    # print(generate_text(model, tokenizer, "Hi, ", max_length=100))
+    input_ids = torch.ones((1, 1), dtype=torch.long).cuda()
+    # naive model inference
+    output = model(input_ids)
+    print("original model output:", output)
+
+    model.quantize()
+    print("original model generated text:")
+    print(generate_text(model, tokenizer, "Hi, ", max_length=100))
+
+    model.save_pretrained(saved_model_path)
+
+    # load quant config
+    quant_config_path = cached_file(model_name_or_path, "quantize_config.json")
+    with open(quant_config_path, "r") as f:
+        quant_config = json.load(f)
+    print("quant config:")
+    print(quant_config)
+    quant_config["checkpoint_format"] = "bitblas"
+
+    # save quant config
+    quant_config_path = os.path.join(saved_model_path, "quantize_config.json")
open(quant_config_path, "w") as f: + json.dump(quant_config, f) + print("quant config saved to:", quant_config_path) + + # copy benchmark filed into saved model path + file_list = [ + "configuration_bitnet.py", + "eval_utils.py", + "modeling_bitnet.py", + "tokenization_bitnet.py", + "utils_quant.py", + "README.md", + ] + for file in file_list: + file_path = cached_file(model_name_or_path, file) + os.system(f"cp {file_path} {saved_model_path}") + # load quantized model + qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + print("quantized model generated text:") + print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) + + +if __name__ == '__main__': + main() diff --git a/integration/BitNet/eval_correctness.py b/integration/BitNet/eval_correctness.py index 578715da4..cef89313d 100644 --- a/integration/BitNet/eval_correctness.py +++ b/integration/BitNet/eval_correctness.py @@ -1,19 +1,53 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import argparse import torch - +import bitblas from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +from transformers import GenerationConfig +import time torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + + +def generate_text(model, tokenizer, prompt, max_length=100): + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) + # Generate cos and sin values + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + # position_embeddings = model.embed_positions(position_ids) + # cos = position_embeddings[:, :, 0::2].cos() + # sin = position_embeddings[:, :, 1::2].sin() + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + # output_ids = model.generate(input_ids, generation_config=generation_config, cos=cos, sin=sin) + end_time = time.time() + + generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generation_time = end_time - start_time + num_tokens = len(output_ids[0]) + tokens_per_second = num_tokens / generation_time -parser = argparse.ArgumentParser() -parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str) + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_text def profile(model, input_data): - import time import numpy as np model = model.cuda() @@ -36,23 +70,26 @@ def get_runtime(num_repeats=1): return np.mean(times) +model_path = '1bitLLM/bitnet_b1_58-3B' + + def main(): model = BitnetForCausalLM.from_pretrained( - '1bitLLM/bitnet_b1_58-3B', + model_path, use_flash_attention_2=True, torch_dtype=torch.float16, ).cuda().half() with torch.no_grad(): model._post_process_weights() - input_id = torch.ones(1, 1).long().cuda() - - # test forward + tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) + input_id = tokenizer("Hello")['input_ids'] + input_id = torch.tensor(input_id).unsqueeze(0).cuda() output = model(input_id) - - # make sure the output is the same as the simulated output print(output) + print(generate_text(model, tokenizer, "Hello", max_length=100)) + if __name__ == '__main__': main() diff --git 
a/integration/BitNet/load_from_quantized.py b/integration/BitNet/load_from_quantized.py new file mode 100644 index 000000000..acea3bd0a --- /dev/null +++ b/integration/BitNet/load_from_quantized.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import bitblas +from modeling_bitnet import BitnetForCausalLM +from tokenization_bitnet import BitnetTokenizer +import os +from transformers import GenerationConfig +import time + +filepath = os.path.abspath(__file__) +dirpath = os.path.dirname(filepath) + +torch.set_grad_enabled(False) +bitblas.set_log_level("INFO") + +model_name_or_path = "BitBLASModel/open_llama_3b_1.58bits" +saved_model_path = os.path.join(dirpath, "models", f"{model_name_or_path}_bitblas") + + +def generate_text(model, tokenizer, prompt, max_length=100): + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device) + # Generate cos and sin values + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + generation_config = GenerationConfig( + max_length=max_length, + do_sample=True, + top_k=50, + top_p=0.95, + num_return_sequences=1, + ) + + start_time = time.time() + output_ids = model.generate(input_ids, generation_config=generation_config) + end_time = time.time() + + generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + generation_time = end_time - start_time + num_tokens = len(output_ids[0]) + tokens_per_second = num_tokens / generation_time + + print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds") + print(f"Tokens per second: {tokens_per_second:.2f}") + + return generated_text + + +def main(): + # load quantized model + qmodel = BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half() + tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) + # print("original model generated text:") + # print(generate_text(model, tokenizer, "Hi, ", max_length=100)) + input_ids = torch.ones((1, 1), dtype=torch.long).cuda() + # naive model inference + output = qmodel(input_ids) + print("original model output:", output) + print("quantized model generated text:") + print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100)) + + +if __name__ == "__main__": + main() diff --git a/integration/BitNet/modeling_bitnet.py b/integration/BitNet/modeling_bitnet.py index 11be4059f..e4e1d88ea 100644 --- a/integration/BitNet/modeling_bitnet.py +++ b/integration/BitNet/modeling_bitnet.py @@ -31,7 +31,6 @@ from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache -from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -49,14 +48,27 @@ replace_return_docstrings, ) from configuration_bitnet import BitnetConfig -from utils_quant import BitLinear - +from utils_quant import BitLinear, BitLinearBitBLAS +from transformers.utils.hub import cached_file if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401 +def find_layers(module, layers=None, name=""): + if not layers: + layers = [nn.Linear] + for layer in layers: + if isinstance(module, layer): + return {name: module} + res = {} + for name1, child in 
module.named_children(): + res.update( + find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + return res + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "BitnetConfig" @@ -75,6 +87,7 @@ def _get_unpad_data(attention_mask): class BitnetRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): """ BitnetRMSNorm is equivalent to T5LayerNorm @@ -95,23 +108,34 @@ def forward(self, hidden_states): class BitnetRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + + def __init__(self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0): super().__init__() self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base + **(torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) - self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + self.register_buffer( + "_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer( + "_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) @property def sin_cached(self): @@ -132,12 +156,14 @@ def cos_cached(self): @torch.no_grad() def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + inv_freq_expanded = self.inv_freq[None, :, + None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = device_type if isinstance(device_type, + str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) @@ -148,8 +174,8 @@ def forward(self, x, position_ids): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) @@ -181,22 +207,32 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class BitnetMLP(nn.Module): + def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = 
config.intermediate_size self.gate_proj = BitLinear( - self.hidden_size, self.intermediate_size, bias=False, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.intermediate_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.up_proj = BitLinear( - self.hidden_size, self.intermediate_size, bias=False, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.intermediate_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.down_proj = BitLinear( - self.intermediate_size, self.hidden_size, bias=False, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.intermediate_size, + self.hidden_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.act_fn = ACT2FN[config.hidden_act] self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps) @@ -216,7 +252,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, + head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) @@ -231,8 +268,7 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): logger.warning_once( f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + "when creating this class.") self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size @@ -247,24 +283,35 @@ def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) + f" and `num_heads`: {self.num_heads}).") self.q_proj = BitLinear( - self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.num_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.k_proj = BitLinear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.v_proj = BitLinear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self.o_proj = BitLinear( - self.hidden_size, self.hidden_size, bias=config.attention_bias, - weight_bits=config.weight_bits, input_bits=config.input_bits, + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, ) self._init_rope() self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps) @@ -297,8 +344,10 @@ def forward( value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) past_key_value = getattr(self, "past_key_value", past_key_value) cos, sin = self.rotary_emb(value_states, position_ids) @@ -307,27 +356,30 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim) if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, 
p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + f" {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -353,7 +405,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() @@ -380,8 +432,10 @@ def forward( # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -391,7 +445,8 @@ def forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -420,16 +475,14 @@ def forward( logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) + f" {target_dtype}.") query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate - ) + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.inner_attn_ln(attn_output) @@ -440,9 +493,14 @@ def forward( return attn_output, attn_weights, past_key_value - def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None - ): + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. @@ -472,8 +530,7 @@ def _flash_attention_forward( if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) + query_states, key_states, value_states, attention_mask, query_length) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -494,8 +551,12 @@ def _flash_attention_forward( attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) return attn_output @@ -504,29 +565,27 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) if query_length == kv_seq_len: query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k elif query_length == 1: max_seqlen_in_batch_q = 1 cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. + batch_size + 1, dtype=torch.int32, + device=query_layer.device) # There is a memcpy here, that is very bad. indices_q = cu_seqlens_q[:-1] query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. 
attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask) return ( query_layer, @@ -538,7 +597,6 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query ) - LLAMA_ATTENTION_CLASSES = { "eager": BitnetAttention, "flash_attention_2": BitnetFlashAttention2, @@ -546,11 +604,13 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class BitnetDecoderLayer(nn.Module): + def __init__(self, config: BitnetConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx) self.mlp = BitnetMLP(config) self.input_layernorm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -583,8 +643,8 @@ def forward( """ if "padding_mask" in kwargs: warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`", + stacklevel=2) residual = hidden_states @@ -676,8 +736,7 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = else: dtype = layer.self_attn.o_proj.weight.dtype layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) + self.config, max_batch_size, max_cache_len, device=device, dtype=dtype) def _reset_cache(self): for layer in self.model.layers: @@ -776,9 +835,9 @@ def __init__(self, config: BitnetConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) + self.layers = nn.ModuleList([ + BitnetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers) + ]) self.norm = BitnetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -807,8 +866,8 @@ def forward( ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -827,17 +886,17 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() if cache_position is None: if isinstance(past_key_values, StaticCache): raise 
ValueError("cache_position is a required argument when using StaticCache.") cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -895,10 +954,11 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache - ) + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, Cache) else next_decoder_cache) if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -923,10 +983,13 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): target_length = self.config.max_position_embeddings else: # dynamic cache target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 - ) + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1) - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + causal_mask = torch.full((sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) @@ -935,8 +998,10 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.dim() == 2: mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) - causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + padding_mask = causal_mask[..., :mask_length].eq( + 0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill( + padding_mask, min_dtype) elif attention_mask.dim() == 4: # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. 
@@ -946,9 +1011,8 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[ - : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] - ] = mask_slice + causal_mask[:mask_shape[0], :mask_shape[1], + offset:mask_shape[2] + offset, :mask_shape[3]] = mask_slice return causal_mask @@ -961,7 +1025,7 @@ def __init__(self, config): self.model = BitnetModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.quantized = False # Initialize weights and apply final processing self.post_init() @@ -1026,8 +1090,8 @@ def forward( ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -1073,9 +1137,13 @@ def forward( attentions=outputs.attentions, ) - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs - ): + def prepare_inputs_for_generation(self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + **kwargs): # With static cache, the `past_key_values` is None # TODO joao: standardize interface for the different Cache classes and remove of this if has_static_cache = False @@ -1086,13 +1154,13 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_position[ + 0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + if past_key_values.get_max_length() is not None else None) + cache_length = past_length if max_cache_length is None else torch.min( + max_cache_length, past_length) # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] @@ -1103,7 +1171,7 @@ def prepare_inputs_for_generation( # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: @@ -1111,11 +1179,8 @@ def prepare_inputs_for_generation( # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): + if (max_cache_length is not None and attention_mask is not None and + cache_length + input_ids.shape[1] > max_cache_length): attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) @@ -1124,7 +1189,7 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + position_ids = position_ids[:, -input_ids.shape[1]:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1137,40 +1202,134 @@ def prepare_inputs_for_generation( input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] if cache_position is None: - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + cache_position = torch.arange( + past_length, past_length + input_length, device=input_ids.device) else: cache_position = cache_position[-input_length:] if has_static_cache: past_key_values = None - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) + model_inputs.update({ + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + }) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) + reordered_past += (tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past),) return reordered_past - + @staticmethod + def recursive_set(model, name, attr): + ''' + set layers.25.mlp.up_proj to attr + ''' + + names = name.split('.') + obj = model + for n in names[:-1]: + obj = getattr(obj, n) + setattr(obj, names[-1], attr) + + def quantize(self): + for name, module in self.model.named_modules(): + # if is bitnet layer + if isinstance(module, BitLinear): + # create quantized version of the layer + print("Quantizing module", name) + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) + print("Replacing module", name, "with a quantized version") + self.recursive_set(self.model, name, bitblas_linear) + self.quantized = True + def _post_process_weights(self): for name, module in self.model.named_modules(): if hasattr(module, "post_process_weights"): print("Post processing weights for module", name) module.post_process_weights() + def _replace_weight_param_with_qweight(self): + for name, module in self.model.named_modules(): + if hasattr(module, "replace_weight_param_with_qweight"): + print("Replacing weight param with qweight for module", name) + module.replace_weight_param_with_qweight() + + @classmethod + def from_quantized( + cls, + model_name_or_path: Optional[str], + trust_remote_code: bool = False, + **kwargs, + ): + """load quantized model from local disk""" + # Parameters related to loading from Hugging Face Hub + 
cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + # == step1: prepare configs and file names == # + config: BitnetConfig = BitnetConfig.from_pretrained( + model_name_or_path, + trust_remote_code=trust_remote_code, + **cached_file_kwargs, + ) + # only load from remote instead of local + # TODO(lei): add local support + quantize_file = cached_file(model_name_or_path, "quantize_config.json") + assert quantize_file is not None, "quantize config file not found" + import json + # get quantize format + with open(quantize_file, "r") as f: + quant_config = json.load(f) + checkpoint_format = quant_config["checkpoint_format"] + assert checkpoint_format in ["bitblas"], "quantize format not supported" + + import accelerate + if checkpoint_format == "bitblas": + model = cls(config) + for name, module in model.named_modules(): + if isinstance(module, BitLinear): + # create quantized version of the layer + print("Quantizing module", name) + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) + print("Replacing module", name, "with a quantized version") + model.recursive_set(model, name, bitblas_linear) + accelerate.utils.modeling.load_checkpoint_in_model( + model, + checkpoint=model_name_or_path, + offload_state_dict=True, + offload_buffers=True, + ) + return model + + @add_start_docstrings( """ The LLaMa Model transformer with a sequence classification head on top (linear layer). 
@@ -1187,6 +1346,7 @@ def _post_process_weights(self): LLAMA_START_DOCSTRING, ) class BitnetForSequenceClassification(BitnetPreTrainedModel): + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels @@ -1250,7 +1410,8 @@ def forward( else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = torch.eq(input_ids, + self.config.pad_token_id).int().argmax(-1) - 1 sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: @@ -1264,7 +1425,8 @@ def forward( if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + elif self.num_labels > 1 and (labels.dtype == torch.long or + labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" @@ -1390,4 +1552,4 @@ def forward( end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - ) \ No newline at end of file + ) diff --git a/integration/BitNet/utils_quant.py b/integration/BitNet/utils_quant.py index 121649387..d9cc25ae7 100644 --- a/integration/BitNet/utils_quant.py +++ b/integration/BitNet/utils_quant.py @@ -6,14 +6,12 @@ import torch from torch import nn -import bitblas from bitblas.cache import global_operator_cache, get_database_path from bitblas import Matmul, MatmulConfig from bitblas import auto_detect_nvidia_target from logging import getLogger logger = getLogger(__name__) -bitblas.set_log_level("INFO") BITBLAS_TARGET = auto_detect_nvidia_target() BITBLAS_DATABASE_PATH = get_database_path() @@ -36,14 +34,22 @@ def activation_quant(x, num_bits=8): return result.type(dtype) -# BitBLAS BitLinear -class BitLinear(nn.Linear): +class BitLinearBitBLAS(nn.Module): - def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): - super(BitLinear, self).__init__(*kargs, **kwargs) + def __init__( + self, + in_features: int, + out_features: int, + weight_bits=1, + input_bits=8, + **kwargs, + ): + super().__init__() """ RMSNorm is placed outside BitLinear """ + self.in_features = in_features + self.out_features = out_features self.weight_bits = weight_bits self.input_bits = input_bits matmul_config = MatmulConfig( @@ -64,6 +70,7 @@ def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): ENABLE_TUNING = True self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, ENABLE_TUNING) + self.format = "bitnet" self.Qp = 2**(self.input_bits - 1) - 1 def _get_or_create_bitblas_operator(self, config, enable_tuning): @@ -86,14 +93,46 @@ def _get_or_create_bitblas_operator(self, config, enable_tuning): print("BitBLAS Operator found in global_operator_cache.") return bitblas_matmul + def replace_weight_param_with_qweight(self): + if hasattr(self, "weight"): + del self.weight + quant_weight = torch.empty(self.bitblas_matmul.retrieve_weight_shape()) + self.qweight = nn.Parameter(quant_weight, requires_grad=False) + self.format = "bitblas" + + @classmethod + def from_bit_linear(cls, bitlinear): + bitblas_linear = cls( + bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) + sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight) + bitblas_linear.register_buffer("qweight", qweight) + 
bitblas_linear.register_buffer("sw", sw) + if bitlinear.bias is not None: + bitblas_linear.register_buffer("bias", bitlinear.bias) + else: + bitblas_linear.bias = None + return bitblas_linear + + def create_bitblas_weights(self, weight): + sw = 1 / weight.abs().mean().clamp(min=1e-5) + quant_weight = self.weight_quant(weight).detach() + quant_weight = self.bitblas_matmul.transform_weight(quant_weight) + qweight = nn.Parameter(quant_weight, requires_grad=False) + return sw, qweight + def post_process_weights(self): sw = 1 / self.weight.abs().mean().clamp(min=1e-5) self.sw = sw quant_weight = self.weight_quant(self.weight).detach() quant_weight = self.bitblas_matmul.transform_weight(quant_weight) - self.weight = nn.Parameter(quant_weight, requires_grad=False) - - def weight_quant(self, weight): + # remove self.weight and replace it with quant_weight + if hasattr(self, "weight"): + del self.weight + self.qweight = nn.Parameter(quant_weight, requires_grad=False) + self.format = "bitblas" + + @staticmethod + def weight_quant(weight): weight = weight.float() s = 1 / weight.abs().mean().clamp(min=1e-5) result = (weight * s).round().clamp(-1, 1) @@ -139,9 +178,8 @@ def forward_fp32_simulated(self, input): def forward(self, input): # return self.forward_fp32_simulated(input) - quant_input = self.activation_quant(input, self.input_bits).detach() - fp32_out = self.bitblas_matmul(quant_input, self.weight) + fp32_out = self.bitblas_matmul(quant_input, self.qweight) sw = self.sw Qp = self.Qp si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) @@ -154,25 +192,25 @@ def forward(self, input): return out -# # Naive BitLinear from HuggingFace -# class BitLinear(nn.Linear): +# Naive BitLinear from HuggingFace +class BitLinear(nn.Linear): -# def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): -# super(BitLinear, self).__init__(*kargs, **kwargs) -# """ -# RMSNorm is placed outside BitLinear -# """ -# self.weight_bits = weight_bits -# self.input_bits = input_bits + def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs): + super(BitLinear, self).__init__(*kargs, **kwargs) + """ + RMSNorm is placed outside BitLinear + """ + self.weight_bits = weight_bits + self.input_bits = input_bits -# def forward(self, input): + def forward(self, input): -# quant_input = input + (activation_quant(input, self.input_bits) - input).detach() -# quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - -# self.weight).detach() + quant_input = input + (activation_quant(input, self.input_bits) - input).detach() + quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - + self.weight).detach() -# out = nn.functional.linear(quant_input, quant_weight) -# if not self.bias is None: -# out += self.bias.view(1, -1).expand_as(out) + out = nn.functional.linear(quant_input, quant_weight) + if self.bias is not None: + out += self.bias.view(1, -1).expand_as(out) -# return out + return out diff --git a/requirements-dev.txt b/requirements-dev.txt index 085de6a4f..99c101afb 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ yapf==0.32.0 toml==0.10.2 tomli==2.0.1 ruff==0.1.5 -codespell==2.2.6 +codespell==2.3.0 cffi cpplint