From 69ae7f421e2f1931e353949fa5e3a0fb23dbe622 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 20 May 2024 00:08:08 -0700 Subject: [PATCH] [IR] Fix broadcast_to_matmul (#1542) - Check whether the shape tensor is constant before using it in the logic. Exiting early if needed. - Handle cases when the input is 1d or 0d Thanks @borisfom for the proposed fix! Fix #1541 --- .../rewriter/examples/broadcast_matmul.py | 71 +++++++------ onnxscript/rewriter/_ir_utils.py | 47 +++++---- onnxscript/rewriter/broadcast_to_matmul.py | 99 ++++++++++--------- .../rewriter/broadcast_to_matmul_test.py | 83 ++++++++++++++++ 4 files changed, 210 insertions(+), 90 deletions(-) diff --git a/docs/tutorial/rewriter/examples/broadcast_matmul.py b/docs/tutorial/rewriter/examples/broadcast_matmul.py index 84d16c6bf..ad48842a9 100644 --- a/docs/tutorial/rewriter/examples/broadcast_matmul.py +++ b/docs/tutorial/rewriter/examples/broadcast_matmul.py @@ -9,7 +9,6 @@ import logging -import numpy as np import onnx import onnxscript @@ -65,71 +64,81 @@ def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_): def check_if_not_need_reshape( context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_ ) -> bool: - """If matmul broadcasting is enough, then we don't need the reshapes. + """Condition to check if we need to replace the pattern. + + If matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: 1. Input shapes check: input_a and input_b should be broadcastable 2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b) If the above are true, then we don't need the reshapes. + + Returns: + True if we need to replace the pattern, False otherwise. """ del context # Reserved for future extensions + input_a_shape = input_a.shape input_b_shape = input_b.shape # TODO: Get a helper func to get const_value - shape_c_value = _ir_utils.propagate_const_value(shape_c) - shape_c = shape_c_value.const_value.numpy() # type: ignore[union-attr] - if shape_c is None: - return False - if not isinstance(shape_c, np.ndarray): - logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c)) + _ir_utils.propagate_const_value(shape_c) + shape_c_tensor = shape_c.const_value + if shape_c_tensor is None: + logger.info("The value 'shape_c' is not statically known.") return False - if len(shape_c.shape) != 1: + + if len(shape_c_tensor.shape) != 1: logger.info( "Unexpected final shape. The shape of 'shape' value is %s", - shape_c.shape, + shape_c_tensor.shape, ) return False - shape_c_list = shape_c.tolist() # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape # information. So, we need to check if the shape is None and return False. - if input_a_shape is None or input_b_shape is None or shape_c is None: + if input_a_shape is None or input_b_shape is None: logger.info("Shape information is not available for the inputs and outputs.") return False - input_a_shape = list(input_a_shape) - input_b_shape = list(input_b_shape) + input_a_shape = input_a_shape.numpy() + input_b_shape = input_b_shape.numpy() + shape_c = shape_c_tensor.numpy().tolist() + + a_rank = len(input_a_shape) + b_rank = len(input_b_shape) - dim_a = len(input_a_shape) - dim_b = len(input_b_shape) + # TODO(justinchuby): Check shape size # 1. Check if input shapes are broadcastable # 1.a. If the first input is 1-D, check whether # the dim matches the last second dim of the second input. 
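# Illustration only (not part of this patch): the 1-D promotion that checks 1.a/1.b
# mimic, shown with NumPy, whose matmul follows the same promotion rules as ONNX MatMul.
import numpy as np

vec = np.ones((4,))           # 1-D left operand: promoted to (1, 4); the prepended 1 is dropped from the result
batched = np.ones((3, 4, 2))
assert np.matmul(vec, batched).shape == (3, 2)

rhs_vec = np.ones((2,))       # 1-D right operand: promoted to (2, 1); the appended 1 is dropped from the result
assert np.matmul(batched, rhs_vec).shape == (3, 4)
# Two 1-D operands reduce to a scalar dot product, which the early return just below declines to handle.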
mimic_matmul_broadcast_behavior = False - if dim_a < 2: + if a_rank < 2: + if b_rank < 2: + logger.info("Optimization of dot product is not supported yet.") + return False if input_a_shape[-1] != input_b_shape[-2]: logger.info("Original shape is not MatMul compatible.") return False else: input_a_shape = [1, *input_a_shape] - dim_a = len(input_a_shape) + a_rank = len(input_a_shape) mimic_matmul_broadcast_behavior = True # 1.b. If the second input is 1-D, check whether # the dim matches the last dim of the first input. - if dim_b < 2: + if b_rank < 2: if input_b_shape[-1] != input_a_shape[-1]: logger.info("Original shape is not MatMul compatible.") return False else: input_b_shape = [*input_b_shape, 1] - dim_b = len(input_b_shape) + b_rank = len(input_b_shape) mimic_matmul_broadcast_behavior = True # 1.c. If both inputs are at least 2-D, check whether # the last dimension of the first input matches the second # last dimension of the second input, and shape[:-2] are # broadcastable. - input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]] + input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]] input_b_shape_except_last_dim = input_b_shape[:-1] broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]] for idx, (dim_from_a, dim_from_b) in enumerate( @@ -149,23 +158,27 @@ def check_if_not_need_reshape( # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b) # Prepend the broadcast_matmul_output_shape with the longer shape of input - if dim_a > dim_b: + if a_rank > b_rank: longer_shape = input_a_shape shorter_shape = input_b_shape else: longer_shape = input_b_shape shorter_shape = input_a_shape - broadcast_matmul_output_shape = ( - longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape - ) - if mimic_matmul_broadcast_behavior and dim_b == 2: + broadcast_matmul_output_shape = [ + *longer_shape[: -len(shorter_shape)], + *broadcast_matmul_output_shape, + ] + if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1: + # If input_b is expanded to 2-D, then we need to remove the last dimension broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1] - if mimic_matmul_broadcast_behavior and dim_a == 2: + if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1: + # If input_a is expanded to 2-D, then we need to remove the first dimension + # of input_a, which would be the -2nd dimension of the output shape. broadcast_matmul_output_shape.pop(-2) - if shape_c_list != broadcast_matmul_output_shape: + if shape_c != broadcast_matmul_output_shape: logger.info( "Final output shape is not the same. 
Expected %s vs actual %s", - shape_c_list, + shape_c, broadcast_matmul_output_shape, ) return False diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index b8dd5f45f..9bfc4ac5a 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -2,6 +2,8 @@ from __future__ import annotations +import typing + import numpy as np from onnxscript import ir @@ -10,24 +12,35 @@ def propagate_const_value(ir_value: ir.Value) -> ir.Value: + """Temporary method to propagate a constant value to the IR value.""" node = ir_value.producer() - if ir_value.const_value is None and node is not None and node.op_type == "Constant": - attr_names = [ - "value_float", - "value_int", - "value_string", - "value", - "value_floats", - "value_ints", - "value_strings", - ] - for attr_name in attr_names: - attr_value = node.attributes.get(attr_name) - if attr_value is not None: - # TODO: RefAttr should be also supported? - if isinstance(attr_value, ir.Attr): - ir_value.const_value = attr_value.value # type: ignore[union-attr] - break + if node is None: + return ir_value + if node.op_type != "Constant": + return ir_value + attr_name, attr_value = next(iter(node.attributes.items())) + if attr_value is None or not isinstance(attr_value, ir.Attr): + return ir_value + + const_value: ir.TensorProtocol + if attr_name in {"value_float", "value_floats"}: + const_value = ir.Tensor( + np.array(attr_value.value, dtype=np.float32), name=ir_value.name + ) + elif attr_name in {"value_int", "value_ints"}: + const_value = ir.Tensor(np.array(attr_value.value, dtype=np.int64), name=ir_value.name) + elif attr_name in {"value_string", "value_strings"}: + const_value = ir.StringTensor( + np.array(attr_value.value, dtype=np.bytes_), name=ir_value.name + ) + elif attr_name == "value": + const_value = typing.cast(ir.TensorProtocol, attr_value.value) + else: + return ir_value + + ir_value.const_value = const_value + ir_value.shape = const_value.shape # type: ignore + ir_value.dtype = const_value.dtype return ir_value diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/broadcast_to_matmul.py index b9ba56585..ead1bbada 100644 --- a/onnxscript/rewriter/broadcast_to_matmul.py +++ b/onnxscript/rewriter/broadcast_to_matmul.py @@ -2,17 +2,18 @@ import logging -import numpy as np - +from onnxscript import ir from onnxscript.rewriter import _ir_utils, pattern -op = pattern.onnxop logger = logging.getLogger(__name__) -# condition to check if we need to replace the pattern -def check_if_not_need_reshape(context, input_a, input_b, shape_c, **_) -> bool: - """If matmul broadcasting is enough, then we don't need the reshapes. +def check_if_not_need_reshape( + context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_ +) -> bool: + """Condition to check if we need to replace the pattern. + + If matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: 1. Input shapes check: input_a and input_b should be broadcastable @@ -21,65 +22,74 @@ def check_if_not_need_reshape(context, input_a, input_b, shape_c, **_) -> bool: If the above are true, then we don't need the reshapes. Returns: - bool: True if we need to replace the pattern, False otherwise. - + True if we need to replace the pattern, False otherwise. 
""" + del context # Reserved for future extensions + input_a_shape = input_a.shape input_b_shape = input_b.shape # TODO: Get a helper func to get const_value - shape_c_value = _ir_utils.propagate_const_value(shape_c) - shape_c = shape_c_value.const_value.numpy() # type: ignore[union-attr] - if shape_c is None: - return False - if not isinstance(shape_c, np.ndarray): - logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c)) + _ir_utils.propagate_const_value(shape_c) + shape_c_tensor = shape_c.const_value + if shape_c_tensor is None: + logger.info("The value 'shape_c' is not statically known.") return False - if len(shape_c.shape) != 1: + + if len(shape_c_tensor.shape) != 1: logger.info( "Unexpected final shape. The shape of 'shape' value is %s", - shape_c.shape, + shape_c_tensor.shape, ) return False - shape_c = shape_c.tolist() # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape # information. So, we need to check if the shape is None and return False. - if input_a_shape is None or input_b_shape is None or shape_c is None: + if input_a_shape is None or input_b_shape is None: logger.info("Shape information is not available for the inputs and outputs.") return False - input_a_shape = list(input_a_shape) - input_b_shape = list(input_b_shape) + if any(isinstance(dim, ir.SymbolicDim) for dim in input_a_shape): + logger.info("Symbolic dimensions are not yet supported.") + return False + if any(isinstance(dim, ir.SymbolicDim) for dim in input_b_shape): + logger.info("Symbolic dimensions are not yet supported.") + return False + input_a_shape = input_a_shape.numpy() # type: ignore[assignment] + input_b_shape = input_b_shape.numpy() # type: ignore[assignment] + shape_c = shape_c_tensor.numpy().tolist() - dim_a = len(input_a_shape) - dim_b = len(input_b_shape) + a_rank = len(input_a_shape) + b_rank = len(input_b_shape) # 1. Check if input shapes are broadcastable # 1.a. If the first input is 1-D, check whether # the dim matches the last second dim of the second input. mimic_matmul_broadcast_behavior = False - if dim_a < 2: + if a_rank < 2: + if b_rank < 2: + logger.info("Optimization of dot product is not supported yet.") + return False if input_a_shape[-1] != input_b_shape[-2]: logger.info("Original shape is not MatMul compatible.") return False else: - input_a_shape = [1, *input_a_shape] - dim_a = len(input_a_shape) + input_a_shape = [1, *input_a_shape] # type: ignore[assignment] + a_rank = len(input_a_shape) mimic_matmul_broadcast_behavior = True # 1.b. If the second input is 1-D, check whether # the dim matches the last dim of the first input. - if dim_b < 2: + if b_rank < 2: if input_b_shape[-1] != input_a_shape[-1]: logger.info("Original shape is not MatMul compatible.") return False else: - input_b_shape = [*input_b_shape, 1] - dim_b = len(input_b_shape) + input_b_shape = [*input_b_shape, 1] # type: ignore[assignment] + b_rank = len(input_b_shape) mimic_matmul_broadcast_behavior = True # 1.c. If both inputs are at least 2-D, check whether # the last dimension of the first input matches the second # last dimension of the second input, and shape[:-2] are # broadcastable. 
- input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]] + input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]] input_b_shape_except_last_dim = input_b_shape[:-1] broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]] for idx, (dim_from_a, dim_from_b) in enumerate( @@ -93,25 +103,26 @@ def check_if_not_need_reshape(context, input_a, input_b, shape_c, **_) -> bool: return False elif idx > 0: broadcast_matmul_output_shape = [ - max(dim_from_a, dim_from_b), + max(dim_from_a, dim_from_b), # type: ignore[type-var] *broadcast_matmul_output_shape, ] # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b) # Prepend the broadcast_matmul_output_shape with the longer shape of input - if dim_a > dim_b: + if a_rank > b_rank: longer_shape = input_a_shape shorter_shape = input_b_shape else: longer_shape = input_b_shape shorter_shape = input_a_shape - broadcast_matmul_output_shape = ( - longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape - ) - if mimic_matmul_broadcast_behavior and dim_b == 2 and input_b_shape[-1] == 1: + broadcast_matmul_output_shape = [ + *longer_shape[: -len(shorter_shape)], + *broadcast_matmul_output_shape, + ] + if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1: # If input_b is expanded to 2-D, then we need to remove the last dimension broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1] - if mimic_matmul_broadcast_behavior and dim_a == 2 and input_a_shape[0] == 1: + if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1: # If input_a is expanded to 2-D, then we need to remove the first dimension # of input_a, which would be the -2nd dimension of the output shape. broadcast_matmul_output_shape.pop(-2) @@ -126,7 +137,7 @@ def check_if_not_need_reshape(context, input_a, input_b, shape_c, **_) -> bool: return True -def two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c): +def _two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c): # TODO: Modified from `value_ints` to `value` to match pattern in benchmark models. # This implementation misses pattern of Constants with `value_ints` attribute. # See more at https://github.com/microsoft/onnx-rewriter/issues/191. @@ -138,11 +149,11 @@ def two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, return op.Reshape(matmul, shape_c) -def matmul(op, input_a, input_b, **_): +def _matmul(op, input_a, input_b, **_): return op.MatMul(input_a, input_b) -def one_reshape_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_c): +def _one_reshape_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_c): reshape_a = op.Reshape(input_a, shape_a) matmul = op.MatMul(reshape_a, input_b) return op.Reshape(matmul, shape_c) @@ -150,15 +161,15 @@ def one_reshape_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_c): # Register the rewrite rules two_reshapes_matmul_reshape_rule = pattern.RewriteRule( - two_reshapes_matmul_reshape_pattern, - matmul, + _two_reshapes_matmul_reshape_pattern, + _matmul, check_if_not_need_reshape, ) one_reshape_matmul_reshape_rule = pattern.RewriteRule( - one_reshape_matmul_reshape_pattern, - matmul, + _one_reshape_matmul_reshape_pattern, + _matmul, # We can use the same check_if_not_need_reshape function for both the rules, - # as one_reshape_matmul_reshape_pattern is a subset of two_reshapes_matmul_reshape_pattern. 
+ # as one_reshape_matmul_reshape_pattern is a subset of _two_reshapes_matmul_reshape_pattern. check_if_not_need_reshape, ) diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/broadcast_to_matmul_test.py index a654a5734..cc390d7a3 100644 --- a/onnxscript/rewriter/broadcast_to_matmul_test.py +++ b/onnxscript/rewriter/broadcast_to_matmul_test.py @@ -1,12 +1,23 @@ +from __future__ import annotations + import unittest import onnx.parser import onnx.shape_inference +import parameterized from onnxscript import ir from onnxscript.rewriter import broadcast_to_matmul +def _infer_shapes(model: ir.Model) -> ir.Model: + """Run shape inference on the IR model.""" + # TODO: Update when shape inference is supported on the IR + return ir.serde.deserialize_model( + onnx.shape_inference.infer_shapes(ir.serde.serialize_model(model)) + ) + + class TwoReshapesMatMulReshapeTest(unittest.TestCase): def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): model_proto = onnx.parser.parse_model( @@ -29,6 +40,78 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) + @parameterized.parameterized.expand( + [ + ( + "0d", + [], + [1, 1], + [], + [1, 1], + [1, 1], + [1, 1], + ), + ( + "x_1d", + [4], + [1, 4], + [4, 2], + [4, 2], + [1, 2], + [1, 2], + ), + ( + "y_1d", + [1, 4], + [1, 4], + [2], + [4, 2], + [1, 2], + [1, 2], + ), + ( + "both_1d", + [2], + [1, 2], + [2], + [2, 1], + [], + [], + ), + ] + ) + def test_reshape_matmul_reshape_does_not_replace_when_output_sizes_do_not_match( + self, + _: str, + input_x_shape: list[int], + shape_a: list[int], + input_y_shape: list[int], + shape_b: list[int], + output_shape: list[int], + shape_c: list[int], + ): + model_proto = onnx.parser.parse_model( + f""" + + agraph (float{input_x_shape} input_x, float{input_y_shape} input_y) => (float{output_shape} output) + {{ + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + }} + """ + ) + model = ir.serde.deserialize_model(model_proto) + count = broadcast_to_matmul.rules.apply_to_model(model) + self.assertEqual(count, 0) + self.assertEqual(len(model.graph), 7) + model = _infer_shapes(model) + self.assertEqual(model.graph.outputs[0].shape, output_shape) + def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nested_function( self, ):
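# Usage sketch (illustration only, not part of this patch) for the updated
# _ir_utils.propagate_const_value helper above; the model text and value names
# here are hypothetical.
import onnx.parser

from onnxscript import ir
from onnxscript.rewriter import _ir_utils

model_proto = onnx.parser.parse_model(
    """
    <ir_version: 7, opset_import: [ "" : 17]>
    agraph (float[4] x) => (int64[2] shape_out)
    {
        shape_out = Constant <value_ints = [2, 2]> ()
    }
    """
)
model = ir.serde.deserialize_model(model_proto)
shape_value = model.graph.outputs[0]          # produced by the Constant node
_ir_utils.propagate_const_value(shape_value)  # populates const_value, shape, and dtype in place
assert shape_value.const_value is not None
print(shape_value.const_value.numpy())        # [2 2]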
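# Usage sketch (illustration only, not part of this patch): applying the registered
# rewrite rules to a whole model, mirroring what the tests above do; the file names
# are hypothetical.
import onnx

from onnxscript import ir
from onnxscript.rewriter import broadcast_to_matmul

model_proto = onnx.load("model.onnx")
model = ir.serde.deserialize_model(model_proto)
count = broadcast_to_matmul.rules.apply_to_model(model)  # number of Reshape->MatMul->Reshape subgraphs rewritten
print(f"Applied the broadcast_to_matmul rules {count} time(s)")
onnx.save(ir.serde.serialize_model(model), "model.rewritten.onnx")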