Skip to content

Commit

Permalink
[Dev][TL] Add TL BaseScheduler and Library Generator (#200)
Browse files Browse the repository at this point in the history
* Refactor tilelang dequantize module and add matmul_blocked_weight_only function

* remove un-implemented code.

* Implement BaseScheduler to wrap some related items.

* lint fix

* test skip

* Refactor tilelang dequantize module and add matmul_blocked_weight_only function

* test fix
  • Loading branch information
LeiWang1999 authored Sep 29, 2024
1 parent cd41b4e commit 69350cb
Show file tree
Hide file tree
Showing 41 changed files with 954 additions and 475 deletions.
20 changes: 17 additions & 3 deletions bitblas/builder/lib_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from bitblas.base.arch import TileDevice
import ctypes
import os
import os.path as osp
import tempfile
import subprocess
import logging
Expand All @@ -26,7 +27,7 @@ def update_lib_code(self, lib_code: str):
def load_lib(self):
return ctypes.CDLL(self.libpath)

def compile_lib(self, timeout: float = None):
def compile_lib(self, timeout: float = None, with_tl: bool = False):
arch = self.arch
src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
compute_version = arch.compute_capability
Expand All @@ -45,9 +46,22 @@ def compile_lib(self, timeout: float = None):
"-lcuda",
"-gencode",
f"arch=compute_{compute_version},code=sm_{compute_version}",
"-o",
libpath,
]
if with_tl:
tvm_root = osp.join(osp.dirname(__file__), "../../../3rdparty/tvm")
tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl"))
if "TL_CUTLASS_PATH" in os.environ:
cutlass_path = os.environ["TL_CUTLASS_PATH"]
else:
cutlass_path = osp.abspath(osp.join(tvm_root, "3rdparty/cutlass/include"))

command += [
"-I" + tl_template_path,
"-I" + cutlass_path,
]
command += ["-diag-suppress=20013"]
command += ["-o", libpath]

src.write(self.lib_code)
src.flush()
try:
Expand Down
1 change: 1 addition & 0 deletions bitblas/builder/wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .tir import TIRWrapper # noqa: F401
from .tl import TLWrapper # noqa: F401
16 changes: 16 additions & 0 deletions bitblas/builder/wrapper/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@
# Licensed under the MIT License.
from abc import ABC, abstractmethod

PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {});
"""

PREDEF_INIT_FUNC = """
extern "C" void init() {{
{}
}}
"""

PREDEF_HOST_FUNC = """
extern "C" void call({}) {{
{}
}}
"""


class BaseWrapper(ABC):

Expand Down
37 changes: 11 additions & 26 deletions bitblas/builder/wrapper/tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,11 @@
import re
import logging

from .base import BaseWrapper
from .base import (BaseWrapper, PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY, PREDEF_INIT_FUNC,
PREDEF_HOST_FUNC)

logger = logging.getLogger(__name__)

PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {});
"""

PREDEF_INIT_FUNC = """
extern "C" void init() {{
{}
}}
"""

PREDEF_HOST_FUNC = """
extern "C" void call({}) {{
{}
}}
"""


class TIRCUDASourceWrapper(object):
_TYPE_MAP = {
Expand All @@ -48,8 +33,8 @@ class TIRCUDASourceWrapper(object):
"uchar": "uint8_t",
}

def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice):
self.mod = optimized_mod
def __init__(self, scheduled_ir_module: IRModule, source: str, arch: TileDevice):
self.mod = scheduled_ir_module
self.arch = arch
self.source = source
self.function_name: Optional[str] = None
Expand Down Expand Up @@ -190,8 +175,8 @@ def prim_func(self):

class TIRCUDASourceWrapperWithDynamic(TIRCUDASourceWrapper):

def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice):
super().__init__(optimized_mod, source, arch)
def __init__(self, scheduled_ir_module: IRModule, source: str, arch: TileDevice):
super().__init__(scheduled_ir_module, source, arch)

def get_cuda_init_func(self):
# Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory
Expand Down Expand Up @@ -387,16 +372,16 @@ class TIRWrapper(BaseWrapper):

def __init__(self, arch: TileDevice):
super().__init__()
self.optimized_mod = None
self.scheduled_ir_module = None
self.arch = arch
self.lib = None

def assign_optimized_module(self, optimized_mod: IRModule):
self.optimized_mod = optimized_mod
def assign_optimized_module(self, scheduled_ir_module: IRModule):
self.scheduled_ir_module = scheduled_ir_module

# Get Scheduled Rt Module and return source to be compiled
def wrap(self, c_source: str, is_dynamic: bool = False):
assert self.optimized_mod is not None, "Please assign optimized module first."
assert self.scheduled_ir_module is not None, "Please assign optimized module first."
wrapper_class = TIRCUDASourceWrapper if not is_dynamic else TIRCUDASourceWrapperWithDynamic
wrapper = wrapper_class(self.optimized_mod, c_source, self.arch)
wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.arch)
return wrapper.lib_code
193 changes: 193 additions & 0 deletions bitblas/builder/wrapper/tl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm
from typing import Optional, List, Dict, Union
from tvm import IRModule
from bitblas.base.arch import TileDevice
from bitblas.utils import match_global_kernel
from bitblas.utils.rtmod_analysis import get_annotated_device_mod
import re
import logging

from .base import (BaseWrapper, PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY, PREDEF_INIT_FUNC,
PREDEF_HOST_FUNC)

logger = logging.getLogger(__name__)


class TLCUDASourceWrapper(object):
_TYPE_MAP = {
"float32": "float",
"float16": "half_t",
"bfloat16": "__nv_bfloat16",
"e4m3_float8": "__nv_fp8_e4m3",
"e5m2_float8": "__nv_fp8_e5m2",
"float64": "double",
"int64": "int64_t",
"int32": "int",
"uint32": "unsigned int",
"bool": "int8_t",
"int8": "int8_t",
"uint8": "uint8_t",
"int16": "int16_t",
"uchar": "uint8_t",
}

def __init__(self, scheduled_ir_module: IRModule, source: str, arch: TileDevice):
self.mod = scheduled_ir_module
self.arch = arch
self.source = source
self.function_name: Optional[str] = None
self.dynamic_smem_buf: Optional[int] = None
self.block_info: Union[List[int], Dict] = [1, 1, 1]
self.grid_info: Union[List[int], Dict] = [1, 1, 1]
self.parse_source_information()
self.srcpath: Optional[str] = None
self.libpath: Optional[str] = None
self.lib_code: Optional[str] = self.update_lib_code(source)

def parse_source_information(self):
device_mod = get_annotated_device_mod(self.mod, self.arch.target, backend="tl")
assert (len(device_mod.functions) == 1
), "Only support one function in the module for static shape kernel."
for g_var, func in device_mod.functions.items():
self.function_name = g_var.name_hint
attrs = func.attrs
if "dyn_shared_memory_buf" in attrs:
self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
if "thread_extent" in attrs:
thread_extent = attrs["thread_extent"]
for tag, extent in thread_extent.items():
if "threadIdx" in tag:
self.block_info["xyz".index(tag[-1])] = extent
elif "blockIdx" in tag:
self.grid_info["xyz".index(tag[-1])] = extent

def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function
dynamic_symbolic_set = set()
for param in prim_func.params:
buffer = prim_func.buffer_map[param]
for dim in buffer.shape:
if isinstance(dim, tvm.tir.Var):
dynamic_symbolic_set.add(dim.name)
return dynamic_symbolic_set

def get_cuda_init_func(self):
# Initialize an empty string for the CUDA function call
call_str = """"""
# If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call
if self.dynamic_smem_buf is not None:
call_str = (
PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
self.dynamic_smem_buf))
# Format the initialization function using the call_str
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs

def update_lib_code(self, code: str):
# Update the library code with the given code string
self.lib_code = code
# Find the index of the global kernel function in the code
index = match_global_kernel(code)
# Extract the declaration of the function starting from the found index
declaration = code[index:].split(";")[0]

function_name = self.function_name
# Get the CUDA initialization function
init_func = self.get_cuda_init_func()

# Locate the opening brace of the function to insert arguments
index = code.index("{", index)
function_args = []
# Populate the function arguments from the primary function's parameters and buffers
for param in self.prim_func.params:
buffer = self.prim_func.buffer_map[param]
function_args.append({
"name": buffer.name,
"type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
})

dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
# Add dynamic symbolic parameters as integers to the function arguments
for dyn_sym in dynamic_symbolic_set:
function_args.append({"name": dyn_sym, "type": "int"})

function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])

def func_call_args(s, function_args):
# Extract the function call arguments matching the function definition
pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
matches = re.findall(pattern, s)
call_args = []
for match in matches:
for arg in function_args:
if arg["name"] == match:
call_args.append(match)
return call_args

call_args = ", ".join(func_call_args(declaration, function_args))
block_info, grid_info = self.block_info, self.grid_info

def legalize_c(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p).replace("//", "/")

# Prepare the block and grid dimensions for the CUDA kernel launch
block_str = "dim3({}, {}, {})".format(
legalize_c(block_info[0]),
legalize_c(block_info[1]),
legalize_c(block_info[2]),
)
grid_str = "dim3({}, {}, {})".format(
legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
# Determine the shared memory size, defaulting to 0 if not specified
smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf
# Format the CUDA kernel launch string
if len(dynamic_symbolic_set) != 0:
call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0])
else:
call_str = ""
call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
smem_str, call_args)
# Create the host function wrapper for the CUDA kernel
host_func = PREDEF_HOST_FUNC.format(def_args, call_str)
# Combine the source, initialization function, and host function to form the complete library code
lib_code = self.source + init_func + host_func
return lib_code

@property
def prim_func(self):
if len(self.mod.get_global_vars()) == 1:
return self.mod[self.mod.get_global_vars()[0]]
elif "main" in self.mod:
return self.mod["main"]
else:
raise ValueError("Unable to determine primary function.")


class TLWrapper(BaseWrapper):

def __init__(self, arch: TileDevice):
super().__init__()
self.scheduled_ir_module = None
self.arch = arch
self.lib = None

def assign_optimized_module(self, scheduled_ir_module: IRModule):
self.scheduled_ir_module = scheduled_ir_module

# Get Scheduled Rt Module and return source to be compiled
def wrap(self, c_source: str, is_dynamic: bool = False):
assert is_dynamic is False, "Dynamic kernel is not supported in TLWrapper."
assert self.scheduled_ir_module is not None, "Please assign optimized module first."
wrapper_class = TLCUDASourceWrapper
wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.arch)
return wrapper.lib_code
4 changes: 2 additions & 2 deletions bitblas/cache/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path):
# For writing optimized.py file
optimized_file_path = os.path.join(config_path, "optimized.py")
with open(optimized_file_path, "w") as optimized_file:
if op_inst.optimized_mod is not None:
optimized_file.write(op_inst.optimized_mod.script(show_meta=False))
if op_inst.scheduled_ir_module is not None:
optimized_file.write(op_inst.scheduled_ir_module.script(show_meta=False))
if op_inst.libpath is not None:
# copy lib name to the same directory as the artifact
srcpath = op_inst.srcpath
Expand Down
2 changes: 1 addition & 1 deletion bitblas/gpu/matmul_mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tvm import tir, DataType
from tvm.target import Target

from ..ops.operator import TransformKind
from ..ops.common import TransformKind
from ..base.roller import Hint
from ..base.roller.rasterization import NoRasterization
from ..base import analysis
Expand Down
2 changes: 1 addition & 1 deletion bitblas/gpu/matmul_mma_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tvm import tir, DataType
from tvm.target import Target

from ..ops.operator import TransformKind
from ..ops.common import TransformKind
from ..base.roller.hint import Hint, IntrinInfo
from ..base.roller.rasterization import NoRasterization
from ..base import analysis
Expand Down
Loading

0 comments on commit 69350cb

Please sign in to comment.