diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
index eea256fd9..13658aab4 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -424,7 +424,9 @@ def apply_config(
         threads = warp_size * (block_row_warps * block_col_warps)
 
         # Calculate local fragment sizes for tensor core
-        local_size = (micro_size_x * micro_size_y) // warp_size
+        local_size_a = (micro_size_x * micro_size_k) // warp_size
+        local_size_b = (micro_size_y * micro_size_k) // warp_size
+        local_size_c = (micro_size_x * micro_size_y) // warp_size
         warp_rows = warp_row_tiles // micro_size_x
         warp_cols = warp_col_tiles // micro_size_y
 
@@ -459,9 +461,9 @@ def main(
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
-            A_local = T.alloc_local((warp_rows * local_size), in_dtype)
-            B_local = T.alloc_local((warp_cols * local_size), in_dtype)
-            C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
+            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
 
             # Thread-level parallelism for Tensor Cores
             thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py
index d755ba2f8..d57951455 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py
@@ -231,7 +231,9 @@ def apply_config(
         block_K = chunk
         threads = warp_size * (block_row_warps * block_col_warps)
 
-        fragement_size = (micro_size_x * micro_size_y) // warp_size
+        fragement_size_a = (micro_size_x * micro_size_k) // warp_size
+        fragement_size_b = (micro_size_y * micro_size_k) // warp_size
+        fragement_size_c = (micro_size_x * micro_size_y) // warp_size
         warp_rows = warp_row_tiles // micro_size_x
         warp_cols = warp_col_tiles // micro_size_y
 
@@ -318,9 +320,9 @@ def general_dequant_matmul(
            B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype)
 
-            A_frag = T.alloc_local((warp_rows * fragement_size), in_dtype)
-            B_frag = T.alloc_local((warp_cols * fragement_size), in_dtype)
-            C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size), accum_dtype)
+            A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype)
+            B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype)
+            C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype)
 
             B_local = T.alloc_local([local_size_compressed], storage_dtype)
             B_dequantize_local = T.alloc_local([local_size], in_dtype)
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py
index bb463e59a..7f8920575 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py
@@ -71,7 +71,9 @@ def apply_config(
         block_K = chunk
         threads = warp_size * (block_row_warps * block_col_warps)
 
-        fragement_size = (micro_size_x * micro_size_y) // warp_size
+        fragement_size_a = (micro_size_x * micro_size_k) // warp_size
+        fragement_size_b = (micro_size_y * micro_size_k) // warp_size
+        fragement_size_c = (micro_size_x * micro_size_y) // warp_size
         warp_rows = warp_row_tiles // micro_size_x
         warp_cols = warp_col_tiles // micro_size_y
 
@@ -173,11 +175,11 @@ def general_dequant_matmul(
 
             B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
             C_shared = T.alloc_shared(C_shared_shape, out_dtype)
-            A_frag = T.alloc_local((warp_rows * fragement_size), in_dtype)
-            B_frag = T.alloc_local((warp_cols * fragement_size // num_elems_per_byte),
+            A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype)
+            B_frag = T.alloc_local((warp_cols * fragement_size_b // num_elems_per_byte),
                                    storage_dtype)
-            B_dequantize_frag = T.alloc_local((warp_cols * fragement_size), in_dtype)
-            C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size), accum_dtype)
+            B_dequantize_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype)
+            C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype)
 
             tx = T.thread_binding(0, threads, thread="threadIdx.x")
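
For context on why a single shared size was wrong: a tensor core MMA tile has three differently shaped operands, A (micro_size_x x micro_size_k), B (micro_size_y x micro_size_k), and C (micro_size_x x micro_size_y), so each warp-level fragment needs its own per-thread element count. Below is a minimal sketch of the arithmetic in plain Python, assuming NVIDIA's 32-thread warps; the tile shapes in the examples are illustrative assumptions, not values taken from this patch.

    WARP_SIZE = 32  # threads per warp on NVIDIA GPUs

    def fragment_sizes(micro_size_x, micro_size_y, micro_size_k):
        # Per-thread register counts for the A (x*k), B (y*k), and C (x*y)
        # fragments of one MMA tile, spread evenly across a warp.
        size_a = (micro_size_x * micro_size_k) // WARP_SIZE
        size_b = (micro_size_y * micro_size_k) // WARP_SIZE
        size_c = (micro_size_x * micro_size_y) // WARP_SIZE
        return size_a, size_b, size_c

    # 16x16x16 tile: all three operands need 8 elements per thread, so the
    # old single size happened to give the right answer for every buffer.
    assert fragment_sizes(16, 16, 16) == (8, 8, 8)

    # 16x16x32 tile (illustrative): A and B now need 16 elements per thread
    # while C still needs 8, so sizing every buffer from
    # (micro_size_x * micro_size_y) // warp_size under-allocates A and B.
    assert fragment_sizes(16, 16, 32) == (16, 16, 8)

Whenever micro_size_k differs from the output tile dimensions, the A and B register buffers must be sized from their own operand shapes; only C is correctly covered by (micro_size_x * micro_size_y) // warp_size, which is exactly the split this patch applies in all three schedulers.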