Skip to content

Commit

Permalink
Address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Oct 23, 2024
1 parent 9299b92 commit d8fc6a4
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def _to_value_pattern(
return x
if isinstance(x, (int, float)):
return Constant(x)
if isinstance(x, list):
if isinstance(x, Sequence):
if all(isinstance(i, (int, float)) for i in x):
return Constant(x)
raise ValueError("Only lists of int/float can be used as a ValuePattern")

Check warning on line 288 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L287-L288

Added lines #L287 - L288 were not covered by tests
Expand Down Expand Up @@ -602,12 +602,12 @@ class Constant(ValuePattern):

def __init__(
self,
value: int | float | list[int] | list[float],
value: int | float | Sequence[int] | Sequence[float],
rel_tol: float = 1e-5,
abs_tol: float = 1e-8,
) -> None:
super().__init__(None)
self._value = value
self._value = list(value) if isinstance(value, Sequence) else value
self._rel_tol = rel_tol
self._abs_tol = abs_tol

Expand Down Expand Up @@ -685,6 +685,11 @@ def visit(value_patterns: Sequence[ValuePattern | None]) -> None:


def _add_backward_slice(node: NodePattern, backward_slice: set[NodePattern]) -> None:
"""Adds all nodes in the backward slice of given node to the set `backward_slice`.
The backward slice of a node is the set of all nodes that are reachable from the node
in a backward traversal from the given node.
"""
if node in backward_slice:
return
backward_slice.add(node)
Expand Down

0 comments on commit d8fc6a4

Please sign in to comment.