# [Quant tool] Improve performance of int4 weight quantization (#20935)
### Description
- Uses our own quantization functions instead of the ONNX reference
implementation of QuantizeLinear when quantizing weights to int4.
- Uses a custom function that packs pairs of 8-bit quantized values into single bytes holding two 4-bit elements.



### Motivation and Context
Running the quantization tool to create QDQ models with int4 weights
could take up to 7x longer than necessary. This PR uses our own quantization and byte-packing
utilities to improve performance; a simplified sketch of the packing idea follows the measurements below.

#### Measurements
Measured on a model with ~5M parameters quantized to int4:

- Current implementation: **84.5s**
- Replace only the ONNX QuantizeLinear implementation: **50.3s** (1.68x speedup)
- This PR (replace the ONNX QuantizeLinear implementation and use a custom packing function): **13.5s** (6.26x speedup)
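
The nibble-packing idea can be sketched as follows. This is a simplified, hypothetical illustration (not the code added in this PR): two already-quantized 8-bit values are combined into one byte, low nibble first.

```python
# Simplified sketch of 4-bit nibble packing; assumes values are already in the int4 range [-8, 7].
import numpy as np

q = np.array([-2, 7, 3], dtype=np.int8)   # hypothetical quantized weights

packed = bytearray((len(q) + 1) // 2)     # 3 elements -> 2 bytes
for i in range(0, len(q) - 1, 2):
    # Low nibble holds the even-indexed element, high nibble the odd-indexed one.
    packed[i // 2] = ((int(q[i + 1]) & 0xF) << 4) | (int(q[i]) & 0xF)
if len(q) % 2:
    packed[-1] = int(q[-1]) & 0xF         # odd tail: high nibble stays 0

assert packed == bytearray([0x7E, 0x03])
```

The helper added by this PR, `pack_bytes_to_4bit()` in quant_utils.py, applies the same layout to the raw bytes of the quantized weight array (see the diff below).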

---------

Signed-off-by: adrianlizarraga <[email protected]>
adrianlizarraga authored Jun 5, 2024
1 parent 4cb23b0 commit df28c7d
Showing 3 changed files with 137 additions and 49 deletions.
39 changes: 23 additions & 16 deletions onnxruntime/python/tools/quantization/base_quantizer.py
@@ -25,6 +25,7 @@
find_by_name,
model_has_infer_metadata,
normalize_axis,
pack_bytes_to_4bit,
quantize_data,
quantize_nparray,
save_and_reload_model_with_shape_infer,
@@ -340,13 +341,17 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa
f"\nraw={str(q_weight_initializer)[:200]}."
)
elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
# TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed
# within int32_data is fixed.
# q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, q_weight_data)
packed_data = onnx.helper.pack_float32_to_4bit(q_weight_data.flatten(), qType == onnx.TensorProto.INT4)
q_weight_initializer = onnx.helper.make_tensor(
q_weight_name, qType, weight.dims, packed_data.tobytes(), raw=True
)
if q_weight_data.dtype not in (np.int8, np.uint8):
raise RuntimeError(
f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
)

# We do not use onnx.helper.pack_float32_to_4bit() due to performance.
# This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))

# We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, packed_data, raw=True)
else:
q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape(
weight.dims
@@ -483,16 +488,18 @@ def quantize_weight_per_channel_impl(

if not keep_float_weight:
if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4):
# TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed
# within int32_data is fixed.
# q_weight_initializer = onnx.helper.make_tensor(
# q_weight_name, weight_qType, weights_shape, quantized_weights
# )
packed_data = onnx.helper.pack_float32_to_4bit(
quantized_weights.flatten(), weight_qType == onnx.TensorProto.INT4
)
if quantized_weights.dtype not in (np.int8, np.uint8):
raise RuntimeError(
f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values."
)

# We do not use onnx.helper.pack_float32_to_4bit() due to performance.
# This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes.
packed_data = bytes(pack_bytes_to_4bit(quantized_weights.tobytes()))

# We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161
q_weight_initializer = onnx.helper.make_tensor(
q_weight_name, weight_qType, weights_shape, packed_data.tobytes(), raw=True
q_weight_name, weight_qType, weights_shape, packed_data, raw=True
)
self.model.initializer_extend([q_weight_initializer])
else:
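
For illustration, the pattern introduced above can be used standalone roughly as follows. This is only a sketch, assuming an onnxruntime build that includes this change and ONNX >= 1.16 for TensorProto.INT4; the initializer name and weight values are made up.

```python
# Hypothetical standalone example: build an INT4 initializer from already-quantized int8 data.
import numpy as np
import onnx

from onnxruntime.quantization.quant_utils import pack_bytes_to_4bit

q_weight_data = np.array([[-8, -1, 0, 7], [3, -4, 5, 6]], dtype=np.int8)  # made-up int4-range weights

# Pack two 4-bit elements per byte, then store the result as raw tensor data.
packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes()))
q_weight_initializer = onnx.helper.make_tensor(
    "weight_quantized",               # hypothetical initializer name
    onnx.TensorProto.INT4,
    list(q_weight_data.shape),
    packed_data,
    raw=True,
)
```

Raw data is used here for the same reason as in the diff: make_tensor with non-raw int4 data hits the ONNX bug referenced above.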
78 changes: 46 additions & 32 deletions onnxruntime/python/tools/quantization/quant_utils.py
@@ -21,10 +21,18 @@
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions

try:
from onnx.reference.custom_element_types import float8e4m3fn, int4, uint4
from onnx.reference.custom_element_types import float8e4m3fn
except ImportError:
float8e4m3fn = None

# INT4 np.dtypes added in ONNX 1.16. These map to np.int8/np.uint8 because numpy
# does not support sub-byte types.
try:
from onnx.reference.custom_element_types import int4, uint4
except ImportError:
int4 = None
uint4 = None


__producer__ = "onnx.quantize"
__version__ = "0.1.0"
@@ -134,8 +142,8 @@ def from_string(format):
onnx_proto.TensorProto.INT16: numpy.dtype("int16"),
onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"),
onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn,
onnx_proto.TensorProto.INT4: int4,
onnx_proto.TensorProto.UINT4: uint4,
onnx_proto.TensorProto.INT4: int4, # base_dtype is np.int8
onnx_proto.TensorProto.UINT4: uint4, # base_dtype is np.uint8
}

ONNX_INT_TYPE_RANGE = {
@@ -212,36 +220,12 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None):
)
ref = ReferenceEvaluator(onnx_model)
return _check_type(ref.run(None, {"X": arr, "scale": scale})[0])
elif qType in (
onnx_proto.TensorProto.INT4,
onnx_proto.TensorProto.UINT4,
):
if arr.dtype == numpy.float32:
onnx_type = TensorProto.FLOAT
elif arr.dtype == numpy.float16:
onnx_type = TensorProto.FLOAT16
else:
raise ValueError(f"Unexpected dtype {arr.dtype}.")
onnx_model = make_model(
make_graph(
[
make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]),
],
"qu",
[
make_tensor_value_info("X", onnx_type, None),
make_tensor_value_info("scale", onnx_type, None),
make_tensor_value_info("zero_point", qType, None),
],
[make_tensor_value_info("Y", qType, None)],
)
)
# The reference ONNX implementation of QuantizeLinear<int4> returns "unpacked" int8 numpy values
# because numpy cannot represent 4bit values (although ONNX TensorProto has no problem with this).
# These "unpacked" int8 values are correctly re-packed when passed to onnx.make_tensor().
ref = ReferenceEvaluator(onnx_model)
return _check_type(ref.run(None, {"X": arr, "scale": scale, "zero_point": zero_point})[0])
else:
# Quantizes data for all integer types.
#
# For int4 types, the quantized data is returned as either np.int8 or np.uint8,
# which matches the python reference ONNX implementation of QuantizeLinear.
# This data can be packed into 4-bit elements by using pack_bytes_to_4bit().
dtype = ONNX_TYPE_TO_NP_TYPE[qType]
(qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True)

@@ -482,6 +466,36 @@ def normalize_axis(axis: int, rank: int) -> tuple[bool, int]:
return is_valid, axis_norm


def pack_bytes_to_4bit(src_8bit: bytes) -> bytearray:
"""
Copies a source array of 8-bit values into a destination bytearray of packed 4-bit values.
Assumes that the source values are already in the appropriate int4 range.
:parameter src_8bit: The 8-bit element values to pack.
:return A bytearray with every two 8-bit src elements packed into a single byte.
"""
num_elems = len(src_8bit)
if num_elems == 0:
return bytearray()

dst_size = (num_elems + 1) // 2 # Ex: 5 8-bit elems packed into 3 bytes
dst = bytearray(dst_size)

src_i: int = 0
dst_i: int = 0

# Pack two 8-bit elements into a single byte in each iteration.
while src_i < num_elems - 1:
dst[dst_i] = ((src_8bit[src_i + 1] & 0xF) << 4) | (src_8bit[src_i] & 0xF)
dst_i += 1
src_i += 2

if src_i < num_elems:
# Odd number of elements.
dst[dst_i] = src_8bit[src_i] & 0xF

return dst


class QuantizedInitializer:
"""
Represents a linearly quantized weight input from ONNX operators
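
A quick sanity check of `pack_bytes_to_4bit()` on an odd-length input (five 8-bit elements pack into three bytes, as the in-code comment notes); the input values are arbitrary, and this assumes a build that contains the new helper.

```python
# Minimal check of pack_bytes_to_4bit(): low nibble = even-indexed element, odd tail in a final low nibble.
import numpy as np

from onnxruntime.quantization.quant_utils import pack_bytes_to_4bit

src = np.array([1, 2, 3, 4, 5], dtype=np.uint8)
packed = pack_bytes_to_4bit(src.tobytes())

assert len(packed) == 3
assert list(packed) == [0x21, 0x43, 0x05]
```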
69 changes: 68 additions & 1 deletion onnxruntime/test/python/quantization/test_quant_util.py
@@ -13,7 +13,13 @@
import onnx
from onnx import TensorProto, helper, numpy_helper

from onnxruntime.quantization.quant_utils import compute_scale_zp, load_model_with_shape_infer, model_has_infer_metadata
from onnxruntime.quantization.quant_utils import (
compute_scale_zp,
load_model_with_shape_infer,
model_has_infer_metadata,
pack_bytes_to_4bit,
quantize_data,
)


class TestQuantUtil(unittest.TestCase):
@@ -101,6 +107,67 @@ def test_load_external_model(self):
model_reloaded = load_model_with_shape_infer(Path(model_file_path))
self.assertTrue(model_has_infer_metadata(model_reloaded))

def test_pack_bytes_to_4bit(self):
"""
Tests the pack_bytes_to_4bit() utility.
"""
subtest_configs = [
(-8, 6, True), # Odd num elems, signed
(-8, 7, True), # Even num elems, signed
(0, 14, False), # Odd num elems, unsigned
(0, 15, False), # Even num elems, unsigned
]
for min_val, max_val, signed in subtest_configs:
with self.subTest(min_val=min_val, max_val=max_val, signed=signed):
src_float = numpy.arange(min_val, max_val + 1).astype(numpy.float32)
src_int = src_float.astype(numpy.int8 if signed else numpy.uint8)

actual_packed_vals = bytes(pack_bytes_to_4bit(src_int.tobytes()))
expected_packed_vals = onnx.helper.pack_float32_to_4bit(src_float, signed).tobytes()
self.assertEqual(actual_packed_vals, expected_packed_vals)

def test_quantize_data_4bit(self):
"""
Test that calling quantize_data for int4 quantization returns data of the correct type and range.
"""
data_float = numpy.arange(-20, 17).astype(numpy.float32)

subtest_configs = [
(onnx.TensorProto.INT4, True), # int4, symmetric quant
(onnx.TensorProto.INT4, False), # int4, asymmetric quant
(onnx.TensorProto.UINT4, True), # uint4, symmetric quant
(onnx.TensorProto.UINT4, False), # uint4, asymmetric quant
]

for onnx_type, symmetric in subtest_configs:
with self.subTest(onnx_type=onnx_type, symmetric=symmetric):
_, _, zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric)
is_signed = onnx_type == onnx.TensorProto.INT4
np_int_type = numpy.int8 if is_signed else numpy.uint8
qmin = numpy.array(-8 if is_signed else 0, dtype=np_int_type)
qmax = numpy.array(7 if is_signed else 15, dtype=np_int_type)

self.assertEqual(zero_point.dtype, np_int_type)
self.assertEqual(scale.dtype, data_float.dtype)

expected_zp, expected_scale = compute_scale_zp(
data_float.min(), data_float.max(), qmin, qmax, symmetric=symmetric
)
self.assertEqual(zero_point, expected_zp)
self.assertEqual(scale, expected_scale)

# Even int4 quantization generates 8-bit numpy values.
self.assertEqual(data_quant.dtype, np_int_type)
for index, actual_quant_val in enumerate(data_quant.flatten()):
self.assertTrue(actual_quant_val >= qmin and actual_quant_val <= qmax)

expected_quant_val = numpy.asarray((data_float[index] / scale).round() + zero_point).astype(
np_int_type
)
numpy.clip(expected_quant_val, qmin, qmax, out=expected_quant_val)

self.assertEqual(numpy.array(actual_quant_val), expected_quant_val)


if __name__ == "__main__":
unittest.main()
