BUGFIX: UINT8/INT8 Decoding
LeiWang199 committed Jun 2, 2024
1 parent 9122ff7 commit b508acc
Showing 4 changed files with 60 additions and 19 deletions.
1 change: 1 addition & 0 deletions python/bitblas/base/roller/hint.py
@@ -6,6 +6,7 @@
import numpy as np
from .rasterization import *


class TensorCoreExtraConfig:
"""
This class is used to store extra information for tensorcore
13 changes: 10 additions & 3 deletions python/bitblas/gpu/intrin/lop3.py
@@ -1483,7 +1483,11 @@ def fast_decode_impl(
TensorIntrin.register(
LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
source_bit=2, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
source_bit=2,
source_format="int",
storage_dtype="int8",
target_dtype="int8",
loops_extent=16),
)

LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_")
@@ -1497,10 +1501,13 @@ def fast_decode_impl(
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
source_bit=1, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
source_bit=1,
source_format="int",
storage_dtype="int8",
target_dtype="int8",
loops_extent=16),
)


LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN,
13 changes: 11 additions & 2 deletions python/bitblas/ops/general_matmul.py
@@ -148,9 +148,18 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
object.__setattr__(self, "zeros_mode", "original")

def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):

def is_not_fast_decoding_supported():
conditions = []
conditions.append("int" not in self.W_dtype)
conditions.append(self.W_dtype == self.A_dtype)
# int8 and uint8 have no fast-decoding implementation and do not require one
conditions.append(self.W_dtype in ["int8", "uint8"])
return any(conditions)

if fast_decoding is not None:
object.__setattr__(self, "fast_decoding", fast_decoding)
elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype):
elif is_not_fast_decoding_supported():
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", True)
@@ -450,7 +459,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
source_format, bit = self.source_format, self.bit

# Process integer source format
if source_format == "int":
if source_format == "int" and bit < 8:
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2**(bit - 1)
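The added `bit < 8` guard keeps 8-bit integer weights out of the sub-byte re-biasing path; only sub-byte signed formats need maxq. The sketch below is a rough illustration only: the rest of transform_weight is truncated in this diff, so treating a shift by maxq as the re-biasing step is an assumption of the sketch, not the committed code.

import numpy as np

# Illustration only: offset-binary re-biasing of sub-byte signed weights.
# The actual body of transform_weight is not shown in this diff; the "+ maxq"
# shift is an assumed detail of this sketch.
def rebias_int_weight(weight: np.ndarray, bit: int) -> np.ndarray:
    if bit >= 8:
        # int8 weights are left as-is, matching the new `bit < 8` guard.
        return weight
    maxq = 2**(bit - 1)
    # Map the signed range [-maxq, maxq - 1] onto [0, 2**bit - 1] before packing.
    return weight + maxq

w_int4 = np.array([-8, -1, 0, 7], dtype=np.int8)
print(rebias_int_weight(w_int4, bit=4))  # [ 0  7  8 15]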
52 changes: 38 additions & 14 deletions python/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -33,6 +33,7 @@ def matmul_nt_dequantize_b(
with_bias=False,
zeros_mode="original",
):
assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
if not isinstance(M, int):
M = tvm.te.var("m")

@@ -78,13 +79,20 @@ def decode_func(n, k):
dtype=in_dtype,
)
elif source_format == "uint":
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
if bit == 8:
# 8 bit does not need to be compressed
w = B[n, k].astype(in_dtype)
else:
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif source_format == "int":
if bit == 1:
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif bit == 8:
# 8 bit does not need to be compressed
w = B[n, k].astype(in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
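In the packed branches, n_float_per_elem = storage_nbit // bit sub-byte values share one int8 storage element, and k // n_float_per_elem and k % n_float_per_elem pick the element and the field within it. For bit == 8 that ratio is 1 and extraction degenerates to the identity, which is why the fix reads B[n, k] directly. A small Python sketch of the unsigned case (names are illustrative, not the TIR helper _tir_packed_to_unsigned_convert):

storage_nbit = 8  # int8 storage

# Illustrative analogue of the packed unsigned decode; not the TIR implementation.
def unpack_uint(packed_elem: int, bit: int, pos: int) -> int:
    # Extract the pos-th bit-wide field from one storage element.
    mask = (1 << bit) - 1
    return (packed_elem >> (pos * bit)) & mask

bit = 4
n_float_per_elem = storage_nbit // bit      # 2 uint4 values per int8 element
packed = 0b10110010                         # low nibble 2, high nibble 11
print([unpack_uint(packed, bit, p) for p in range(n_float_per_elem)])  # [2, 11]

# bit == 8: one value per element, so "unpacking" is the identity --
# exactly the new B[n, k] passthrough above.
print(unpack_uint(0x7F, bit=8, pos=0))      # 127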
@@ -187,6 +195,7 @@ def matmul_nt_dequantize_b_propagate_b(
zeros_mode="original",
transform_kind: TransformKind = TransformKind.IntraWarpTransform,
):
assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
if not isinstance(M, int):
M = tvm.te.var("m")

@@ -241,17 +250,24 @@ def fcompute(i, j):

def decode_func(n, k):
if source_format == "uint":
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
if bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "int":
if bit == 1:
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit,
@@ -360,6 +376,7 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b(
transform_kind_input: TransformKind = TransformKind.IntraWarpTransform,
transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform,
):
assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit)
if not isinstance(M, int):
M = tvm.te.var("m")

@@ -429,17 +446,24 @@ def fcompute(i, j):

def decode_func(n, k):
if source_format == "uint":
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
if bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
bit,
B_reindex[n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype,
)
elif source_format == "int":
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
if bit == 1:
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
elif bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[n, k].astype(in_dtype)
else:
w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
bit,
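Taken together, the changes let 8-bit weight dtypes flow through the dequantize matmul without sub-byte packing and without LOP3 fast decoding. A hedged usage sketch, assuming the bitblas.MatmulConfig / bitblas.Matmul front end as of this commit (shapes and values are illustrative):

import torch
import bitblas

# Sketch only; assumes the MatmulConfig/Matmul API available at this commit.
# With W_dtype="int8", fast_decoding now defaults to False and transform_weight
# leaves the weight uncompressed (the new `bit < 8` guard).
config = bitblas.MatmulConfig(
    M=16,
    N=1024,
    K=1024,
    A_dtype="float16",
    W_dtype="int8",
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",
)
matmul = bitblas.Matmul(config=config)

A = torch.rand((16, 1024), dtype=torch.float16).cuda()
W = torch.randint(-128, 127, (1024, 1024), dtype=torch.int8).cuda()
W_q = matmul.transform_weight(W)
C = matmul(A, W_q)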
