diff --git a/examples/pattern_rewriting.py b/examples/pattern_rewriting.py index 7ebe10157..c9dc2394f 100644 --- a/examples/pattern_rewriting.py +++ b/examples/pattern_rewriting.py @@ -14,7 +14,7 @@ import onnx.numpy_helper as onh from onnxscript import ir -from onnxscript.rewriter import generic_pattern +from onnxscript.rewriter import pattern def get_rotary_model(bad_model=False): @@ -99,9 +99,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis): # # The rule is easy to create. -rule = generic_pattern.make_pattern_rule( - rotary_match_pattern, rotary_apply_pattern, verbose=10 -) +rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10) ########################## # Let's apply it. @@ -136,9 +134,7 @@ def rotary_apply_pattern(op, x, pos_ids, axis): # The match did not happen. # Let's increase the verbosity. -rule = generic_pattern.make_pattern_rule( - rotary_match_pattern, rotary_apply_pattern, verbose=10 -) +rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10) rule.apply_to_model(ir_model) diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index a27952a0c..1fad112bd 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -4,6 +4,7 @@ import inspect import os import textwrap +import warnings from typing import Any, Callable, Iterator, Sequence import onnxscript.rewriter.pattern as orp @@ -79,7 +80,7 @@ def _to_match_result(pmr: PatternMatchResult) -> orp.MatchResult: TODO: This is a temporary hack until MatchResult and PatternMatchResult are unified. """ - result = orp.MatchResult(success=True) + result = orp.MatchResult() result.nodes.extend(pmr.model_nodes) for var, val in pmr.matched_pattern_to_model_value.items(): if var.name is not None: @@ -633,6 +634,11 @@ def make_pattern_rule( the rewriting rule """ + warnings.warn( + "make_pattern_rule(...) is deprecated, use pattern.RewriteRule(...) instead", + FutureWarning, + stacklevel=2, + ) pattern = orp._to_graph_pattern(match_pattern_function) matcher = GenericPatternMatcher(pattern) return orp.RewriteRule( diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py index b45c49455..c96aa37d9 100644 --- a/onnxscript/rewriter/generic_pattern_test.py +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -11,7 +11,7 @@ import onnxruntime as ort from onnxscript import ir -from onnxscript.rewriter import generic_pattern +from onnxscript.rewriter import generic_pattern, pattern FLOAT = onnx.TensorProto.FLOAT @@ -41,8 +41,11 @@ def validate_mapping(context, x, y, z, **_) -> bool: del context return True - rule = generic_pattern.make_pattern_rule( - match_pattern, apply_pattern, validate_mapping + rule = pattern.RewriteRule( + match_pattern, + apply_pattern, + validate_mapping, + generic_pattern.GenericPatternMatcher, ) class AddAdd(onnx.reference.op_run.OpRun): @@ -118,8 +121,12 @@ def apply_pattern(op, x, y, w, z, **_): def validate_mapping(context, **_) -> bool: return True - rule = generic_pattern.make_pattern_rule( - match_pattern, apply_pattern, validate_mapping, verbose=10 + rule = pattern.RewriteRule( + match_pattern, + apply_pattern, + validate_mapping, + generic_pattern.GenericPatternMatcher, + verbose=10, ) class AddAddAddAdd(onnx.reference.op_run.OpRun): @@ -284,8 +291,12 @@ def apply_pattern(op, x, pos_ids, axis, **_): outputs=2, ) - rule = generic_pattern.make_pattern_rule( - match_pattern, apply_pattern, validate_mapping, verbose=10 + rule = pattern.RewriteRule( + match_pattern, + apply_pattern, + validate_mapping, + generic_pattern.GenericPatternMatcher, + verbose=10, ) model = self.get_rotary_model() @@ -345,10 +356,11 @@ def rotary_apply_pattern(op, x, pos_ids, axis, **_): ) return part1, part2 - rule = generic_pattern.make_pattern_rule( + rule = pattern.RewriteRule( rotary_match_pattern, rotary_apply_pattern, validate_rotary_mapping, + generic_pattern.GenericPatternMatcher, verbose=10, ) @@ -416,10 +428,11 @@ def rotary_apply_pattern(op, x, pos_ids, axis): model = onnx.shape_inference.infer_shapes(model) ir_model = ir.serde.deserialize_model(model) - rule = generic_pattern.make_pattern_rule( + rule = pattern.RewriteRule( rotary_match_pattern, rotary_apply_pattern, validate_rotary_mapping, + generic_pattern.GenericPatternMatcher, verbose=10, ) @@ -472,10 +485,11 @@ def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): composed_perm = transpose_transpose_mapping(perm0, perm1) return op.Transpose(X, perm=composed_perm) - rule = generic_pattern.make_pattern_rule( + rule = pattern.RewriteRule( transpose_transpose_pattern, transpose_transpose_apply_pattern, transpose_transpose_check, + generic_pattern.GenericPatternMatcher, verbose=0, ) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index d17f93c78..504cfdeea 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -284,8 +284,8 @@ def pattern(x, shape1, shape2): contain the values that are bound to the variables `x`, `shape1`, and `shape2`. """ - def __init__(self, success: bool) -> None: - self._success: bool = success + def __init__(self) -> None: + self._success: bool = True # For a successful match, _matched_nodes is a list of values that matched the pattern. # These include the internal nodes of the pattern that were matched, but not # the leaves (sub-trees) that match against the variables in the pattern. @@ -295,13 +295,20 @@ def __init__(self, success: bool) -> None: # to values. self.bindings: dict[str, Any] = {} self.outputs: list[ir.Value] = [] + # For a failed match, _reason is a string that describes the reason for the failure. + self._reason: str = "" def __bool__(self): return self._success - @classmethod - def FAIL(cls): - return cls(False) + def fail(self, reason: str = "") -> MatchResult: + self._success = False + self._reason = reason + return self + + @property + def reason(self) -> str: + return self._reason @property def nodes(self) -> MutableSequence[ir.Node]: @@ -369,12 +376,6 @@ def append_use(self, node: NodePattern, index: int): def __repr__(self) -> str: return f"ValuePattern({self._name!r})" - def matches(self, value: ir.Value): - result = MatchResult(success=True) - if self._name is not None: - result.bind(self._name, value) - return result - def commute(self) -> Sequence[ValuePattern]: """Return a list of commuted patterns. @@ -470,61 +471,35 @@ def op_identifier(self) -> Tuple[str, str, str] | None: def op_type(self) -> str: return str(self.op) - def matches(self, node: ir.Node) -> bool: + def matches(self, node: ir.Node, match: MatchResult) -> MatchResult: """Matches the pattern represented by self against a node. This is purely a local node-level match, and does not consider the subgraph rooted at the node. We check the domain, op_type, and attributes of the node, but not the inputs. """ - if not self.op.matches(node.op_type): - return False # TODO(rama): Ensure we handle "" and "onnx.ai" correctly. if not self.domain.matches(node.domain): - return False - - # for name, attr_pattern in self.attributes.items(): - # attr_value = node.attributes.get(name) - # if attr_value is None: - # return False - # if not attr_pattern.matches(attr_value): - # return False - return True - - def matches_subgraph(self, node: ir.Node) -> MatchResult: - """Matches the pattern subgraph represented by self against subgraph rooted at node.""" - if not self.domain.matches(node.domain): - return MatchResult.FAIL() + return match.fail(f"Domain mismatch: expected {self.domain}, got {node.domain}.") if not self.op.matches(node.op_type): - return MatchResult.FAIL() - match = MatchResult(success=True) - # TODO: We should add filtered logging starting from here to emit why - # matching failed. This should cut a lot of noises compared to logging everything, - # because at least the starting node op_type is already matched. - for arg_value, previous_node_output_pattern in zip(node.inputs, self.inputs): - # previous_node_output_pattern could be a Var, if it's the original arg. - if arg_value is None and previous_node_output_pattern is None: - continue - if arg_value is None or previous_node_output_pattern is None: - return MatchResult.FAIL() - sub_match = previous_node_output_pattern.matches(arg_value) - match.extend(sub_match) - if not match: # If sub-match failed, - return match - # Sub-graphs not handled yet. + return match.fail(f"OpType mismatch: expected {self.op}, got {node.op_type}.") + for name, attr_pattern in self.attributes.items(): attr_value = node.attributes.get(name) if attr_value is None: - return MatchResult.FAIL() + return match.fail(f"Attribute {name} not found in node.") if not attr_pattern.matches(attr_value): - return MatchResult.FAIL() + return match.fail( + f"Attribute {name} mismatch: expected {attr_pattern}, got {attr_value}." + ) if attr_pattern.name is not None: if not match.bind(attr_pattern.name, attr_value): return match + for name in node.attributes: # TODO: Support matching default nodes for attributes. if name not in self.attributes: - return MatchResult.FAIL() - match.nodes.append(node) + return match.fail(f"Attribute {name} not expected in node.") + return match def commute(self) -> Sequence[NodePattern]: @@ -570,15 +545,6 @@ def __init__( def output_index(self) -> int: return self._output_index - def matches(self, value: ir.Value): - """Match the StaticValueInfo from IR with the `matches_subgraph()` in node pattern.""" - node = value.producer() - if node is None: - return MatchResult.FAIL() - if value.index() != self._output_index: - return MatchResult.FAIL() - return self._producer.matches_subgraph(node) - def commute(self) -> Sequence[ValuePattern]: # TODO return [ @@ -604,27 +570,30 @@ def __init__( self._rel_tol = rel_tol self._abs_tol = abs_tol - def match_scalar(self, scalar_value): - status = math.isclose( - scalar_value, self._value, rel_tol=self._rel_tol, abs_tol=self._abs_tol - ) - # Note: If the value is produced by a Constant node, we could include - # the Constant node in the return_value list. However, we don't do that. - # Instead, we will rely on DCE to remove the constant node if it is not - # used elsewhere. - return MatchResult(success=status) + @property + def value(self) -> int | float: + return self._value - def matches(self, value: ir.Value): + def matches(self, value: ir.Value, match: MatchResult) -> MatchResult: value = _ir_utils.propagate_const_value(value) constant_value = _ir_utils.get_numpy_from_ir_value(value) if constant_value is None: - return MatchResult.FAIL() + return match.fail(f"Value is not a constant, expecting {self.value}.") # TODO (rama): allow users to specify shape requirement, if desired. if constant_value.size != 1: - return MatchResult.FAIL() + return match.fail(f"Value is not a scalar, expecting {self.value}.") - return self.match_scalar(constant_value.item()) + if not math.isclose( + constant_value.item(), self._value, rel_tol=self._rel_tol, abs_tol=self._abs_tol + ): + match.fail(f"Value mismatch: expected {self._value}, got {constant_value.item()}.") + + # Note: If the value is produced by a Constant node, we could include + # the Constant node in the return_value list. However, we don't do that. + # Instead, we will rely on DCE to remove the constant node if it is not + # used elsewhere. + return match def commute(self) -> list[ValuePattern]: return [self] @@ -707,11 +676,6 @@ def has_single_output_node(self) -> bool: def num_outputs(self) -> int: return len(self._outputs) - def matches_subgraph(self, node: ir.Node) -> MatchResult: - if self._output_node is None: - return MatchResult.FAIL() - return self._output_node.matches_subgraph(node) - def commute(self) -> Sequence[GraphPattern]: if self._output_node is None: raise NotImplementedError( @@ -912,6 +876,112 @@ def __init__(self, pattern: GraphPattern) -> None: ), "SimplePatternMatcher only supports patterns with a single output node." super().__init__(pattern) + def fail(self, reason: str) -> bool: + if self._verbose: + if self._matched: # Print only if at least one node successfully matched. + count = len(self._matched) + print(f"Match failed after {count} nodes: {reason}") + self._match.fail(reason) + return False + + def _match_constant(self, pattern_constant: Constant, value: ir.Value) -> bool: + """Match a Constant pattern against a value. + + If the constant value is produced by a Constant node, we do not include + the constant node as part of the matched graph. Thus, it will not be deleted, + if subgraph replacement happens. But subsequent DCE will remove the constant + node if it is not used elsewhere. + """ + value = _ir_utils.propagate_const_value(value) + constant_value = _ir_utils.get_numpy_from_ir_value(value) + if constant_value is None: + return self.fail( + f"Value {value.name} is not a constant, expecting {pattern_constant.value}.", + ) + + # TODO (rama): allow users to specify shape requirement, if desired. + if constant_value.size != 1: + return self.fail( + f"Value {value.name} is not a scalar, expecting {pattern_constant.value}.", + ) + + if not math.isclose( + constant_value.item(), + pattern_constant._value, + rel_tol=pattern_constant._rel_tol, + abs_tol=pattern_constant._abs_tol, + ): + return self.fail( + f"Constant value mismatch: expected {pattern_constant._value}, got {constant_value.item()}.", + ) + + return True + + def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: + """Matches a pattern subgraph against subgraph rooted at node.""" + + # Graph-matching: we do not allow the same pattern node to be matched against + # different graph nodes. + if pattern_node in self._matched: + if self._matched[pattern_node] is not node: + return self.fail("Same pattern node is matched against different graph nodes.") + return True + match = self._match + if not pattern_node.matches(node, match): + return self.fail(match.reason) + + if self._verbose: + print(f"Matched: {node.op_type}") + + self._matched[pattern_node] = node + + for arg_value, previous_node_output_pattern in zip(node.inputs, pattern_node.inputs): + # previous_node_output_pattern could be a Var, if it's the original arg. + if arg_value is None and previous_node_output_pattern is None: + continue + if arg_value is None or previous_node_output_pattern is None: + msg = ( + "Input not expected to be None" + if arg_value is None + else "Input expected to be None" + ) + return self.fail(msg) + if not self._match_value(previous_node_output_pattern, arg_value): + return False + + match.nodes.append(node) + return True + + def _match_value(self, pattern_value: ValuePattern, value: ir.Value) -> bool: + """Match an IR value against a ValuePattern instance.""" + if pattern_value.name is not None: + match = self._match + if pattern_value.name in match.bindings: + # TODO(rama): Use appropriate equality-check here: future extension possibility. + if match.bindings[pattern_value.name] == value: + return True + return self.fail(f"Variable {pattern_value.name} is bound to multiple values.") + match.bindings[pattern_value.name] = value + + if isinstance(pattern_value, NodeOutputPattern): + return self._match_node_output(pattern_value, value) + if isinstance(pattern_value, Constant): + return self._match_constant(pattern_value, value) + return True + + def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value) -> bool: + """Match an IR value against a NodeOutputPattern instance.""" + node = value.producer() + if node is None: + return self.fail( + "Mismatch: Computed node pattern does not match uncomputed IR value." + ) + if value.index() != pattern_value.output_index: + return self.fail( + f"Node output index mismatch: expected {pattern_value._output_index}, got {value.index()}." + ) + return self._match_node(pattern_value.producer(), node) + def match( self, model: ir.Model, @@ -919,16 +989,27 @@ def match( node: ir.Node, verbose: int = 0, ) -> MatchResult: - # TODO(rama): support verbose del model del graph_or_function - if len(node.outputs) != self.pattern.num_outputs: - return MatchResult.FAIL() - match = self.pattern.matches_subgraph(node) - if not match: - return MatchResult.FAIL() - if not _valid_to_replace(match.nodes): - return MatchResult.FAIL() + self._verbose = verbose + self._matched: dict[NodePattern, ir.Node] = {} + self._match: MatchResult = MatchResult() + + pattern = self.pattern + match = self._match + if len(node.outputs) != pattern.num_outputs: + return match.fail( + f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}." + ) + if pattern._output_node is None: + return match.fail( + "Internal Error: SimplePatternMatcher should not be used for patterns with multiple output nodes." + ) + + if self._match_node(pattern._output_node, node): + if not _valid_to_replace(match.nodes): + return match.fail("Matched nodes have other uses preventing replacement.") + match.outputs.extend(node.outputs) return match @@ -939,19 +1020,19 @@ def __init__( target_pattern: GraphPattern | Callable, replacement_pattern: ReplacementPatternFunction | Callable, condition_function: Callable | None = None, - matcher: PatternMatcher | None = None, + matcher: PatternMatcher | Callable[[GraphPattern], PatternMatcher] | None = None, verbose: int = 0, ) -> None: """Create a rewrite rule. Args: - target_pattern: The pattern function that will be - matched against the IR. - replacement_pattern: The replacement function that - will be used to replace the matched pattern. - condition_function: The condition function that - will be used to check if the pattern matches the IR with ir.Values - constraints in consideration. + target_pattern: The GraphPattern that will be matched against the IR. + If a callable is provided, it will be converted to a GraphPattern. + replacement_pattern: The ReplacementPatternFunction that will be used to + replace the matched pattern. If a callable is provided, it will be + converted to a ReplacementPatternFunction. + condition_function: The condition function that will be used to check if + the pattern match found should be rewritten. matcher: The pattern matcher that will be used to match the pattern. If not provided, a default matcher will be used. verbose: The verbosity level of the rule. @@ -965,21 +1046,29 @@ def __init__( replacement_pattern = ReplacementPatternFunction(replacement_pattern) self._replacement_pattern = replacement_pattern self._condition_function = condition_function or always_true - if matcher is None: + if isinstance(matcher, PatternMatcher): + self._matcher = matcher + elif matcher is None: if target_pattern.has_single_output_node: - matcher = SimplePatternMatcher(self._target_pattern) + self._matcher = SimplePatternMatcher(self._target_pattern) else: import onnxscript.rewriter.generic_pattern as generic_pattern - matcher = generic_pattern.GenericPatternMatcher(self._target_pattern) - self._matcher = matcher + self._matcher = generic_pattern.GenericPatternMatcher(self._target_pattern) + else: + self._matcher = matcher(self._target_pattern) self._verbose = verbose def try_rewrite( - self, model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + node: ir.Node, + verbose: int | None = None, ) -> ReplacementSubgraph | None: """If the node matches the pattern, then replace the node with the replacement pattern.""" - match = self._matcher.match(model, graph_or_function, node, verbose=self._verbose) + verbose = verbose if verbose is not None else self._verbose + match = self._matcher.match(model, graph_or_function, node, verbose=verbose) if match: context = None # TODO(rama) if not self._condition_function(context, **match.bindings): @@ -998,9 +1087,12 @@ def try_rewrite( return replacement_subgraph return None - def apply_to_model(self, model: ir.Model, *, commute: bool = False): - # TODO(titaiwang): Why do we need RewriteRuleSet? - return RewriteRuleSet([self], commute=commute).apply_to_model(model) + def apply_to_model( + self, model: ir.Model, *, commute: bool = False, verbose: int | None = None + ): + # A convenience method to apply the rule to a model. We use a RewriteRuleSet to + # handle commutative rules. + return RewriteRuleSet([self], commute=commute).apply_to_model(model, verbose=verbose) def commute(self) -> Sequence[RewriteRule]: def replace_pattern(new_pattern): @@ -1080,6 +1172,7 @@ def _apply_to_graph_or_function( self, model: ir.Model, graph_or_function: ir.Graph | ir.Function, + verbose: int | None, ) -> int: count = 0 @@ -1087,7 +1180,7 @@ def _apply_to_graph_or_function( # And the graph is applied in order. for rule in self.rules: for node in graph_or_function: - delta = rule.try_rewrite(model, graph_or_function, node) + delta = rule.try_rewrite(model, graph_or_function, node, verbose=verbose) if delta is None: continue _apply_delta(graph_or_function, node, delta) @@ -1095,9 +1188,9 @@ def _apply_to_graph_or_function( return count - def apply_to_model(self, model: ir.Model) -> int: + def apply_to_model(self, model: ir.Model, verbose: int | None = None) -> int: assert isinstance(model, ir.Model) - count = self._apply_to_graph_or_function(model, model.graph) + count = self._apply_to_graph_or_function(model, model.graph, verbose=verbose) for function in model.functions.values(): - count += self._apply_to_graph_or_function(model, function) + count += self._apply_to_graph_or_function(model, function, verbose=verbose) return count diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 7296f7610..1ccddcc31 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -1,3 +1,5 @@ +import contextlib +import io import logging import unittest @@ -57,6 +59,16 @@ def test_failed_match(self): self.assertEqual(count, 0) self.assertEqual(len(model.graph), 4) + # Test verbose output produces something: + # TODO(rama): Need a better way to test this. + # Well-defined error-codes and messages would be helpful. + + buffer = io.StringIO() + with contextlib.redirect_stdout(buffer): + self.rule().apply_to_model(model, verbose=5) + out = buffer.getvalue() + self.assertIn("Match failed", out) + def test_multiple_matches(self): model_proto = onnx.parser.parse_model( """