[Dev][BitNET] Implement INT4xINT2 GEMM (#233)
* Merge TL Update

* submodule update

* Re-implement macro with a sub-function.

* lint fix

* Refactor tensor core memory allocation in MatmulFineGrainScheduler

- Adjusted the local fragment sizes for tensor core memory allocation in the MatmulFineGrainScheduler class.
- Updated the allocation sizes for A_local, B_local, and C_local variables based on the new fragment sizes.
- The changes ensure efficient memory utilization and improve performance.

Refactor tensor core memory allocation in MatmulDequantizeFineGrainedScheduler

- Modified the fragment sizes for tensor core memory allocation in the MatmulDequantizeFineGrainedScheduler class.
- Updated the allocation sizes for A_frag, B_frag, and C_frag variables based on the new fragment sizes.
- The changes optimize memory usage and enhance the efficiency of the dequantization process.

Refactor tensor core memory allocation in MatmulDequantizeWeightPropagationScheduler

- Adjusted the fragment sizes for tensor core memory allocation in the MatmulDequantizeWeightPropagationScheduler class.
- Updated the allocation sizes for A_frag, B_frag, B_dequantize_frag, and C_frag variables based on the new fragment sizes.
- The changes improve memory utilization and optimize the weight propagation process (see the sizing sketch below).
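
For reference, a minimal sketch (not part of this diff) of the fragment-size arithmetic these notes describe, assuming the 32-thread warps, 16x16 MMA tiles (k = 32 for int8), and 64x64 warp tiles used by the kernel added in this commit:

warp_size = 32
micro_size_x = micro_size_y = 16
micro_size_k = 32  # 16 for float16, 32 when accumulating int8 into int32

# Per-thread fragment sizes for a single MMA tile
fragment_size_a = (micro_size_x * micro_size_k) // warp_size  # 16
fragment_size_b = (micro_size_y * micro_size_k) // warp_size  # 16
fragment_size_c = (micro_size_x * micro_size_y) // warp_size  # 8

# Register allocations then scale with the warp tiling (64x64 warp tiles)
warp_rows = 64 // micro_size_x  # 4
warp_cols = 64 // micro_size_y  # 4
A_local_size = warp_rows * fragment_size_a              # 64
B_local_size = warp_cols * fragment_size_b              # 64
C_local_size = warp_rows * warp_cols * fragment_size_c  # 128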

* Implement int4 tensorcore

* lint fix

* Support uint2->uint4 fast dequantize

* Support int4 tensorcore decoding

* lint fix
LeiWang1999 authored Nov 3, 2024
1 parent e94f65d commit 451b466
Showing 13 changed files with 1,921 additions and 1 deletion.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
41 changes: 41 additions & 0 deletions bitblas/gpu/intrin/lop3.py
@@ -978,6 +978,47 @@
}
"""

decode_i2s_to_i4s = r"""
template <typename T1, typename T2, bool isSigned>
__device__ void decode_i2b_to_i4s(T1 *_i2b, T2 *_i4s, const int N = 16)
{
uint *i4s = reinterpret_cast<uint *>(_i4s);
uint *i2b = reinterpret_cast<uint *>(_i2b);
    // First, extract each 2-bit field into the low bits of a 4-bit lane to build the intermediate int4 values.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint BOTTOM_MASK = 0x33333333;          // 0x3 -> 0b0011, keep the low 2 bits of each 4-bit lane
static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0
static constexpr uint MEDIAN_NUM = isSigned ? 0x33333333 : 0x00000000;
#pragma unroll
for (int i = 0; i < (N / 8); i++)
{
        // i4s[i] = (i2b[i / 2] >> (2 * (i % 2))) & BOTTOM_MASK, fused into a single lop3
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(i4s[i])
: "r"(i2b[i / 2] >> (2 * (i % 2))), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut));
if constexpr (isSigned)
{
// TODO(lei): uint4 sub should be enhanced.
// 0x03 0x03 0x03 0x03
i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i];
}
}
}
template <typename T1, typename T2>
__device__ void decode_i2s_to_i4s(T1 *_i4s, T2 *B_local_decode, const int N = 16)
{
decode_i2b_to_i4s<T1, T2, true>(_i4s, B_local_decode, N);
}
template <typename T1, typename T2>
__device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16)
{
decode_i2b_to_i4s<T1, T2, false>(_i4u, B_local_decode, N);
}
"""

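For intuition only, a plain-Python sketch of the logical repack that the unsigned path (decode_i2u_to_i4s) performs: each compressed byte carries four 2-bit values, and each adjacent pair becomes one byte holding two int4 lanes. This mirrors the non-interleaved fallback in the TL kernel later in this diff; the PTX version additionally assumes the weights were pre-permuted by LOP3Permutate, so its raw lane ordering differs. The helper name is illustrative, not part of the commit.

def decode_i2u_to_i4s_reference(packed_bytes):
    """Expand bytes of four unsigned 2-bit values into bytes of two 4-bit lanes."""
    out = []
    for byte in packed_bytes:
        for nibble_idx in range(2):        # each nibble holds two 2-bit values
            pair = (byte >> (nibble_idx * 4)) & 0x0F
            lo = (pair >> 0) & 0x03        # first int2 -> low int4 lane
            hi = (pair >> 2) & 0x03        # second int2 -> high int4 lane
            out.append((hi << 4) | lo)
    return out

# 0b11_10_01_00 encodes the values [0, 1, 2, 3] -> two int4-pair bytes
assert decode_i2u_to_i4s_reference([0b11100100]) == [0x10, 0x32]
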

def get_fast_decode_intrin(
source_bit=4,
2 changes: 2 additions & 0 deletions bitblas/ops/lop3_permutate/lop3_permutate_impl.py
@@ -126,6 +126,8 @@ def interleave_weight_int8_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((
return interleave_weight_f16_1b
elif target_dtype == "int8" and bits == 1:
return interleave_weight_int8_1b
elif target_dtype == "int4" and bits == 2:
pass
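        # No specialized interleave kernel for 2-bit sources targeting int4 yet; this case falls through to the generic interleave_weight returned below.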

return interleave_weight

21 changes: 21 additions & 0 deletions bitblas/tl/utils.py
@@ -133,3 +133,24 @@ def transform_func(i, j):
return [new_warp_i, new_warp_j]

return T.Layout(shape, transform_func)


def index_to_coordinates(index, shape):
    '''
    General implementation of converting a flat index into coordinates,
    peeling off the last dimension first:

        coordinates[-1] = index % shape[-1]; index //= shape[-1]
        coordinates[-2] = index % shape[-2]; index //= shape[-2]
        ...

    This generalizes hand-written decompositions such as:

        vjj = index % (micro_size_k // num_elems_per_byte)
        vii = index // (micro_size_k // num_elems_per_byte) % micro_size_y
        vj = index // (micro_size_k // num_elems_per_byte * micro_size_y) % block_K // (micro_size_k // num_elems_per_byte)
        vi = index // (micro_size_k // num_elems_per_byte * micro_size_y * (block_K // (micro_size_k // num_elems_per_byte))) % block_N // micro_size_y
    '''
coordinates = []
dims = len(shape)
for i in range(dims):
coordinates.append(index % shape[dims - i - 1])
index = index // shape[dims - i - 1]
coordinates.reverse()
return coordinates
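
A brief usage sketch (illustrative values, not part of the diff): the flat index is peeled off last dimension first, matching the hand-written modulo/divide chains the docstring refers to.

# For the kernel's B_shared_shape = (block_N, block_K // num_elems_per_byte) = (128, 32):
assert index_to_coordinates(1000, (128, 32)) == [31, 8]
# And for a 4-D layout such as (2, 2, 16, 8):
assert index_to_coordinates(100, (2, 2, 16, 8)) == [0, 0, 12, 4]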
275 changes: 275 additions & 0 deletions integration/BitNet/int4_kernel/tl_int4xint2.py
@@ -0,0 +1,275 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.backends
from bitblas import tvm as tvm
import bitblas.testing
from tvm import DataType
from tvm import tl as TL
import tvm.tl.language as T
from bitblas.tl.utils import (make_swizzle_layout, index_to_coordinates)
from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s

from bitblas.tl.macro_generator import (
INT4TensorCoreIntrinEmitter,)
from bitblas.ops.base_scheduler import simplify_prim_func

torch.manual_seed(0)


@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
fast_decoding=True,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"

K = K // 2

micro_size_x = micro_size_y = micro_size_k = 16

if accum_dtype == "int32":
micro_size_k = 32

num_elems_per_byte = 2
MAX_TRANSACTION_SIZE_IN_BITS = 128
local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
local_size_compressed = local_size // num_elems_per_byte
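    # e.g. with int8 storage: 128 // 8 = 16 elements per 128-bit vectorized access, and 16 // 2 = 8 compressed bytes.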

# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
chunk = 32 if in_dtype == "float16" else 64
shared_scope = "shared.dyn"
storage_dtype = "int8"

# Pipeline Stage
stage = 2

block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk

A_shape = (M, K) # int8 storage represents int4*2
B_shape = (N, K // num_elems_per_byte) # int8 storage represents int4*2
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)

warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
fragement_size_a = (micro_size_x * micro_size_k) // warp_size
fragement_size_b = (micro_size_y * micro_size_k) // warp_size
fragement_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y

# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = INT4TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)

@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N),
T.ceildiv(M, block_M),
threads=threads,
prelude=decode_i2s_to_i4s) as (bx, by):
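            # The prelude injects the CUDA decode_i2s_to_i4s / decode_i2u_to_i4s device functions defined in lop3.py, so the fast-decoding path below can reach them via T.call_extern.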

A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
B_dequantize_shared = T.alloc_shared(
B_dequantize_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype)
B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype)
C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype)

B_local = T.alloc_local([local_size_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([local_size], in_dtype)

thread_bindings = T.thread_binding(0, threads, "threadIdx.x")

T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_dequantize_shared: make_swizzle_layout(B_dequantize_shared),
})

# Improve L2 Cache
T.use_swizzle(panel_size=10)

T.clear(C_frag)

for ko in T.Pipelined((K // block_K), num_stages=stage):

# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

# Load B into shared memory
for j, k in T.Parallel(block_N, block_K // num_elems_per_byte):
B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k]

for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = (
i * threads * local_size_compressed +
thread_bindings * local_size_compressed + v)
vi, vj = index_to_coordinates(index, B_shared_shape)
B_local[v] = B_shared[vi, vj]

if fast_decoding:
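                    # Fast path: the lop3-based decode expands 32 packed 2-bit values into 32 int4 lanes; with int8 storage that is the 8 bytes of B_local into the 16 bytes of B_dequantize_local.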
T.call_extern('handle', 'decode_i2u_to_i4s', T.address_of(B_local[0]),
T.address_of(B_dequantize_local[0]), 32)
else:
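                    # Reference path: each compressed byte holds four 2-bit values; nibble (v % 2) of byte (v // 2) contributes two of them, repacked as one int8 carrying two int4 lanes.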
for v in T.serial(0, local_size):
int2x2_value = (B_local[v // 2] >> ((v % 2) * 4)) & 0x0F

int4_0 = (int2x2_value >> 0) & 0x03
int4_1 = (int2x2_value >> 2) & 0x03

B_dequantize_local[v] = (int4_1 << 4) | int4_0

for v in T.vectorized(0, local_size):
index = i * threads * local_size + thread_bindings * local_size + v
vi, vj = index_to_coordinates(index, B_dequantize_shared_shape)
B_dequantize_shared[vi, vj] = B_dequantize_local[v]

for ki in T.serial(0, (block_K // micro_size_k)):

# Load A into fragment
mma_emitter.ldmatrix_a(
A_frag,
A_shared,
ki,
thread_bindings=thread_bindings,
)

# Load B into fragment
mma_emitter.ldmatrix_b(
B_frag,
B_dequantize_shared,
ki,
thread_bindings=thread_bindings,
)

# Perform Matrix Multiplication
mma_emitter.mma(A_frag, B_frag, C_frag)

# Perform STMatrix
mma_emitter.stmatrix(
C_frag,
C_shared,
thread_bindings=thread_bindings,
)

# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]

return main


def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding)
print(matmul)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
# src_code is the generated cuda source
assert src_code is not None
print(src_code)
# A = torch.ones(M, K, device="cuda", dtype=getattr(torch, in_dtype))
# B = torch.ones(N, K, device="cuda", dtype=getattr(torch, in_dtype))
A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype))
B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype))

lop3_permutate_config = bitblas.ops.LOP3PermutateConfig(
M=N,
N=K,
datatype="int4",
dequantize_bits=2,
storage_dtype="int8",
)
lop3_permutate = bitblas.ops.LOP3Permutate(
config=lop3_permutate_config,
target=tvm.target.Target("llvm"),
)

C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4)
compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + (
(B[:, 3::4] & 0x03) << 6)
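    # compressed_A packs two int4 values per int8 (element 0 in the low nibble); compressed_B packs four int2 values per int8 (element 0 in the lowest bit pair), e.g. [1, 0, 1, 1] -> 0x51.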

mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
print(f"{compressed_B=}")
lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda()
print(f"{lop3_compressed_B=}")
mod(compressed_A, lop3_compressed_B, C)
print(C)
latency = mod.do_bench(mod.func, warmup=25, profiler="tvm")
print(latency)
# Ensure that the latency is not None
assert latency is not None

# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))

print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")


if __name__ == "__main__":
# bitblas.testing.main()
# assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
# assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32")
# assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32")
# assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32")