Refactor matmul implementation for splitk layout
LeiWang1999 committed Jun 5, 2024
1 parent e06ce10 commit d67cc6d
Showing 9 changed files with 54 additions and 59 deletions.
17 changes: 5 additions & 12 deletions python/bitblas/gpu/matmul_analysis.py
@@ -8,7 +8,7 @@
from typing import List, Optional, Set, Union, Tuple, Dict
from tvm import tir
from tvm.ir import Range
from tvm.tir import IterVar, PrimExpr, Var, BufferRegion
from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV
from ..base.analysis import (
@@ -17,12 +17,12 @@
get_reduction_blocks,
)
from tvm.target.target import Target
from tvm.tir import IndexMap, Var
from tvm.tir.stmt_functor import pre_order_visit
import logging

logger = logging.getLogger(__name__)


def collect_vars_from_expr(prim_expr):
vars = []

@@ -352,10 +352,7 @@ def is_common_reduce(var: Var) -> bool:

def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
for v in vars:
if is_common_reduce(v):
return True
return False
return any(is_common_reduce(v) for v in vars)

def check_last_trait(region: List[Range]):
axes = get_ordered_axes(region)
@@ -605,16 +602,12 @@ def is_common_reduce(var: Var) -> bool:

def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
for v in vars:
if is_common_reduce(v):
return True
return False

return any(is_common_reduce(v) for v in vars)

def check_last_trait(region: List[Range]):
axes = get_ordered_axes(region)
return has_common_reduce(axes[-1])


intrin_info: dict = {}
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
intrin_info["in_dtype"] = in_dtype
3 changes: 2 additions & 1 deletion python/bitblas/ops/general_matmul.py
@@ -40,6 +40,7 @@
def is_native_compute(A_dtype, W_dtype) -> bool:
return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS


@dataclass(frozen=True)
class MatmulConfig:
M: Union[int, Tuple[int]] = None
@@ -497,7 +498,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
if self.dynamic_range is not None:
m = reduce(operator.mul, A.shape[:-1], 1)
args.append(m)

stream = torch.cuda.current_stream()

if self.lib is None:
13 changes: 8 additions & 5 deletions python/bitblas/ops/general_matmul_splitk.py
@@ -3,7 +3,7 @@
from tvm.target import Target
import operator
from functools import reduce
from typing import Any, Optional, Union
from typing import Any, Optional, Union
from .operator import TransformKind
from .impl.matmul_splitk_impl import select_implementation as consistent_implementation
from .impl.matmul_dequantize_splitk_impl import select_implementation as weight_dequantize_implementation
@@ -17,9 +17,10 @@

WORKSPACE_SIZE = 1024 * 1024 * 256


@dataclass(frozen=True)
class MatmulConfigWithSplitK(MatmulConfig):
k_split: int = 1 # split K dimension
k_split: int = 1 # split K dimension


class MatmulWithSplitK(Matmul):
@@ -158,8 +159,10 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
args.append(self.lut)

if output is None:
output = torch.empty((self.k_split, ) +
A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device)
output = torch.empty(
(self.k_split,) + A.shape[:-1] + (self.N,),
dtype=self.torch_output_dtype,
device=A.device)
if scale is not None:
args.append(scale)
if zeros is not None:
@@ -171,7 +174,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
if self.dynamic_range is not None:
m = reduce(operator.mul, A.shape[:-1], 1)
args.append(m)

stream = torch.cuda.current_stream()

if self.lib is None:
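For context on the split-K layout that MatmulWithSplitK targets: each slice of the (k_split, M, N) output buffer allocated above holds the partial GEMM over one chunk of the reduction dimension, and the partials are summed to recover the full result. A minimal PyTorch sketch of that idea (illustrative shapes and names only; this is not the BitBLAS kernel itself):

    # Conceptual split-K GEMM in plain PyTorch; illustrative only, not the BitBLAS kernel.
    import torch

    k_split, M, N, K = 4, 1, 4096, 12800        # K must be divisible by k_split
    RK = K // k_split                            # reduction length handled by each split

    A = torch.rand(M, K, dtype=torch.float16, device="cuda") - 0.5
    W = torch.rand(N, K, dtype=torch.float16, device="cuda") - 0.5  # "nt" layout: W is (N, K)

    # One partial result per K-chunk, mirroring the (k_split,) + A.shape[:-1] + (N,) buffer above.
    partial = torch.empty((k_split, M, N), dtype=torch.float16, device="cuda")
    for sk in range(k_split):
        a_chunk = A[:, sk * RK:(sk + 1) * RK]
        w_chunk = W[:, sk * RK:(sk + 1) * RK]
        partial[sk] = a_chunk @ w_chunk.t()

    out = partial.sum(dim=0)                     # reduce over the split dimension
    torch.testing.assert_close(out, A @ W.t(), rtol=1e-2, atol=1e-1)
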
27 changes: 15 additions & 12 deletions python/bitblas/ops/impl/batch_matmul_dequantize_impl.py
@@ -6,14 +6,10 @@
from tvm.tir import IndexMap
from bitblas.ops.operator import TransformKind
from bitblas.gpu.matmul_analysis import get_propagate_map
from bitblas.quantization import (
_tir_packed_int_to_int_convert,
_tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert,
_tir_u32_to_f4_to_f16,
_tir_u8_to_f8_e4m3_to_f16,
_tir_packed_to_unsigned_convert_with_zeros,
)
from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16,
_tir_u8_to_f8_e4m3_to_f16)


def matmul_nt_dequantize_b(
Batch,
@@ -48,7 +44,6 @@ def matmul_nt_dequantize_b(
Scale = te.placeholder((Batch, N, K // group_size), name="Scale", dtype=in_dtype)
Bias = te.placeholder((N,), name="Bias", dtype=in_dtype)


def decode_func(b, n, k):
if source_format == "uint":
if bit == 8:
@@ -187,11 +182,16 @@ def matmul_nt_dequantize_b_propagate_b(
group_size = K
qr = r * bit // storage_nbit
A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype)
B = te.placeholder((Batch, N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype)
B = te.placeholder((Batch, N // l, (K // scaling_factor) // qr, l, qr),
name="B",
dtype=storage_dtype)
LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype)
Scale = te.placeholder((Batch, N, K // group_size), name="Scale", dtype=in_dtype)
Zeros = te.placeholder((Batch, N, K // group_size), name="Zeros", dtype=in_dtype)
Bias = te.placeholder((Batch, N,), name="Bias", dtype=in_dtype)
Bias = te.placeholder((
Batch,
N,
), name="Bias", dtype=in_dtype)

def fcompute(b, i, j):
warp_i, warp_j = i % l, j % qr
@@ -223,7 +223,10 @@ def decode_func(b, n, k):
if bit == 1:
# Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1.
w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
bit, B_reindex[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
bit,
B_reindex[b, n, k // n_float_per_elem],
k % n_float_per_elem,
dtype=in_dtype)
elif bit == 8:
# 8 bit does not need to be compressed
w = B_reindex[b, n, k].astype(in_dtype)
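The decode_func bodies above rely on packing intrinsics such as _tir_packed_to_unsigned_convert to pull bit-wide fields out of a wider storage word. As a rough illustration of that unpacking arithmetic only (the real conversion is a TIR intrinsic and also covers the signed, zero-point, and scale variants), a standalone NumPy sketch:

    # NumPy sketch of unpacking `bit`-wide unsigned values from uint8 storage words.
    # Illustrative only; BitBLAS performs this inside TIR intrinsics such as
    # _tir_packed_to_unsigned_convert, which also handle signed/zero-point/scale variants.
    import numpy as np

    bit = 4
    storage_nbit = 8
    n_float_per_elem = storage_nbit // bit          # values packed into one storage word

    packed = np.array([0xB7, 0x2F], dtype=np.uint8)  # two words holding four 4-bit values

    def unpack(packed, k):
        # Mirrors the B[n, k // n_float_per_elem] indexing used in decode_func.
        word = int(packed[k // n_float_per_elem])
        shift = (k % n_float_per_elem) * bit
        return (word >> shift) & ((1 << bit) - 1)

    print([unpack(packed, k) for k in range(4)])    # -> [7, 11, 15, 2]
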
4 changes: 2 additions & 2 deletions python/bitblas/ops/impl/batch_matmul_impl.py
@@ -3,7 +3,6 @@
# pre-transformed tir expression of matmul
import tvm
from tvm import te
from bitblas.gpu.matmul_analysis import get_propagate_map
from bitblas.ops.operator import TransformKind


@@ -27,7 +26,8 @@ def matmul_nt(
k = te.reduce_axis((0, K), name="k")
C = te.compute(
(Batch, M, N),
lambda b, i, j: te.sum(A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k),
lambda b, i, j: te.sum(
A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k),
name="C",
)
last_output = C
21 changes: 7 additions & 14 deletions python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py
@@ -2,18 +2,11 @@
# Licensed under the MIT License.
# pre-transformed tir expression of matmul
import tvm
from tvm import te, DataType
from tvm.tir import IndexMap
from bitblas.ops.operator import TransformKind
from bitblas.gpu.matmul_analysis import get_propagate_map
from bitblas.quantization import (
_tir_packed_int_to_int_convert,
_tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert,
_tir_u32_to_f4_to_f16,
_tir_u8_to_f8_e4m3_to_f16,
_tir_packed_to_unsigned_convert_with_zeros,
)
from tvm import te
from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert,
_tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16,
_tir_u8_to_f8_e4m3_to_f16)


def matmul_nt_dequantize_b(
SplitK,
@@ -48,7 +41,6 @@ def matmul_nt_dequantize_b(
Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype)
Bias = te.placeholder((N,), name="Bias", dtype=in_dtype)


def decode_func(n, k):
if source_format == "uint":
if bit == 8:
@@ -98,7 +90,8 @@ def decode_func(n, k):
C = te.compute(
(SplitK, M, N),
lambda sk, i, j: te.sum(
A[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype(accum_dtype), axis=k),
A[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype(accum_dtype),
axis=k),
name="C",
)
D = te.compute((SplitK, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D")
4 changes: 2 additions & 2 deletions python/bitblas/ops/impl/matmul_splitk_impl.py
@@ -3,7 +3,6 @@
# pre-transformed tir expression of matmul
import tvm
from tvm import te
from bitblas.gpu.matmul_analysis import get_propagate_map
from bitblas.ops.operator import TransformKind


@@ -28,7 +27,8 @@ def matmul_nt(
k = te.reduce_axis((0, RK), name="k")
C = te.compute(
(SplitK, M, N),
lambda sk, i, j: te.sum(A[i, sk * RK + k].astype(accum_dtype) * B[j, sk * RK + k].astype(accum_dtype), axis=k),
lambda sk, i, j: te.sum(
A[i, sk * RK + k].astype(accum_dtype) * B[j, sk * RK + k].astype(accum_dtype), axis=k),
name="C",
)
last_output = C
3 changes: 2 additions & 1 deletion python/bitblas/ops/operator.py
@@ -289,7 +289,7 @@ def _forward_from_prebuild_lib(self, *args, stream=0):
]
ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args)

def call_lib(self, *args, stream=0):
self.lib.call(*args, ctypes.c_void_p(stream))

@@ -340,6 +340,7 @@ class OPExecutorCPU:
"""
A class to execute a sequence of operators on the CPU.
"""

def __init__(self, operators: Optional[List[Operator]] = None):
if operators is None:
operators = []
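The call path above (_forward_from_prebuild_lib and call_lib) hands raw pointers plus a CUDA stream handle to the generated kernel through ctypes; the forward methods earlier in this commit obtain that stream via torch.cuda.current_stream(). A hedged sketch of the calling convention (hypothetical lib handle; the actual argument count and order depend on the compiled operator):

    # Sketch of the ctypes calling convention behind _forward_from_prebuild_lib.
    # `lib` is a hypothetical ctypes.CDLL wrapping a generated kernel that exposes `call`;
    # the real argument count/order depends on the compiled operator.
    import ctypes
    import torch  # only needed for the usage example below

    def call_prebuilt(lib, *tensors, stream: int = 0):
        # Each tensor is passed as a raw device pointer ...
        ctypes_args = [ctypes.c_void_p(t.data_ptr()) for t in tensors]
        # ... followed by the CUDA stream handle, as in lib.call(*ctypes_args) above.
        ctypes_args.append(ctypes.c_void_p(stream))
        lib.call(*ctypes_args)

    # Usage (assumes a GPU and an already-compiled kernel):
    #   stream = torch.cuda.current_stream().cuda_stream  # raw stream handle
    #   call_prebuilt(lib, A, W, C, stream=stream)
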
21 changes: 11 additions & 10 deletions testing/python/operators/test_general_matmul_splitk_ops.py
@@ -3,8 +3,6 @@
import pytest
import bitblas
from bitblas.ops.general_matmul_splitk import MatmulWithSplitK, MatmulConfigWithSplitK
import logging
from bitblas import set_log_level


def get_codegen_result(ops):
@@ -75,17 +73,19 @@ def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
matmul.hardware_aware_finetune(topk=10)
assert get_codegen_result(matmul)


@pytest.mark.parametrize(
"SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode",
[
(1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False,
None),
(4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False,
None),
(1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False,
False, None),
(4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False,
False, None),
],
)
def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias,
group_size, with_scaling, with_zeros, zeros_mode):
def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype,
layout, with_bias, group_size, with_scaling, with_zeros,
zeros_mode):
import torch
torch.random.manual_seed(0)
matmul_config = MatmulConfigWithSplitK(
@@ -111,11 +111,12 @@ def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accu
inputs = []
inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5)

output_bitblas = matmul.forward(*inputs)
output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1])
output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1])
torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1)


# fmt: on
if __name__ == "__main__":
bitblas.testing.main()
