Skip to content

Commit

Permalink
Refactor dequantize scheduler and simplify pass
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Oct 23, 2024
1 parent 20f8ce6 commit f56b1ed
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,6 @@ def general_dequant_matmul(
Zeros,
Qzeros,
local_size,
local_size_compressed,
bx,
tx,
k,
Expand Down Expand Up @@ -384,7 +383,6 @@ def _normal_dequant(
zeros_buffer: T.Buffer,
qzeros_buffer: T.Buffer,
local_size: int,
local_size_compressed: int,
pid_n: T.Var,
tx: T.Var,
k: T.Var,
Expand Down Expand Up @@ -413,9 +411,9 @@ def _normal_dequant_impl(
qzeros_buffer: T.Buffer,
):
for v in T.serial(0, local_size):
index = (i * threads * local_size_compressed + tx * local_size_compressed + v)
vi = index // (stride_k // num_elems_per_byte)
vj = index % (stride_k // num_elems_per_byte)
index = (i * threads * local_size + tx * local_size + v)
vi = index // stride_k
vj = index % stride_k
if not with_scaling:
dequant_weight_local[v] = self._decode_func(
num_bits,
Expand Down Expand Up @@ -486,12 +484,9 @@ def _normal_fast_dequant(
qzeros_buffer: T.Buffer,
func_name: str,
pid_n: T.Var,
tx: T.Var,
k: T.Var,
i: T.Var,
stride_n: int,
stride_k: int,
threads: int,
):
num_elems_per_byte = self.num_elems_per_byte
with_scaling = self.with_scaling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,9 @@ def _normal_dequant_impl(
qzeros_buffer: T.Buffer,
):
for v in T.serial(0, local_size):
index = (i * threads * local_size_compressed + tx * local_size_compressed + v)
vi = index // (stride_k // num_elems_per_byte)
vj = index % (stride_k // num_elems_per_byte)
index = (i * threads * local_size + tx * local_size + v)
vi = index // (stride_k)
vj = index % (stride_k)
if not with_scaling:
dequant_weight_local[v] = self._decode_func(
num_bits,
Expand Down Expand Up @@ -592,6 +592,7 @@ def _normal_fast_dequant(
stride_k: int,
threads: int,
):
# TODO(lei): un-used arguments should be removed
num_elems_per_byte = self.num_elems_per_byte
with_scaling = self.with_scaling
with_zeros = self.with_zeros
Expand Down
Loading

0 comments on commit f56b1ed

Please sign in to comment.