From 282311e440f934b9154ef17200d8335c94eaa8c1 Mon Sep 17 00:00:00 2001 From: WeiweiZhang1 Date: Wed, 7 Aug 2024 09:19:23 +0800 Subject: [PATCH] Limit the scale minimum value not to 0 (#209) Signed-off-by: Zhang, Weiwei1 --- auto_round/data_type/int.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 726a2b9c..3852f4e7 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -19,7 +19,7 @@ @register_dtype("int_asym") def quant_tensor_asym(weight, num_bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, - weight_min=None, weight_max=None, q_scale_thresh=0.0,**kwargs): + weight_min=None, weight_max=None, q_scale_thresh=0.0 ,**kwargs): """Quantizes and dequantizes weight asymmetrically. Args: @@ -59,6 +59,8 @@ def quant_tensor_asym(weight, num_bits=4, v=0, min_scale=1.0, max_scale=1.0, sca wmax[tmp] = +1 scale = ((wmax - wmin) / maxq).to(scale_dtype) scale = torch.clamp(scale, min=q_scale_thresh) + if (scale == 0.).any(): + scale = torch.clamp(scale, min=1e-5) zp = round_ste(-wmin / scale) # pylint: disable=E1130 scale = scale.unsqueeze(dim=-1) zp = zp.unsqueeze(dim=-1) @@ -68,7 +70,7 @@ def quant_tensor_asym(weight, num_bits=4, v=0, min_scale=1.0, max_scale=1.0, sca @register_dtype("int_sym") def quant_tensor_sym(weight, num_bits=4, v=0, min_scale=1.0, max_scale=1.0, scale_dtype=torch.float16, weight_min=None, - weight_max=None, q_scale_thresh=0.0,**kargs): + weight_max=None, q_scale_thresh=0.0, **kargs): """Quantizes and dequantizes weight symmetrically. Args: @@ -114,9 +116,12 @@ def quant_tensor_sym(weight, num_bits=4, v=0, min_scale=1.0, max_scale=1.0, scal wmax_new[tmp] = +1 scale = ((wmax_new - wmin_new) / maxq).to(scale_dtype) scale = torch.clamp(scale, min=q_scale_thresh) + if (scale == 0.).any(): + scale = torch.clamp(scale, min=1e-5) scale = scale.unsqueeze(dim=-1) zp = torch.full_like(scale, (maxq + 1) / 2) int_w = round_ste(weight / scale + v) q = torch.clamp(int_w + zp, 0, maxq) return scale * (q - zp), scale, zp +