Skip to content

Commit

Permalink
[Dev] FIx a bug of repack AutoGPTQ quantized parameters (#57)
Browse files Browse the repository at this point in the history
* [Dev] Issue#24: FIx a bug of repack AutoGPTQ quantized parameters

* Remove .test as test files can be put into ./debug/

---------

Co-authored-by: Lei Wang <[email protected]>
  • Loading branch information
tzj-fxz and LeiWang1999 authored Jun 15, 2024
1 parent c090df6 commit d589a79
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 6 deletions.
14 changes: 8 additions & 6 deletions python/bitblas/module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ def unpack_qzeros(qzeros, bits):
device=qzeros.device,
requires_grad=False,
)

for col in range(unpacked_zeros.shape[1]):
i = col % elems_per_int32
unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) & 0xF
unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i))

return unpacked_zeros + 1
# Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303
# NOTE: It appears that casting after the `unpacked_zeros + 1` is important.
return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1)


class Linear(nn.Module):
Expand Down Expand Up @@ -232,18 +233,19 @@ def forward(self, A, output=None):
A = A.half()
# can be lifted to post init.
self.init_params()

if output is None:
output = torch.empty(
A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device)
m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1))
A = self.bitblas_matmul.transform_input(A)
stream = torch.cuda.current_stream()

A_void = ctypes.c_void_p(A.data_ptr())
stream_handle = ctypes.c_void_p(stream.cuda_stream)
# m is the product of the last n - 1 dimensions of A
self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, stream_handle)
self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m,
stream_handle)

return output

Expand Down
120 changes: 120 additions & 0 deletions testing/python/operators/test_matmul_dequantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ def test_matmul_dequantize_torch_forward(M, N, K, in_dtype, out_dtype, accum_dty
if with_scaling:
if group_size == -1:
group_size = K
# Note that scaling is default to all 1
permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda())
if with_zeros:
permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros)
Expand All @@ -515,6 +516,125 @@ def test_matmul_dequantize_torch_forward(M, N, K, in_dtype, out_dtype, accum_dty
torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-0, atol=1e-1)


@pytest.mark.parametrize(
"M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout,zeros_mode",
[
(1, 768, 768, "float16", "float16", "float16", 2, "int8", "uint", True, False, 128, False,
False, False, False, "nt", "quantized"),
(1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", True, True, 128, False,
False, False, False, "nt", "quantized"),
],
)
def test_matmul_dequantize_torch_forward_with_asym_quantized_zeros(M, N, K, in_dtype, out_dtype, accum_dtype, bit,
storage_dtype, source_format, with_scaling, with_zeros,
group_size, fast_decoding, with_bias, propagate_a,
propagate_b, layout, zeros_mode):
import torch
import numpy as np
torch.random.manual_seed(0)
from bitblas.quantization.utils import general_compress
matmul_config = MatmulWeightOnlyDequantizeConfig(
M=M,
N=N,
K=K,
in_dtype=in_dtype,
out_dtype=out_dtype,
accum_dtype=accum_dtype,
bit=bit,
storage_dtype=storage_dtype,
source_format=source_format,
with_scaling=with_scaling,
with_zeros=with_zeros,
group_size=group_size,
fast_decoding=fast_decoding,
with_bias=with_bias,
propagate_a=propagate_a,
propagate_b=propagate_b,
layout=layout,
zeros_mode=zeros_mode)
matmul = MatmulWeightOnlyDequantize(
config=matmul_config,
target=target,
)
if not isinstance(M, int):
M = int(32)
# matmul.hardware_aware_finetune(topk=20)
input_shape = (M, K)
weight_shape = (N, K) if layout == "nt" else (K, N)
output_shape = (M, N)
scaling_shape = (N, K // group_size)
zeros_shape = (K // group_size, N)

input_A = torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5
max_quantization = 2 ** (bit - 1)
scaling_matrix = torch.rand(scaling_shape, dtype=torch.float16).cuda()
zeros_matrix = torch.randint(0, max_quantization, zeros_shape, dtype=torch.int8).cuda()
bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda()

if source_format == "uint":
input_W = torch.randint(0, max_quantization, weight_shape, dtype=torch.int8).cuda()
elif source_format == "int":
input_W = torch.randint(-max_quantization, max_quantization, weight_shape, dtype=torch.int8).cuda()
else:
raise NotImplementedError

# Now begin bitblas matmul
input_W_int = input_W.cpu().numpy().astype(np.int8)
if source_format == "int":
input_W_int = input_W_int + max_quantization
qw_np = general_compress(input_W_int, source_bits=bit, storage_dtype=np.int8)
qw_torch = torch.from_numpy(qw_np).cuda()

permuted_inputs = []
# input and weight
if matmul.input_transform is not None:
permuted_inputs.append(matmul.input_transform(input_A.cpu()).cuda())
else:
permuted_inputs.append(input_A)
if matmul.weight_transform is not None:
permuted_inputs.append(matmul.weight_transform(qw_torch.cpu()).cuda())
else:
permuted_inputs.append(qw_torch)
# scale
if with_scaling:
if group_size == -1:
group_size = K
permuted_inputs.append(scaling_matrix)
# zeros
if with_zeros:
if zeros_mode == "quantized":
original_zeros = zeros_matrix
qzeros = general_compress(
original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8)
permuted_inputs.append(torch.from_numpy(qzeros).cuda())
else:
raise NotImplementedError
# bias
if with_bias:
permuted_inputs.append(bias)
# output
permuted_inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda())
matmul(*permuted_inputs)
bitblas_result = permuted_inputs[-1]

# Now begin torch matmul
if with_scaling and with_zeros and zeros_mode == "quantized":
rescaling_tensor = torch.zeros_like(input_W, dtype=torch.float16).cuda()
for i in range(K // group_size):
for j in range(group_size):
rescaling_tensor[:, i * group_size + j] = (
input_W[:, i * group_size + j].to(torch.float16) - zeros_matrix[i, :]
) * scaling_matrix[:, i]
elif with_scaling:
rescaling_tensor = torch.zeros_like(input_W, dtype=torch.float16).cuda()
for i in range(K // group_size):
for j in range(group_size):
rescaling_tensor[:, i * group_size + j] = input_W[:, i * group_size + j].to(torch.float16) * scaling_matrix[:, i]
ref_result = torch.matmul(input_A, rescaling_tensor.t().to(torch.float16))

torch.testing.assert_close(bitblas_result, ref_result, rtol=1e-1, atol=1e-1)


@pytest.mark.parametrize(
"M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,layout,zeros_mode",
[
Expand Down

0 comments on commit d589a79

Please sign in to comment.