diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 1fbdf19dd..511b95833 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -65,12 +65,3 @@ jobs:
           source bitblas_ci/bin/activate
           cd testing/python
           python -m pytest
-
-  # Control notifications
-  notify:
-    runs-on: self-hosted
-    needs: [format-check, build-test]
-    if: failure()
-    steps:
-      - name: Notification
-        run: echo "Jobs failed, but no email will be sent."
diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py
index 417eaacfc..1ed6b3404 100644
--- a/bitblas/ops/impl/matmul_dequantize_impl.py
+++ b/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -15,6 +15,204 @@
     _tir_packed_to_unsigned_convert_with_zeros,
 )
 
+# TODO: The following code should be refactored.
+class MatMulNTDequantizeEmitter:
+    def __init__(
+        self,
+        M,
+        N,
+        K,
+        in_dtype="float16",
+        out_dtype="float16",
+        accum_dtype="float16",
+        bit=4,
+        storage_dtype="int8",
+        source_format="uint",
+        with_scaling=False,
+        with_zeros=False,
+        group_size=-1,
+        fast_decoding=False,
+        with_bias=False,
+        zeros_mode="original",
+        propagate_a: TransformKind = TransformKind.NonTransform,
+        propagate_b: TransformKind = TransformKind.NonTransform,
+    ):
+        self.M = self._validate_dimension(M, "M")
+        self.N = N
+        self.K = K
+        self.in_dtype = in_dtype
+        self.out_dtype = out_dtype
+        self.accum_dtype = accum_dtype
+        self.bit = bit
+        self.storage_dtype = storage_dtype
+        self.source_format = source_format
+        self.with_scaling = with_scaling
+        self.with_zeros = with_zeros
+        self.group_size = group_size if group_size != -1 else K
+        self.fast_decoding = fast_decoding
+        self.with_bias = with_bias
+        self.zeros_mode = zeros_mode
+        self.propagate_a = propagate_a
+        self.propagate_b = propagate_b
+        # This emitter only supports the NT layout (A as M x K, B stored as N x K).
+        self.layout = "nt"
+
+        self._validate_bit()
+        self._validate_layout()
+
+    @staticmethod
+    def _validate_dimension(dim, name):
+        if not isinstance(dim, int):
+            return tvm.te.var(name.lower())
+        return dim
+
+    def _validate_bit(self):
+        if self.bit not in [1, 2, 4, 8]:
+            raise ValueError(f"Unsupported bit: {self.bit}")
+
+    def _validate_layout(self):
+        if self.layout not in ["nt"]:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+
+    def _create_placeholders(self):
+        storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit()))
+        n_float_per_elem = storage_nbit // self.bit
+
+        A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype)
+        B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype)
+        LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype)
+        Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype)
+        Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype)
+        QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit),
+                                name="QZeros",
+                                dtype=self.storage_dtype)
+        Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype)
+        return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem
+
+    def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem):
+
+        def decode(n, k):
+            if self.with_zeros and self.zeros_mode == "quantized":
+                qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
+                    self.bit,
+                    QZeros[k, n // n_float_per_elem],
+                    n % n_float_per_elem,
+                    dtype=self.storage_dtype,
+                )
+                w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)(
+                    self.bit,
+                    B[n, k // n_float_per_elem],
+                    k % n_float_per_elem,
+                    qzeros_dequantize,
+                    dtype=self.in_dtype,
+                )
+            elif self.source_format == "uint":
+                if self.bit == 8:
+                    # 8-bit weights are stored uncompressed; use them directly.
+                    w = B[n, k].astype(self.in_dtype)
+                else:
+                    w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
+                        self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
+            elif self.source_format == "int":
+                if self.bit == 1:
+                    # int1 decodes to -1/+1 rather than 0/1.
+                    w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)(
+                        self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
+                elif self.bit == 8:
+                    # 8-bit weights are stored uncompressed; use them directly.
+                    w = B[n, k].astype(self.in_dtype)
+                else:
+                    w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)(
+                        self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
+            elif self.source_format == "fp":
+                w = _tir_u32_to_f4_to_f16(
+                    self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
+            elif self.source_format == "fp_e4m3":
+                w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype)
+            elif self.source_format == "nf":
+                index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
+                    self.bit,
+                    B[n, k // n_float_per_elem],
+                    k % n_float_per_elem,
+                    dtype="int32",
+                )
+                w = LUT[index]
+            else:
+                raise ValueError(f"Unsupported source_format: {self.source_format}")
+
+            group_size = self.group_size
+            zeros_mode = self.zeros_mode
+
+            if not self.with_scaling:
+                return w
+
+            if not self.with_zeros:
+                return w * Scale[n, k // group_size]
+
+            if zeros_mode == "original":
+                w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size]
+            elif zeros_mode == "rescale":
+                w = w * Scale[n, k // group_size] - Zeros[n, k // group_size]
+            elif zeros_mode == "quantized":
+                w = w * Scale[n, k // group_size]
+            else:
+                raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode))
+
+            return w
+
+        return te.compute((self.N, self.K), decode, name="B_decode")
+
+    def _compute_matmul(self, A, B_decode):
+        k = te.reduce_axis((0, self.K), name="k")
+        C = te.compute(
+            (self.M, self.N),
+            lambda i, j: te.sum(
+                A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k),
+            name="C",
+        )
+        return C
+
+    def _convert_dtype(self, tensor):
+        if self.accum_dtype != self.out_dtype:
+            return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D")
+        return tensor
+
+    def _apply_bias(self, tensor, Bias):
+        if self.with_bias:
+            return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E")
+        return tensor
+
+    def emit(self):
+        A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders()
+        B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem)
+        C = self._compute_matmul(A, B_decode)
+        D = self._convert_dtype(C)
+        last_output = self._apply_bias(D, Bias)
+
+        args = [A, B]
+        if self.source_format == "nf":
+            args.append(LUT)
+        if self.with_scaling:
+            args.append(Scale)
+        if self.with_zeros:
+            args.append(QZeros if self.zeros_mode == "quantized" else Zeros)
+        if self.with_bias:
+            args.append(Bias)
+        args.append(last_output)
+
+        func = te.create_prim_func(args).with_attr(
+            "dequantize_info",
+            {
+                "B_decode": {
+                    "decode_block": "B_decode",
+                    "fast_decoding": self.fast_decoding,
+                    "source_format": {
+                        "bits": self.bit,
+                        "format": self.source_format,
+                    },
+                    "storage_dtype": self.storage_dtype,
+                    "target_format": self.in_dtype,
+                    "with_zeros": self.with_zeros,
+                    "zeros_mode": self.zeros_mode,
+                    "with_scaling": self.with_scaling,
+                    "group_size": self.group_size,
+                }
+            },
+        )
+        return tvm.IRModule.from_expr(func)
 
 
 def matmul_nt_dequantize_b(
     M,
diff --git a/testing/python/transform/test_weight_only_transform.py b/integration/mlc_llm/test_weight_only_transform.py
similarity index 100%
rename from testing/python/transform/test_weight_only_transform.py
rename to integration/mlc_llm/test_weight_only_transform.py
diff --git a/testing/python/operators/test_tir_script_emitter.py b/testing/python/operators/test_tir_script_emitter.py
new file mode 100644
index 000000000..cec56b473
--- /dev/null
+++ b/testing/python/operators/test_tir_script_emitter.py
@@ -0,0 +1,68 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from bitblas.ops.impl.matmul_dequantize_impl import (
+    MatMulNTDequantizeEmitter,
+    matmul_nt_dequantize_b,
+    matmul_nt_dequantize_b_propagate_b,
+    matmul_nt_dequantize_b_propagate_a_propagate_b,
+)
+from bitblas import tvm
+import logging
+from bitblas import set_log_level
+
+set_log_level(logging.DEBUG)
+
+def compare_tir_scripts_and_emitter(
+    M,
+    N,
+    K,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    bit,
+    storage_dtype,
+    source_format,
+    with_scaling,
+    with_zeros,
+    group_size,
+    fast_decoding,
+    with_bias,
+    zeros_mode,
+):
+    tir_script_func = matmul_nt_dequantize_b(
+        M,
+        N,
+        K,
+        in_dtype,
+        out_dtype,
+        accum_dtype,
+        bit,
+        storage_dtype,
+        source_format,
+        with_scaling,
+        with_zeros,
+        group_size,
+        fast_decoding,
+        with_bias,
+        zeros_mode,
+    )
+
+    emitter_func = MatMulNTDequantizeEmitter(
+        M,
+        N,
+        K,
+        in_dtype,
+        out_dtype,
+        accum_dtype,
+        bit,
+        storage_dtype,
+        source_format,
+        with_scaling,
+        with_zeros,
+        group_size,
+        fast_decoding,
+        with_bias,
+        zeros_mode,
+    ).emit()
+
+    tvm.ir.assert_structural_equal(tir_script_func, emitter_func)
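
For reference, a minimal sketch of how the new comparison helper might be driven: both code paths produce a tvm.IRModule, and tvm.ir.assert_structural_equal raises if the emitter output and the TIR-script implementation diverge. The concrete shapes, dtypes, and quantization settings below are illustrative assumptions, not values taken from the patch.

# Editor's usage sketch (not part of the diff above); every argument value here is a
# hypothetical example, e.g. appended to the new test file for a quick manual check.
compare_tir_scripts_and_emitter(
    M=16,
    N=1024,
    K=1024,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    bit=4,
    storage_dtype="int8",
    source_format="uint",
    with_scaling=False,
    with_zeros=False,
    group_size=-1,
    fast_decoding=False,
    with_bias=False,
    zeros_mode="original",
)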