diff --git a/VERSION b/VERSION
index 407ab24ea..9eac5e019 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.0.1.dev7
\ No newline at end of file
+0.0.1.dev8
\ No newline at end of file
diff --git a/integration/BitNet/utils_quant.py b/integration/BitNet/utils_quant.py
index 06a8dc119..121649387 100644
--- a/integration/BitNet/utils_quant.py
+++ b/integration/BitNet/utils_quant.py
@@ -119,7 +119,6 @@ def native_forward(self, input):
         return out

     def forward_fp32_simulated(self, input):
-        print("input: ", input)
         quant_input = self.activation_quant(input, self.input_bits).detach()
         quant_weight = self.weight_quant(self.weight).detach()

@@ -139,6 +138,8 @@ def forward_fp32_simulated(self, input):
         return out

     def forward(self, input):
+        # return self.forward_fp32_simulated(input)
+
         quant_input = self.activation_quant(input, self.input_bits).detach()
         fp32_out = self.bitblas_matmul(quant_input, self.weight)
         sw = self.sw
diff --git a/python/bitblas/__init__.py b/python/bitblas/__init__.py
index 3f806bfde..3bd32875e 100644
--- a/python/bitblas/__init__.py
+++ b/python/bitblas/__init__.py
@@ -81,4 +81,4 @@ def _init_logger():

 _init_logger()

-__version__ = "0.0.1.dev7"
+__version__ = "0.0.1.dev8"
diff --git a/python/bitblas/base/utils.py b/python/bitblas/base/utils.py
index 23a817f78..7da309dd5 100644
--- a/python/bitblas/base/utils.py
+++ b/python/bitblas/base/utils.py
@@ -19,7 +19,7 @@
 import tempfile
 import itertools
 from tvm.ir.supply import GlobalVarSupply
-from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 import logging

 logger = logging.getLogger(__name__)
@@ -205,6 +205,7 @@ def _build(context) -> str:
     def tvm_callback_cuda_postproc(code, _):
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code

     with tvm.transform.PassContext(config={"tir.use_async_copy": True, **config.pass_context}):
diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py
index 9fe7d1345..af2da3f02 100644
--- a/python/bitblas/ops/general_matmul.py
+++ b/python/bitblas/ops/general_matmul.py
@@ -10,7 +10,7 @@
 from .impl.matmul_dequantize_impl import (
     select_implementation as weight_dequantize_implementation,)
 from .impl.matmul_impl import select_implementation as consistent_implementation
-from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from bitblas.utils.target_detector import auto_detect_nvidia_target
 from dataclasses import dataclass
 from .ladder_permutate import LadderPermutate, LadderPermutateConfig
@@ -398,6 +398,7 @@ def _select_implementation(self):
     def post_process(self, code: str) -> str:
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code

     def retrieve_weight_shape(self):
diff --git a/python/bitblas/ops/matmul.py b/python/bitblas/ops/matmul.py
index 59729a426..7783c4972 100644
--- a/python/bitblas/ops/matmul.py
+++ b/python/bitblas/ops/matmul.py
@@ -7,7 +7,7 @@
 from typing import List, Union, Optional, Any, Tuple
 from .operator import Operator, TransformKind
 from .impl.matmul_impl import select_implementation
-from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from dataclasses import dataclass
 from .ladder_permutate import LadderPermutate, LadderPermutateConfig
 import logging
@@ -189,6 +189,7 @@ def _select_implementation(self):
     def post_process(self, code: str) -> str:
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code

     def _profile_latency_with_dynamic_range(self) -> List:
diff --git a/python/bitblas/ops/matmul_dequantize.py b/python/bitblas/ops/matmul_dequantize.py
index d1dc35c94..25c68b121 100644
--- a/python/bitblas/ops/matmul_dequantize.py
+++ b/python/bitblas/ops/matmul_dequantize.py
@@ -6,7 +6,7 @@
 from typing import Any, List, Literal, Optional, Tuple, Union
 from .operator import Operator, TransformKind
 from .impl.matmul_dequantize_impl import select_implementation
-from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from bitblas.utils.tensor_adapter import tvm_tensor_to_torch
 from dataclasses import dataclass
 from .ladder_permutate import LadderPermutate, LadderPermutateConfig
@@ -234,6 +234,7 @@ def _select_implementation(self):
     def post_process(self, code: str) -> str:
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code

     def retrieve_weight_shape(self):
diff --git a/python/bitblas/utils/__init__.py b/python/bitblas/utils/__init__.py
index 416d6b1f2..00bddc2a5 100644
--- a/python/bitblas/utils/__init__.py
+++ b/python/bitblas/utils/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4  # noqa: F401
+from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2  # noqa: F401
 from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor  # noqa: F401
 from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target  # noqa: F401
diff --git a/python/bitblas/utils/post_process.py b/python/bitblas/utils/post_process.py
index e4fe5f95f..cabee6be1 100644
--- a/python/bitblas/utils/post_process.py
+++ b/python/bitblas/utils/post_process.py
@@ -27,3 +27,12 @@ def tensor_remove_make_int4(source: str) -> str:
         "make_int4(0, 0, 0, 0)",
     )
     return source
+
+def tensor_remove_make_int2(source: str) -> str:
+    # remove make_int2 with 8 signed char arguments
+    # TODO(lei): this is something that should be fixed in TVM in the future
+    source = source.replace(
+        "make_int2((signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0)",
+        "make_int2(0, 0)",
+    )
+    return source
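Illustrative usage (not part of the diff): a minimal sketch of the post-processing chain this change extends. Each helper in python/bitblas/utils/post_process.py is a plain string rewrite over the generated CUDA kernel source, and the post_process()/tvm_callback_cuda_postproc hooks now also collapse the verbose make_int2 initializer. The sample generated string below is hypothetical, and running the sketch assumes BitBLAS (with its TVM dependency) is installed.

from bitblas.utils import (
    tensor_replace_dp4a,
    tensor_remove_make_int4,
    tensor_remove_make_int2,
)

# Hypothetical fragment of TVM-generated CUDA source; the eight "(signed char)0"
# arguments are exactly what the new tensor_remove_make_int2 helper rewrites away.
generated = (
    "C_local[0] = make_int2((signed char)0, (signed char)0, (signed char)0, "
    "(signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0);"
)

# Apply the helpers in the same order as the post_process() methods in this diff;
# each call is a pure string-level replacement on the kernel source.
code = tensor_replace_dp4a(generated)
code = tensor_remove_make_int4(code)
code = tensor_remove_make_int2(code)
print(code)  # C_local[0] = make_int2(0, 0);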