Refactor tensor core memory allocation in MatmulFineGrainScheduler
- Split the single local fragment size in the MatmulFineGrainScheduler class into per-operand sizes: local_size_a (micro_size_x * micro_size_k), local_size_b (micro_size_y * micro_size_k), and local_size_c (micro_size_x * micro_size_y), each divided by warp_size.
- Updated the allocation sizes of the A_local, B_local, and C_local buffers to use the corresponding per-operand fragment sizes.
- Each local buffer is now sized to its operand's micro-tile footprint instead of sharing the C-fragment size, so the A and B fragments are no longer mis-sized when the K micro dimension differs from the M/N micro dimensions.

Refactor tensor core memory allocation in MatmulDequantizeFineGrainedScheduler

- Replaced the shared fragement_size in the MatmulDequantizeFineGrainedScheduler class with per-operand sizes fragement_size_a, fragement_size_b, and fragement_size_c, computed the same way as above.
- Updated the allocation sizes of the A_frag, B_frag, and C_frag buffers to use the corresponding per-operand fragment sizes.
- This carries the same per-operand sizing fix into the dequantization path.

Refactor tensor core memory allocation in MatmulDequantizeWeightPropagationScheduler

- Applied the same split in the MatmulDequantizeWeightPropagationScheduler class, introducing fragement_size_a, fragement_size_b, and fragement_size_c.
- Updated the allocation sizes of the A_frag, B_frag, B_dequantize_frag, and C_frag buffers accordingly; B_frag keeps its division by num_elems_per_byte since it holds packed storage_dtype data.
- Each fragment in the weight-propagation path is now sized to its operand's micro tile (see the sizing sketch below).
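
All three schedulers now compute one fragment size per operand. A minimal standalone sketch of that sizing logic follows, assuming a 16x16x16 MMA micro tile and a 32-thread warp (illustrative values, not taken from this commit; variable names mirror the diff):

# Sketch of the per-operand fragment sizing; the micro tile shape and
# warp size below are assumed example values, not from the commit.
warp_size = 32
micro_size_x, micro_size_y, micro_size_k = 16, 16, 16

local_size_a = (micro_size_x * micro_size_k) // warp_size  # A elements per thread: 8
local_size_b = (micro_size_y * micro_size_k) // warp_size  # B elements per thread: 8
local_size_c = (micro_size_x * micro_size_y) // warp_size  # C accumulators per thread: 8

# When the K micro dimension differs from M/N, the sizes diverge, which is why
# a single shared size (micro_size_x * micro_size_y // warp_size) could mis-size
# the A and B fragments before this change.
micro_size_k = 8
assert (micro_size_x * micro_size_k) // warp_size == 4  # A fragment: 4 elements
assert (micro_size_x * micro_size_y) // warp_size == 8  # C fragment: still 8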
LeiWang1999 committed Nov 1, 2024
1 parent 4a0afc9 commit d2f7fcb
Showing 3 changed files with 19 additions and 13 deletions.
10 changes: 6 additions & 4 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -424,7 +424,9 @@ def apply_config(
threads = warp_size * (block_row_warps * block_col_warps)

# Calculate local fragment sizes for tensor core
- local_size = (micro_size_x * micro_size_y) // warp_size
+ local_size_a = (micro_size_x * micro_size_k) // warp_size
+ local_size_b = (micro_size_y * micro_size_k) // warp_size
+ local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y

@@ -459,9 +461,9 @@ def main(
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
- A_local = T.alloc_local((warp_rows * local_size), in_dtype)
- B_local = T.alloc_local((warp_cols * local_size), in_dtype)
- C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
+ A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+ B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+ C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

# Thread-level parallelism for Tensor Cores
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
Second changed file (MatmulDequantizeFineGrainedScheduler)
@@ -231,7 +231,9 @@ def apply_config(
block_K = chunk
threads = warp_size * (block_row_warps * block_col_warps)

- fragement_size = (micro_size_x * micro_size_y) // warp_size
+ fragement_size_a = (micro_size_x * micro_size_k) // warp_size
+ fragement_size_b = (micro_size_y * micro_size_k) // warp_size
+ fragement_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y

@@ -318,9 +320,9 @@ def general_dequant_matmul(
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype)

- A_frag = T.alloc_local((warp_rows * fragement_size), in_dtype)
- B_frag = T.alloc_local((warp_cols * fragement_size), in_dtype)
- C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size), accum_dtype)
+ A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype)
+ B_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype)
+ C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype)

B_local = T.alloc_local([local_size_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([local_size], in_dtype)
Third changed file (MatmulDequantizeWeightPropagationScheduler)
@@ -71,7 +71,9 @@ def apply_config(
block_K = chunk
threads = warp_size * (block_row_warps * block_col_warps)

- fragement_size = (micro_size_x * micro_size_y) // warp_size
+ fragement_size_a = (micro_size_x * micro_size_k) // warp_size
+ fragement_size_b = (micro_size_y * micro_size_k) // warp_size
+ fragement_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y

@@ -173,11 +175,11 @@ def general_dequant_matmul(
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
C_shared = T.alloc_shared(C_shared_shape, out_dtype)

- A_frag = T.alloc_local((warp_rows * fragement_size), in_dtype)
- B_frag = T.alloc_local((warp_cols * fragement_size // num_elems_per_byte),
+ A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype)
+ B_frag = T.alloc_local((warp_cols * fragement_size_b // num_elems_per_byte),
storage_dtype)
- B_dequantize_frag = T.alloc_local((warp_cols * fragement_size), in_dtype)
- C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size), accum_dtype)
+ B_dequantize_frag = T.alloc_local((warp_cols * fragement_size_b), in_dtype)
+ C_frag = T.alloc_local((warp_rows * warp_cols * fragement_size_c), accum_dtype)

tx = T.thread_binding(0, threads, thread="threadIdx.x")

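
In the weight-propagation scheduler, B_frag holds the still-packed weights in storage_dtype, so its length is additionally divided by num_elems_per_byte, while B_dequantize_frag holds the unpacked in_dtype values. A minimal sketch of that buffer-length arithmetic, using assumed example values (4-bit weights packed two per byte, 2x2 warp tiling, 16x16x16 micro tile; none of these specific numbers come from the commit):

# Sketch of the weight-propagation local buffer lengths; the packing factor
# and warp tiling below are assumed example values, not from the commit.
warp_size = 32
micro_size_x = micro_size_y = micro_size_k = 16
warp_rows, warp_cols = 2, 2
num_elems_per_byte = 2  # assumed: int4 weights packed into int8 storage

fragement_size_a = (micro_size_x * micro_size_k) // warp_size  # 8
fragement_size_b = (micro_size_y * micro_size_k) // warp_size  # 8
fragement_size_c = (micro_size_x * micro_size_y) // warp_size  # 8

a_frag_len = warp_rows * fragement_size_a                        # 16 in_dtype elements
b_frag_len = warp_cols * fragement_size_b // num_elems_per_byte  # 8 packed storage_dtype elements
b_dequantize_frag_len = warp_cols * fragement_size_b             # 16 in_dtype elements after unpacking
c_frag_len = warp_rows * warp_cols * fragement_size_c            # 32 accum_dtype accumulators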
