Skip to content

Commit

Permalink
Update rewriter to allow matching variable against None
Browse files Browse the repository at this point in the history
Signed-off-by: Ganesan Ramalingam <[email protected]>
  • Loading branch information
gramalingam committed Oct 2, 2024
1 parent c8a299a commit dc9a12d
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,9 @@ def producer(self) -> NodePattern:

Var = ValuePattern

def _is_pattern_variable(x: Any) -> bool:
# The derived classes of ValuePattern represent constant patterns and node-output patterns.
return type(x) is ValuePattern

class Constant(ValuePattern):
"""Represents a pattern that matches against a scalar constant value."""
Expand Down Expand Up @@ -954,18 +957,14 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool:

self._matched[pattern_node] = node

for arg_value, previous_node_output_pattern in zip(node.inputs, pattern_node.inputs):
# previous_node_output_pattern could be a Var, if it's the original arg.
if arg_value is None and previous_node_output_pattern is None:
continue
if arg_value is None or previous_node_output_pattern is None:
msg = (
"Input not expected to be None"
if arg_value is None
else "Input expected to be None"
)
return self.fail(msg)
if not self._match_value(previous_node_output_pattern, arg_value):
for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs):
# arg_pattern could be a Var, if it's the original arg.
if arg_pattern is None:
if arg_value is None:
continue
else:
return self.fail("(Optional) input is expected to be None but is not.")
if not self._match_value(arg_pattern, arg_value):
return False

for i, output_value_pattern in enumerate(pattern_node.outputs):
Expand All @@ -975,7 +974,7 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool:
match.nodes.append(node)
return True

def _bind_value(self, pattern_value: ValuePattern, value: ir.Value) -> bool:
def _bind_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bool:
"""Bind a ValuePattern var to ir Value."""
if pattern_value.name is not None:
match = self._match
Expand All @@ -987,8 +986,12 @@ def _bind_value(self, pattern_value: ValuePattern, value: ir.Value) -> bool:
match.bindings[pattern_value.name] = value
return True

def _match_value(self, pattern_value: ValuePattern, value: ir.Value) -> bool:
def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bool:
"""Match an IR value against a ValuePattern instance."""
if value is None:
if not _is_pattern_variable(pattern_value):
return self.fail("Mismatch: input value is None, but pattern value is not a variable.")

if not self._bind_value(pattern_value, value):
return False

Expand Down

0 comments on commit dc9a12d

Please sign in to comment.