[Dev] Refactor the ops script implementation with SE (#78)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Disable failure email for CI

* Remove email notifications

* Move the relax pass from testing to mlc_llm

* Refactor scripts with the SE check_eual_ref_scripts_with_emitter function

* Lint Fix
LeiWang1999 authored Jul 6, 2024
1 parent bf41fc4 commit 49c41a4
Showing 2 changed files with 388 additions and 18 deletions.
339 changes: 333 additions & 6 deletions bitblas/ops/impl/matmul_dequantize_impl.py
@@ -214,6 +214,326 @@ def emit(self):
        )
        return tvm.IRModule.from_expr(func)

# TODO: The following code should be refactored.
class MatMulNTDequantizeEmitter:

    def __init__(
        self,
        M,
        N,
        K,
        in_dtype="float16",
        out_dtype="float16",
        accum_dtype="float16",
        bit=4,
        storage_dtype="int8",
        source_format="uint",
        with_scaling=False,
        with_zeros=False,
        group_size=-1,
        fast_decoding=False,
        with_bias=False,
        zeros_mode="original",
        propagate_a: TransformKind = TransformKind.NonTransform,
        propagate_b: TransformKind = TransformKind.NonTransform,
    ):
        self.M = self._validate_dimension(M, "M")
        self.N = N
        self.K = K
        self.in_dtype = in_dtype
        self.out_dtype = out_dtype
        self.accum_dtype = accum_dtype
        self.bit = bit
        self.storage_dtype = storage_dtype
        self.source_format = source_format
        self.with_scaling = with_scaling
        self.with_zeros = with_zeros
        self.group_size = group_size if group_size != -1 else K
        self.fast_decoding = fast_decoding
        self.with_bias = with_bias
        self.zeros_mode = zeros_mode
        self.propagate_a = self._legalize_transform_kind(propagate_a)
        self.propagate_b = self._legalize_transform_kind(propagate_b)

        self._validate_bit()
        self._validate_layout()
    @staticmethod
    def _validate_dimension(dim, name):
        # Symbolic (dynamic) dimensions become TE vars; static ints pass through.
        if not isinstance(dim, int):
            return tvm.te.var(name.lower())
        return dim

    def _validate_bit(self):
        if self.bit not in [1, 2, 4, 8]:
            raise ValueError(f"Unsupported bit: {self.bit}")

    def _validate_layout(self):
        # TODO: extend the dequantize operators into General Layout
        pass

    def _legalize_group_size(self):
        if self.group_size == -1:
            self.group_size = self.K

    def _legalize_transform_kind(self, propagate):
        if propagate is None:
            return TransformKind.NonTransform
        if isinstance(propagate, bool):
            return TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform
        if isinstance(propagate, int):
            return TransformKind(propagate)
        raise ValueError(f"Unsupported propagate kind: {propagate}")
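
    # Illustrative mapping (restating the branches above, not extra API):
    # booleans select between no transform and the intra-warp transform, while
    # raw ints are cast to the TransformKind enum directly, e.g.
    #   _legalize_transform_kind(True)  -> TransformKind.IntraWarpTransform
    #   _legalize_transform_kind(False) -> TransformKind.NonTransform
    #   _legalize_transform_kind(2)     -> TransformKind(2)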

    def _create_placeholders(self):
        storage_dtype = self.storage_dtype
        storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
        in_dtype = self.in_dtype
        bit = self.bit
        l = r = 16  # noqa: E741
        if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
            l, r = 16, 32  # noqa: E741

        A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype)
        B = te.placeholder((self.N, self.K // storage_nbit * bit), name="B", dtype=storage_dtype)
        if self.propagate_a:
            A = te.placeholder((self.M // l, self.K // r, l, r), name="A", dtype=in_dtype)
        if self.propagate_b:
            target_dtype = DataType(in_dtype)
            scaling_factor = 1
            if bit > 0 and bit < target_dtype.bits:
                scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits //
                                  target_dtype.bits)
            qr = r * bit // storage_nbit
            B = te.placeholder((self.N // l, (self.K // scaling_factor) // qr, l, qr),
                               name="B",
                               dtype=storage_dtype)

        LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype)
        Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype)
        Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=in_dtype)
        QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * bit),
                                name="QZeros",
                                dtype=self.storage_dtype)
        Bias = te.placeholder((self.N,), name="Bias", dtype=in_dtype)
        return A, B, LUT, Scale, Zeros, QZeros, Bias
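
    # Worked shape example (derived from the expressions above, nothing new):
    # with K=1024, bit=4, storage_dtype="int8" we get storage_nbit=8, so B packs
    # two 4-bit weights per int8 element and has shape (N, 1024 // 8 * 4) = (N, 512);
    # with group_size=128, Scale and Zeros both have shape (N, 1024 // 128) = (N, 8).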

    def _propagate_input(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="A"):
        if transform_kind == TransformKind.NonTransform:
            return tensor
        in_dtype = self.in_dtype
        l = r = 16  # noqa: E741
        if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
            l, r = 16, 32  # noqa: E741
        _, inversed_index_map = get_propagate_map(
            trans=False, dtype=in_dtype, matrix_name=matrix_name)

        def fcompute(i, j):
            warp_i, warp_j = i % l, j % r
            spatial_args = i // l, j // r
            if transform_kind >= TransformKind.IntraWarpTransform:
                warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j])
            new_index = (*spatial_args, warp_i, warp_j)
            return tensor[new_index]

        return te.compute(
            (self.M, self.K),
            fcompute,
            name=f"{matrix_name}_reindex",
        )
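
    # Index sketch (restating fcompute above): the tiled input has shape
    # (M // l, K // r, l, r), so logical element (i, j) is read from
    # tensor[i // l, j // r, warp_i, warp_j], where (warp_i, warp_j) is the
    # inverse warp-level permutation of (i % l, j % r).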

    def _propagate_weight(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="B"):
        if transform_kind == TransformKind.NonTransform:
            return tensor
        in_dtype = self.in_dtype
        bit = self.bit
        storage_dtype = self.storage_dtype
        storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit()))

        l = r = 16  # noqa: E741
        if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
            l, r = 16, 32  # noqa: E741
        _, inversed_index_map = get_propagate_map(
            trans=True, dtype=in_dtype, matrix_name=matrix_name)
        target_dtype = DataType(in_dtype)
        scaling_factor = 1
        if bit > 0 and bit < target_dtype.bits:
            scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits //
                              target_dtype.bits)
        initial_indices = inversed_index_map.initial_indices
        scaling_final_indices = inversed_index_map.map_indices(
            initial_indices[:-1] + [initial_indices[-1] * scaling_factor])
        scaling_final_indices = scaling_final_indices[:-1] + [
            scaling_final_indices[-1] // scaling_factor
        ]
        inversed_index_map = IndexMap(
            initial_indices,
            scaling_final_indices,
            None,
        )

        qr = r * bit // storage_nbit

        def fcompute(i, j):
            warp_i, warp_j = i % l, j % qr
            spatial_args = i // l, j // qr
            if transform_kind >= TransformKind.IntraWarpTransform:
                warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j])
            new_index = (*spatial_args, warp_i, warp_j)
            return tensor[new_index]

        return te.compute(
            (self.N, self.K // storage_nbit * bit),
            fcompute,
            name=f"{matrix_name}_reindex",
        )
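
    # Worked factor example (derived from the expressions above): for
    # in_dtype="float16" (16 bits), bit=4, storage_dtype="int8" (8 bits),
    # scaling_factor = (16 // 4) * 8 // 16 = 2 and qr = 16 * 4 // 8 = 8, i.e.
    # each 16-wide warp tile of weights collapses into 8 packed int8 lanes.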

    def _decode_func(self, B, LUT, Scale, Zeros, QZeros):
        bit = self.bit
        in_dtype = self.in_dtype
        storage_dtype = self.storage_dtype
        storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
        storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
        n_float_per_elem = storage_nbit // bit

        # TODO: Move the decode function into a more general place
        def decode(n, k):
            w = None
            if self.with_zeros and self.zeros_mode == "quantized":
                qzeros_dequantize = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
                    bit,
                    QZeros[k, n // n_float_per_elem],
                    n % n_float_per_elem,
                    dtype=self.storage_dtype,
                )
                w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)(
                    bit,
                    B[n, k // n_float_per_elem],
                    k % n_float_per_elem,
                    qzeros_dequantize,
                    dtype=in_dtype,
                )
            elif self.source_format == "uint":
                if bit == 8:
                    # 8-bit weights are stored one per element; no unpacking needed.
                    w = B[n, k].astype(in_dtype)
                else:
                    w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
                        bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
            elif self.source_format == "int":
                if bit == 1:
                    # Decode int1 to -1/+1 instead of the 0/1 that uint1 would give.
                    w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
                        bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
                elif bit == 8:
                    # 8-bit weights are stored one per element; no unpacking needed.
                    w = B[n, k].astype(in_dtype)
                else:
                    w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
                        bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
            elif self.source_format == "fp":
                w = _tir_u32_to_f4_to_f16(
                    bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
            elif self.source_format == "fp_e4m3":
                w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype)
            elif self.source_format == "nf":
                index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
                    bit,
                    B[n, k // n_float_per_elem],
                    k % n_float_per_elem,
                    dtype="int32",
                )
                w = LUT[index]
            else:
                raise ValueError(f"Unsupported source_format: {self.source_format}")

            assert w is not None, "w is None"

            group_size = self.group_size
            zeros_mode = self.zeros_mode

            if not self.with_scaling:
                return w

            if not self.with_zeros:
                return w * Scale[n, k // group_size]

            if zeros_mode == "original":
                w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size]
            elif zeros_mode == "rescale":
                w = w * Scale[n, k // group_size] - Zeros[n, k // group_size]
            elif zeros_mode == "quantized":
                # Zeros were already folded in during unpacking above.
                w = w * Scale[n, k // group_size]
            else:
                raise ValueError(f"Unsupported zeros_mode: {zeros_mode}")

            return w

        return te.compute((self.N, self.K), decode, name="B_decode")
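
    # Sketch of the unpacking that _tir_packed_to_unsigned_convert is presumed
    # to emit (an illustrative plain-Python analogue, not the actual TIR
    # helper): shift the packed storage word right by the element's bit offset
    # and mask off `bit` bits.
    #
    #   def unpack_uint(storage_word, pos, bit):
    #       return (storage_word >> (pos * bit)) & ((1 << bit) - 1)
    #
    #   # e.g. two uint4 values 0x2 and 0xA packed into one byte 0xA2:
    #   #   unpack_uint(0xA2, 0, 4) == 0x2, unpack_uint(0xA2, 1, 4) == 0xA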

    def _compute_matmul(self, A, B_decode):
        k = te.reduce_axis((0, self.K), name="k")
        C = te.compute(
            (self.M, self.N),
            lambda i, j: te.sum(
                A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k),
            name="C",
        )
        return C

    def _convert_dtype(self, tensor):
        if self.accum_dtype != self.out_dtype:
            return te.compute((self.M, self.N),
                              lambda i, j: tensor[i, j].astype(self.out_dtype),
                              name="D")
        return tensor

    def _apply_bias(self, tensor, Bias):
        if self.with_bias:
            return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E")
        return tensor

    def emit(self):
        A, B, LUT, Scale, Zeros, QZeros, Bias = self._create_placeholders()
        A_reindex = self._propagate_input(A, self.propagate_a, "A")
        B_reindex = self._propagate_weight(B, self.propagate_b, "B")

        B_decode = self._decode_func(B_reindex, LUT, Scale, Zeros, QZeros)
        C = self._compute_matmul(A_reindex, B_decode)
        D = self._convert_dtype(C)
        last_output = self._apply_bias(D, Bias)

        args = [A, B]
        if self.source_format == "nf":
            args.append(LUT)
        if self.with_scaling:
            args.append(Scale)
        if self.with_zeros:
            args.append(QZeros if self.zeros_mode == "quantized" else Zeros)
        if self.with_bias:
            args.append(Bias)
        args.append(last_output)

        func = te.create_prim_func(args).with_attr(
            "dequantize_info",
            {
                "B_decode": {
                    "decode_block": "B_decode",
                    "fast_decoding": self.fast_decoding,
                    "source_format": {
                        "bits": self.bit,
                        "format": self.source_format,
                    },
                    "storage_dtype": self.storage_dtype,
                    "target_format": self.in_dtype,
                    "with_zeros": self.with_zeros,
                    "zeros_mode": self.zeros_mode,
                    "with_scaling": self.with_scaling,
                    "group_size": self.group_size,
                }
            },
        )
        if self.propagate_a:
            func = func.with_attr("input_transform_kind", self.propagate_a.value)
        if self.propagate_b:
            func = func.with_attr("weight_transform_kind", self.propagate_b.value)
        return tvm.IRModule.from_expr(func)
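
# Illustrative usage (not part of this diff; the argument values are assumptions):
#
#   emitter = MatMulNTDequantizeEmitter(
#       M=16, N=16384, K=16384, bit=4, storage_dtype="int8", with_scaling=True)
#   mod = emitter.emit()  # -> tvm.IRModule, same contract as matmul_nt_dequantize_b below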


def matmul_nt_dequantize_b(
    M,
    N,
@@ -335,9 +655,12 @@ def decode_func(n, k):
            A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k),
        name="C",
    )

    last_output = C
    if accum_dtype != out_dtype:
        # Cast to out_dtype only when the accumulator type differs.
        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
        last_output = D
    args = [A, B]
    if source_format == "nf":
        args.append(LUT)
    if with_scaling:
@@ -517,9 +840,11 @@ def decode_func(n, k):
            A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k),
        name="C",
    )
    last_output = C
    if accum_dtype != out_dtype:
        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
        last_output = D
    args = [A, B]
    if source_format == "nf":
        args.append(LUT)
    if with_scaling:
@@ -715,9 +1040,11 @@ def decode_func(n, k):
        ),
        name="C",
    )
    last_output = C
    if accum_dtype != out_dtype:
        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
        last_output = D
    args = [A, B]
    if source_format == "nf":
        args.append(LUT)
    if with_scaling: