[Dev] Fix a bug in general matmul ops with zero (#79)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix in test
LeiWang1999 authored Jul 6, 2024
1 parent 49c41a4 commit 541839b
Showing 3 changed files with 22 additions and 228 deletions.
207 changes: 3 additions & 204 deletions bitblas/ops/impl/matmul_dequantize_impl.py
@@ -15,204 +15,6 @@
_tir_packed_to_unsigned_convert_with_zeros,
)

# TODO: The following code should be refactored.
class MatMulNTDequantizeEmitter:
def __init__(
self,
M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
bit=4,
storage_dtype="int8",
source_format="uint",
with_scaling=False,
with_zeros=False,
group_size=-1,
fast_decoding=False,
with_bias=False,
zeros_mode="original",
propagate_a: TransformKind = TransformKind.NonTransform,
propagate_b: TransformKind = TransformKind.NonTransform,
):
self.M = self._validate_dimension(M, "M")
self.N = N
self.K = K
self.in_dtype = in_dtype
self.out_dtype = out_dtype
self.accum_dtype = accum_dtype
self.bit = bit
self.storage_dtype = storage_dtype
self.source_format = source_format
self.with_scaling = with_scaling
self.with_zeros = with_zeros
self.group_size = group_size if group_size != -1 else K
self.fast_decoding = fast_decoding
self.with_bias = with_bias
self.zeros_mode = zeros_mode
self.propagate_a = propagate_a
self.propagate_b = propagate_b

self._validate_bit()
self._validate_layout()

@staticmethod
def _validate_dimension(dim, name):
if not isinstance(dim, int):
return tvm.te.var(name.lower())
return dim

def _validate_bit(self):
if self.bit not in [1, 2, 4, 8]:
raise ValueError(f"Unsupported bit: {self.bit}")

def _validate_layout(self):
if self.layout not in ["nt"]:
raise ValueError(f"Unsupported layout: {self.layout}")

def _create_placeholders(self):
storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit()))
n_float_per_elem = storage_nbit // self.bit

A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype)
B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype)
LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype)
Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype)
Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype)
QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit),
name="QZeros",
dtype=self.storage_dtype)
Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype)
return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem

def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem):
w = None
def decode(n, k):
if self.with_zeros and self.zeros_mode == "quantized":
qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
self.bit,
QZeros[k, n // n_float_per_elem],
n % n_float_per_elem,
dtype=self.storage_dtype,
)
w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)(
self.bit,
B[n, k // n_float_per_elem],
k % n_float_per_elem,
qzeros_dequantize,
dtype=self.in_dtype,
)
elif self.source_format == "uint":
if self.bit == 8:
w = B[n, k].astype(self.in_dtype)
w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
elif self.source_format == "int":
if self.bit == 1:
w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)(
self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
if self.bit == 8:
w = B[n, k].astype(self.in_dtype)
w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)(
self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
elif self.source_format == "fp":
w = _tir_u32_to_f4_to_f16(
self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
elif self.source_format == "fp_e4m3":
w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype)
elif self.source_format == "nf":
index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
self.bit,
B[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype="int32",
)
w = LUT[index]
else:
raise ValueError(f"Unsupported source_format: {self.source_format}")

group_size = self.group_size
zeros_mode = self.zeros_mode

if not self.with_scaling:
return w

if not self.with_zeros:
return w * Scale[n, k // group_size]

if zeros_mode == "original":
w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size]
elif zeros_mode == "rescale":
w = w * Scale[n, k // group_size] - Zeros[n, k // group_size]
elif zeros_mode == "quantized":
w = w * Scale[n, k // group_size]
else:
raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode))

return w

return te.compute((self.N, self.K), decode, name="B_decode")

def _compute_matmul(self, A, B_decode):
k = te.reduce_axis((0, self.K), name="k")
C = te.compute(
(self.M, self.N),
lambda i, j: te.sum(
A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k),
name="C",
)
return C

def _convert_dtype(self, tensor):
if self.accum_dtype != self.out_dtype:
return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D")
return tensor

def _apply_bias(self, tensor, Bias):
if self.with_bias:
return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E")
return tensor

def emit(self):
A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders()
B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem)
C = self._compute_matmul(A, B_decode)
D = self._convert_dtype(C)
last_output = self._apply_bias(D, Bias)

args = [A, B]
if self.source_format == "nf":
args.append(LUT)
if self.with_scaling:
args.append(Scale)
if self.with_zeros:
args.append(QZeros if self.zeros_mode == "quantized" else Zeros)
if self.with_bias:
args.append(Bias)
args.append(last_output)

func = te.create_prim_func(args).with_attr(
"dequantize_info",
{
"B_decode": {
"decode_block": "B_decode",
"fast_decoding": self.fast_decoding,
"source_format": {
"bits": self.bit,
"format": self.source_format,
},
"storage_dtype": self.storage_dtype,
"target_format": self.in_dtype,
"with_zeros": self.with_zeros,
"zeros_mode": self.zeros_mode,
"with_scaling": self.with_scaling,
"group_size": self.group_size,
}
},
)
return tvm.IRModule.from_expr(func)
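
For readers tracing the decode logic above, the three `zeros_mode` branches differ only in where the zero-point enters the dequantization arithmetic. Below is a minimal PyTorch sketch of that arithmetic, not the TIR implementation itself; `w_q`, `scale`, and `zeros` are hypothetical stand-ins for the already-unpacked weight codes and per-group parameters.

```python
import torch

def dequantize(w_q, scale, zeros, zeros_mode="original"):
    # w_q: unpacked integer weight codes; scale/zeros: per-group parameters
    # already broadcast to w_q's shape (shapes are illustrative only).
    w = w_q.to(torch.float16)
    if zeros_mode == "original":
        # subtract the zero-point first, then apply the scale
        return (w - zeros) * scale
    if zeros_mode == "rescale":
        # apply the scale first, then subtract a pre-scaled zero-point
        return w * scale - zeros
    if zeros_mode == "quantized":
        # the zero-point was already folded in while unpacking the packed
        # storage (the _tir_packed_to_unsigned_convert_with_zeros path above),
        # so only the scale remains to be applied here
        return w * scale
    raise ValueError(f"Unsupported zeros_mode: {zeros_mode}")

w_q = torch.randint(0, 16, (4, 8))                        # e.g. 4-bit codes
scale = torch.full((4, 8), 0.01, dtype=torch.float16)
zeros = torch.full((4, 8), 8.0, dtype=torch.float16)
print(dequantize(w_q, scale, zeros, "original"))
```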

# TODO: The following code should be refactored.
class MatMulNTDequantizeEmitter:
@@ -671,8 +473,7 @@ def decode_func(n, k):
else:
args.append(Zeros)
if with_bias:
E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E")
last_output = E
last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
args.append(Bias)
args.append(last_output)
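
The same one-line change appears in three dequantize variants in this file: the bias epilogue is now built from `last_output`, i.e. whatever compute stage actually precedes it, instead of hard-coding the `D` cast stage. A minimal, self-contained sketch of that chaining pattern follows; the shapes, dtypes, and the conditional cast are assumptions for illustration, not a copy of the surrounding functions.

```python
from tvm import te

M, N, K = 16, 16, 16
accum_dtype, out_dtype = "float32", "float16"
with_bias = True

A = te.placeholder((M, K), dtype="float16", name="A")
B = te.placeholder((N, K), dtype="float16", name="B")
Bias = te.placeholder((N,), dtype=out_dtype, name="Bias")

k = te.reduce_axis((0, K), name="k")
C = te.compute(
    (M, N),
    lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k),
    name="C",
)

# Chain every optional epilogue stage onto last_output so the bias stage never
# references a tensor that may not be the immediately preceding one.
last_output = C
if accum_dtype != out_dtype:
    last_output = te.compute((M, N), lambda i, j: last_output[i, j].astype(out_dtype), name="D")
if with_bias:
    last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")

func = te.create_prim_func([A, B, Bias, last_output])
```

This mirrors the replacement line above: `E` is computed from `last_output[i, j] + Bias[j]` rather than `D[i, j] + Bias[j]`.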

@@ -852,8 +653,7 @@ def decode_func(n, k):
if with_zeros:
args.append(Zeros)
if with_bias:
E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E")
last_output = E
last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
args.append(Bias)
args.append(last_output)

@@ -1052,8 +852,7 @@ def decode_func(n, k):
if with_zeros:
args.append(Zeros)
if with_bias:
E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E")
last_output = E
last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
args.append(Bias)
args.append(last_output)

41 changes: 18 additions & 23 deletions testing/python/module/test_bitblas_linear.py
@@ -11,16 +11,7 @@
torch.manual_seed(0)
bitblas.set_log_level("DEBUG")

@pytest.mark.parametrize(
"m, in_features, out_features, bias",
[
(1, 1024, 1024, False),
(1, 1024, 1024, True),
(1024, 1024, 1024, True),
([1, 1024], 1024, 1024, True),
],
)
def test_correctness_consistent(m, in_features, out_features, bias):
def correctness_consistent(m, in_features, out_features, bias):
linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda())
linear_bitblas = BitBLASLinear(
in_features,
@@ -48,19 +39,13 @@ def test_correctness_consistent(m, in_features, out_features, bias):
torch.testing.assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2)


@pytest.mark.parametrize(
"m, in_features, out_features, bias, W_dtype, group_size, with_scaling, with_zeros, zeros_mode",
[
(1, 1024, 1024, False, "uint4", -1, False, False, None),
(1, 1024, 1024, False, "uint4", -1, False, False, None),
(1024, 1024, 1024, True, "uint4", -1, False, False, None),
(1, 1024, 1024, True, "uint2", -1, True, False, None),
(1, 1024, 1024, True, "uint2", 128, True, True, "original"),
(1024, 1024, 1024, True, "uint2", 128, True, True, "original"),
(1, 1024, 1024, True, "uint2", 128, True, True, "rescale"),
],
)
def test_correctness_weight_only_dequantize(
def test_correctness_consistent():
correctness_consistent(1, 1024, 1024, False)
correctness_consistent(1, 1024, 1024, True)
correctness_consistent(1024, 1024, 1024, True)
correctness_consistent([1, 1024], 1024, 1024, True)

def correctness_weight_only_dequantize(
m,
in_features,
out_features,
@@ -169,6 +154,16 @@ def test_correctness_weight_only_dequantize(
torch.testing.assert_close(output_bitblas, ref_result, rtol=1e0, atol=1e0)


def test_correctness_weight_only_dequantize():
correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)
correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None)
correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original")
correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original")
correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale")
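
Since the converted test entry points take no parameters, they can be invoked directly or collected by pytest as usual; a minimal invocation sketch (path taken from the diff header above):

```python
# Run the test module with pytest; direct calls such as
# test_correctness_consistent() also work since the wrappers take no arguments.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["-q", "testing/python/module/test_bitblas_linear.py"]))
```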


def profile(model, input_data):
model = model.cuda()
model.eval()
2 changes: 1 addition & 1 deletion testing/python/operators/test_general_matmul_ops.py
@@ -195,7 +195,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
if with_bias:
permuted_inputs.append(bias)
permuted_inputs.append(inputs[2])
matmul(*permuted_inputs[:2], output=permuted_inputs[-1])
matmul(*permuted_inputs[:-1], output=permuted_inputs[-1])
if zeros_mode == "rescale":
torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0)
else:
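The one-line change above is the fix named in the commit title: when scaling, zeros, or bias are enabled, `permuted_inputs` holds more than two input tensors, so slicing with `[:2]` silently dropped the extra operands before invoking the op; `[:-1]` forwards every input and leaves only the last element, `inputs[2]`, as the output buffer. A small sketch of the slicing difference with hypothetical operand names:

```python
# Hypothetical operand list mirroring the test: activations, packed weight,
# scale, zeros, bias, and finally the preallocated output buffer.
permuted_inputs = ["A", "qweight", "scale", "zeros", "bias", "output"]

old_args = permuted_inputs[:2]    # ["A", "qweight"] -- scale/zeros/bias were dropped
new_args = permuted_inputs[:-1]   # every input tensor, excluding the output buffer

print(old_args)
print(new_args)
```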
