diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py
index 4d3a4c51b..cc1d9f075 100644
--- a/python/bitblas/ops/general_matmul.py
+++ b/python/bitblas/ops/general_matmul.py
@@ -8,7 +8,8 @@
 from typing import Any, List, Literal, Optional, Tuple, Union
 from .operator import Operator, TransformKind
 from .impl.matmul_dequantize_impl import (
-    select_implementation as weight_dequantize_implementation,)
+    select_implementation as weight_dequantize_implementation,
+)
 from .impl.matmul_impl import select_implementation as consistent_implementation
 from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
 from bitblas.utils.target_detector import auto_detect_nvidia_target
@@ -94,35 +95,51 @@ class MatmulConfig:
 
     storage_dtype: str = "int8"
 
     # weight transform related flags
-    fast_decoding: Optional[bool] = None  # enable fast decoding by default, if not specified, it is enabled by a rule.
-    propagate_a: Optional[TransformKind] = None  # propagate_a is a flag to control the ladder permutation.
-    propagate_b: Optional[TransformKind] = None  # propagate_b is a flag to control the ladder permutation
+    fast_decoding: Optional[bool] = (
+        None  # enable fast decoding by default, if not specified, it is enabled by a rule.
+    )
+    propagate_a: Optional[TransformKind] = (
+        None  # propagate_a is a flag to control the ladder permutation.
+    )
+    propagate_b: Optional[TransformKind] = (
+        None  # propagate_b is a flag to control the ladder permutation
+    )
 
-
     def __legalize_dynamic_symbolic(self, M):
         return tuple(self.M) if isinstance(self.M, list) else self.M
-
+
     def __legalize_propagate(self, propagate):
         if isinstance(propagate, bool):
-            return (TransformKind.IntraWarpTransform
-                    if propagate else TransformKind.NonTransform)
+            return (
+                TransformKind.IntraWarpTransform
+                if propagate
+                else TransformKind.NonTransform
+            )
         elif isinstance(propagate, int):
             return TransformKind(propagate)
         return propagate
 
-    def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]):
+    def __initialize_propagate(
+        self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]
+    ):
         MICRO_KERNEL_SIZE = 16
-        if isinstance(
-                self.M,
-                int) and (self.M % MICRO_KERNEL_SIZE) == 0 and (self.K % MICRO_KERNEL_SIZE) == 0:
+        if (
+            isinstance(self.M, int)
+            and (self.M % MICRO_KERNEL_SIZE) == 0
+            and (self.K % MICRO_KERNEL_SIZE) == 0
+        ):
             object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform)
         else:
             object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
 
-        if self.M == 1 or (
-                self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or isinstance(
-                    self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized"):
+        if (
+            self.M == 1
+            or (self.N % MICRO_KERNEL_SIZE) != 0
+            or (self.K % MICRO_KERNEL_SIZE) != 0
+            or isinstance(self.M, Tuple)
+            or (self.with_zeros and self.zeros_mode == "quantized")
+        ):
             object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
             object.__setattr__(self, "propagate_b", TransformKind.NonTransform)
         else:
@@ -133,7 +150,7 @@ def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate
             object.__setattr__(self, "propagate_a", propagate_a)
         if propagate_b is not None:
             object.__setattr__(self, "propagate_b", propagate_b)
-
+
         # TODO(lei): This is a limitation arose by pytorch and llvm
         # Should be removed in the future.
         if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
@@ -145,7 +162,11 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
             object.__setattr__(self, "zeros_mode", "original")
 
     def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):
-        if "int" not in self.W_dtype or self.W_dtype == self.A_dtype:
+        if (
+            "int" not in self.W_dtype
+            and "nf" not in self.W_dtype
+            or self.W_dtype == self.A_dtype
+        ):
             object.__setattr__(self, "fast_decoding", False)
         else:
             object.__setattr__(self, "fast_decoding", True)
@@ -164,10 +185,14 @@ def __post_init__(self):
         # set M to tuple if it is list
         # otherwise, M is not hashable
         object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M))
-
+
         # set propagate_a and propagate_b to default value if it is None
-        object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a))
-        object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b))
+        object.__setattr__(
+            self, "propagate_a", self.__legalize_propagate(self.propagate_a)
+        )
+        object.__setattr__(
+            self, "propagate_b", self.__legalize_propagate(self.propagate_b)
+        )
 
         # This is hack to legalize propagate_a and b
         # TODO(lei): should be removed in the future when tc+br template is ready.
@@ -190,7 +215,10 @@ def __post_init__(self):
             object.__setattr__(self, "with_zeros", False)
 
         if self.A_dtype == self.W_dtype and self.W_dtype in [
-                "float16", "int8", "e4m3_float8", "e5m2_float8"
+            "float16",
+            "int8",
+            "e4m3_float8",
+            "e5m2_float8",
         ]:
             object.__setattr__(self, "storage_dtype", self.W_dtype)
 
@@ -234,8 +262,9 @@ def __init__(
         if target is None:
             target = auto_detect_nvidia_target()
             logger.info(f"Auto detected target: {target}")
-        assert (config.A_dtype
-                in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"
+        assert (
+            config.A_dtype in self.BITBLAS_TRICK_DTYPE_MAP
+        ), f"Unsupported input dtype {config.A_dtype}"
 
         source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]
         self.source_format = source_format
@@ -256,7 +285,8 @@ def __init__(
         if isinstance(self.M, Tuple):
             self.dynamic_range = {"m": self.M}
             self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs(
-                {"opt_shapes": self.dynamic_range})
+                {"opt_shapes": self.dynamic_range}
+            )
         else:
             self.dynamic_range = None
 
@@ -303,6 +333,7 @@ def __init__(
         self.ladder_permutate_b = None
 
         if self.fast_decoding:
+            assert self.source_format in ["int", "uint", "nf"]
             lop3_permutate_config = LOP3PermutateConfig(
                 M=self.N,
                 N=self.K,
@@ -335,13 +366,27 @@ def __init__(
             self.hardware_aware_finetune()
 
         if source_format == "nf":
-            self.lut = torch.Tensor(([
-                -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
-                -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
-                0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
-                0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
-            ]),
-                                    dtype=getattr(torch, self.A_dtype)).cuda()
+            self.lut = torch.tensor(
+                [
+                    -1.0,
+                    -0.6961928009986877,
+                    -0.5250730514526367,
+                    -0.39491748809814453,
+                    -0.28444138169288635,
+                    -0.18477343022823334,
+                    -0.09105003625154495,
+                    0.0,
+                    0.07958029955625534,
+                    0.16093020141124725,
+                    0.24611230194568634,
+                    0.33791524171829224,
+                    0.44070982933044434,
+                    0.5626170039176941,
+                    0.7229568362236023,
+                    1.0,
+                ],
+                dtype=getattr(torch, self.A_dtype),
+            ).cuda()
         else:
             self.lut = None
 
@@ -350,7 +395,9 @@ def __init__(
 
     def _build_default_module(self, target: Target):
         try:
-            self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target)
+            self.optimized_func = self.apply_default_schedule(
+                self.prim_func_mod, target
+            )
         except Exception:
             self.optimized_func = None
             logger.warning(
@@ -401,7 +448,9 @@ def post_process(self, code: str) -> str:
         return code
 
     def retrieve_weight_shape(self):
-        return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape]
+        return [
+            int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape
+        ]
 
     def transform_weight(self, weight, scale=None, zeros=None, bias=None):
         """
@@ -433,7 +482,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
         if source_format == "int":
             assert not self.with_scaling, "scale should be False for int source format"
             assert not self.with_zeros, "zeros should be False for int source format"
-            maxq = 2**(bit - 1)
+            maxq = 2 ** (bit - 1)
             # Clamp weight values to be within the quantizable range and adjust
             weight = torch.clamp(weight, -maxq, maxq).int() + maxq
         else:
@@ -443,7 +492,8 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
 
         np_storage_dtype = getattr(np, self.storage_dtype)
         weight = general_compress(
-            weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype)
+            weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype
+        )
 
         weight = torch.from_numpy(weight).cuda().contiguous()
 
@@ -469,20 +519,24 @@ def transform_input(self, input_tensor):
                 raise ValueError(
                     f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size."
                 )
-            self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace)
+            self.ladder_permutate_a._forward_from_prebuild_lib(
+                input_tensor, self.workspace
+            )
             return self.workspace
         return input_tensor
 
     def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
         args = []
         args.append(self.transform_input(A))
+        args.append(W)
+
         if self.lut is not None:
             args.append(self.lut)
-        args.append(W)
 
         if output is None:
             output = torch.empty(
-                A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device)
+                A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device
+            )
         if scale is not None:
             args.append(scale)
         if zeros is not None:
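
For reviewers: the 16 constants in the `self.lut` hunk are the NF4 ("normal float") quantization levels, and the tensor is passed to the kernel so each 4-bit weight code can be decoded by a table lookup. As a reference for what that lookup computes, here is a plain-PyTorch decode of NF4 codes; `NF4_LEVELS` and `dequantize_nf4_reference` are illustrative names, not part of this PR:

```python
import torch

# The 16 NF4 levels, copied verbatim from the LUT constructed in __init__.
NF4_LEVELS = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]

def dequantize_nf4_reference(codes: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    """Reference (non-fast) decode: each 4-bit code in [0, 15] indexes one NF4 level."""
    lut = torch.tensor(NF4_LEVELS, dtype=dtype, device=codes.device)
    return lut[codes.long()]
```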
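And a minimal end-to-end sketch of the path this PR touches, adapted from the quick-start example in the BitBLAS README. The shapes are illustrative, and it assumes nf4 weights are supplied to `transform_weight` as 4-bit code indices (which matches how the non-"int" branch packs weights via `general_compress`):

```python
import bitblas
import torch

# Illustrative config: fp16 activations, nf4 weights. With the condition fixed
# above, __initialize_fast_decoding now treats "nf" dtypes as eligible for
# fast decoding when fast_decoding is left unspecified.
config = bitblas.MatmulConfig(
    M=1,
    N=1024,
    K=1024,
    A_dtype="float16",
    W_dtype="nf4",
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",
)
matmul = bitblas.Matmul(config=config)

# Assumption: pre-quantized nf4 weights arrive as code indices in [0, 15];
# transform_weight packs them into the int8 storage layout the kernel expects.
codes = torch.randint(0, 16, (1024, 1024), dtype=torch.int8).cuda()
packed_w = matmul.transform_weight(codes)

a = torch.rand((1, 1024), dtype=torch.float16).cuda()
# forward() appends self.lut after W automatically when source_format == "nf".
out = matmul(a, packed_w)
```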