[Dev] Refactor the weight transformation to support upcoming stage3 transform (#130)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix in test

* lint fix.

* test cuda i4 kernel

* Refactor copyright notice in i4matmul.hpp

* Refactor BitBLASLinear test module for improved readability and maintainability

* refactor test as Python versions below 3.9 cannot handle int32 overflow.

* format lint for test

* Refactor test_int4b_fp16_convert.py for improved readability and maintainability

* remove unused design file

* move tile device from package to base

* dummy impl for codegen

* Refactor file structure for ladder_permutate module

* Refactor backend class and fix typos in comments

* Deep refactor Lib related code.

* remove ci pull.

* LintFix

* refactor builder for whl build

* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module

* Refactor lib_generator to set library and source paths

* lint fix

* BitNet vllm integration

* chore: update codespell to version 2.3.0

* Lintfix

* Bump version to 0.0.1.dev13

* lint fix

* disable fast decoding [u]int4xint8 by default.

* optimize from dict design in Hint

* Implement SplitK

* bitnet benchmark generation.

* Add benchmark script for BitNet integration

* AtomicAdd Support

* LintFix

* ci fix when 3rdparty tvm is initialized.

* bug fix for setup

* fix a bug in block reduce

* typo fix

* BUG Fix for block reduce.

* Lint fix

* Refactor block reduce schedule template

* transform branch from bitblas to bitblas_tl

* Fix subproject commit reference in 3rdparty/tvm

* chore: update submodule branch from bitblas to bitblas_tl

* force update config.cmake

* Bug fix

* Fix subproject commit reference in 3rdparty/cutlass

* chore: Add submodule for cutlass library

* update tl cutlass path

* Refactor BitBLASLinear test module for improved readability and maintainability

* format fix

* Copy CUTLASS to the package directory

* Refactor setup.py to include additional TVM header files

* lint fix

* bug fix

* Refactor BitBLASLinear test module for improved readability and maintainability

* Implement Matmul Benchmark Design

* chore: Update BitBLAS Matmul benchmark script

* lint fix

* Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* lint fix

* Benchmark bot test

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* int8 test case

* Refactor compare_benchmark.py to handle missing benchmark results gracefully

* ci fix

* disable ci for test benchmark

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* remove cli installation

* chore: Create virtual environment and install dependencies for benchmark

* chore: Update benchmark workflow to include comparison step

* Lint fix

* update tvm commit

* Improve lower warp memory pass

* Bug fix

* Enhance to support warp schedule.

* Enhance LOP3 Instructions

* Enhance LOP3 Instructions

* add test for stage3 propagate

* implement propagate func

* Stage3 Ladder Permutate integration

* get_ladder_stage3_propagate

* comment out benchmark scripts as the setting is too big

* ci fix for benchmark

* lint fix

* chore: Update benchmark workflow to trigger on pull request comments

* Add LDMatrix Transform 3

* Support GPTQ Test

* Fuse BlockReduce Schedule

* Support mma propagate 3

* Support MMA Propagate Stage 3

* Lint Fix

* Merge block reduce for dequantize config.

* fix codeql

* chore: Update submodule reference to latest commit

* chore: Disable common subexpression elimination in TIR passes

* Lint Fix

* 4bit related lop3 updates.

* lint fix

* gptq test fix

* Fix for test

* lint fix

* lint fix

* typo fix

* QuantCompress Test

* chore: Refactor quant_compress_impl.py for readability and maintainability

* Enhance docs to update latest works.

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* removed legacy operator

* Refactor weight executors in Matmul class for improved readability and maintainability

* LintFix

* Fix GPTQ Repack with the latest weight transform

* lint fix

* bug fix for rescale dequantize

* test fix

* typo fix
LeiWang1999 authored Aug 5, 2024
1 parent 4d218c1 commit 5d14d31
Showing 13 changed files with 122 additions and 637 deletions.
25 changes: 24 additions & 1 deletion bitblas/__init__.py
@@ -39,9 +39,10 @@
from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401
from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401
from .module import Linear # noqa: F401

import warnings
import functools
import logging
from tqdm import tqdm

@@ -89,4 +90,26 @@ def _init_logger():

_init_logger()


def deprecated(reason):
"""
This is a decorator which can be used to mark functions as deprecated.
It will result in a warning being emitted when the function is used.
"""

def decorator(func):

@functools.wraps(func)
def new_func(*args, **kwargs):
warnings.warn(
f"Call to deprecated function {func.__name__} ({reason}).",
category=DeprecationWarning,
stacklevel=2)
return func(*args, **kwargs)

return new_func

return decorator


__version__ = "0.0.1.dev13"
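
A quick illustration of how the new deprecated decorator is intended to be used (a hypothetical snippet, not part of this diff; legacy_helper is a made-up function):

import warnings
from bitblas import deprecated

@deprecated("use bitblas.Matmul instead")
def legacy_helper(x):
    return x

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_helper(1)
    # The decorator emits a DeprecationWarning pointing at the caller.
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)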
9 changes: 5 additions & 4 deletions bitblas/gpu/matmul_mma_dequantize.py
@@ -2264,10 +2264,11 @@ def get_idx():
lop3_intrin_info["compute"],
)
# Assume the grouped K is the last dim of the scaling
grouped_k = sch.get(bf).reads[1].buffer.shape[-1]
# TODO(lei): This is a hack to get the loop extent
loop_extent = 8 if out_dtype == "float16" else 16
sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k)
if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]:
grouped_k = sch.get(bf).reads[1].buffer.shape[-1]
# TODO(lei): This is a hack to get the loop extent
loop_extent = 8 if out_dtype == "float16" else 16
sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k)
import_source.append(lop3_intrin_info["c_source"])

def tensorize_init_store_compute():
21 changes: 20 additions & 1 deletion bitblas/module/__init__.py
@@ -40,6 +40,24 @@ def unpack_qzeros(qzeros, bits):
return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1)


def unpack_qweight(qweight, bits):
qweight = qweight.view(torch.int8)
elems_per_int8 = 8 // bits
unpacked_weight = torch.zeros(
(qweight.shape[0], qweight.shape[1] * elems_per_int8),
dtype=torch.int8,
device=qweight.device,
requires_grad=False,
)
for col in range(unpacked_weight.shape[1]):
i = col % elems_per_int8
unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> (bits * i))

# Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303
# NOTE: It appears that casting after the `unpacked_zeros + 1` is important.
return torch.bitwise_and(unpacked_weight, 2**bits - 1)


class Linear(nn.Module):
opt_M = [1, 16, 32, 64, 128, 256, 512]
STORAGE_DTYPE = "int8" # assume int8 storage
@@ -279,8 +297,9 @@ def load_and_transform_weight(
def repack_from_gptq(self, gptq_module):
# qweight in gptq old quant linear stored with (out_features, in_features), should be transposed.
qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE)
intweight = unpack_qweight(qweight, self.bits).contiguous()
if self.bitblas_matmul.weight_transform is not None:
qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda()
qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).cuda()
self.qweight = qweight
# scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed.
scales = gptq_module.scales.T.contiguous().view(self.torch_dtype)
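A small worked example for the new unpack_qweight helper (hypothetical values; assumes the helper stays importable from bitblas.module). Two 4-bit values, 3 in the low nibble and 10 in the high nibble, are packed into one int8 element (0b1010_0011 = 163, i.e. -93 as a signed byte):

import torch
from bitblas.module import unpack_qweight

packed = torch.tensor([[-93]], dtype=torch.int8)
# Each packed int8 yields 8 // bits = 2 unpacked elements.
print(unpack_qweight(packed, bits=4))  # tensor([[ 3, 10]], dtype=torch.int8)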
3 changes: 1 addition & 2 deletions bitblas/ops/__init__.py
@@ -1,8 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .operator import Operator, OperatorConfig # noqa: F401
from .matmul import Matmul, MatmulConfig # noqa: F401
from .matmul_dequantize import MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig # noqa: F401
from .general_matmul import Matmul, MatmulConfig # noqa: F401
from .ladder_permutate import LadderPermutate, LadderPermutateConfig # noqa: F401
from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig # noqa: F401
from .quant_compress import QuantCompress, QuantCompressConfig # noqa: F401
49 changes: 32 additions & 17 deletions bitblas/ops/general_matmul/__init__.py
@@ -13,6 +13,7 @@
from bitblas.utils.target_detector import auto_detect_nvidia_target
from dataclasses import dataclass
from ..ladder_permutate import LadderPermutate, LadderPermutateConfig
from ..quant_compress import QuantCompress, QuantCompressConfig
from ..lop3_permutate import LOP3Permutate, LOP3PermutateConfig
import logging
import torch
@@ -292,6 +293,7 @@ def dispatch_tir(self,
# create permutate_opertors
self.ladder_permutate_a = self._assign_ladder_permutate_a(target, enable_tuning)
self.ladder_permutate_b = self._assign_ladder_permutate_b(target, enable_tuning)
self.weight_compress = self._assign_weight_compress(target, enable_tuning)
self.lop3_permutate = self._assign_lop3_permutate(target, enable_tuning)
# create cpu weight executors
self.input_executors = self._create_input_executors()
@@ -338,11 +340,14 @@ def _assign_ladder_permutate_b(self, target: Target, enable_tuning: bool):
del enable_tuning

if self.propagate_b:
# weight transform should be done in the unpacked level
# otherwise the bit trick should be applied and that is
# too complex to be implemented in the ladder permutation.
ladder_permutate_config = LadderPermutateConfig(
M=self.N,
N=self.K,
datatype=self.A_dtype,
dequantize_bits=self.bit,
dequantize_bits=-1,
storage_dtype=self.storage_dtype,
propagate_kind="B",
transpose_matrix=self.layout == "nt",
@@ -354,6 +359,25 @@
)
return None

def _assign_weight_compress(self, target: Target, enable_tuning: bool):
# unused variables
del target
del enable_tuning

require_compress: bool = self.bit in [1, 2, 4]
if require_compress:
quant_compress_config = QuantCompressConfig(
M=self.N,
N=self.K,
input_dtype=self.storage_dtype,
storage_dtype=self.storage_dtype,
dequantize_bits=self.bit)
return QuantCompress(
config=quant_compress_config,
target=tvm.target.Target("llvm"),
)
return None

def _assign_lop3_permutate(self, target: Target, enable_tuning: bool):
# unused variables
del target
@@ -381,10 +405,12 @@ def _create_input_executors(self):

def _create_weight_executors(self):
weight_executors = OPExecutorCPU()
if self.fast_decoding:
weight_executors.append(self.lop3_permutate)
if self.propagate_b is not TransformKind.NonTransform:
weight_executors.append(self.ladder_permutate_b)
if self.weight_compress is not None:
weight_executors.append(self.weight_compress)
if self.fast_decoding:
weight_executors.append(self.lop3_permutate)
return weight_executors

def _select_implementation(self):
Expand Down Expand Up @@ -452,10 +478,6 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
return self.weight_transform(weight.cpu()).cuda().contiguous()
return weight

from bitblas.quantization import general_compress
import torch
import numpy as np

source_format, bit = self.source_format, self.bit

# Process integer source format
@@ -464,20 +486,13 @@
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2**(bit - 1)
# Clamp weight values to be within the quantizable range and adjust
weight = torch.clamp(weight, -maxq, maxq).int() + maxq
weight = torch.clamp(weight, -maxq, maxq).char() + maxq
elif source_format in ["fp_e5m2", "fp_e4m3"]:
weight = weight.view(torch.int8)
weight = weight.int()
else:
# For non-integer formats, simply convert weights to integers
weight = weight.int()

np_storage_dtype = getattr(np, self.storage_dtype)

weight = general_compress(
weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype)

weight = torch.from_numpy(weight).cuda().contiguous()
# And assume weight is in the range of [-128, 127] for int8
weight = weight.char()

# Apply an optional weight transformation if specified
if self.weight_transform is not None:
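For context, a hedged usage sketch of the refactored weight path (not part of this diff; the config fields follow the public BitBLAS examples and may differ slightly). After this change, transform_weight expects the unpacked integer weight, and quant compression plus the ladder/LOP3 permutations are applied by the CPU weight executors assembled in _create_weight_executors:

import torch
import bitblas

config = bitblas.MatmulConfig(
    M=1, N=1024, K=1024,
    A_dtype="float16", W_dtype="int4",
    accum_dtype="float16", out_dtype="float16",
    layout="nt", with_bias=False,
)
matmul = bitblas.Matmul(config=config)

# Unpacked int4 values in [-8, 7]; packing into the int8 storage dtype now
# happens inside transform_weight via the weight_compress executor.
w = torch.randint(-8, 8, (1024, 1024), dtype=torch.int8)
w_transformed = matmul.transform_weight(w)  # assumes a CUDA-capable environment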
18 changes: 18 additions & 0 deletions bitblas/ops/ladder_permutate/__init__.py
@@ -5,6 +5,7 @@
from ..operator import Operator
from .ladder_permutate_impl import select_implementation
from dataclasses import dataclass
import torch


@dataclass(frozen=True)
@@ -57,6 +58,23 @@ def _select_implementation(self):
target_instruction=self.target_instruction,
)

def forward(self, inp, out=None):
if out is None:
out_shape, out_dtype = self.retrieve_output_shape()
out = torch.zeros(out_shape, dtype=out_dtype).to(inp.device)
self.torch_func(inp, out)
return out

def retrieve_output_shape(self):
"""
Retrieve the output shape of the operator
"""
func = self.prim_func
param = func.params[-1]
assert param in func.buffer_map, f"param {param} not in buffer_map"
arg = func.buffer_map[param]
return [int(i) for i in arg.shape], getattr(torch, arg.dtype)

@property
def M(self):
return self.config.M
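An illustrative call to the new LadderPermutate.forward() (hypothetical; the config values mirror how _assign_ladder_permutate_b builds the operator, and the input buffer is derived from the prim_func rather than guessed):

import torch
import tvm
from bitblas.ops.ladder_permutate import LadderPermutate, LadderPermutateConfig

config = LadderPermutateConfig(
    M=1024, N=1024,
    datatype="float16",
    dequantize_bits=-1,
    storage_dtype="int8",
    propagate_kind="B",
    transpose_matrix=True,
)
op = LadderPermutate(config=config, target=tvm.target.Target("llvm"))

# Build an input tensor matching the operator's first buffer.
in_buf = op.prim_func.buffer_map[op.prim_func.params[0]]
inp = torch.zeros([int(s) for s in in_buf.shape], dtype=getattr(torch, in_buf.dtype))
out = op.forward(inp)  # `out` is allocated from retrieve_output_shape() when omitted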
18 changes: 15 additions & 3 deletions bitblas/ops/lop3_permutate/__init__.py
@@ -42,11 +42,23 @@ def _select_implementation(self):
dequantize_bits=self.dequantize_bits,
)

def forward(self, weight, res):
def forward(self, inp, out=None):
out_shape = inp.shape
out_dtype = inp.dtype
if out is None:
# lop3 transform does not change the shape of the input tensor
out = torch.zeros(out_shape, dtype=out_dtype)
# reinterpret the input tensor to int32 format
args = [arg.view(torch.int32) for arg in [weight, res]]
shape_2dim = self.retrieve_2d_weight_shape()
args = [arg.view(inp.dtype).view(shape_2dim).view(torch.int32) for arg in [inp, out]]
self.torch_func(*args)
return args[-1].view(weight.dtype)
return args[-1].view(out_dtype).view(out_shape)

def retrieve_2d_weight_shape(self):
storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit()))
elems_per_byte = storage_nbit // self.dequantize_bits
weight_shape = (self.M, self.N // elems_per_byte)
return weight_shape

@property
def M(self):
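A worked example of the arithmetic in the new retrieve_2d_weight_shape() (hypothetical values): with storage_dtype "int8" and dequantize_bits=4, each storage byte holds two elements, so an M x N = 1024 x 1024 operator views its weight as (1024, 512) before the int32 reinterpretation in forward():

storage_dtype, dequantize_bits = "int8", 4
M, N = 1024, 1024

storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))  # 8
elems_per_byte = storage_nbit // dequantize_bits                      # 2
weight_shape = (M, N // elems_per_byte)                               # (1024, 512)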