From acf168ea80562a5105e53a6e2ca7a2bb38e67627 Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Tue, 30 Apr 2024 17:32:52 +0800
Subject: [PATCH] [FP8] Improve tensor adapter to support fp8 conversion
 between torch and numpy (#30)

* Add Str Parse library to requirements.txt and requirements-dev.txt
* Support quantized zero types for uint2.
* Support FP8 Codegen
* Add support for e4m3_float8 and e5m2_float8 types in CUDA wrapper
* Support FP8
* Fix data type limitation in MatmulConfig and LadderPermutateConfig
* Fix storage_dtype assignment in MatmulConfig
---
 python/bitblas/ops/general_matmul.py   |  2 +-
 python/bitblas/ops/ladder_permutate.py |  2 +-
 python/bitblas/ops/operator.py         |  5 ++++
 python/bitblas/utils/__init__.py       |  2 +-
 python/bitblas/utils/tensor_adapter.py | 40 ++++++++++++++++++++++++++
 5 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py
index 4a48fb901..ce8a8aef4 100644
--- a/python/bitblas/ops/general_matmul.py
+++ b/python/bitblas/ops/general_matmul.py
@@ -172,7 +172,7 @@ def __post_init__(self):
                 "float16", "int8", "e4m3_float8", "e5m2_float8"
         ]:
             object.__setattr__(self, "storage_dtype", self.W_dtype)
-        # TODO(lei): This is a limitation arose by pytorch
+        # TODO(lei): This is a limitation arising from pytorch and llvm.
         # Should be removed in the future.
         if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
             object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
diff --git a/python/bitblas/ops/ladder_permutate.py b/python/bitblas/ops/ladder_permutate.py
index 06b521496..70999b09d 100644
--- a/python/bitblas/ops/ladder_permutate.py
+++ b/python/bitblas/ops/ladder_permutate.py
@@ -11,7 +11,7 @@ class LadderPermutateConfig:
     M: int
     N: int
-    datatype: Literal["float16", "int8"] = "float16"
+    datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16"
     dequantize_bits: int = -1
     storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16"
     propagate_kind: Literal["A", "B"] = "B"  # "A" or "B"
diff --git a/python/bitblas/ops/operator.py b/python/bitblas/ops/operator.py
index 224726b6d..3bf97d440 100644
--- a/python/bitblas/ops/operator.py
+++ b/python/bitblas/ops/operator.py
@@ -270,7 +270,12 @@ def _forward_from_tvm_args(self, *args):
         _tvm_args = [self._tensor_adapter(arg, self.arch.device) for arg in args]
         self.rt_mod(*_tvm_args)
 
+    def _forward_from_tvm_nd_array(self, *args):
+        self.rt_mod(*args)
+
     def _forward_from_torch_func(self, *args):
+        # torch_func is not reliable for dtypes that torch does not fully
+        # support, such as float8.
         self.torch_func(*args)
         return args[-1]
 
diff --git a/python/bitblas/utils/__init__.py b/python/bitblas/utils/__init__.py
index 35f933542..f9587964c 100644
--- a/python/bitblas/utils/__init__.py
+++ b/python/bitblas/utils/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4  # noqa: F401
-from .tensor_adapter import tvm_tensor_to_torch  # noqa: F401
+from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor  # noqa: F401
 from .target_detector import auto_detect_nvidia_target  # noqa: F401
diff --git a/python/bitblas/utils/tensor_adapter.py b/python/bitblas/utils/tensor_adapter.py
index 0f7eeefa9..55b80d138 100644
--- a/python/bitblas/utils/tensor_adapter.py
+++ b/python/bitblas/utils/tensor_adapter.py
@@ -3,8 +3,10 @@
 import tvm
 from typing import Union
 from enum import IntEnum
+import numpy as np
 import torch
 from torch.utils.dlpack import from_dlpack, to_dlpack
+from math import prod
 
 from tvm.relay import TensorType
 from tvm._ffi.base import _LIB, c_str
@@ -88,3 +90,41 @@ def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]):
         return from_dlpack(tensor)
     else:
         raise RuntimeError("Not supported type: ", type(tensor))
+
+def lazy_tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]):
+    # Reinterpret the raw buffer at `address` as a torch tensor of `torch_type`.
+    def as_tensor(address, shape, elems_inbytes, torch_type):
+        arr = (ctypes.c_int8 * elems_inbytes).from_address(
+            address)
+        return torch.frombuffer(arr, dtype=torch_type).view(*shape)
+
+    if isinstance(tensor, tvm.nd.NDArray):
+        np_array = tensor.asnumpy()
+        shape = np_array.shape
+        dtype = np_array.dtype
+        torch_dtype = getattr(torch, str(dtype))
+        num_elems_inbytes = prod(shape) * np_array.itemsize
+        data_ptr = np_array.ctypes.data
+        tensor = as_tensor(data_ptr, shape, num_elems_inbytes, torch_dtype)
+        return tensor
+    else:
+        raise RuntimeError("Not supported type: ", type(tensor))
+
+def lazy_torch_to_tvm_tensor(tensor):
+    # Reinterpret the raw buffer at `address` as a numpy array of `numpy_type`.
+    def as_tensor(address, shape, elems_inbytes, numpy_type):
+        arr = (ctypes.c_int8 * elems_inbytes).from_address(
+            address)
+        return np.frombuffer(arr, dtype=numpy_type).reshape(shape)
+
+    if isinstance(tensor, torch.Tensor):
+        data_ptr = tensor.data_ptr()
+        shape = tensor.shape
+        torch_dtype = tensor.dtype
+        numpy_dtype = str(torch_dtype).replace("torch.", "")
+        num_elems_inbytes = prod(shape) * tensor.itemsize
+        np_tensor = as_tensor(data_ptr, shape, num_elems_inbytes, numpy_dtype)
+        tvm_tensor = tvm.nd.array(np_tensor)
+        return tvm_tensor
+    else:
+        raise RuntimeError("Not supported type: ", type(tensor))
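
Usage sketch (not part of the diff above): a minimal round trip through the two
new adapters, assuming bitblas and tvm are importable and the tensor lives on
CPU; the 4x4 float16 tensor is only an illustrative assumption.

    import torch
    from bitblas.utils import lazy_torch_to_tvm_tensor, lazy_tvm_tensor_to_torch

    # torch -> tvm.nd.NDArray: the adapter reads the torch buffer via ctypes,
    # wraps it as a numpy array, and copies it into a tvm.nd.NDArray.
    x = torch.randn(4, 4, dtype=torch.float16)
    tvm_nd = lazy_torch_to_tvm_tensor(x)

    # tvm.nd.NDArray -> torch: the adapter materializes the NDArray as numpy
    # and reinterprets that buffer as a torch tensor of the matching dtype.
    y = lazy_tvm_tensor_to_torch(tvm_nd)
    print(y.dtype, y.shape)  # torch.float16 torch.Size([4, 4]); y mirrors x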