From 451b4660611a814171104bad81fb5baedfdf4645 Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sun, 3 Nov 2024 13:56:11 +0800 Subject: [PATCH] [Dev][BitNET] Implement INT4xINT2 GEMM (#233) * Merge TL Update * submodule update * Re-implement macro with sub function. * lint fix * Refactor tensor core memory allocation in MatmulFineGrainScheduler - Adjusted the local fragment sizes for tensor core memory allocation in the MatmulFineGrainScheduler class. - Updated the allocation sizes for A_local, B_local, and C_local variables based on the new fragment sizes. - The changes ensure efficient memory utilization and improve performance. Refactor tensor core memory allocation in MatmulDequantizeFineGrainedScheduler - Modified the fragment sizes for tensor core memory allocation in the MatmulDequantizeFineGrainedScheduler class. - Updated the allocation sizes for A_frag, B_frag, and C_frag variables based on the new fragment sizes. - The changes optimize memory usage and enhance the efficiency of the dequantization process. Refactor tensor core memory allocation in MatmulDequantizeWeightPropagationScheduler - Adjusted the fragment sizes for tensor core memory allocation in the MatmulDequantizeWeightPropagationScheduler class. - Updated the allocation sizes for A_frag, B_frag, B_dequantize_frag, and C_frag variables based on the new fragment sizes. - The changes improve memory utilization and optimize the weight propagation process. * Implement int4 tensorcore * lint fix * support uint2->uint4 fast dequantize * Support int4 tensorcore decoding * lint fix --- 3rdparty/tvm | 2 +- bitblas/gpu/intrin/lop3.py | 41 +++ .../ops/lop3_permutate/lop3_permutate_impl.py | 2 + bitblas/tl/utils.py | 21 ++ .../BitNet/int4_kernel/tl_int4xint2.py | 275 +++++++++++++++ .../tl_int4xint2_ladder_weight_only.py | 331 ++++++++++++++++++ .../BitNet/int4_kernel/tl_int4xint4.py | 211 +++++++++++ .../tl_int4xint4_ladder_weight_only.py | 238 +++++++++++++ .../BitNet/int4_kernel/tl_int8xint8.py | 223 ++++++++++++ .../tl_int8xint8_ladder_weight_only.py | 255 ++++++++++++++ .../cpp/lop3_type_conversion/CMakeLists.txt | 1 + .../lop3_type_conversion/fast_decoding.hpp | 87 +++++ .../lowprecision_to_int4.cu | 235 +++++++++++++ 13 files changed, 1921 insertions(+), 1 deletion(-) create mode 100644 integration/BitNet/int4_kernel/tl_int4xint2.py create mode 100644 integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py create mode 100644 integration/BitNet/int4_kernel/tl_int4xint4.py create mode 100644 integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py create mode 100644 integration/BitNet/int4_kernel/tl_int8xint8.py create mode 100644 integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py create mode 100644 testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu diff --git a/3rdparty/tvm b/3rdparty/tvm index 71fe7ce82..be013f6d5 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 71fe7ce827396b98a3169343c3744e788a82566c +Subproject commit be013f6d5e623e1787351aac897e270970e33ada diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index 8d60c7651..75f4b1757 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -978,6 +978,47 @@ } """ +decode_i2s_to_i4s = r""" +template +__device__ void decode_i2b_to_i4s(T1 *_i2b, T2 *_i4s, const int N = 16) +{ + uint *i4s = reinterpret_cast(_i4s); + uint *i2b = reinterpret_cast(_i2b); + // First, we extract the i4s and construct an intermediate i8 number. 
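+    // Despite the note above, the result here stays in int4 lanes: immLut = (0xf0 & 0xcc) | 0xaa
+    // (= 0xEA) asks lop3.b32 for f(a, b, c) = (a & b) | c, and BOTTOM_MASK = 0x33333333 keeps
+    // the low two bits of every 4-bit lane, i.e. one 2-bit weight per int4 lane of i4s[i].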
+ static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x33333333; // 0xf -> 0b1111 select 0,2,4,6,8,10,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 + static constexpr uint MEDIAN_NUM = isSigned ? 0x33333333 : 0x00000000; + +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i4s[i]) + : "r"(i2b[i / 2] >> (2 * (i % 2))), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + if constexpr (isSigned) + { + // TODO(lei): uint4 sub should be enhanced. + // 0x03 0x03 0x03 0x03 + i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i]; + } + } +} + +template +__device__ void decode_i2s_to_i4s(T1 *_i4s, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4u, B_local_decode, N); +} +""" + def get_fast_decode_intrin( source_bit=4, diff --git a/bitblas/ops/lop3_permutate/lop3_permutate_impl.py b/bitblas/ops/lop3_permutate/lop3_permutate_impl.py index 07d8f4f0c..94ddd13c6 100644 --- a/bitblas/ops/lop3_permutate/lop3_permutate_impl.py +++ b/bitblas/ops/lop3_permutate/lop3_permutate_impl.py @@ -126,6 +126,8 @@ def interleave_weight_int8_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer(( return interleave_weight_f16_1b elif target_dtype == "int8" and bits == 1: return interleave_weight_int8_1b + elif target_dtype == "int4" and bits == 2: + pass return interleave_weight diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py index 2c88bec64..18f0d3274 100644 --- a/bitblas/tl/utils.py +++ b/bitblas/tl/utils.py @@ -133,3 +133,24 @@ def transform_func(i, j): return [new_warp_i, new_warp_j] return T.Layout(shape, transform_func) + + +def index_to_coordinates(index, shape): + ''' + General Implementation of: + vjj = index % (micro_size_k // num_elems_per_byte) + coordinates[-1] = index % shape[-1]; + vii = index // (micro_size_k // num_elems_per_byte) % micro_size_y + index = index // shape[-1]; coordinates[-2] = index % shape[-2]; + vj = index // (micro_size_k // num_elems_per_byte * micro_size_y) % block_K // (micro_size_k // num_elems_per_byte) + index = index // shape[-2]; coordinates[-3] = index % shape[-3]; + vi = index // (micro_size_k // num_elems_per_byte * micro_size_y * (block_K // (micro_size_k // num_elems_per_byte))) % block_N // micro_size_y + index = index // shape[-3]; coordinates[-4] = index % shape[-4]; + ''' + coordinates = [] + dims = len(shape) + for i in range(dims): + coordinates.append(index % shape[dims - i - 1]) + index = index // shape[dims - i - 1] + coordinates.reverse() + return coordinates diff --git a/integration/BitNet/int4_kernel/tl_int4xint2.py b/integration/BitNet/int4_kernel/tl_int4xint2.py new file mode 100644 index 000000000..16797501a --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int4xint2.py @@ -0,0 +1,275 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
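+#
+# Example tile-lang kernel: A carries int4 values packed two per int8, B carries int2
+# values packed four per int8; B is expanded to int4x2 inside the kernel (via the
+# decode_i2u_to_i4s prelude when fast_decoding=True, or a scalar loop otherwise)
+# before the int4 tensor-core MMA.
+#
+# index_to_coordinates (imported below) is a plain row-major unravel. Illustrative
+# values, not taken from this file:
+#   index_to_coordinates(11, (2, 3, 4)) -> [0, 2, 3]   # same result as np.unravel_index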
+ +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import (make_swizzle_layout, index_to_coordinates) +from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s + +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitter,) +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + fast_decoding=True, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + K = K // 2 + + micro_size_x = micro_size_y = micro_size_k = 16 + + if accum_dtype == "int32": + micro_size_k = 32 + + num_elems_per_byte = 2 + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + storage_dtype = "int8" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K // num_elems_per_byte) # int8 storage represents int4*2 + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K // num_elems_per_byte) + B_dequantize_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + fragement_size_a = (micro_size_x * micro_size_k) // warp_size + fragement_size_b = (micro_size_y * micro_size_k) // warp_size + fragement_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=threads, + prelude=decode_i2s_to_i4s) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared( + B_dequantize_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype) + + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) + + thread_bindings = 
T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_dequantize_shared: make_swizzle_layout(B_dequantize_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_frag) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K // num_elems_per_byte): + B_shared[j, k] = B[bx * block_N + j, ko * (block_K // num_elems_per_byte) + k] + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = ( + i * threads * local_size_compressed + + thread_bindings * local_size_compressed + v) + vi, vj = index_to_coordinates(index, B_shared_shape) + B_local[v] = B_shared[vi, vj] + + if fast_decoding: + T.call_extern('handle', 'decode_i2u_to_i4s', T.address_of(B_local[0]), + T.address_of(B_dequantize_local[0]), 32) + else: + for v in T.serial(0, local_size): + int2x2_value = (B_local[v // 2] >> ((v % 2) * 4)) & 0x0F + + int4_0 = (int2x2_value >> 0) & 0x03 + int4_1 = (int2x2_value >> 2) & 0x03 + + B_dequantize_local[v] = (int4_1 << 4) | int4_0 + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + thread_bindings * local_size + v + vi, vj = index_to_coordinates(index, B_dequantize_shared_shape) + B_dequantize_shared[vi, vj] = B_dequantize_local[v] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_frag, + B_dequantize_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_frag, B_frag, C_frag) + + # Perform STMatrix + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) + print(matmul) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + # A = torch.ones(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + # B = torch.ones(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 2, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype="int4", + dequantize_bits=2, + storage_dtype="int8", + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( + (B[:, 3::4] & 0x03) << 6) + + mod = TL.Profiler(mod, params, [], 
TL.TensorSupplyType.Integer) + print(f"{compressed_B=}") + lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() + print(f"{lop3_compressed_B=}") + mod(compressed_A, lop3_compressed_B, C) + print(C) + latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + print(latency) + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") diff --git a/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py new file mode 100644 index 000000000..d44717e7f --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py @@ -0,0 +1,331 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import make_swizzle_layout, index_to_coordinates +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitterWithLadderTransform,) +from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + fast_decoding=True, +): + K = K // 2 + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + num_elems_per_byte = 2 + MAX_TRANSACTION_SIZE_IN_BITS = 128 + local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits + local_size_compressed = local_size // num_elems_per_byte + + transform_b = 3 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + # shared_scope = "shared" + storage_dtype = "int8" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + is_smooth_a = False + can_swizzle = block_K * DataType(in_dtype).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, + micro_size_k // num_elems_per_byte) + A_shared_shape = ( + block_M, + (block_K + pad_factor) if apply_pad_a else block_K, + ) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, 
+ micro_size_y, + micro_size_k // num_elems_per_byte, + ) + B_dequantize_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + fragement_size_a = (micro_size_x * micro_size_k) // warp_size + fragement_size_b = (micro_size_y * micro_size_k) // warp_size + fragement_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=transform_b, + ) + + vec_load_qb = 16 + if block_N * (block_K) // num_elems_per_byte // threads < vec_load_qb: + vec_load_qb = block_N * (block_K) // num_elems_per_byte // threads + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, storage_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), + T.ceildiv(M, block_M), + threads=threads, + prelude=decode_i2s_to_i4s) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) + B_dequantize_shared = T.alloc_shared( + B_dequantize_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype) + B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype) + C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype) + B_local = T.alloc_local([local_size_compressed], storage_dtype) + B_dequantize_local = T.alloc_local([local_size], in_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + # B_dequantize_shared: make_swizzle_layout(B_dequantize_shared, is_smooth=True), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_frag) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + # TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * vec_load_qb)): + for v in T.vectorized(0, vec_load_qb): + t = thread_bindings + idx = i * threads * vec_load_qb + threads * vec_load_qb + t * vec_load_qb + v + vj, vk, vjj, vkk = index_to_coordinates(idx, B_shared_shape) + B_shared[vj, vk, vjj, + vkk] = B[bx * (block_N // micro_size_y) + vj, + ko * (block_K // micro_size_k) + vk, vjj, vkk] + + for i in T.serial(block_N * block_K // num_elems_per_byte // + (threads * local_size_compressed)): + for v in T.vectorized(0, local_size_compressed): + index = ( + i * threads * local_size_compressed + + thread_bindings * local_size_compressed + v) + vi, vj, vii, vjj = index_to_coordinates(index, B_shared_shape) + B_local[v] = B_shared[vi, vj, vii, vjj] + + 
if fast_decoding: + # Simulated dequantization + T.call_extern('handle', 'decode_i2u_to_i4s', T.address_of(B_local[0]), + T.address_of(B_dequantize_local[0]), 32) + else: + for v in T.serial(0, local_size): + int2x2_value = (B_local[v // 2] >> ((v % 2) * 4)) & 0x0F + + int4_0 = (int2x2_value >> 0) & 0x03 + int4_1 = (int2x2_value >> 2) & 0x03 + + B_dequantize_local[v] = (int4_1 << 4) | int4_0 + + for v in T.vectorized(0, local_size): + index = i * threads * local_size + thread_bindings * local_size + v + vi, vj, vii, vjj = index_to_coordinates(index, B_dequantize_shared_shape) + B_dequantize_shared[vi, vj, vii, vjj] = B_dequantize_local[v] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_frag, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_frag, + B_dequantize_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_frag, B_frag, C_frag) + + # Perform STMatrix + mma_emitter.stmatrix( + C_frag, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding=True): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, fast_decoding) + print(matmul) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + print(src_code) + assert src_code is not None + transform_b = 3 + + # A = torch.ones(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + # B = torch.ones(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + + lop3_permutate_config = bitblas.ops.LOP3PermutateConfig( + M=N, + N=K, + datatype="int4", + dequantize_bits=2, + storage_dtype="int8", + ) + lop3_permutate = bitblas.ops.LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() + ladder_shape = compressed_B_ladder.shape + int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) + int2_tensor = torch.zeros(int2_shape, device="cuda", dtype=torch.int8) + for i in range(int2_tensor.shape[-1]): + int2_tensor[..., i] = (compressed_B_ladder[..., 2 * i] & 0x03) | ( + (compressed_B_ladder[..., 2 * i] >> 4) & 0x03) << 2 | ( + (compressed_B_ladder[..., 2 * i + 1] & 0x03) << 4) | ( + (compressed_B_ladder[..., 2 * i + 1] >> 4) << 6) + + raw_tensor_shape = int2_tensor.shape + print(f"{raw_tensor_shape=}") + if fast_decoding: + lop3_compressed_B = lop3_permutate(int2_tensor.cpu()).cuda() + 
lop3_compressed_B = lop3_compressed_B.view(raw_tensor_shape) + else: + lop3_compressed_B = int2_tensor + + mod(compressed_A, lop3_compressed_B, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32", False) diff --git a/integration/BitNet/int4_kernel/tl_int4xint4.py b/integration/BitNet/int4_kernel/tl_int4xint4.py new file mode 100644 index 000000000..5b040db89 --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int4xint4.py @@ -0,0 +1,211 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import ( + make_swizzle_layout,) + +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitter,) +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + K = K // 2 + + micro_size_x = micro_size_y = micro_size_k = 16 + + if accum_dtype == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) # int8 storage represents int4*2 + B_shape = (N, K) # int8 storage represents int4*2 + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: 
T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + # print(src_code) + # A = torch.ones(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + # B = torch.ones(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod(compressed_A, compressed_B, C) + print(C) + latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") + print(latency) + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") 
+ # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") diff --git a/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py new file mode 100644 index 000000000..1603698b2 --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py @@ -0,0 +1,238 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import make_swizzle_layout +from bitblas.tl.macro_generator import ( + INT4TensorCoreIntrinEmitterWithLadderTransform,) +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + K = K // 2 + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + transform_b = 3 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + is_smooth_a = False + can_swizzle = block_K * DataType(in_dtype).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = ( + block_M, + (block_K + pad_factor) if apply_pad_a else block_K, + ) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = INT4TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=transform_b, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, 
scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + # B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, + micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, jj, kk] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + transform_b = 3 + + A = torch.randint(0, 4, (M, K), device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.randint(0, 4, (N, K), device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) + compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=(K // 2), + datatype="int8", + storage_dtype="int8", + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + LB = ladder_permutate(compressed_B.cpu()).cuda() + + mod(compressed_A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # 
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + # assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") + assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") diff --git a/integration/BitNet/int4_kernel/tl_int8xint8.py b/integration/BitNet/int4_kernel/tl_int8xint8.py new file mode 100644 index 000000000..e809c673e --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int8xint8.py @@ -0,0 +1,223 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.backends +from bitblas import tvm as tvm +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import get_swizzle_layout +from bitblas.tl.macro_generator import ( + TensorCoreIntrinEmitter,) +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), 
in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + if in_dtype == "int8": + A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8) + B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + + mod(A, B, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") diff --git a/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py new file mode 100644 index 000000000..733441f2f --- /dev/null +++ b/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py @@ -0,0 +1,255 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
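+#
+# Example tile-lang kernel: a plain int8 x int8 GEMM, except that B is pre-transformed
+# outside the kernel by bitblas.ops.LadderPermutate (transform_kind=3,
+# transpose_matrix=True), so the kernel reads B directly as (N // 16, K // 32, 16, 32)
+# tiles (for the int8/int32 configuration exercised below) instead of a row-major
+# (N, K) matrix.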
+ +import torch +import torch.backends +from bitblas import tvm as tvm +import bitblas.testing +from tvm import DataType +from tvm import tl as TL +import tvm.tl.language as T +from bitblas.tl.utils import get_swizzle_layout +from bitblas.tl.macro_generator import ( + TensorCoreIntrinEmitterWithLadderTransform,) +from bitblas.ops.base_scheduler import simplify_prim_func + +torch.manual_seed(0) + + +def make_swizzle_layout(shared_buf): + dtype = shared_buf.dtype + shape = shared_buf.shape + + can_swizzle = shape[-1] * DataType(dtype).bits == 512 + if not can_swizzle: + return T.Layout(shape, lambda *args: args) + + def transform_func(i, j): + new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) + return [new_warp_i, new_warp_j] + + return T.Layout(shape, transform_func) + + +@simplify_prim_func +def tl_matmul( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, +): + assert in_dtype in [ + "float16", + "int8", + ], "Currently only float16 and int8 are supported" + assert out_dtype in [ + "float16", + "float32", + "int32", + ], "Currently only float16, float32 and int32 are supported" + + micro_size_x = micro_size_y = micro_size_k = 16 + + if out_dtype == "int32": + micro_size_k = 32 + + transform_b = 3 + + # This is a debug config + block_row_warps = 2 + block_col_warps = 2 + warp_row_tiles = 64 + warp_col_tiles = 64 + chunk = 32 if in_dtype == "float16" else 64 + shared_scope = "shared.dyn" + + # Pipeline Stage + stage = 2 + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + is_smooth_a = False + can_swizzle = block_K * DataType(in_dtype).bits == 512 + apply_pad_a = not (is_smooth_a or can_swizzle) + pad_factor = 8 + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = ( + block_M, + (block_K + pad_factor) if apply_pad_a else block_K, + ) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + warp_size = 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + # MMA Wrapper to Auto Generate Code for MMA + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=transform_b, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + + 
thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + # Improve L2 Cache + T.use_swizzle(panel_size=10) + + T.clear(C_local) + + for ko in T.Pipelined((K // block_K), num_stages=stage): + + # Load A into shared memory + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + # Load B into shared memory + for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, + micro_size_y, micro_size_k): + B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, jj, kk] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Perform Matrix Multiplication + mma_emitter.mma(A_local, B_local, C_local) + + # Perform STMatrix + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + # Store shared into global + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): + matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + mod, params = TL.lower(matmul) + src_code = mod.imported_modules[0].get_source() + # src_code is the generated cuda source + assert src_code is not None + print(src_code) + transform_b = 3 + + if in_dtype == "int8": + A = torch.randint(-7, 7, (M, K), device="cuda", dtype=torch.int8) + B = torch.randint(-7, 7, (N, K), device="cuda", dtype=torch.int8) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) + + ladder_permutate_config = bitblas.ops.LadderPermutateConfig( + M=N, + N=K, + datatype="int8", + storage_dtype="int8", + transform_kind=transform_b, + transpose_matrix=True, + ) + + ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) + + mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + LB = ladder_permutate(B.cpu()).cuda() + + mod(A, LB, C) + + latency = mod.do_bench(mod.func, warmup=25) + print(f"Latency: {latency}") + # Ensure that the latency is not None + assert latency is not None + + # Get Reference Result + ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype)) + print(C) + print(ref_c) + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32") + + +if __name__ == "__main__": + # bitblas.testing.main() + # assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16") + # assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32") + assert_tl_matmul_correctness(16384, 16384, 16384, "int8", "int32", "int32") diff --git a/testing/cpp/lop3_type_conversion/CMakeLists.txt b/testing/cpp/lop3_type_conversion/CMakeLists.txt index 61903faf4..8b104ca47 100644 --- 
a/testing/cpp/lop3_type_conversion/CMakeLists.txt +++ b/testing/cpp/lop3_type_conversion/CMakeLists.txt @@ -10,3 +10,4 @@ endfunction(ADD_CUDA_TEST_EXECUTABLE) ADD_CUDA_TEST_EXECUTABLE(lowprecision_to_float16) ADD_CUDA_TEST_EXECUTABLE(lowprecision_to_int8) +ADD_CUDA_TEST_EXECUTABLE(lowprecision_to_int4) diff --git a/testing/cpp/lop3_type_conversion/fast_decoding.hpp b/testing/cpp/lop3_type_conversion/fast_decoding.hpp index 6d5b6335a..e6f8b2923 100644 --- a/testing/cpp/lop3_type_conversion/fast_decoding.hpp +++ b/testing/cpp/lop3_type_conversion/fast_decoding.hpp @@ -797,3 +797,90 @@ __device__ void decode_i1u_to_i8s(T1 *_i1u, T2 *B_local_decode, const int N = 16 { decode_i1b_to_i8s(_i1u, B_local_decode, N); } + + +void general_interleave_int4(int8_t *origin_arr, int8_t *interleaved, const int nbit, size_t size_in_bytes, bool verbose = false) +{ + // For int4 example + // i2s {e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // |-----8b-----||-----8b-----||-----8b-----||-----8b-----| + // 0b00110011 0b00110011 0b00110011 0b00110011 + // interleave {e15,e7,e14,e6,e13,e5,e12,e4,e11,e3,e10,e2,e9,e1,e8,e0} + + size_t size = size_in_bytes / sizeof(int32_t); + int32_t *int32_origin = (int32_t *)origin_arr; + int32_t *int32_interleaved = (int32_t *)interleaved; + + constexpr int bits_stride = 4; + int elems_per_group = bits_stride / nbit; + int mask = (1 << nbit) - 1; + int num_groups = 32 / bits_stride; + + for (int idx = 0; idx < size; ++idx) + { + int32_t current_value = int32_origin[idx]; + int32_t new_value = 0; + for (int i = 0; i < num_groups; ++i) + { + for (int j = 0; j < elems_per_group; ++j) + { + int offset = i * elems_per_group + j; + int shift = (offset % num_groups) * bits_stride + (offset / num_groups) * nbit; + int group_value = (current_value >> (nbit * (i * elems_per_group + j))) & mask; + new_value |= group_value << shift; + if (verbose) + printf("put %d to %d\n", offset, shift); + } + } + if (nbit == 1) + { + throw std::runtime_error("Not implemented"); + } + else + int32_interleaved[idx] = new_value; + } + + // Convert back to int8_t if needed + memcpy(interleaved, int32_interleaved, size * sizeof(int32_t)); +} + + +template +__device__ void decode_i2b_to_i4s(T1 *_i2b, T2 *_i4s, const int N = 16) +{ + uint *i4s = reinterpret_cast(_i4s); + uint *i2b = reinterpret_cast(_i2b); + // First, we extract the i4s and construct an intermediate i8 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x33333333; // 0xf -> 0b1111 select 0,2,4,6,8,10,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 + static constexpr uint MEDIAN_NUM = isSigned ? 0x33333333 : 0x00000000; + +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(i4s[i]) + : "r"(i2b[i / 2] >> (2 * (i % 2))), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + if constexpr (isSigned) + { + // TODO(lei): uint4 sub should be enhanced. 
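+            // (MEDIAN_NUM above is not referenced yet; the shift/or trick below is the
+            //  current signed adjustment.)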
+ // 0x03 0x03 0x03 0x03 + i4s[i] = (((i4s[i] << 1) | i4s[i]) << 1) | i4s[i]; + } + } +} + +template +__device__ void decode_i2s_to_i4s(T1 *_i4s, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_i4s(T1 *_i4u, T2 *B_local_decode, const int N = 16) +{ + decode_i2b_to_i4s(_i4u, B_local_decode, N); +} + diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu new file mode 100644 index 000000000..d39a85dcd --- /dev/null +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_int4.cu @@ -0,0 +1,235 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include +#include +#include +#include +#include "fast_decoding.hpp" + +#define cudaCheckLastError(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + } +inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) +{ + if (code != cudaSuccess) + { + fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) + exit(code); + } +} + +#define REGISTER_GLOBAL_DEVICE_INVOKER(kernel, function) \ + template \ + __global__ void kernel(Args... args) \ + { \ + function(args...); \ + } + +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2s_to_i4s, decode_i2s_to_i4s) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_i4s, decode_i2u_to_i4s) + +// TEST(DecodeTest, DecodeInt4ToINT8) +// { +// using target_dtype = int8_t; +// constexpr int nbits = 2; +// constexpr int N = 32 / nbits; +// constexpr int QN = N / 8 * nbits; +// constexpr bool isSigned = true; +// constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + +// // create four int8_t values +// int8_t in_data[N] = { +// 0, +// }; +// // breed seed +// srand(0); + +// // random initializations with nbits range +// for (int i = 0; i < N; i++) +// { +// in_data[i] = (rand() % (1 << nbits)) - zero_point; +// } + +// // print input data +// printf("in_data \n"); +// for (int i = 0; i < N; i++) +// { +// printf("i:%d %d %x \n", i, in_data[i], in_data[i]); +// } + +// int8_t *ins = new int8_t[QN]; +// for (int i = 0; i < QN; i++) +// { +// ins[i] = (in_data[i * 4] & 0x3) | ((in_data[i * 4 + 1] & 0x3) << 2) | ((in_data[i * 4 + 2] & 0x3) << 4) | ((in_data[i * 4 + 3] & 0x3) << 6); +// } +// // print input data +// printf("ins \n"); +// for (int i = 0; i < QN; i++) +// { +// printf("i:%d %d %x b: ", i, ins[i], ins[i]); +// for (int j = 7; j >= 0; j--) +// { +// printf("%d", (ins[i] >> j) & 1); +// } +// printf("\n"); +// } +// printf("\n"); +// int8_t *interleaved = new int8_t[QN]; +// general_interleave_int4(ins, interleaved, 2, QN * sizeof(int8_t), true); +// printf("interleaved \n"); +// for (int i = 0; i < QN; i++) +// { +// printf("i:%d %d %x b: ", i, interleaved[i], interleaved[i]); +// for (int j = 7; j >= 0; j--) +// { +// printf("%d", (interleaved[i] >> j) & 1); +// } +// printf("\n"); +// } +// target_dtype *decoded = new target_dtype[N]; +// int8_t *ins_gpu; +// target_dtype *decoded_gpu; + +// cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); +// cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(target_dtype))); +// cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); +// cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(target_dtype), cudaMemcpyHostToDevice)); +// cudaCheckLastError(cudaDeviceSynchronize()); + +// kernelWrapper_i2s_to_i4s<<>>(ins_gpu, 
decoded_gpu); +// cudaCheckLastError(cudaDeviceSynchronize()); +// cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(target_dtype), cudaMemcpyDeviceToHost)); +// cudaCheckLastError(cudaFree(ins_gpu)); +// cudaCheckLastError(cudaFree(decoded_gpu)); +// printf("decoded \n"); +// for (int i = 0; i < (N / 2); i++) +// { +// printf("i %d %d %x \n", i, decoded[i], decoded[i]); +// } +// // output data int8 +// int8_t i8_out[N] = { +// 0, +// }; +// for (int i = 0; i < N; i++) +// { +// i8_out[i] = (decoded[i / 2] >> (4 * (i % 2)) ) & 0xf; +// } +// printf("i8_out \n"); +// for (int i = 0; i < N; i++) +// { +// printf("i %d in_data: %d %x decode_data: %d %x \n", i, in_data[i], in_data[i], i8_out[i], i8_out[i]); +// } +// for (int i = 0; i < (N / 2); i++) +// { +// EXPECT_EQ(in_data[i], int(i8_out[i])); +// } +// free(ins); +// free(interleaved); +// free(decoded); +// } + + +// int32 -> 16 int2 -> 4 int8 +// -> 16 int4 -> 8 int8 +TEST(DecodeTest, DecodeUInt4ToINT8) +{ + using target_dtype = int8_t; + constexpr int nbits = 2; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + constexpr int zero_point = isSigned ? ((1 << (nbits - 1)) - 1) : 0; + + // create four int8_t values + int8_t in_data[N] = { + 0, + }; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)) - zero_point; + // in_data[i] = (i % 2); + // in_data[i] = 1; + } + + // print input data + for (int i = 0; i < N; i++) + { + printf("i:%d %d %x \n", i, in_data[i], in_data[i]); + } + + int8_t *ins = new int8_t[QN]; + for (int i = 0; i < QN; i++) + { + ins[i] = (in_data[i * 4] & 0x3) | ((in_data[i * 4 + 1] & 0x3) << 2) | ((in_data[i * 4 + 2] & 0x3) << 4) | ((in_data[i * 4 + 3] & 0x3) << 6); + } + // print input data + printf("ins \n"); + for (int i = 0; i < QN; i++) + { + printf("i:%d %d %x b: ", i, ins[i], ins[i]); + for (int j = 7; j >= 0; j--) + { + printf("%d", (ins[i] >> j) & 1); + } + printf("\n"); + } + printf("\n"); + int8_t *interleaved = new int8_t[QN]; + general_interleave_int4(ins, interleaved, 2, QN * sizeof(int8_t), true); + printf("interleaved \n"); + for (int i = 0; i < QN; i++) + { + printf("i:%d %d %x b: ", i, interleaved[i], interleaved[i]); + for (int j = 7; j >= 0; j--) + { + printf("%d", (interleaved[i] >> j) & 1); + } + printf("\n"); + } + target_dtype *decoded = new target_dtype[N]; + int8_t *ins_gpu; + target_dtype *decoded_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(target_dtype))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(target_dtype), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + + kernelWrapper_i2u_to_i4s<<>>(ins_gpu, decoded_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(target_dtype), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + printf("decoded \n"); + for (int i = 0; i < (N / 2); i++) + { + printf("i %d %d %x \n", i, decoded[i], decoded[i]); + } + // output data int8 + int8_t i8_out[N] = { + 0, + }; + for (int i = 0; i < N; i++) + { + i8_out[i] = (decoded[i / 2] >> (4 * (i % 2)) ) & 0xf; + } + printf("i8_out \n"); + for (int i = 0; i < N; i++) + { + printf("i 
%d in_data: %d %x decode_data: %d %x \n", i, in_data[i], in_data[i], i8_out[i], i8_out[i]);
+    }
+    for (int i = 0; i < (N / 2); i++)
+    {
+        EXPECT_EQ(in_data[i], int(i8_out[i]));
+    }
+    delete[] ins;
+    delete[] interleaved;
+    delete[] decoded;
+}
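
For reference, a minimal NumPy sketch of the int2 -> int4x2 packing contract exercised by the kernels and tests above; the names and values are illustrative only and not part of the patch:

import numpy as np

rng = np.random.default_rng(0)
vals = rng.integers(0, 4, size=32).astype(np.uint8)      # 32 unsigned 2-bit values

# Pack four 2-bit values per byte, lowest-index element in the lowest bits
# (same layout as compressed_B in tl_int4xint2.py).
packed_i2 = (vals[0::4] | (vals[1::4] << 2) | (vals[2::4] << 4) | (vals[3::4] << 6))

# Scalar reference of the kernel's slow path: one nibble (two int2 values) becomes
# one byte holding two unsigned int4 lanes.
decoded_i4x2 = np.empty(vals.size // 2, dtype=np.uint8)
for v in range(decoded_i4x2.size):
    nib = (packed_i2[v // 2] >> ((v % 2) * 4)) & 0x0F
    decoded_i4x2[v] = (((nib >> 2) & 0x03) << 4) | (nib & 0x03)

# Unpacking the int4 lanes recovers the original elements in order.
unpacked = np.stack([decoded_i4x2 & 0x0F, decoded_i4x2 >> 4], axis=1).reshape(-1)
assert np.array_equal(unpacked, vals)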