Skip to content

Commit

Permalink
Limit the minimum scale value so it is never 0 (#209)
Browse files Browse the repository at this point in the history
Signed-off-by: Zhang, Weiwei1 <[email protected]>
  • Loading branch information
WeiweiZhang1 committed Aug 7, 2024
1 parent a988940 commit 282311e
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions auto_round/data_type/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@register_dtype("int_asym")
def quant_tensor_asym(weight, num_bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16,
weight_min=None, weight_max=None, q_scale_thresh=0.0,**kwargs):
weight_min=None, weight_max=None, q_scale_thresh=0.0 ,**kwargs):
"""Quantizes and dequantizes weight asymmetrically.
Args:
Expand Down Expand Up @@ -59,6 +59,8 @@ def quant_tensor_asym(weight, num_bits=4, v=0, min_scale=1.0, max_scale=1.0, sca
wmax[tmp] = +1
scale = ((wmax - wmin) / maxq).to(scale_dtype)
scale = torch.clamp(scale, min=q_scale_thresh)
if (scale == 0.).any():
scale = torch.clamp(scale, min=1e-5)
zp = round_ste(-wmin / scale) # pylint: disable=E1130
scale = scale.unsqueeze(dim=-1)
zp = zp.unsqueeze(dim=-1)
Expand All @@ -68,7 +70,7 @@ def quant_tensor_asym(weight, num_bits=4, v=0, min_scale=1.0, max_scale=1.0, sca

@register_dtype("int_sym")
def quant_tensor_sym(weight, num_bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None,
weight_max=None, q_scale_thresh=0.0,**kargs):
weight_max=None, q_scale_thresh=0.0, **kargs):
"""Quantizes and dequantizes weight symmetrically.
Args:
Expand Down Expand Up @@ -114,9 +116,12 @@ def quant_tensor_sym(weight, num_bits=4, v=0, min_scale=1.0, max_scale=1.0, scal
wmax_new[tmp] = +1
scale = ((wmax_new - wmin_new) / maxq).to(scale_dtype)
scale = torch.clamp(scale, min=q_scale_thresh)
if (scale == 0.).any():
scale = torch.clamp(scale, min=1e-5)
scale = scale.unsqueeze(dim=-1)
zp = torch.full_like(scale, (maxq + 1) / 2)

int_w = round_ste(weight / scale + v)
q = torch.clamp(int_w + zp, 0, maxq)
return scale * (q - zp), scale, zp

0 comments on commit 282311e

Please sign in to comment.