Refactor the gemv schedule to support batch
LeiWang1999 committed May 7, 2024
1 parent e939345 commit d1b9dbf
Showing 1 changed file with 92 additions and 38 deletions.
130 changes: 92 additions & 38 deletions python/bitblas/ops/general_matmul.py
@@ -8,7 +8,8 @@
from typing import Any, List, Literal, Optional, Tuple, Union
from .operator import Operator, TransformKind
from .impl.matmul_dequantize_impl import (
select_implementation as weight_dequantize_implementation,)
select_implementation as weight_dequantize_implementation,
)
from .impl.matmul_impl import select_implementation as consistent_implementation
from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
from bitblas.utils.target_detector import auto_detect_nvidia_target
@@ -94,35 +95,51 @@ class MatmulConfig:
storage_dtype: str = "int8"

# weight transform related flags
fast_decoding: Optional[bool] = None # enable fast decoding by default, if not specified, it is enabled by a rule.
propagate_a: Optional[TransformKind] = None # propagate_a is a flag to control the ladder permutation.
propagate_b: Optional[TransformKind] = None # propagate_b is a flag to control the ladder permutation
fast_decoding: Optional[bool] = (
None # enable fast decoding by default, if not specified, it is enabled by a rule.
)
propagate_a: Optional[TransformKind] = (
None # propagate_a is a flag to control the ladder permutation.
)
propagate_b: Optional[TransformKind] = (
None # propagate_b is a flag to control the ladder permutation
)


def __legalize_dynamic_symbolic(self, M):
return tuple(self.M) if isinstance(self.M, list) else self.M

def __legalize_propagate(self, propagate):
if isinstance(propagate, bool):
return (TransformKind.IntraWarpTransform
if propagate else TransformKind.NonTransform)
return (
TransformKind.IntraWarpTransform
if propagate
else TransformKind.NonTransform
)
elif isinstance(propagate, int):
return TransformKind(propagate)

return propagate
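For orientation, a minimal standalone sketch (not part of this commit) of the legalization that __legalize_propagate performs, assuming TransformKind is an IntEnum laid out as below:

```python
from enum import IntEnum


class TransformKind(IntEnum):
    # Assumed layout of bitblas' TransformKind enum.
    NonTransform = 0
    InterWarpTransform = 1
    IntraWarpTransform = 2


def legalize_propagate(propagate):
    # bool is checked before int: True/False map to the two extremes,
    # bare ints are cast to the enum, and enum members/None pass through.
    if isinstance(propagate, bool):
        return TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform
    elif isinstance(propagate, int):
        return TransformKind(propagate)
    return propagate


assert legalize_propagate(True) is TransformKind.IntraWarpTransform
assert legalize_propagate(0) is TransformKind.NonTransform
```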

def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]):
def __initialize_propagate(
self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]
):
MICRO_KERNEL_SIZE = 16
if isinstance(
self.M,
int) and (self.M % MICRO_KERNEL_SIZE) == 0 and (self.K % MICRO_KERNEL_SIZE) == 0:
if (
isinstance(self.M, int)
and (self.M % MICRO_KERNEL_SIZE) == 0
and (self.K % MICRO_KERNEL_SIZE) == 0
):
object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform)
else:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)

if self.M == 1 or (
self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or isinstance(
self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized"):
if (
self.M == 1
or (self.N % MICRO_KERNEL_SIZE) != 0
or (self.K % MICRO_KERNEL_SIZE) != 0
or isinstance(self.M, Tuple)
or (self.with_zeros and self.zeros_mode == "quantized")
):
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)
else:
@@ -133,7 +150,7 @@ def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate
object.__setattr__(self, "propagate_a", propagate_a)
if propagate_b is not None:
object.__setattr__(self, "propagate_b", propagate_b)

# TODO(lei): This is a limitation arose by pytorch and llvm
# Should be removed in the future.
if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
@@ -145,7 +162,11 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
object.__setattr__(self, "zeros_mode", "original")

def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):
if "int" not in self.W_dtype or self.W_dtype == self.A_dtype:
if (
"int" not in self.W_dtype
or "nf" not in self.W_dtype
or self.W_dtype == self.A_dtype
):
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", True)
@@ -164,10 +185,14 @@ def __post_init__(self):
# set M to tuple if it is list
# otherwise, M is not hashable
object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M))

# set propagate_a and propagate_b to default value if it is None
object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a))
object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b))
object.__setattr__(
self, "propagate_a", self.__legalize_propagate(self.propagate_a)
)
object.__setattr__(
self, "propagate_b", self.__legalize_propagate(self.propagate_b)
)

# This is hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when tc+br template is ready.
@@ -190,7 +215,10 @@ def __post_init__(self):
object.__setattr__(self, "with_zeros", False)

if self.A_dtype == self.W_dtype and self.W_dtype in [
"float16", "int8", "e4m3_float8", "e5m2_float8"
"float16",
"int8",
"e4m3_float8",
"e5m2_float8",
]:
object.__setattr__(self, "storage_dtype", self.W_dtype)

@@ -234,8 +262,9 @@ def __init__(
if target is None:
target = auto_detect_nvidia_target()
logger.info(f"Auto detected target: {target}")
assert (config.A_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"
assert (
config.A_dtype in self.BITBLAS_TRICK_DTYPE_MAP
), f"Unsupported input dtype {config.A_dtype}"
source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]

self.source_format = source_format
@@ -256,7 +285,8 @@ def __init__(
if isinstance(self.M, Tuple):
self.dynamic_range = {"m": self.M}
self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs(
{"opt_shapes": self.dynamic_range})
{"opt_shapes": self.dynamic_range}
)
else:
self.dynamic_range = None

@@ -303,6 +333,7 @@ def __init__(
self.ladder_permutate_b = None

if self.fast_decoding:
assert self.source_format in ["int", "uint"]
lop3_permutate_config = LOP3PermutateConfig(
M=self.N,
N=self.K,
@@ -335,13 +366,27 @@ def __init__(
self.hardware_aware_finetune()

if source_format == "nf":
self.lut = torch.Tensor(([
-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
-0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
]),
dtype=getattr(torch, self.A_dtype)).cuda()
self.lut = torch.tensor(
[
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
],
dtype=getattr(torch, self.A_dtype),
).cuda()
else:
self.lut = None

@@ -350,7 +395,9 @@ def __init__(

def _build_default_module(self, target: Target):
try:
self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target)
self.optimized_func = self.apply_default_schedule(
self.prim_func_mod, target
)
except Exception:
self.optimized_func = None
logger.warning(
@@ -401,7 +448,9 @@ def post_process(self, code: str) -> str:
return code

def retrieve_weight_shape(self):
return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape]
return [
int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape
]

def transform_weight(self, weight, scale=None, zeros=None, bias=None):
"""
@@ -433,7 +482,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
if source_format == "int":
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2**(bit - 1)
maxq = 2 ** (bit - 1)
# Clamp weight values to be within the quantizable range and adjust
weight = torch.clamp(weight, -maxq, maxq).int() + maxq
else:
@@ -443,7 +492,8 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
np_storage_dtype = getattr(np, self.storage_dtype)

weight = general_compress(
weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype)
weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype
)

weight = torch.from_numpy(weight).cuda().contiguous()
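As a rough, plain-NumPy illustration (not part of this commit) of what the clamp/offset step plus general_compress accomplish for a 4-bit integer weight, here is one way two unsigned 4-bit values can be packed per int8 element; the actual bit layout chosen by general_compress may differ:

```python
import numpy as np
import torch

bit = 4
maxq = 2 ** (bit - 1)  # 8 for 4-bit weights

# Mirror the clamp-and-shift above: signed values in [-8, 7] become
# unsigned codes in [0, 15].
w = torch.randint(-maxq, maxq, (4, 8), dtype=torch.int32)
codes = (torch.clamp(w, -maxq, maxq) + maxq).numpy().astype(np.uint8)

# Illustrative packing: two 4-bit codes share one byte of int8 storage.
pairs = codes.reshape(-1, 2)
packed = (pairs[:, 0] | (pairs[:, 1] << 4)).astype(np.uint8).view(np.int8)
```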

@@ -469,20 +519,24 @@ def transform_input(self, input_tensor):
raise ValueError(
f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size."
)
self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace)
self.ladder_permutate_a._forward_from_prebuild_lib(
input_tensor, self.workspace
)
return self.workspace
return input_tensor

def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
args = []
args.append(self.transform_input(A))
args.append(W)

if self.lut is not None:
args.append(self.lut)
args.append(W)

if output is None:
output = torch.empty(
A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device)
A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device
)
if scale is not None:
args.append(scale)
if zeros is not None:
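To put the forward-path changes in context, a hedged end-to-end sketch of how this operator is typically driven; the constructor keywords follow the public BitBLAS examples and may differ slightly across versions:

```python
import torch
import bitblas

# A float16 activation x int4 weight matmul; M may also be given as a tuple
# of dynamic sizes (see the dynamic_range handling above).
config = bitblas.MatmulConfig(
    M=1, N=1024, K=1024,
    A_dtype="float16", W_dtype="int4", out_dtype="float16",
)
matmul = bitblas.Matmul(config=config)

# Pack the plain integer weight into the operator's storage layout, then run.
weight = torch.randint(-7, 8, (1024, 1024), dtype=torch.int8).cuda()
qweight = matmul.transform_weight(weight)
activation = torch.rand(1, 1024, dtype=torch.float16).cuda()
output = matmul(activation, qweight)  # dispatches to the forward(A, W, ...) shown above
```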
