From 35fdcf57a7119a7f94e93696ab209ee740c48dc8 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Wed, 2 Oct 2024 15:17:20 -0700 Subject: [PATCH] Add test cases for pattern matching against optional inputs (#1890) I updated the pattern matcher to support matching against optional inputs. The change was accidentally pushed into the main branch (not a sub-branch as I thought) ... guess the branch protections were not good enough, changed it now. Adding test-cases now to test it in this PR. --------- Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/pattern.py | 10 +++-- onnxscript/rewriter/pattern_test.py | 59 ++++++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 840a54e99..be265963c 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -592,10 +592,12 @@ def producer(self) -> NodePattern: Var = ValuePattern + def _is_pattern_variable(x: Any) -> bool: # The derived classes of ValuePattern represent constant patterns and node-output patterns. return type(x) is ValuePattern + class Constant(ValuePattern): """Represents a pattern that matches against a scalar constant value.""" @@ -988,16 +990,16 @@ def _bind_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bo def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bool: """Match an IR value against a ValuePattern instance.""" - if value is None: - if not _is_pattern_variable(pattern_value): - return self.fail("Mismatch: input value is None, but pattern value is not a variable.") - if not self._bind_value(pattern_value, value): return False if isinstance(pattern_value, NodeOutputPattern): + if value is None: + return self.fail("Mismatch: Computed node pattern does not match None.") return self._match_node_output(pattern_value, value) if isinstance(pattern_value, Constant): + if value is None: + return self.fail("Mismatch: Constant pattern does not match None.") return self._match_constant(pattern_value, value) return True diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 5385a5233..6c9497d7a 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -8,7 +8,8 @@ import onnx.checker import onnx.parser -from onnxscript import ir +from onnxscript import FLOAT, ir, script +from onnxscript import opset17 as op from onnxscript.rewriter import _ir_utils, cast_constant_of_shape, pattern logger = logging.getLogger(__name__) @@ -420,6 +421,62 @@ def concat(op, x, y, result: ir.Value): self.assertEqual(model.graph[0].op_type, "Concat") self.assertNotIn("axis", model.graph[0].attributes) + def test_match_none_input(self): + def none_pattern(op, x): + # match against a call to Original where the first input is None + return op.Original(None, x) + + def replacement(op, x): + return op.Replaced(x) + + rule = pattern.RewriteRule(none_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024]) -> FLOAT[1024]: + # Pattern should match following call + t1 = op.Original(None, x) + # Pattern should not match following call + z = op.Original(t1, x) + return z + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + self.assertEqual(len(model.graph), 2) + self.assertEqual(model.graph.node(0).op_type, "Replaced") + self.assertEqual(model.graph.node(1).op_type, "Original") + + def test_match_optional_input(self): + def none_pattern(op, optional_input, x): + # match against a call to Original where the first input may or may not be None + return op.Original(optional_input, x) + + def replacement(op, optional_input, x): + if optional_input is None: + return op.ReplacedNone(x) + return op.ReplacedNotNone(x) + + rule = pattern.RewriteRule(none_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024]) -> FLOAT[1024]: + # Pattern should match following call + t1 = op.Original(None, x) + # as well as this one + z = op.Original(t1, x) + return z + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + count = rule.apply_to_model(model) + self.assertEqual(count, 2) + self.assertEqual(len(model.graph), 2) + self.assertEqual(model.graph.node(0).op_type, "ReplacedNone") + self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone") + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self):