From d8fc6a4d57a12e2d83d09b365eba4af6d5c03f98 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 22 Oct 2024 20:16:27 -0700 Subject: [PATCH] Address PR feedback --- onnxscript/rewriter/pattern.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index d269b6f5f..059895ea8 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -282,7 +282,7 @@ def _to_value_pattern( return x if isinstance(x, (int, float)): return Constant(x) - if isinstance(x, list): + if isinstance(x, Sequence): if all(isinstance(i, (int, float)) for i in x): return Constant(x) raise ValueError("Only lists of int/float can be used as a ValuePattern") @@ -602,12 +602,12 @@ class Constant(ValuePattern): def __init__( self, - value: int | float | list[int] | list[float], + value: int | float | Sequence[int] | Sequence[float], rel_tol: float = 1e-5, abs_tol: float = 1e-8, ) -> None: super().__init__(None) - self._value = value + self._value = list(value) if isinstance(value, Sequence) else value self._rel_tol = rel_tol self._abs_tol = abs_tol @@ -685,6 +685,11 @@ def visit(value_patterns: Sequence[ValuePattern | None]) -> None: def _add_backward_slice(node: NodePattern, backward_slice: set[NodePattern]) -> None: + """Adds all nodes in the backward slice of given node to the set `backward_slice`. + + The backward slice of a node is the set of all nodes that are reachable from the node + in a backward traversal from the given node. + """ if node in backward_slice: return backward_slice.add(node)