Commit 4cac65a
[Dev] Bump Version to dev0.8 and fix issue INT8xINT2 (#49)
* improve e4m3 decoding.

* append fp16xint1

* Update submodule commit reference

* chore: Update shared memory scope for float32 output dtype

* BUGFIX: UINT8/INT8 Decoding

* feat: Add rasterization options for roller module

* Refactor tensorcore_legalization method to optimize tensor core usage

* feat: Add function to collect variables from expression, improve for splitk

* chore: Update typing import in __init__.py

* chore: Refactor CPU execution of operators

* Refactor matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

* chore: Update version to 0.0.1.dev8

---------

Co-authored-by: LeiWang199 <leiwang199>
LeiWang1999 authored Jun 5, 2024
1 parent 99a744e commit 4cac65a
Showing 9 changed files with 22 additions and 8 deletions.
2 changes: 1 addition & 1 deletion VERSION

```diff
@@ -1 +1 @@
-0.0.1.dev7
+0.0.1.dev8
```
3 changes: 2 additions & 1 deletion integration/BitNet/utils_quant.py

```diff
@@ -119,7 +119,6 @@ def native_forward(self, input):
         return out
 
     def forward_fp32_simulated(self, input):
-        print("input: ", input)
         quant_input = self.activation_quant(input, self.input_bits).detach()
         quant_weight = self.weight_quant(self.weight).detach()
 
@@ -139,6 +138,8 @@ def forward_fp32_simulated(self, input):
         return out
 
     def forward(self, input):
+        # return self.forward_fp32_simulated(input)
+
         quant_input = self.activation_quant(input, self.input_bits).detach()
         fp32_out = self.bitblas_matmul(quant_input, self.weight)
         sw = self.sw
```
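For context on the quantized `forward` path above: `activation_quant` in BitNet-style layers typically performs per-token absmax quantization of the activations. The sketch below is illustrative only, assuming 8-bit signed fake quantization; `activation_quant_sketch` is a hypothetical name, not the repository's implementation:

```python
import torch

def activation_quant_sketch(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
    # Per-token absmax fake quantization: scale each token so its largest
    # magnitude maps onto the signed integer range, round, then rescale back.
    qmax = 2 ** (bits - 1) - 1
    scale = qmax / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * scale).round().clamp(-qmax, qmax) / scale
```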
2 changes: 1 addition & 1 deletion python/bitblas/__init__.py

```diff
@@ -81,4 +81,4 @@ def _init_logger():
 
 _init_logger()
 
-__version__ = "0.0.1.dev7"
+__version__ = "0.0.1.dev8"
```
3 changes: 2 additions & 1 deletion python/bitblas/base/utils.py

```diff
@@ -19,7 +19,7 @@
 import tempfile
 import itertools
 from tvm.ir.supply import GlobalVarSupply
-from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 import logging
 
 logger = logging.getLogger(__name__)
@@ -205,6 +205,7 @@ def _build(context) -> str:
     def tvm_callback_cuda_postproc(code, _):
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code
 
     with tvm.transform.PassContext(config={"tir.use_async_copy": True, **config.pass_context}):
```
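For context on the hook above: TVM calls the function registered under the name `tvm_callback_cuda_postproc` on each generated CUDA source string before compiling it, so these textual rewrites run on every emitted kernel. A minimal sketch of the registration follows; the repository registers its callback inside `_build`, and this standalone form is for illustration only, assuming `tvm` and `bitblas` are importable:

```python
import tvm
from bitblas.utils import (
    tensor_replace_dp4a,
    tensor_remove_make_int4,
    tensor_remove_make_int2,
)

# TVM looks up this global function by name and applies it to the generated
# CUDA source; the returned (rewritten) string replaces the kernel code.
@tvm.register_func("tvm_callback_cuda_postproc", override=True)
def tvm_callback_cuda_postproc(code, _target):
    code = tensor_replace_dp4a(code)
    code = tensor_remove_make_int4(code)
    code = tensor_remove_make_int2(code)
    return code
```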
3 changes: 2 additions & 1 deletion python/bitblas/ops/general_matmul.py

```diff
@@ -10,7 +10,7 @@
 from .impl.matmul_dequantize_impl import (
     select_implementation as weight_dequantize_implementation,)
 from .impl.matmul_impl import select_implementation as consistent_implementation
-from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from bitblas.utils.target_detector import auto_detect_nvidia_target
 from dataclasses import dataclass
 from .ladder_permutate import LadderPermutate, LadderPermutateConfig
@@ -398,6 +398,7 @@ def _select_implementation(self):
     def post_process(self, code: str) -> str:
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code
 
     def retrieve_weight_shape(self):
```
3 changes: 2 additions & 1 deletion python/bitblas/ops/matmul.py

```diff
@@ -7,7 +7,7 @@
 from typing import List, Union, Optional, Any, Tuple
 from .operator import Operator, TransformKind
 from .impl.matmul_impl import select_implementation
-from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from dataclasses import dataclass
 from .ladder_permutate import LadderPermutate, LadderPermutateConfig
 import logging
@@ -189,6 +189,7 @@ def _select_implementation(self):
     def post_process(self, code: str) -> str:
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code
 
     def _profile_latency_with_dynamic_range(self) -> List:
```
3 changes: 2 additions & 1 deletion python/bitblas/ops/matmul_dequantize.py

```diff
@@ -6,7 +6,7 @@
 from typing import Any, List, Literal, Optional, Tuple, Union
 from .operator import Operator, TransformKind
 from .impl.matmul_dequantize_impl import select_implementation
-from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
+from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from bitblas.utils.tensor_adapter import tvm_tensor_to_torch
 from dataclasses import dataclass
 from .ladder_permutate import LadderPermutate, LadderPermutateConfig
@@ -234,6 +234,7 @@ def _select_implementation(self):
     def post_process(self, code: str) -> str:
         code = tensor_replace_dp4a(code)
         code = tensor_remove_make_int4(code)
+        code = tensor_remove_make_int2(code)
         return code
 
     def retrieve_weight_shape(self):
```
2 changes: 1 addition & 1 deletion python/bitblas/utils/__init__.py

```diff
@@ -1,5 +1,5 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4  # noqa: F401
+from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2  # noqa: F401
 from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor  # noqa: F401
 from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target  # noqa: F401
```
9 changes: 9 additions & 0 deletions python/bitblas/utils/post_process.py

```diff
@@ -27,3 +27,12 @@ def tensor_remove_make_int4(source: str) -> str:
         "make_int4(0, 0, 0, 0)",
     )
     return source
+
+def tensor_remove_make_int2(source: str) -> str:
+    # remove make_int2 calls emitted with 8 signed char arguments
+    # TODO(lei): this should be fixed in TVM itself in the future
+    source = source.replace(
+        "make_int2((signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0)",
+        "make_int2(0, 0)",
+    )
+    return source
```
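To see what the new helper does end to end, here is a self-contained sketch; the CUDA line below is fabricated for demonstration, and the function body mirrors the addition above:

```python
def tensor_remove_make_int2(source: str) -> str:
    # Collapse the 8-argument make_int2(...) call that TVM emits for
    # INT8xINT2 kernels into a valid two-argument zero initializer.
    return source.replace(
        "make_int2((signed char)0, (signed char)0, (signed char)0, "
        "(signed char)0, (signed char)0, (signed char)0, (signed char)0, "
        "(signed char)0)",
        "make_int2(0, 0)",
    )

# Fabricated CUDA line containing the malformed call:
cuda_line = (
    "int2 acc = make_int2((signed char)0, (signed char)0, (signed char)0, "
    "(signed char)0, (signed char)0, (signed char)0, (signed char)0, "
    "(signed char)0);"
)
print(tensor_remove_make_int2(cuda_line))
# -> int2 acc = make_int2(0, 0);
```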
