Commit a157dc4: [Bug] Improve the Default Config Value and fix a Bug for TensorCore Config with Small shapes (#32)

* update bitblas

* Merge branch 'main' of https://github.com/microsoft/BitBLAS into main
LeiWang1999 authored May 2, 2024
1 parent 2a26bcd commit a157dc4
Showing 6 changed files with 70 additions and 49 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 1 file
+2 −0 src/target/tag.cc
3 changes: 1 addition & 2 deletions python/bitblas/base/roller/policy/tensorcore.py
@@ -258,8 +258,7 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]:
# allow padding; otherwise we cannot get a valid tile shape
return None
if np.prod(space) % warps != 0:
return None

factors = factorize(np.prod(space) // warps)

def _score(node, thread): # small is better
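For context on this hunk, `factorize(np.prod(space) // warps)` breaks the parallelism left over after warp assignment into prime factors for the `_score` search to distribute. A minimal sketch of such a helper (the name matches the call above, but this body is an assumption, not the library's implementation):

```python
import numpy as np

def factorize(n: int) -> list:
    """Prime factors of n, smallest first (assumed behavior of the helper)."""
    factors, p = [], 2
    while p * p <= n:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    if n > 1:
        factors.append(n)
    return factors

# For a small shape, the leftover space may degenerate entirely:
space, warps = (2, 2), 4
print(factorize(int(np.prod(space)) // warps))  # -> [] (nothing left to split)
```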
4 changes: 4 additions & 0 deletions python/bitblas/base/utils.py
@@ -332,6 +332,10 @@ def fast_tune(
policy = TensorCorePolicy(func=specilized_func, arch=arch, tags=tags)

configs = policy.emit_config(topk)

if len(configs) == 0:
raise ValueError("No valid config generated")

cpresults, best = apply_and_build(
func,
configs,
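With this guard, an exhausted search now fails loudly at `emit_config` time instead of crashing later inside `apply_and_build`. A hypothetical caller might handle it like this (the exact `fast_tune` signature is abbreviated in this diff, so the arguments below are assumptions):

```python
import logging

logger = logging.getLogger(__name__)

try:
    cpresults, best = fast_tune(func, target, topk=20)
except ValueError as e:
    # e.g. "No valid config generated" for shapes no policy can tile
    logger.warning(f"Skipping tuning for this kernel: {e}")
    cpresults, best = [], None
```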
1 change: 1 addition & 0 deletions python/bitblas/cache/operator.py
@@ -164,6 +164,7 @@ def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_n
def load_global_ops_cache(database_path=BITBLAS_DATABASE_PATH, target=None):
if target is None:
target = bitblas.auto_detect_nvidia_target()
logger.info(f"Loading operators from database {database_path} for target {target}")
global_operator_cache.load_from_database(database_path, target)
return global_operator_cache

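The added log line makes cache loading observable. A short usage sketch (the path in the comment is illustrative, not the actual `BITBLAS_DATABASE_PATH` default):

```python
import logging
from bitblas.cache.operator import load_global_ops_cache

logging.basicConfig(level=logging.INFO)

# target=None triggers auto-detection; the new log line then reports
# something like:
#   INFO Loading operators from database /path/to/db for target nvidia/geforce-rtx-4090
cache = load_global_ops_cache()
```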
105 changes: 61 additions & 44 deletions python/bitblas/ops/general_matmul.py
@@ -94,44 +94,24 @@ class MatmulConfig:
storage_dtype: str = "int8"

# weight transform related flags
fast_decoding: bool = True # enable fast decoding by default
propagate_a: TransformKind = TransformKind.NonTransform
propagate_b: TransformKind = TransformKind.NonTransform

def __post_init__(self):
# set M to default dynamic range if it is None
if self.M is None:
object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024])
if self.N is None:
raise ValueError("N should be specified currently.")
if self.K is None:
raise ValueError("K should be specified currently.")

# set M to tuple if it is list
# otherwise, M is not hashable
object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M)
if isinstance(self.propagate_a, bool):
object.__setattr__(
self,
"propagate_a",
(TransformKind.IntraWarpTransform
if self.propagate_a else TransformKind.NonTransform),
)
elif isinstance(self.propagate_a, int):
object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a))

if isinstance(self.propagate_b, bool):
object.__setattr__(
self,
"propagate_b",
(TransformKind.IntraWarpTransform
if self.propagate_b else TransformKind.NonTransform),
)
elif isinstance(self.propagate_b, int):
object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b))

# This is hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when tc+br template is ready.
fast_decoding: Optional[bool] = None  # enable fast decoding by default; if not specified, it is decided by a rule
propagate_a: Optional[TransformKind] = None  # flag to control the ladder permutation of A
propagate_b: Optional[TransformKind] = None  # flag to control the ladder permutation of B


def __legalize_dynamic_symbolic(self, M):
return tuple(M) if isinstance(M, list) else M

def __legalize_propagate(self, propagate):
if isinstance(propagate, bool):
return (TransformKind.IntraWarpTransform
if propagate else TransformKind.NonTransform)
elif isinstance(propagate, int):
return TransformKind(propagate)

return propagate

def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]):
MICRO_KERNEL_SIZE = 16
if isinstance(
self.M,
@@ -148,13 +128,54 @@ def __post_init__(self):
else:
object.__setattr__(self, "propagate_b", TransformKind.IntraWarpTransform)

if self.zeros_mode is None:
# set the a and b values if they are not None
if propagate_a is not None:
object.__setattr__(self, "propagate_a", propagate_a)
if propagate_b is not None:
object.__setattr__(self, "propagate_b", propagate_b)

# TODO(lei): This is a limitation arising from pytorch and llvm.
# Should be removed in the future.
if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)

def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
if zeros_mode is None:
object.__setattr__(self, "zeros_mode", "original")

def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):
if "int" not in self.W_dtype or self.W_dtype == self.A_dtype:
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", True)
if fast_decoding is not None:
object.__setattr__(self, "fast_decoding", fast_decoding)

def __post_init__(self):
# set M to default dynamic range if it is None
if self.M is None:
object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024])
if self.N is None:
raise ValueError("N should be specified currently.")
if self.K is None:
raise ValueError("K should be specified currently.")

# set M to tuple if it is list
# otherwise, M is not hashable
object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M))

# set propagate_a and propagate_b to default value if it is None
object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a))
object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b))

# This is a hack to legalize propagate_a and b
# TODO(lei): should be removed in the future when the tc+br template is ready.
self.__initialize_propagate(self.propagate_a, self.propagate_b)

self.__initialize_zeros_mode(self.zeros_mode)

self.__initialize_fast_decoding(self.fast_decoding)

if self.with_bias is None:
object.__setattr__(self, "with_bias", False)
@@ -172,11 +193,6 @@ def __post_init__(self):
"float16", "int8", "e4m3_float8", "e5m2_float8"
]:
object.__setattr__(self, "storage_dtype", self.W_dtype)
# TODO(lei): This is a limitation arose by pytorch and llvm
# Should be removed in the future.
if self.A_dtype in ["e4m3_float8", "e5m2_float8"]:
object.__setattr__(self, "propagate_a", TransformKind.NonTransform)
object.__setattr__(self, "propagate_b", TransformKind.NonTransform)


class Matmul(Operator):
@@ -217,6 +233,7 @@ def __init__(
# to save compilation time
if target is None:
target = auto_detect_nvidia_target()
logger.info(f"Auto detected target: {target}")
assert (config.A_dtype
in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}"
source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype]
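Taken together, the refactor lets callers leave the new Optional fields unset and rely on `__post_init__` to fill them in. A small usage sketch (field names come from the diff above; the dtype strings are illustrative):

```python
from bitblas.ops.general_matmul import MatmulConfig

# M is omitted, so it defaults to the dynamic range (1, 16, ..., 1024);
# propagate_a/propagate_b fall back to the rule in __initialize_propagate;
# fast_decoding should be enabled by the rule here, since W_dtype is a
# quantized int type that differs from A_dtype.
config = MatmulConfig(
    N=1024,
    K=1024,
    A_dtype="float16",
    W_dtype="int4",
)
print(config.M, config.fast_decoding, config.propagate_a, config.propagate_b)
```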
4 changes: 2 additions & 2 deletions python/bitblas/utils/target_detector.py
@@ -2,11 +2,11 @@
# Licensed under the MIT License.

import subprocess
import logging
from thefuzz import process
from tvm.target import Target
from tvm.target.tag import list_tags

import logging
logger = logging.getLogger(__name__)


@@ -44,6 +44,7 @@ def check_target(best, default):
if check_target(best_match, "cuda"):
return best_match if score >= MATCH_THRESHOLD else "cuda"
else:
logger.info(f"Best match '{best_match}' is not a valid CUDA target, falling back to 'cuda'")
return "cuda"


@@ -65,5 +66,4 @@ def auto_detect_nvidia_target() -> str:
# Get the current GPU model and find the best matching target
gpu_model = get_gpu_model_from_nvidia_smi()
target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda"

return target
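For readers unfamiliar with the detection flow, a sketch of the fuzzy-matching step this file implements (the threshold value and the GPU model string are assumptions; the real `MATCH_THRESHOLD` constant is defined elsewhere in this module):

```python
from thefuzz import process
from tvm.target.tag import list_tags

MATCH_THRESHOLD = 80  # assumed value for illustration

# Fuzzy-match the nvidia-smi model name against TVM's registered target tags,
# falling back to the generic "cuda" target when the best score is too low.
nvidia_tags = [tag for tag in list_tags() if "nvidia" in tag]
best_match, score = process.extractOne("NVIDIA GeForce RTX 4090", nvidia_tags)
print(best_match if score >= MATCH_THRESHOLD else "cuda")
```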
