Cannot use uint2 x float16 #24
hi @xzyaoi, thanks for your attention. We absolutely support fp16 x uint2. A simple way to surface more error messages is to enable DEBUG logging (e.g. bitblas.set_log_level("DEBUG")), which produces output like:
2024-04-25 08:55:20 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-04-25 08:55:23 [BitBLAS:DEBUG]: [BitBLAS][Error] applying rule <bitblas.gpu.gemv.GEMV object at 0x7f11d5c9e490> failed
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply schedule failed: 'i2_to_f16_scale_zeros_quantized'
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply schedule failed: 'i2_to_f16_scale_zeros_quantized'
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply schedule failed: 'i2_to_f16_scale_zeros_quantized'
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply schedule failed: 'i2_to_f16_scale_zeros_quantized'
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B_decode': 8}}
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply schedule failed: 'i2_to_f16_scale_zeros_quantized'
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply schedule failed: 'i2_to_f16_scale_zeros_quantized'
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.
2024-04-25 08:55:45 [BitBLAS:DEBUG]: Apply schedule failed: 'i2_to_f16_scale_zeros_quantized'
The quantized zeros type (qzeros) is currently only supported for the uint4 case; for now you can use one of the other zeros types (rescale/original) to run inference, and we will support it soon.
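For illustration, here is a minimal sketch of that workaround; every config field is copied from the snippets later in this thread, and the only change is the zeros mode (an editorial example, not the maintainer's code):

import bitblas

# Editorial sketch: keep W_dtype="uint2" but switch the zeros mode away from
# "quantized", which at this point in the thread is only supported for uint4.
matmul_config = bitblas.MatmulConfig(
    M=1,
    N=1024,
    K=1024,
    A_dtype="float16",
    W_dtype="uint2",
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",
    with_bias=False,
    group_size=128,
    with_scaling=True,
    with_zeros=True,
    zeros_mode="original",  # or "rescale"; both avoid the unsupported qzeros path
)
matmul = bitblas.Matmul(config=matmul_config)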
hi @xzyaoi, I just made a PR to fix the issue; you can check out the latest code from the main branch. I have another deadline, so I have not added a correctness test yet.
Thanks so much @LeiWang1999! I can compile it and it seems to work, however the result differs from the reference. I have the following test example:
import bitblas
import torch
import os
os.environ['NUMEXPR_MAX_THREADS'] = "32"
M = 1
N = 1024
K = 1024
GROUP_SIZE = 128
matmul_config = bitblas.MatmulConfig(
M=M, # M dimension
N=N, # N dimension
K=K, # K dimension
A_dtype="float16", # activation A dtype
W_dtype="uint2", # weight W dtype
accum_dtype="float16", # accumulation dtype
out_dtype="float16", # output dtype
layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
with_bias=False, # bias
# configs for weight only quantization
group_size=128, # setting for grouped quantization
with_scaling=True, # setting for scaling factor
with_zeros=True, # setting for zeros
zeros_mode="quantized", # setting for how to calculating zeros
)
matmul = bitblas.Matmul(config=matmul_config)
scaling_shape = (1024, 1024//128)
zeros_shape = (1024, 1024//128)
# 4bit uint: 0~15
# 2bit uint: 0~3
# Create input matrices
input_tensor = torch.rand((1, K), dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 3, (N, K), dtype=torch.int8).cuda()
scaling = torch.rand(scaling_shape, dtype=torch.float16).cuda()
zeros = torch.rand(zeros_shape, dtype=torch.float16).cuda()
# Transform weight tensor to the packed uint2 data type
transformed = matmul.transform_weight(weight_tensor, zeros=zeros)
weight_tensor_transformed = transformed[0]
zeros_transformed = transformed[1]
# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_transformed, scale=scaling, zeros=zeros_transformed)
# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance; note that the randint weight values are a bit larger than the float16 activations, so we set atol to 1.0
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
Is there anything that I might be doing wrong? Sorry if this is a silly question :(. For 4 bits it works perfectly.
The problem lies in the zeros, as you use the "quantized" zeros mode. In your case, I think you should use the zeros mode "rescale" or "original"; we have docs that describe the difference.
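As an editorial aside, here is a rough sketch of how the three zeros modes are commonly interpreted during dequantization; this is an assumption for illustration only, so consult the BitBLAS docs for the authoritative definitions:

def dequantize_sketch(qw, scale, zeros, mode):
    # Sketch only: per-group broadcasting is omitted for brevity. qw is the
    # unpacked low-bit integer weight, scale/zeros are per-group parameters.
    if mode == "original":
        # zeros stored as floats and subtracted before scaling
        return (qw.half() - zeros) * scale
    if mode == "rescale":
        # zeros stored already multiplied by the scale
        return qw.half() * scale - zeros
    if mode == "quantized":
        # zeros themselves stored as packed low-bit integers (qzeros)
        return (qw.half() - zeros.half()) * scale
    raise ValueError(f"unknown zeros mode: {mode}")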
A relevant test can be found here: https://github.com/microsoft/BitBLAS/blob/main/testing/python/operators/test_general_matmul_ops.py#L134. You should also account for the scale and zeros in your reference matmul.
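Applied to the first test snippet above, a scale/zeros-aware reference might look like the sketch below; it assumes the "(w - zeros) * scale" convention and simply broadcasts the per-group tensors, so treat it as illustrative rather than the kernel's exact semantics:

# Editorial sketch: reference matmul that accounts for scale and zeros,
# reusing the tensors from the first snippet (N = K = 1024, GROUP_SIZE = 128).
expanded_scale = scaling.repeat_interleave(GROUP_SIZE, dim=1)  # (N, K)
expanded_zeros = zeros.repeat_interleave(GROUP_SIZE, dim=1)    # (N, K)
dequant_w = (weight_tensor.to(torch.float16) - expanded_zeros) * expanded_scale
ref_result = torch.matmul(input_tensor, dequant_w.t())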
@LeiWang1999 Thanks for your prompt reply! I just realized that I misinterpreted this. Here is my updated test:
import torch
import bitblas
import torch.nn as nn
from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import (
QuantLinear as CudaOldQuantLinear,
)
import os
from bitblas import Matmul
os.environ["NUMEXPR_MAX_THREADS"] = "16"
group_size = 1024
in_features = 1024
out_features = 1024
bitwidth = 2
def gen_quant(bitwidth, k, n, groupsize=-1):
    maxq = 2**bitwidth
    w = torch.randn((k, n), dtype=torch.half, device="cpu")
    original_w = w.clone()
    if groupsize == -1:
        groupsize = k
    if groupsize != -1:
        w = w.reshape((-1, groupsize, n))
        w = w.permute(1, 0, 2)
        w = w.reshape((groupsize, -1))
    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
    s *= 2 / maxq
    # Quantize.
    w = torch.round(w / s).int()
    # Unsigned storage.
    w += (maxq) // 2
    w = torch.clamp(w, 0, maxq)
    # Dequantize.
    ref = (w - (maxq) // 2).half() * s
    if groupsize != -1:
        def reshape(w):
            w = w.reshape((groupsize, -1, n))
            w = w.permute(1, 0, 2)
            w = w.reshape((k, n)).contiguous()
            return w
        ref = reshape(ref)
        w = reshape(w)
    s = s.reshape((-1, n)).contiguous()
    linear = nn.Linear(k, n, bias=False)
    linear.weight.data = ref.t()
    return original_w, linear, s, (w - (maxq) // 2)
cuda_old_linear = CudaOldQuantLinear(
bits=bitwidth,
group_size=group_size,
infeatures=in_features,
outfeatures=out_features,
bias=False,
)
original_w, linear, s, qw = gen_quant(
bitwidth, in_features, out_features, group_size
)
zeros = torch.full((in_features // group_size, out_features), 3, dtype=torch.int32)
cuda_old_linear.pack(linear, s.T, zeros.T, g_idx=None)
bitblas_linear = bitblas.Linear(
in_features=in_features,
out_features=out_features,
bias=False,
A_dtype="float16", # activation A dtype
W_dtype=f"uint{bitwidth}", # weight W dtype
accum_dtype="float16", # accumulation dtype
out_dtype="float16", # output dtype
# configs for weight only quantization
group_size=group_size, # setting for grouped quantization
with_scaling=True, # setting for scaling factor
with_zeros=True, # setting for zeros
zeros_mode="quantized", # setting for how to calculating zeros
)
bitblas_linear.repack_from_gptq(cuda_old_linear)
m = 1 # Batch size
matmul_config = bitblas.MatmulConfig(
M=m,
N=out_features,
K=in_features,
# fast_decoding=True,
A_dtype="float16",
W_dtype=f"uint{bitwidth}",
accum_dtype="float16",
out_dtype="float16",
layout="nt",
with_bias=False,
group_size=group_size,
with_scaling=True,
with_zeros=True,
zeros_mode="quantized",
)
# for some reason, I got segmentation fault when using bitblas_linear(inp)
matmul = Matmul(matmul_config)
inp = torch.rand(m, in_features, dtype=torch.float16, device="cuda")
cuda_old_linear = cuda_old_linear.to("cuda")
bitblas_linear = bitblas_linear.to("cuda")
with torch.no_grad():
    res_cuda_old = cuda_old_linear(inp)
    res_bitblas = matmul(inp, bitblas_linear.qweight, bitblas_linear.scales, bitblas_linear.zeros)
print(f"CudaOldQuantLinear output: {res_cuda_old}")
print(f"BitBLAS output: {res_bitblas}")
torch.testing.assert_close(res_bitblas, res_cuda_old, rtol=1, atol=10)
and the two printed outputs do not match.
Also, for some reason I get a segmentation fault when running the auto-gptq example in the QuickStart guide, so I created a new matmul operator here. The weird part is that most results are correct, but some values are not even close. I tried a real GPTQ-compressed model and saw the same behaviour. Is there anything I might be doing wrong? Thanks for any pointers!
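Since only some values are off, it may help to localize exactly which output elements disagree before tuning tolerances; below is a small editorial debugging sketch, using the variable names from the snippet above:

# Editorial sketch: report how many elements disagree and by how much,
# instead of relying on a single pass/fail assert.
diff = (res_bitblas.float() - res_cuda_old.float()).abs()
rel = diff / res_cuda_old.float().abs().clamp_min(1e-3)
mismatched = rel > 0.05
print(f"{int(mismatched.sum())} / {mismatched.numel()} elements differ by >5%")
print("max abs diff:", diff.max().item())
print("first mismatch indices:", mismatched.nonzero()[:10].tolist())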
Oh, I just noticed a mis-configured number in zeros (I have updated the above code snippet), and once I set it correctly as …
Sorry for the long stall on this issue; I finally got a chance to come back to it and created a full reproduction script. It still fails my test case with the latest version on PyPI (0.0.1dev5). Here's my test script:
import os
import torch
import logging
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer, TextGenerationPipeline
import safetensors as st
from safetensors.torch import save_file
from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import (
QuantLinear as CudaOldQuantLinear,
)
import bitblas
from bitblas import Matmul
logging.basicConfig(
format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)
bitblas.set_log_level("DEBUG")
bitwidth = 2
pretrained_model_dir = "facebook/opt-125m"
quantized_model_dir = f".local/opt-125m-{bitwidth}bit"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
examples = [
tokenizer(
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
)
]
quantize_config = BaseQuantizeConfig(
bits=bitwidth, # quantize model to 2-bit
group_size=128, # it is recommended to set the value to 128,
sym=False,
)
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
model.quantize(examples)
# save quantized model using safetensors
model.save_quantized(quantized_model_dir, use_safetensors=True)
module_name = "model.decoder.layers.9.self_attn.q_proj"
with st.safe_open(os.path.join(quantized_model_dir, f"gptq_model-{bitwidth}bit-128g.safetensors"), "pt", device="cuda") as f:
    keys = f.keys()
    tensors = {key: f.get_tensor(key) for key in keys if module_name in key}
infeatures = 768
outfeatures = 768
cuda_old_linear = CudaOldQuantLinear(
bits=bitwidth,
group_size=128,
infeatures=infeatures,
outfeatures=outfeatures,
bias=False,
)
cuda_old_linear.qweight = tensors[f"{module_name}.qweight"]
cuda_old_linear.qzeros = tensors[f"{module_name}.qzeros"]
cuda_old_linear.scales = tensors[f"{module_name}.scales"]
cuda_old_linear.g_idx = tensors[f"{module_name}.g_idx"]
bitblas_linear = bitblas.Linear(
in_features=infeatures,
out_features=outfeatures,
bias=False,
A_dtype="float16",
W_dtype=f"uint{bitwidth}",
accum_dtype="float16",
out_dtype="float16",
group_size=128,
with_scaling=True,
with_zeros=True,
zeros_mode="quantized",
)
matmul_config = bitblas.MatmulConfig(
M=1,
N=outfeatures,
K=infeatures,
fast_decoding=True,
A_dtype="float16",
W_dtype=f"uint{bitwidth}",
accum_dtype="float16",
out_dtype="float16",
layout="nt",
with_bias=False,
group_size=128,
with_scaling=True,
with_zeros=True,
zeros_mode="quantized",
)
matmul = Matmul(matmul_config)
bitblas_linear.repack_from_gptq(cuda_old_linear)
print("repack done")
tensors = {
"qweight": bitblas_linear.qweight,
"scales": bitblas_linear.scales,
"qzeros": bitblas_linear.zeros,
}
save_file(tensors, os.path.join(quantized_model_dir,"bitblas.safetensors"))
with st.safe_open(os.path.join(quantized_model_dir, "bitblas.safetensors"), "pt", device="cuda") as f:
    keys = f.keys()
    tensors = {key: f.get_tensor(key) for key in keys}
    bitblas_linear.qweight = tensors["qweight"]
    bitblas_linear.scales = tensors["scales"]
    bitblas_linear.zeros = tensors["qzeros"]
print("BitBLAS quantized weight: ", bitblas_linear.qweight.shape)
print("BitBLAS quantized scales: ", bitblas_linear.scales.shape)
print("BitBLAS quantized zeros: ", bitblas_linear.zeros.shape)
inp = torch.rand(1, infeatures, dtype=torch.float16, device="cuda")
cuda_old_linear = cuda_old_linear.to("cuda")
res_cuda_old = cuda_old_linear(inp)
print(f"CudaOldQuantLinear output: {res_cuda_old}")
res_bitblas = matmul(
inp,
bitblas_linear.qweight,
scale=bitblas_linear.scales,
zeros=bitblas_linear.zeros
)
print(f"BitBLAS output: {res_bitblas}")
torch.testing.assert_close(res_bitblas, res_cuda_old, rtol=1, atol=1)
The script basically compresses a model using auto-gptq and then calls BitBLAS for a matmul. It works fine when I set …
I am using a 3090 with torch==2.1.2+cu121. Is there anything that I might be doing wrong? Thank you very much in advance for your support! cc @LeiWang1999
Hi!
Thanks for this great project! I am trying to use it in a project where the model weights are compressed to 2 bits, similar to the AutoGPTQ approach.
However, when I was trying to create a uint2 matmul kernel, as below:
I got the following error:
Is this a known issue, or is it just not supported yet? Any pointers on how to solve it would be appreciated!
Thanks again