diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index ef61a8e3b..ddff1b93b 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -959,6 +959,7 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: self._matched[pattern_node] = node + # Note: Need to revisit this to handle optional trailing inputs better. if len(node.inputs) != len(pattern_node.inputs): return self.fail("Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}")