From cae3d823cec4eb9ad781d9e589f1487e79c9286f Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 15 May 2024 19:17:31 -0400 Subject: [PATCH] Match torch.fake_quantize numerics in 8da4w QAT (#229) Summary: There are two subtle differences between the 8da4w quant primitives and `torch.fake_quantize_per_channel_affine` today: 1. 8da4w uses float32 zero points torch.fake_quantize uses int32 zero points 2. 8da4w uses input.div(scales) torch.fake_quantize uses input.mul(1.0 / scales) Of these two differences, the second one is smaller and only resulted in 0.1% elements mismatched in unit tests, but it is a source of numerical divergence nonetheless. This commit changes 8da4w QAT quant primitives to match the torch.fake_quantize behavior for both of these differences. In a future commit, we will change the 8da4w PTQ quant primitives as well so PTQ and QAT remain consistent. Note: This commit also has the side effect of reducing memory footprint significantly for bf16 inputs. We now cast them to fp32 before multiplying them with fp32 scales. This reduced memory usage presumably because bf16 * fp32 kernels are not as memory efficient. Test Plan: python test/quantization/test_qat.py -k test_qat_generic_fake_quantize Reviewers: jerryzh168, cpuhrsch Subscribers: jerryzh168, cpuhrsch, supriyar --- test/quantization/test_qat.py | 45 +++++++++++++---- torchao/quantization/prototype/qat.py | 64 ++++++++++++++++-------- torchao/quantization/quant_primitives.py | 2 +- 3 files changed, 78 insertions(+), 33 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index fe2db8066..93323df0f 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -14,6 +14,7 @@ from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torchao.quantization.prototype.qat import ( _choose_qparams_per_token_asymmetric, + _GenericFakeQuantize, fake_quantize_per_channel_group, fake_quantize_per_token, ) @@ -58,7 +59,7 @@ def _get_qmin_qmax(self, n_bit: int): qmax = 2 ** (n_bit - 1) - 1 return (qmin, qmax) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_fake_quantize_per_channel_group(self): n_bit = 4 (qmin, qmax) = self._get_qmin_qmax(n_bit) @@ -67,6 +68,7 @@ def test_fake_quantize_per_channel_group(self): torch.manual_seed(self.SEED) x = torch.randn(100, 256).requires_grad_() (s, zp) = get_group_qparams_symmetric(x, n_bit, group_size) + zp = zp.to(torch.int32) x2 = copy.deepcopy(x) # fake quant op @@ -84,7 +86,7 @@ def test_fake_quantize_per_channel_group(self): ) torch.testing.assert_close(out, out_ptq, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_fake_quantize_per_token(self): (qmin, qmax) = self._get_qmin_qmax(8) @@ -92,10 +94,7 @@ def test_fake_quantize_per_token(self): x = torch.randn(100, 256).requires_grad_() x2 = copy.deepcopy(x) # TODO: use torch.ops.aten.quantized_decomposed version instead - (s, zp) = _choose_qparams_per_token_asymmetric( - x, - torch.int8, # not used - ) + (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32) # fake quant op out = fake_quantize_per_token(x, s, zp, qmin, qmax) @@ -130,7 +129,7 @@ def _set_ptq_weight( ptq_linear.scales = s ptq_linear.zeros = zp - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_linear(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear @@ -155,7 +154,7 @@ def test_qat_8da4w_linear(self): ptq_out = ptq_linear(x2) torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer @@ -189,7 +188,7 @@ def test_qat_8da4w_quantizer(self): for k in ptq_state_dict.keys(): torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_meta_weights(self): from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer @@ -201,7 +200,7 @@ def test_qat_8da4w_quantizer_meta_weights(self): qat_model = qat_quantizer.prepare(m) self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_disable_fake_quant(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward. @@ -254,7 +253,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self): qat_out2 = qat_model2(*x2) torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0) - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): """ Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward. @@ -299,6 +298,30 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0) torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + def test_qat_generic_fake_quantize(self): + """ + Test that the generic fake quantize used in 8da4w QAT matches + the numerics of existing fake quantize ops in Pytorch in both + the forward and the backward passes. + """ + (qmin, qmax) = self._get_qmin_qmax(4) + py_input = torch.randn(16, 64).float().requires_grad_() + py_s = torch.randn(16).float() + py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32) + py_out = torch.fake_quantize_per_channel_affine(py_input, py_s, py_zp, 0, qmin, qmax) + py_out.sum().backward() + + ao_input = copy.deepcopy(py_input) + ao_input.grad.data.zero_() + ao_s = copy.deepcopy(py_s).reshape(-1, 1) + ao_zp = copy.deepcopy(py_zp).reshape(-1, 1) + ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax) + ao_out.sum().backward() + + torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0) + torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py index 314543bb8..6cda8eeee 100644 --- a/torchao/quantization/prototype/qat.py +++ b/torchao/quantization/prototype/qat.py @@ -4,18 +4,18 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Optional, Tuple +from typing import Any, Tuple import torch from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib from torch.library import impl -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 from torchao.quantization.quant_primitives import get_group_qparams_symmetric from torchao.quantization.unified import TwoStepQuantizer -if TORCH_VERSION_AFTER_2_3: +if TORCH_VERSION_AFTER_2_4: from torchao.quantization.GPTQ import ( _replace_linear_8da4w, Int8DynActInt4WeightLinear, @@ -54,7 +54,7 @@ def prepare( self.precision, self.scales_precision, Int8DynActInt4WeightQATLinear, - copy_weights = True, + copy_weights=True, ) return model @@ -95,7 +95,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module): quantized_linear.zeros = zp else: _convert_qat_linear_8da4w(child) - + class Int8DynActInt4WeightQATLinear(torch.nn.Linear): """ This module implements a linear layer with int8 dynamic per token fake @@ -131,6 +131,8 @@ def __init__( self.groupsize = groupsize self.precision = precision self.scales_precision = scales_precision + # TODO: make this configurable? + self.zero_points_precision = torch.int32 self._fake_quant_enabled = True def enable_fake_quant(self, enabled: bool = True): @@ -142,8 +144,8 @@ def disable_fake_quant(self): def forward(self, x: torch.Tensor) -> torch.Tensor: # activations: int8 dynamic asymmetric quant if self._fake_quant_enabled: - (act_scales, act_zp) =_choose_qparams_per_token_asymmetric( - x, torch.int8, # dtype not used + (act_scales, act_zp) = _choose_qparams_per_token_asymmetric( + x, self.scales_precision, self.zero_points_precision, ) (act_qmin, act_qmax) = self._get_qmin_qmax(8) x_fq = fake_quantize_per_token( @@ -157,6 +159,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: (weight_scales, weight_zp) = get_group_qparams_symmetric( self.weight, 4, self.groupsize, self.scales_precision, ) + # TODO: pass zp dtype to `get_group_qparams_symmetric` instead + weight_zp = weight_zp.to(self.zero_points_precision) (weight_qmin, weight_qmax) = self._get_qmin_qmax(4) w_fq = fake_quantize_per_channel_group( self.weight, @@ -190,6 +194,20 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module): if isinstance(mod, Int8DynActInt4WeightQATLinear): mod.disable_fake_quant() +else: # not TORCH_VERSION_AFTER_2_4 + + class Int8DynActInt4WeightQATQuantizer: + def __init__(*args, **kwargs): + raise ValueError( + "Int8DynActInt4WeightQATQuantizer is only supported after PyTorch 2.4+" + ) + + class Int8DynActInt4WeightQATLinear: + def __init__(*args, **kwargs): + raise ValueError( + "Int8DynActInt4WeightQATLinear is only supported after PyTorch 2.4+" + ) + # ======================== # | QUANT PRIMITIVES | @@ -205,13 +223,14 @@ class _GenericFakeQuantize(torch.autograd.Function): @staticmethod def forward(ctx, input, scales, zero_points, quant_min, quant_max): - # Note: this diverges from `torch.fake_quantize_per_channel_affine`, - # which rounds first before adding the zero points. However, this - # is what `quantize_per_channel_group` and `quantize_per_token` - # do and here we try to match that behavior as closely as possible. - q = input.mul(1.0 / scales).add(zero_points).round() + # Note: for bf16 inputs, casting them to fp32 has the unexpected + # side effect of reducing memory footprint significantly, presumably + # because bf16 * fp32 kernels are not as memory efficient + assert input.dtype == torch.float32 + assert scales.dtype == torch.float32 + assert zero_points.dtype == torch.int32 + q = input.mul(1.0 / scales).round().add(zero_points) dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales) - # TODO: do we need this mask? mask = torch.logical_and((q >= quant_min), (q <= quant_max)) ctx.save_for_backward(mask) return dq @@ -239,14 +258,13 @@ def fake_quantize_per_channel_group( assert group_size > 1 assert input.shape[-1] % group_size == 0 assert input.dim() == 2 - assert torch.isnan(input).sum() == 0 - grouped_input = input.reshape(-1, group_size) + grouped_input = input.reshape(-1, group_size).to(torch.float32) scales = scales.reshape(-1, 1) zero_points = zero_points.reshape(-1, 1) fq = _GenericFakeQuantize.apply( grouped_input, scales, zero_points, quant_min, quant_max, ) - return fq.reshape_as(input) + return fq.reshape_as(input).to(input.dtype) # TODO: move this to core quantized_decomposed_lib.define( @@ -266,9 +284,11 @@ def fake_quantize_per_token( from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check _per_token_quant_qparam_dim_check(input, scales, zero_points) - return _GenericFakeQuantize.apply( - input, scales, zero_points, quant_min, quant_max, + fq_input = input.to(torch.float32) + fq = _GenericFakeQuantize.apply( + fq_input, scales, zero_points, quant_min, quant_max, ) + return fq.reshape_as(input).to(input.dtype) # TODO: This is copied from torch/ao/quantization/fx/_decomposed.py. # The version in pytorch does not have backward support yet so we add @@ -276,7 +296,8 @@ def fake_quantize_per_token( # is landed. def _choose_qparams_per_token_asymmetric( input: torch.Tensor, - dtype: torch.dtype, + scales_precision: torch.dtype = torch.float32, + zero_points_precision: torch.dtype = torch.float32, ) -> Tuple[torch.Tensor, torch.Tensor]: """Choose quantization parameters for per token quantization. This means for a N dimension Tensor (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize @@ -285,7 +306,8 @@ def _choose_qparams_per_token_asymmetric( Args: input (torch.Tensor): original float32/float16 Tensor - dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor + scales_precision (torch.dtype): precision of returned scales + zero_points_precision (torch.dtype): precision of returned zero points Returns: scales and zero_points, both float32 Tensors @@ -314,4 +336,4 @@ def _choose_qparams_per_token_asymmetric( ) zero_point = torch.clamp(zero_point, qmin, qmax).round() - return scale.to(torch.float32), zero_point.to(torch.float32) + return scale.to(scales_precision), zero_point.to(zero_points_precision) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 30c685448..e1de871e2 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -764,7 +764,7 @@ def groupwise_affine_dequantize_tensor( ) -# TODO: replace this with torch.ao.quantization.PerChannelMinMaxObserver +# TODO: separate scale and zero point precision def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32): # needed for GPTQ with padding if groupsize > w.shape[-1]: