Skip to content

Commit

Permalink
lint fix
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Nov 1, 2024
1 parent 811e5c7 commit 4a0afc9
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 117 deletions.
10 changes: 10 additions & 0 deletions bitblas/ops/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,13 @@ def common_header(self):
# TODO(lei): For HIP Backend it should be different
common_header = "#include <tl_templates/cuda/common.h>\n"
return common_header


# Decorator to simplify the output of a function
def simplify_prim_func(func: Callable):
    """Decorator that simplifies the IR returned by the wrapped function.

    The wrapped callable is expected to return a ``PrimFunc`` or ``IRModule``;
    the decorator feeds that result through ``BaseScheduler.Simplify`` and
    returns the simplified form.
    """
    # Local import keeps this block self-contained; stdlib only.
    from functools import wraps

    @wraps(func)  # preserve the wrapped function's __name__/__doc__ for debugging
    def wrapper(*args, **kwargs):
        stmt: Union[PrimFunc, IRModule] = func(*args, **kwargs)
        return BaseScheduler.Simplify(stmt)

    return wrapper
125 changes: 42 additions & 83 deletions bitblas/tl/macro_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,27 +61,21 @@ def __init__(
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_local_size(
self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE
)
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
self._initialize_mma_prefix(self.k_dim)
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
self.threads = (
self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
)
self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
self.num_elems_per_byte = num_elems_per_byte

def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str):
a_dtype = DataType(a_dtype)
self.k_dim = 256 // a_dtype.bits

def _initialize_local_size(
self, m_dim=16, n_dim=16, k_dim=16, warp_size=32
):
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
self.local_size_a = (m_dim * k_dim) // warp_size
self.local_size_b = (n_dim * k_dim) // warp_size
self.local_size_out = (m_dim * n_dim) // warp_size
Expand Down Expand Up @@ -136,20 +130,14 @@ def _warp_ldmatrix_a(
".b16",
A_local_buf.data,
i * local_size_a,
T.address_of(
A_shared_buf[
ty * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k,
]
),
get_ldmatrix_offset(
"A", tx, 0, stride, a_dtype, a_transposed
),
T.address_of(A_shared_buf[
ty * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k,
]),
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
)

return _warp_ldmatrix_a(
A_local_buf, A_shared_buf, ki, thread_bindings, rk
)
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)

def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):

Expand All @@ -175,9 +163,7 @@ def _warp_ldmatrix_b(
):
stride = B_shared_buf.shape[-1]
tx = thread_bindings % WARP_SIZE
tz = (
thread_bindings // (WARP_SIZE * block_row_warps)
) % block_col_warps
tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps

for j in T.serial(warp_cols):
# Assign B_shared_elem
Expand All @@ -195,14 +181,10 @@ def _warp_ldmatrix_b(
B_local_buf.data,
j * local_size_b,
T.address_of(B_shared_elem),
get_ldmatrix_offset(
"B", tx, 0, stride, b_dtype, b_transposed
),
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
)

return _warp_ldmatrix_b(
B_local_buf, B_shared_buf, ki, thread_bindings, rk
)
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)

def mma(self, A_local_buf, B_local_buf, C_local_buf):
warp_rows = self.warp_rows
Expand Down Expand Up @@ -249,9 +231,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out
+ j * local_size_out
+ lift(local_size_out) // 2,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)

Expand All @@ -273,21 +253,15 @@ def stmatrix(self, C_local_buf, C_shared_buf, thread_bindings):
def _warp_stmatrix(C_local_buf, C_shared_buf, thread_bindings):
tx = thread_bindings % WARP_SIZE
ty = (thread_bindings // WARP_SIZE) % block_row_warps
tz = (
thread_bindings // (WARP_SIZE * block_row_warps)
) % block_col_warps
tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps
for i, j in T.grid(warp_rows, warp_cols):
for local_id_o in T.serial(local_size_out // 2):
for local_id_i in T.vectorized(2):
local_id = local_id_o * 2 + local_id_i
row, col = T.meta_var(mma_store_index_map(tx, local_id))
C_shared_buf[
ty * warp_rows + i, tz * warp_cols + j, row, col
] = C_local_buf[
i * (warp_cols * local_size_out)
+ j * local_size_out
+ local_id
]
C_shared_buf[ty * warp_rows + i, tz * warp_cols + j, row,
col] = C_local_buf[i * (warp_cols * local_size_out) +
j * local_size_out + local_id]

return _warp_stmatrix(C_local_buf, C_shared_buf, thread_bindings)

Expand Down Expand Up @@ -334,9 +308,7 @@ def __init__(
def _initialize_k_dim(self, a_dtype="float16"):
self.k_dim = 256 // DataType(a_dtype).bits

def _initialize_local_size(
self, m_dim=16, n_dim=16, k_dim=16, warp_size=32
):
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
self.local_size_a = (m_dim * k_dim) // warp_size
self.local_size_b = (n_dim * k_dim) // warp_size
self.local_size_out = (m_dim * n_dim) // warp_size
Expand Down Expand Up @@ -380,34 +352,31 @@ def _initialize_transform_kind(self, transform_kind_a, transform_kind_b):
assert transform_kind_b in [0, 3], "Currently only support 0 and 3"

def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
WARP_SIZE = self.WARP_SIZE
block_row_warps = self.block_row_warps
block_col_warps = self.block_col_warps
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
b_dtype = self.b_dtype
transform_kind_b = self.transform_kind_b
b_transposed = self.b_transposed
num_elems_per_byte = self.num_elems_per_byte

@T.macro
def _warp_ldmatrix_b(
inst,
B_local_buf,
B_shared_buf,
ki,
thread_bindings,
rk=0,
):
WARP_SIZE = inst.WARP_SIZE
block_row_warps = inst.block_row_warps
block_col_warps = inst.block_col_warps
warp_col_tiles = inst.warp_col_tiles
warp_cols = inst.warp_cols
chunk = inst.chunk
micro_size_y = inst.micro_size_y
micro_size_k = inst.micro_size_k
local_size_b = inst.local_size_b
b_dtype = inst.b_dtype
transform_kind_b = inst.transform_kind_b
b_transposed = inst.b_transposed
num_elems_per_byte = inst.num_elems_per_byte

stride = B_shared_buf.shape[-1]
tx = thread_bindings % WARP_SIZE
tz = (
thread_bindings // (WARP_SIZE * block_row_warps)
) % block_col_warps
tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps

if transform_kind_b < TransformKind.LDMatrixTransform:
for j in T.serial(warp_cols):
Expand All @@ -422,9 +391,7 @@ def _warp_ldmatrix_b(
(ri) % micro_size_y,
(rj) % micro_size_k,
)
args = (
(ni, nj, nii, njj) if transform_kind_b > 0 else (ri, rj)
)
args = ((ni, nj, nii, njj) if transform_kind_b > 0 else (ri, rj))
B_shared_elem = B_shared_buf[args]

T.ptx_ldmatrix(
Expand All @@ -435,9 +402,7 @@ def _warp_ldmatrix_b(
B_local_buf.data,
j * local_size_b,
T.address_of(B_shared_elem),
get_ldmatrix_offset(
"B", tx, 0, stride, b_dtype, b_transposed
),
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
)
else:
local_size_dequantize = local_size_b // num_elems_per_byte
Expand All @@ -448,18 +413,14 @@ def _warp_ldmatrix_b(
tz * warp_cols + j,
rk * (chunk // micro_size_k) + ki,
)
rii, rjj = (tx * local_size_dequantize + local_id) // (
micro_size_k // num_elems_per_byte
), (tx * local_size_dequantize + local_id) % (
micro_size_k // num_elems_per_byte
)
rii, rjj = (tx * local_size_dequantize +
local_id) // (micro_size_k // num_elems_per_byte), (
tx * local_size_dequantize + local_id) % (
micro_size_k // num_elems_per_byte)
B_local_buf[j * local_size_dequantize + local_id] = (
B_shared_buf[ri, rj, rii, rjj]
)
B_shared_buf[ri, rj, rii, rjj])

return _warp_ldmatrix_b(
B_local_buf, B_shared_buf, ki, thread_bindings, rk
)
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)

def mma(self, A_local_buf, B_local_buf, C_local_buf):
warp_rows = self.warp_rows
Expand Down Expand Up @@ -506,9 +467,7 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
B_local_buf.data,
j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out
+ j * local_size_out
+ lift(local_size_out) // 2,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)

Expand Down
10 changes: 5 additions & 5 deletions bitblas/tl/mma_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
return row, col


def ldmatrix_16x32_to_shared_16x32_layout_a(thread_id, local_id):
    """Map an ldmatrix A-fragment (thread_id, local_id) to a shared-memory (row, col).

    Covers a 16x32 tile: 32 threads each own 16 elements. Rows follow
    ``thread_id % 16``; columns pick the 16-wide half via ``thread_id // 16``
    plus the element offset ``local_id % 16``.
    """
    row = thread_id % 16
    col = 16 * (thread_id // 16) + local_id % 16
    return row, col


def ldmatrix_16x32_to_shared_16x32_layout_b(thread_id, local_id):
    """Map an ldmatrix B-fragment (thread_id, local_id) to a shared-memory (row, col).

    Covers a 16x32 tile: rows interleave the two 8-row groups
    (``8 * (thread_id // 16) + thread_id % 8``); columns select the 16-wide
    half via ``(thread_id % 16) // 8`` plus the element offset ``local_id % 16``.
    """
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 16 * ((thread_id % 16) // 8) + local_id % 16
    return row, col


Expand Down
42 changes: 27 additions & 15 deletions bitblas/tl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from .mma_layout import (
ldmatrix_32x8_to_shared_16x16_layout,
ldmatrix_trans_32x8_to_shared_16x16_layout,
ldmatrix_32x16_to_shared_16x32_layout_a,
ldmatrix_32x16_to_shared_16x32_layout_b,
ldmatrix_16x32_to_shared_16x32_layout_a,
ldmatrix_16x32_to_shared_16x32_layout_b,
mma_store_32x8_to_shared_16x16_layout,
)

Expand Down Expand Up @@ -70,28 +70,40 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)


# the original implementation and insight is from the following code snippet
# 3rdparty/tvm/python/tvm/tir/tensor_intrin/cuda.py#get_ldmatrix_intrin
def get_ldmatrix_offset(
    matrix: Literal["A", "B"],
    row_idx,
    col_idx,
    stride,
    dtype: Literal["float16", "int8"] = "float16",
    transposed: bool = False,
):
    """Return the flat shared-memory offset for an ldmatrix load.

    Dispatches on the element width of ``dtype``:
    - 16-bit dtypes use the 32x8 -> 16x16 layouts (trans/non-trans chosen by
      ``transposed``; ``matrix`` does not matter at 16 bit).
    - 8-bit dtypes use the 16x32 -> 16x32 layouts, and only support
      B-transposed or A-non-transposed fragments.

    Raises:
        ValueError: for an unsupported dtype width, or an unsupported
            (matrix, transposed) combination at 8 bit.
    """
    assert matrix in ["A", "B"], "matrix should be either A or B"
    dtype_bits = DataType(dtype).bits
    if dtype_bits == 16:
        transform_func = ldmatrix_32x8_to_shared_16x16_layout
        transform_func_trans = ldmatrix_trans_32x8_to_shared_16x16_layout
        if transposed:
            new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx)
            return new_row_idx * stride + new_col_idx
        else:
            new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
            return new_row_idx * stride + new_col_idx
    elif dtype_bits == 8:
        if matrix == "B" and transposed:
            transform_func = ldmatrix_16x32_to_shared_16x32_layout_b
            new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
            return new_row_idx * stride + new_col_idx
        elif matrix == "A" and not transposed:
            transform_func = ldmatrix_16x32_to_shared_16x32_layout_a
            new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
            return new_row_idx * stride + new_col_idx
        else:
            raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8")
    else:
        raise ValueError(f"Unsupported dtype {dtype}")


def mma_store_index_map(*args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ echo 'bitblas codespell: Done'
echo 'bitblas ruff: Check Start'
# Lint specified files
# Lint the given files with ruff; `ruff check` is the required subcommand
# form for ruff >= 0.5 (the bare `ruff <files>` invocation was removed).
lint() {
    ruff check "$@"
}

# Lint files that differ from main branch. Ignores dirs that are not slated
Expand All @@ -170,7 +170,7 @@ lint_changed() {

if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
ruff
ruff check
fi

}
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
yapf==0.40.2
toml==0.10.2
tomli==2.0.1
ruff==0.1.5
ruff==0.6.5
codespell==2.3.0

cffi
Expand Down
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
yapf==0.40.2
toml==0.10.2
tomli==2.0.1
ruff==0.1.5
ruff==0.6.5
codespell==2.3.0

cffi
Expand Down
Loading

0 comments on commit 4a0afc9

Please sign in to comment.