[Dev][TL] Hardware Aware Tuning Examples with TL (#201)
* Refactor tilelang dequantize module and add matmul_blocked_weight_only function

* remove un-implemented code.

* Implement BaseScheduler to wrap some related items.

* lint fix

* test skip

* Refactor tilelang dequantize module and add matmul_blocked_weight_only function

* test fix

* hardware tuning demo
LeiWang1999 authored Sep 29, 2024
1 parent 69350cb commit 5af67f7
Showing 9 changed files with 490 additions and 49 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 1 file
+7 −3 python/tvm/tl/engine.py
2 changes: 1 addition & 1 deletion bitblas/base/utils.py
@@ -193,7 +193,7 @@ def _apply_schedule(f, c):
sch = None
return sch

with ThreadPoolExecutor(max_workers=4) as scheduler:
with ThreadPoolExecutor(max_workers=max_workers) as scheduler:
futures = {scheduler.submit(_apply_schedule, func, config) for config in configs}
for future in as_completed(futures, timeout=timeout):
_sched.append(future.result())
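The hunk above makes the schedule-application pool size configurable instead of hard-coding four workers. As a minimal, self-contained sketch of the same pattern (apply_schedule, func, configs, and the timeout value are placeholder names, not the repository's identifiers):

from concurrent.futures import ThreadPoolExecutor, as_completed

def apply_schedules_in_parallel(apply_schedule, func, configs, max_workers=4, timeout=30):
    """Apply every candidate config to func concurrently and collect the resulting schedules."""
    schedules = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(apply_schedule, func, config) for config in configs}
        for future in as_completed(futures, timeout=timeout):
            sch = future.result()
            if sch is not None:  # a config that failed to schedule yields None
                schedules.append(sch)
    return schedules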
5 changes: 5 additions & 0 deletions bitblas/ops/base_scheduler.py
@@ -4,6 +4,7 @@
from dataclasses import dataclass, field
from tvm.tir.transform import Simplify
from abc import ABC, abstractmethod
from bitblas.base.arch import TileDevice


@dataclass
@@ -20,6 +21,10 @@ def Simplify(stmt: Union[PrimFunc, IRModule]):
else:
raise ValueError(f"Unsupported type: {type(stmt)}")

def get_hardware_aware_configs(self, arch: TileDevice = None):
raise NotImplementedError(
f"{self.__class__.__name__} does not support hardware-aware tuning for {arch}")

def activate_simplify(self):
self._enable_simplify = True
return self
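The new get_hardware_aware_configs hook raises NotImplementedError by default, so only schedulers that override it take part in hardware-aware tuning. A hypothetical override could look like the sketch below; the class name and config values are illustrative, not part of the library.

from dataclasses import dataclass

from bitblas.base.arch import TileDevice
from bitblas.ops.base_scheduler import BaseScheduler


@dataclass
class MyTunableScheduler(BaseScheduler):
    # Schedule-construction methods are omitted; this only shows the tuning hook.

    def get_hardware_aware_configs(self, arch: TileDevice = None):
        # Return the tile configurations the tuner should sweep on this device.
        return [
            {"block_M": 128, "block_N": 128, "block_K": 32, "threads": 128, "num_stages": 2},
            {"block_M": 64, "block_N": 256, "block_K": 32, "threads": 128, "num_stages": 2},
        ]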
8 changes: 6 additions & 2 deletions bitblas/ops/general_matmul/tilelang/dense/__init__.py
@@ -1,13 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .matmul import (
from .matmul_simt import (
MatmulFineGrainSIMTScheduler, # noqa: F401
)

from .matmul_tensorcore import (
matmul_blocked, # noqa: F401
matmul_macro_tensorcore, # noqa: F401
matmul_macro_tensorcore_weight_propagation_level_ldmatrix # noqa: F401
)

from .matmul import (
from .matmul_tensorcore import (
MatmulScheduler, # noqa: F401
MatmulFineGrainScheduler, # noqa: F401
MatmulWeightPropagationScheduler, # noqa: F401
62 changes: 62 additions & 0 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
@@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm as tvm
from tvm import DataType
import tvm.tl.language as T
from typing import Optional
from bitblas.tl.utils import (
get_mma_micro_size,
make_swizzle_layout,
)

from bitblas.ops.base_scheduler import BaseScheduler

from dataclasses import dataclass


@dataclass
class MatmulFineGrainSIMTScheduler(BaseScheduler):
# Fine-grained matrix multiplication scheduler
# Allows for more detailed configuration.

# Operation Configuration
M: Optional[int] = None
N: Optional[int] = None
K: Optional[int] = None
in_dtype: str = "float16"
out_dtype: str = "float16"
trans_A: bool = False
trans_B: bool = True
accum_dtype: str = "float16"

# Tensor Core Warp Configuration
block_row_warps: int = 2
block_col_warps: int = 2
warp_row_tiles: int = 32
warp_col_tiles: int = 32
chunk: int = 32 # Usually determines the K-dimension split size

# Tiling and Other Optimization Parameters
num_stages: int = 2
enable_rasterization: bool = False

def with_default_config(self):
raise NotImplementedError

def apply_config(
self,
):

# M, N, K = self.M, self.N, self.K
# trans_A, trans_B = self.trans_A, self.trans_B
# in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype

raise NotImplementedError


def __post_init__(self):
# Validate the matrix transpose settings
assert self.trans_A is False, "Currently only support Matrix A not transposed"
assert self.trans_B is True, "Currently only support Matrix B transposed"

return
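At this point MatmulFineGrainSIMTScheduler is a stub: with_default_config and apply_config both raise NotImplementedError, so it only carries its configuration. Purely for illustration, its dataclass fields would be populated like this (the shapes and tile sizes are arbitrary example values):

from bitblas.ops.general_matmul.tilelang.dense import MatmulFineGrainSIMTScheduler

scheduler = MatmulFineGrainSIMTScheduler(
    M=1024, N=1024, K=1024,
    in_dtype="float16", out_dtype="float16", accum_dtype="float16",
    trans_A=False, trans_B=True,  # the only layouts __post_init__ currently accepts
    block_row_warps=2, block_col_warps=2,
    warp_row_tiles=32, warp_col_tiles=32, chunk=32,
    num_stages=2, enable_rasterization=False,
)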
bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import itertools
from bitblas import tvm as tvm
from tvm import DataType
import tvm.tl.language as T
@@ -15,7 +16,7 @@
)
from bitblas.ops.common import TransformKind
from bitblas.ops.base_scheduler import BaseScheduler

from bitblas.base.arch import CUDA
from dataclasses import dataclass


@@ -40,6 +41,22 @@ class MatmulScheduler(BaseScheduler):
threads: int = 128
enable_rasterization: bool = False # Enhance L2 Locality

def get_configs_sm80(self):
num_stages = 2
configs = [
{'block_M': 128, 'block_N': 256, 'block_K': 32, 'threads': 128},
{'block_M': 256, 'block_N': 128, 'block_K': 32, 'threads': 128},
{'block_M': 128, 'block_N': 128, 'block_K': 32, 'threads': 128},
]
configs = [{**c, 'num_stages': num_stages} for c in configs]
return configs

def get_hardware_aware_configs(self, arch: CUDA = None):
# TODO(lei): implement only for SM80 Currently
sm_version: int = int(arch.sm_partition)
assert sm_version is not None, "Please provide a valid CUDA Arch"
return self.get_configs_sm80()

def with_default_config(self):
block_M = getattr(self, "block_M", 64)
block_N = getattr(self, "block_N", 64)
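This hunk also adds import itertools, although its use is not visible in the excerpt. As a hypothetical sketch only (not the committed code), the hand-written SM80 list could be generated from a product over tile choices using the same keys:

import itertools


def get_configs_sm80_sweep(num_stages=2):
    # Hypothetical sweep over (block_M, block_N) pairs with fixed block_K and thread count.
    block_mn = [(128, 256), (256, 128), (128, 128)]
    block_k = [32]
    threads = [128]
    return [
        {"block_M": m, "block_N": n, "block_K": k, "threads": t, "num_stages": num_stages}
        for (m, n), k, t in itertools.product(block_mn, block_k, threads)
    ]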
121 changes: 77 additions & 44 deletions bitblas/ops/operator.py
@@ -10,9 +10,10 @@
from tvm.contrib.dlpack import to_pytorch_func
import bitblas
import ctypes
from typing import (List, Dict, Any, Optional, Tuple, Literal, Callable)
from typing import List, Dict, Any, Optional, Tuple, Literal, Callable
import numpy as np
from bitblas.base import fast_tune, fast_tune_with_dynamic_range
from bitblas.tl.tuner import apply_and_build as tl_apply_and_build
from copy import deepcopy
from bitblas.ops.base_scheduler import BaseScheduler
from bitblas.base.arch import get_arch, TileDevice
@@ -38,6 +39,7 @@
@dataclass(frozen=True)
class OperatorConfig:
"""Base class for operator configurations. Used for typing."""

pass


@@ -55,7 +57,7 @@ def is_valid_config(self, config: OperatorConfig):

@abstractmethod
def generate(self, hint: Hint = None) -> str:
'''Generate the kernel name based on the config and hint'''
"""Generate the kernel name based on the config and hint"""
pass


@@ -73,18 +75,20 @@ def generate(self, hint: Hint = None) -> str:
return self.DEFAULT_PREFIX

def is_valid_config(self, config: OperatorConfig) -> bool:
# hint is not used
# config is not used
assert config is not None
return True


class Operator(object):

def __init__(self,
name,
config: OperatorConfig,
target: Target = None,
backend: Literal["tir", "tl"] = "tir"):
def __init__(
self,
name,
config: OperatorConfig,
target: Target = None,
backend: Literal["tir", "tl"] = "tir",
):
if isinstance(target, str):
target = Target(target)
self.name = name
@@ -169,7 +173,7 @@ def tvm_callback_cuda_postproc(code, _):
config={
"tir.use_async_copy": True,
"tir.disable_cse_tir": True,
**(self.pass_context if self.pass_context else {})
**(self.pass_context if self.pass_context else {}),
}):
if self.is_tir_backend():
rt_mod = tvm.build(self.scheduled_ir_module, target=target)
@@ -183,9 +187,12 @@ def tvm_callback_cuda_postproc(code, _):
raise ValueError(f"Unsupported backend: {self.backend}")
except Exception: # noqa: F841
logger.debug(
BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE.format(self.__class__.__name__, target,
"optimized",
"Failed to build optimized module"))
BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE.format(
self.__class__.__name__,
target,
"optimized",
"Failed to build optimized module",
))
else:
# For non-CUDA platforms or when no optimized function is available, build with the primary function
rt_mod = tvm.build(self.prim_func, target=target, name=self.name)
@@ -248,10 +255,12 @@ def _build_default_module(self, target: Target):
scheduled_mod = self.apply_default_schedule(self.ir_module, target)
elif self.is_tilelang_backend():
scheduled_mod = self.scheduler_with_default(self.scheduler)
assert len(scheduled_mod.get_global_vars()) == 1, (
"The optimized module should only have one global variable for default schedule.")
assert "main" in scheduled_mod, (
"The optimized module should have a function named 'main' for default schedule.")
assert (
len(scheduled_mod.get_global_vars()) == 1
), "The optimized module should only have one global variable for default schedule."
assert (
"main" in scheduled_mod
), "The optimized module should have a function named 'main' for default schedule."
default_kernal_name = self.kernel_name_generator.generate()
func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name)
scheduled_ir_module = tvm.IRModule({default_kernal_name: func})
@@ -267,54 +276,77 @@ def post_process(self, code: str) -> str:
def post_process(self, code: str) -> str:
return code

def apply_fast_tuning(self,
func: PrimFunc,
target: Target,
topk: int = 20,
parallel_build=True) -> Tuple[IRModule, Hint]:
_, best = fast_tune(func, target, topk=topk, parallel_build=parallel_build)
# annotate the best pass context
# TODO(lei): actually we should remove this by enable pass through
# annotation in the func's attribute.
self.pass_context = best.config.pass_context
return ((best.sch.mod, best.config) if best is not None else (None, None))
def get_tl_tuning_config(self):
assert self.is_tilelang_backend(), "Only support tilelang backend"
return self.scheduler.get_hardware_aware_configs(self.arch)

def apply_fast_tuning(
self,
func_or_scheduler: PrimFunc,
target: Target,
topk: int = 20,
parallel_build=True,
) -> Tuple[IRModule, Hint]:
if self.is_tir_backend():
_, best = fast_tune(func_or_scheduler, target, topk=topk, parallel_build=parallel_build)
# annotate the best pass context
# TODO(lei): actually we should remove this by enable pass through
# annotation in the func's attribute.
self.pass_context = best.config.pass_context
return (best.sch.mod, best.config) if best is not None else (None, None)
elif self.is_tilelang_backend():
# Finetune the schedule
tuning_configs = self.get_tl_tuning_config()
_, best = tl_apply_and_build(
func_or_scheduler, tuning_configs, arch=self.arch, parallel_build=False)
# Return the best Config as Hint
return (best.sch.mod, best.config) if best is not None else (None, None)

def apply_fast_tuning_with_dynamic_range(
self,
func: PrimFunc,
func_or_scheduler: PrimFunc,
target: Target,
topk: int = 20,
dynamic_range: Dict[str, List[int]] = None,
):
scheduled_ir_module = fast_tune_with_dynamic_range(
func,
func_or_scheduler,
target,
topk=topk,
parallel_build=True,
dynamic_range=dynamic_range,
kernel_name_generator=self.kernel_name_generator)
kernel_name_generator=self.kernel_name_generator,
)
if scheduled_ir_module is not None:
return scheduled_ir_module
return None

def hardware_aware_finetune(self,
topk: int = 20,
target: Optional[tvm.target.Target] = None,
parallel_build=True):
def hardware_aware_finetune(
self,
topk: int = 20,
target: Optional[tvm.target.Target] = None,
parallel_build=True,
):
if target is None:
target = self.target
dynamic_range = self.dynamic_range
func = self.prim_func
if dynamic_range is not None:
self.scheduled_ir_module = self.apply_fast_tuning_with_dynamic_range(
func, target, topk, dynamic_range)
if self.is_tir_backend():
func = self.prim_func
self.scheduled_ir_module = self.apply_fast_tuning_with_dynamic_range(
func, target, topk, dynamic_range)
elif self.is_tilelang_backend():
raise NotImplementedError("Not support dynamic range for tilelang backend")
else:
func_or_scheduler = (self.prim_func if self.is_tir_backend() else self.scheduler)
scheduled_mod, best_hint = self.apply_fast_tuning(
func, target, topk, parallel_build=parallel_build)
assert len(scheduled_mod.get_global_vars()) == 1, (
"The optimized module should only have one global variable for default schedule.")
assert "main" in scheduled_mod, (
"The optimized module should have a function named 'main' for default schedule.")
func_or_scheduler, target, topk, parallel_build=parallel_build)
assert (
len(scheduled_mod.get_global_vars()) == 1
), "The optimized module should only have one global variable for default schedule."
assert (
"main" in scheduled_mod
), "The optimized module should have a function named 'main' for default schedule."
default_kernal_name = self.kernel_name_generator.generate(best_hint)
func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name)
scheduled_ir_module = tvm.IRModule({default_kernal_name: func})
Expand All @@ -341,8 +373,9 @@ def var_warpper(v):
for i in func.attrs["opt_shapes"][v.name]:
avg_shape += i.value
avg_shape = avg_shape // len(func.attrs["opt_shapes"][v.name])
_info_message = f"Doesn't provide dynamic symbolic constrains for {v.name} when do benchmarking, "\
f"use average shape {avg_shape}"
_info_message = (
f"Doesn't provide dynamic symbolic constrains for {v.name} when do benchmarking, "
f"use average shape {avg_shape}")
logger.info(_info_message)
return avg_shape
else:
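Taken together, the operator.py changes route hardware_aware_finetune through the tilelang path: get_tl_tuning_config() asks the scheduler for its hardware-aware configs, and apply_fast_tuning builds each candidate with tl_apply_and_build, keeping the best module and config. A hedged usage sketch follows; whether bitblas.Matmul exposes a backend keyword in exactly this form is an assumption, and the shape and dtype values are arbitrary.

import bitblas

config = bitblas.MatmulConfig(M=1024, N=1024, K=1024, A_dtype="float16", W_dtype="float16")
matmul = bitblas.Matmul(config=config, backend="tl")  # "tl" backend assumed to be selectable here

# For the tilelang backend this now sweeps scheduler.get_hardware_aware_configs(arch)
# and adopts the best candidate as the tuned kernel.
matmul.hardware_aware_finetune(topk=20)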
Diffs for the remaining changed files are not shown.