diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index d49e503f1..059895ea8 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -282,12 +282,11 @@ def _to_value_pattern( return x if isinstance(x, (int, float)): return Constant(x) - # TODO(rama): support lists of int/float - # if isinstance(x, list): - # if all(isinstance(i, (int, float)) for i in x): - # return Constant(x) - # raise ValueError("Only lists of int/float can be used as a ValuePattern") - # TODO(titaiwang): Could this be wrapped Constant? + if isinstance(x, Sequence): + if all(isinstance(i, (int, float)) for i in x): + return Constant(x) + raise ValueError("Only lists of int/float can be used as a ValuePattern") + raise TypeError(f"Cannot convert {type(x)} to ValuePattern") @@ -602,10 +601,13 @@ class Constant(ValuePattern): """Represents a pattern that matches against a scalar constant value.""" def __init__( - self, value: int | float, rel_tol: float = 1e-5, abs_tol: float = 1e-8 + self, + value: int | float | Sequence[int] | Sequence[float], + rel_tol: float = 1e-5, + abs_tol: float = 1e-8, ) -> None: super().__init__(None) - self._value = value + self._value = list(value) if isinstance(value, Sequence) else value self._rel_tol = rel_tol self._abs_tol = abs_tol @@ -614,7 +616,7 @@ def clone(self, node_map: dict[NodePattern, NodePattern]) -> Constant: return Constant(self._value, self._rel_tol, self._abs_tol) @property - def value(self) -> int | float: + def value(self) -> int | float | list[int] | list[float]: return self._value def matches(self, value: ir.Value, match: MatchResult) -> MatchResult: @@ -623,6 +625,24 @@ def matches(self, value: ir.Value, match: MatchResult) -> MatchResult: return match.fail(f"Value is not a constant, expecting {self.value}.") constant_value_numpy = constant_value.numpy() + if isinstance(self._value, list): + if constant_value_numpy.shape != (len(self._value),): + return match.fail(f"Value has mismatching shape, expecting ({self.value},).") + if not all( + math.isclose( + constant_value_numpy.item(i), + self._value[i], + rel_tol=self._rel_tol, + abs_tol=self._abs_tol, + ) + for i in range(len(self._value)) + ): + return match.fail( + f"Value mismatch: expected {self._value}, got {constant_value_numpy}." + ) + return match + + # Scalar constant case: # TODO (rama): allow users to specify shape requirement, if desired. if constant_value_numpy.size != 1: return match.fail(f"Value is not a scalar, expecting {self.value}.") @@ -664,6 +684,20 @@ def visit(value_patterns: Sequence[ValuePattern | None]) -> None: return node_patterns +def _add_backward_slice(node: NodePattern, backward_slice: set[NodePattern]) -> None: + """Adds all nodes in the backward slice of given node to the set `backward_slice`. + + The backward slice of a node is the set of all nodes that are reachable from the node + in a backward traversal from the given node. + """ + if node in backward_slice: + return + backward_slice.add(node) + for value_pattern in node.inputs: + if isinstance(value_pattern, NodeOutputPattern): + _add_backward_slice(value_pattern.producer(), backward_slice) + + class GraphPattern: """Represents a pattern that can be matched against a subgraph.""" @@ -679,8 +713,10 @@ def __init__( raise ValueError("GraphPattern must have at least one output") self._nodes = nodes # _nodes_in_pattern(outputs) - # Check if all outputs are produced by the same node. + # Determine the output nodes of the pattern. These are a minimal set of nodes + # whose backward-slices cover the entire pattern. output_nodes: set[NodePattern] = set() + covered: set[NodePattern] = set() for value_pattern in outputs: if not isinstance(value_pattern, ValuePattern): raise TypeError( @@ -691,7 +727,11 @@ def __init__( "Constant values are not allowed as graph pattern outputs." ) if isinstance(value_pattern, NodeOutputPattern): - output_nodes.add(value_pattern.producer()) + candidate = value_pattern.producer() + if candidate not in covered: + output_nodes.add(candidate) + _add_backward_slice(candidate, covered) + self.output_nodes: list[NodePattern] = list(output_nodes) @property @@ -924,20 +964,41 @@ def _match_constant(self, pattern_constant: Constant, value: ir.Value) -> bool: constant_value_numpy = constant_value.numpy() except FileNotFoundError: return self.fail(f"Constant value of {value.name} not available.") + + pattern_constant_value = pattern_constant._value + + if isinstance(pattern_constant_value, list): + expected_shape = (len(pattern_constant_value),) + if constant_value_numpy.shape != expected_shape: + return self.fail(f"Value has mismatching shape, expecting {expected_shape}.") + if not all( + math.isclose( + constant_value_numpy.item(i), + pattern_constant_value[i], + rel_tol=pattern_constant._rel_tol, + abs_tol=pattern_constant._abs_tol, + ) + for i in range(len(pattern_constant_value)) + ): + return self.fail( + f"Value mismatch: expected {pattern_constant_value}, got {constant_value_numpy}." + ) + return True + # TODO (rama): allow users to specify shape requirement, if desired. if constant_value_numpy.size != 1: return self.fail( - f"Value {value.name} is not a scalar, expecting {pattern_constant.value}.", + f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.", ) if not math.isclose( constant_value_numpy.item(), - pattern_constant._value, + pattern_constant_value, rel_tol=pattern_constant._rel_tol, abs_tol=pattern_constant._abs_tol, ): return self.fail( - f"Constant value mismatch: expected {pattern_constant._value}, got {constant_value_numpy.item()}.", + f"Constant value mismatch: expected {pattern_constant_value}, got {constant_value_numpy.item()}.", ) return True @@ -1079,11 +1140,6 @@ def _match_single_output_node( if not _valid_to_replace(match.nodes, output_values): return match.fail("Matched nodes have other uses preventing replacement.") - if len(node.outputs) != pattern.num_outputs: - return match.fail( - f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}." - ) - match.outputs.extend(output_values) return match