[FP8] Support FP8 MatrixCore Code gen and related test #29

Merged 6 commits on Apr 29, 2024

2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 933 files
13 changes: 13 additions & 0 deletions README.md
@@ -13,6 +13,17 @@ Some of the key features of BitBLAS include:
- BitBLAS first implemented $W_{INT2}A_{INT8}$ GEMV/GEMM in [BitNet-b1.58](https://arxiv.org/abs/2402.17764) with 8x/2x speedup over cuBLAS $W_{FP16}A_{FP16}$ on A100; please check out [op_benchmark_a100_int2_scaling](https://github.com/microsoft/BitBLAS/blob/main/images/figures/op_benchmark_a100_int2_scaling.png) for detailed benchmark results, and [BitNet-b1.58 integration](https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet) for the integration with the third-party reproduction of the BitNet-b1.58 model.
- Support customizing mixed-precision DNN operations for your specific scenarios via the flexible DSL (TIR Script).

## Latest News

- 2024.04.19: BitBLAS is now open source! We are excited to announce that BitBLAS, a high-performance library for mixed-precision DNN model deployment, is now available to the public.
- 2024.04.30: BitBLAS now supports FP8 MatrixCore code generation: $W_{FP8}A_{FP8}$ GEMM in E4M3/E5M2 formats with FP32 accumulation on SM_89 GPUs (e.g., RTX 4090).

## Integration Example of FasterTransformer with BitBLAS
![FasterTransformer Integration](images/gif/FasterTransformer.gif)

## Benchmark Summary



@@ -63,6 +74,8 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and
| INT8 | UINT4/INT4 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| INT8 | UINT2/INT2 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| INT8 | UINT1 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) |
| FP8_E4M3 | FP8_E4M3 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) |
| FP8_E5M2 | FP8_E5M2 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) |

We are continuously expanding the support matrix. If you have any specific requirements, please feel free to open an issue or PR.
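A minimal usage sketch for the new FP8 kernels through the high-level operator API (this assumes the `MatmulConfig`/`Matmul` classes from `python/bitblas/ops/general_matmul.py` are exposed as `bitblas.MatmulConfig`/`bitblas.Matmul`, and a PyTorch build with `torch.float8_e4m3fn`; shapes are illustrative):

```python
import torch
import bitblas

# FP8_E4M3 x FP8_E4M3 -> FP16 output with FP32 accumulation (SM_89 only, per the table above).
config = bitblas.MatmulConfig(
    M=16,
    N=4096,
    K=4096,
    A_dtype="e4m3_float8",
    W_dtype="e4m3_float8",
    accum_dtype="float32",
    out_dtype="float16",
)
matmul = bitblas.Matmul(config=config)

# Build FP8 tensors by down-casting FP16 data; torch cannot sample FP8 directly.
A = torch.randn(16, 4096, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)
W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)

C = matmul(A, W)
print(C.dtype)  # torch.float16
```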

17 changes: 14 additions & 3 deletions python/bitblas/base/utils.py
@@ -135,16 +135,27 @@ def var_wrapper(v):
else:
raise ValueError("Not supported type: ", type(func))

def map_numpy_type(intype):
typemap = {
'e4m3_float8': 'float8_e4m3fn',
'e5m2_float8': 'float8_e5m2',
}
if intype in typemap:
return typemap[intype]
else:
return intype

numpy_dtype = map_numpy_type(arg.dtype)
if distribution == "uniform":
profile_tensors.append(
tvm.nd.array(
np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(arg.dtype),
np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(numpy_dtype),
device=device,
))
elif distribution == "onefill":
profile_tensors.append(
tvm.nd.array(
np.ones([var_wrapper(i) for i in arg.shape]).astype(arg.dtype),
np.ones([var_wrapper(i) for i in arg.shape]).astype(numpy_dtype),
device=device,
))
else:
@@ -245,7 +256,7 @@ def tvm_callback_cuda_postproc(code, _):
try:
latency = cpresult.profile()
except Exception as e_mesg:
logger.debug("Evaluation with config failed: ", e_mesg)
logger.debug(f"Evaluation with config failed {e_mesg}")
continue
logger.info("Evaluation with config {}".format(config))
logger.info("Time cost of this config: {:.3f} ms".format(latency))
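The nested `map_numpy_type` helper only rewrites the two FP8 names into the spelling numpy (via `ml_dtypes`) understands; every other dtype string passes through untouched. A standalone sketch of the same mapping, with a quick sanity check:

```python
def map_numpy_type(intype: str) -> str:
    """Translate TVM FP8 dtype names to their numpy/ml_dtypes equivalents."""
    typemap = {
        "e4m3_float8": "float8_e4m3fn",
        "e5m2_float8": "float8_e5m2",
    }
    return typemap.get(intype, intype)

assert map_numpy_type("e4m3_float8") == "float8_e4m3fn"
assert map_numpy_type("e5m2_float8") == "float8_e5m2"
assert map_numpy_type("float16") == "float16"  # non-FP8 names are unchanged
```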
7 changes: 3 additions & 4 deletions python/bitblas/gpu/gemv.py
@@ -64,10 +64,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):


def get_bytes(dtype: Union[DataType, str]) -> int:
num = re.findall(r"\d+", dtype)
if len(num) != 1:
raise ValueError(f"Cannot get bytes from {dtype}")
return int(num[0]) // 8
if isinstance(dtype, str):
dtype = DataType(dtype)
return int(dtype.bits) // 8


def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
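Deriving the byte width from `DataType.bits` also fixes FP8 names such as `e4m3_float8`, where the old regex would find three digit groups (4, 3, 8) and raise. A quick check (a sketch, assuming a TVM build with the FP8 dtypes registered):

```python
from tvm import DataType
from bitblas.gpu.gemv import get_bytes

assert get_bytes("float16") == 2
assert get_bytes("int8") == 1
assert get_bytes("e4m3_float8") == 1            # 8 bits -> 1 byte
assert get_bytes(DataType("e5m2_float8")) == 1  # DataType objects work too
```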
21 changes: 15 additions & 6 deletions python/bitblas/gpu/matmul_analysis.py
@@ -512,7 +512,7 @@ def get_tensorized_func_and_tags(
allow_gemv: bool = False,
) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]:
from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel
get_wmma_intrin_group,)
get_mma_intrin_group,)
"""
transform function to matmul if necessary (e.g. transform conv2d with im2col)
"""
@@ -607,14 +607,18 @@ def check_last_trait(region: List[Range]):

block_stmt = sch.get(main_block)
if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70:
# TODO(lei): we should consider the dtypes of inputs a and b separately
# instead of assuming both share the same dtype,
# as the tensor core may support e4m3_float8 * e5m2_float8.
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
try:
_ = get_wmma_intrin_group(
in_dtype=in_dtype,
_ = get_mma_intrin_group(
a_dtype=in_dtype,
b_dtype=in_dtype,
out_dtype=out_dtype,
)
except Exception:
logger.debug("Cannot find the corresponding wmma intrin group")
logger.debug("Cannot find the corresponding mma intrin group")
return func, None

# reindex and transform functions
@@ -651,11 +655,16 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b,
)

assert dtype in ["float16", "int8"], "Only support float16 for now"
assert dtype in [
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
], "Only support float16, int8, e4m3_float8, e5m2_float8"
if dtype == "float16":
ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout
ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
elif dtype == "int8":
elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
# 8-bit mma (int8/FP8) only supports the 32x16 to 16x32 layout
if matrix_name == "A" and trans is False:
ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a
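Because FP8 is an 8-bit type, it reuses the int8 ldmatrix layouts (a 32x16 register fragment mapped onto a 16x32 shared-memory tile), so a propagate map can now be requested for FP8 operands. A hedged sketch, mirroring how the impl modules call this function:

```python
from bitblas.gpu.matmul_analysis import get_propagate_map

# Matrix A, no transpose: FP8 takes the same 32x16 -> 16x32 path as int8.
intra_index_map, _ = get_propagate_map(trans=False, dtype="e4m3_float8", matrix_name="A")

# float16 still resolves to the 16x16 layouts; any other dtype hits the assert above.
fp16_map, _ = get_propagate_map(trans=False, dtype="float16", matrix_name="A")
```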
6 changes: 4 additions & 2 deletions python/bitblas/gpu/matmul_mma.py
@@ -303,7 +303,8 @@ def store_output(block_outer, write_buffer_idx):
intrin_group = get_mma_intrin_group(
load_scope="shared.dyn",
store_scope="shared.dyn",
in_dtype=str(dtype_a),
a_dtype=str(dtype_a),
b_dtype=str(dtype_b),
out_dtype=str(dtype_c),
trans_a=is_transpose_a,
trans_b=is_transpose_b,
@@ -396,7 +397,8 @@ def check_has_dynamic(func: tir.PrimFunc):
intrin_group = get_mma_intrin_group(
load_scope=shared_scope,
store_scope=shared_scope if cache_write_required else "global",
in_dtype=intrin_info.in_dtype,
a_dtype=intrin_info.in_dtype,
b_dtype=intrin_info.in_dtype,
out_dtype=intrin_info.out_dtype,
trans_a=intrin_info.trans_a,
trans_b=intrin_info.trans_b,
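The scheduler now passes `a_dtype` and `b_dtype` separately, which is what lets an FP8 GEMM resolve to an MMA intrinsic group. A sketch of the lookup, mirroring the call sites in this file and assuming the updated signature provided by the vendored TVM submodule:

```python
from tvm.tir.tensor_intrin.cuda import get_mma_intrin_group

intrin_group = get_mma_intrin_group(
    load_scope="shared.dyn",
    store_scope="shared.dyn",
    a_dtype="e4m3_float8",
    b_dtype="e4m3_float8",
    out_dtype="float32",
    trans_a=False,
    trans_b=True,   # NT GEMM: B is transposed
)
# The group bundles the init/load/compute/store intrinsics used by the schedule.
print(sorted(intrin_group.keys()))
```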
9 changes: 6 additions & 3 deletions python/bitblas/gpu/matmul_mma_dequantize.py
@@ -167,7 +167,8 @@ def check_weight_decode_info(weight_decode_info):
intrin_group = get_mma_intrin_group(
load_scope=shared_scope,
store_scope=shared_scope if cache_write_required else "global",
in_dtype=intrin_info.in_dtype,
a_dtype=intrin_info.in_dtype,
b_dtype=intrin_info.in_dtype,
out_dtype=intrin_info.out_dtype,
trans_a=intrin_info.trans_a,
trans_b=intrin_info.trans_b,
@@ -654,7 +655,8 @@ def check_weight_decode_info(weight_decode_info):
intrin_group = get_mma_intrin_group(
load_scope=shared_scope,
store_scope=shared_scope if cache_write_required else "global",
in_dtype=intrin_info.in_dtype,
a_dtype=intrin_info.in_dtype,
b_dtype=intrin_info.in_dtype,
out_dtype=intrin_info.out_dtype,
trans_a=intrin_info.trans_a,
trans_b=intrin_info.trans_b,
@@ -1143,7 +1145,8 @@ def check_weight_decode_info(weight_decode_info):
intrin_group = get_mma_intrin_group(
load_scope=shared_scope,
store_scope=shared_scope if cache_write_required else "global",
in_dtype=intrin_info.in_dtype,
a_dtype=intrin_info.in_dtype,
b_dtype=intrin_info.in_dtype,
out_dtype=intrin_info.out_dtype,
trans_a=intrin_info.trans_a,
trans_b=intrin_info.trans_b,
36 changes: 32 additions & 4 deletions python/bitblas/ops/general_matmul.py
@@ -23,6 +23,24 @@

WORKSPACE_SIZE = 1024 * 1024 * 256

# TODO(lei): This should be improved into a general
# Method to get the consistent compute patterns.
NATIVE_COMPUTE_PATTERNS = [
# A_dtype, W_dtype
("float64", "float64"),
("float32", "float32"),
("float16", "float16"),
("int8", "int8"),
("e4m3_float8", "e4m3_float8"),
("e4m3_float8", "e5m2_float8"),
("e5m2_float8", "e4m3_float8"),
("e5m2_float8", "e5m2_float8"),
]


def is_native_compute(A_dtype, W_dtype) -> bool:
return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS
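`is_native_compute` decides whether the consistent (non-dequantize) matmul implementation can be used; the mixed E4M3/E5M2 entries reflect that SM_89 tensor cores can pair different FP8 formats on A and B (see the TODO in matmul_analysis.py). A quick check (module path as in this PR):

```python
from bitblas.ops.general_matmul import is_native_compute

assert is_native_compute("float16", "float16")
assert is_native_compute("e4m3_float8", "e5m2_float8")   # mixed FP8 is native
assert not is_native_compute("float16", "int4")          # handled by the dequantize path
assert not is_native_compute("int8", "e4m3_float8")      # not a native pattern
```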


class OPExecutorCPU:

@@ -150,8 +168,15 @@ def __post_init__(self):
if self.with_zeros is None:
object.__setattr__(self, "with_zeros", False)

if self.A_dtype == self.W_dtype and self.W_dtype in ["float16", "int8"]:
if self.A_dtype == self.W_dtype and self.W_dtype in [
"float16", "int8", "e4m3_float8", "e5m2_float8"
]:
object.__setattr__(self, "storage_dtype", self.W_dtype)
# TODO(lei): this is a limitation imposed by PyTorch
# and should be removed in the future.
if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)


class Matmul(Operator):
@@ -176,6 +201,8 @@ class Matmul(Operator):
"nf4": ("nf", 4),
"fp8_e5m2": ("fp", 8),
"fp4_e2m1": ("fp", 4),
"e4m3_float8": ("fp", 8), # "e4m3_float8" is a trick for "float8_e4m3fn"
"e5m2_float8": ("fp", 8),
}

def __init__(
@@ -316,7 +343,7 @@ def _build_default_module(self, target: Target):
self._build_runtime_module(target)

def _select_implementation(self):
if self.A_dtype == self.W_dtype:
if is_native_compute(self.A_dtype, self.W_dtype):
return consistent_implementation(
M=self.M,
N=self.N,
@@ -446,8 +473,9 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
args.append(bias)
args.append(output)

m = reduce(operator.mul, A.shape[:-1], 1)
args.append(m)
if self.dynamic_range is not None:
m = reduce(operator.mul, A.shape[:-1], 1)
args.append(m)

if self.lib is None:
self._forward_from_torch_func(*args)
4 changes: 2 additions & 2 deletions python/bitblas/ops/impl/ladder_permutate_impl.py
@@ -9,7 +9,7 @@
def select_implementation(
M: int,
N: int,
datatype: Literal["float16", "int8"] = "float16",
datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16",
dequantize_bits: int = -1,
storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16",
propagate_kind: Literal["A", "B"] = "B",
@@ -23,7 +23,7 @@ def select_implementation(
# This is a trick to get the basic tile size for the current datatype:
# for nvidia tensor core instructions, the basic tile size is 16x16 for float16 and 16x32 for 8-bit dtypes
l = r = 16 # noqa: E741
if datatype == "int8":
if datatype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741
intra_index_map, _ = get_propagate_map(
transpose_matrix, dtype=datatype, matrix_name=propagate_kind)
6 changes: 2 additions & 4 deletions python/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -97,8 +97,6 @@ def decode_func(n, k):
else:
raise ValueError("Unsupported source_format: {}".format(source_format))



if not with_scaling:
return w

@@ -187,7 +185,7 @@ def matmul_nt_dequantize_b_propagate_b(
M = tvm.te.var("m")

l = r = 16 # noqa: E741
if in_dtype == "int8":
if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741

_, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B")
@@ -358,7 +356,7 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b(
M = tvm.te.var("m")

l = r = 16 # noqa: E741
if in_dtype == "int8":
if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741
_, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A")
A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype)
6 changes: 3 additions & 3 deletions python/bitblas/ops/impl/matmul_impl.py
@@ -111,7 +111,7 @@ def matmul_nt_propagate_a(
if not isinstance(M, int):
M = tvm.te.var("m")
l = r = 16 # noqa: E741
if in_dtype == "int8":
if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741

_, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A")
@@ -171,7 +171,7 @@ def matmul_nt_propagate_b(
if not isinstance(M, int):
M = tvm.te.var("m")
l = r = 16 # noqa: E741
if in_dtype == "int8":
if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741

_, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B")
@@ -232,7 +232,7 @@ def matmul_nt_propagate_a_propagate_b(
if not isinstance(M, int):
M = tvm.te.var("m")
l = r = 16 # noqa: E741
if in_dtype == "int8":
if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741

A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype)
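With an 8-bit input dtype the basic tensor-core tile is 16x32 rather than 16x16, so the propagated `A` placeholder above becomes a 4-D tiled buffer. A worked example of the resulting shape (sizes are illustrative):

```python
# in_dtype = "e4m3_float8"  ->  l, r = 16, 32 (vs. 16, 16 for float16)
M, K = 16384, 16384
l, r = 16, 32

A_shape = (M // l, K // r, l, r)
print(A_shape)  # (1024, 512, 16, 32)

# The same sizes with float16 tiling would give (1024, 1024, 16, 16).
```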
2 changes: 1 addition & 1 deletion python/bitblas/ops/impl/param_permutate_impl.py
@@ -24,7 +24,7 @@ def select_implementation(
# This is a trick to get the basic tile size for the current datatype:
# for nvidia tensor core instructions, the basic tile size is 16x16 for float16 and 16x32 for 8-bit dtypes
l = r = 16 # noqa: E741
if datatype == "int8":
if datatype in ["int8", "e4m3_float8", "e5m2_float8"]:
l, r = 16, 32 # noqa: E741
if group_size == -1:
group_size = N
14 changes: 13 additions & 1 deletion python/bitblas/ops/operator.py
@@ -220,15 +220,27 @@ def var_warpper(v):
else:
raise RuntimeError("Not supported type: ", type(v))

def map_numpy_type(intype):
typemap = {
'e4m3_float8': 'float8_e4m3fn',
'e5m2_float8': 'float8_e5m2',
}
if intype in typemap:
return typemap[intype]
else:
return intype

profile_tensors = []
for param in func.params:
if param not in func.buffer_map:
# dynamic symbolic variables may appear in params without a buffer mapping
continue
arg = func.buffer_map[param]
numpy_dtype = map_numpy_type(arg.dtype)
profile_tensors.append(
tvm.nd.array(
np.random.uniform(0, 1, [var_warpper(i) for i in arg.shape]).astype(arg.dtype),
np.random.uniform(0, 1,
[var_warpper(i) for i in arg.shape]).astype(numpy_dtype),
device=device,
))
self.profile_tensors = profile_tensors
3 changes: 2 additions & 1 deletion python/bitblas/relax/transform/weight_only_propagate.py
@@ -130,7 +130,8 @@ def transform_matmul(self, g_var: GlobalVar, func: tir.PrimFunc, intrin_info):
intrin_group = get_mma_intrin_group(
load_scope="shared",
store_scope="shared",
in_dtype=intrin_info["in_dtype"],
a_dtype=intrin_info["in_dtype"],
b_dtype=intrin_info["in_dtype"],
out_dtype=intrin_info["out_dtype"],
trans_a=False,
trans_b=intrin_info["trans_b"],