Skip to content

Commit

Permalink
A couple of extensions to rewriter (#1912)
Browse files Browse the repository at this point in the history
A couple of extensions to the rewriter, motivated by fusion optimization
experimentation with SmoLLM.

* Support list of constants in match-pattern.
* One multi-output scenario is easy to handle with the single-output
pattern-matcher (eg. defining a fusion rule for SkipNormalization):
namely when the extra outputs are intermediate values used in the
computation of the first value. Extend algorithm to handle this scenario
using the efficient single-output matching-algorithm.

An example for the second point is the following pattern:
```py
def skip_norm_pattern(op, input, skip, gamma, epsilon, stash_type):
    skip_sum = op.Add(input, skip)
    normalized = op.SimplifiedLayerNormalization(
        skip_sum,
        gamma,
        axis=-1,
        epsilon=epsilon,
        stash_type=stash_type,
        _domain="com.microsoft")
    return normalized, skip_sum
```
If we successfully find a match for `normalized` (which transitively
finds a match for all of the pattern subgraph that leads up to
`normalized`), we have also found a successful match for `skip_sum`, so
no need for a multi-output match.

(Will add test-cases later, as I work through the fusion optimizations I
am experimenting with.)
  • Loading branch information
gramalingam authored Oct 23, 2024
1 parent 3016daa commit f18dadc
Showing 1 changed file with 75 additions and 19 deletions.
94 changes: 75 additions & 19 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,11 @@ def _to_value_pattern(
return x
if isinstance(x, (int, float)):
return Constant(x)
# TODO(rama): support lists of int/float
# if isinstance(x, list):
# if all(isinstance(i, (int, float)) for i in x):
# return Constant(x)
# raise ValueError("Only lists of int/float can be used as a ValuePattern")
# TODO(titaiwang): Could this be wrapped Constant?
if isinstance(x, Sequence):
if all(isinstance(i, (int, float)) for i in x):
return Constant(x)
raise ValueError("Only lists of int/float can be used as a ValuePattern")

raise TypeError(f"Cannot convert {type(x)} to ValuePattern")


Expand Down Expand Up @@ -602,10 +601,13 @@ class Constant(ValuePattern):
"""Represents a pattern that matches against a scalar constant value."""

def __init__(
self, value: int | float, rel_tol: float = 1e-5, abs_tol: float = 1e-8
self,
value: int | float | Sequence[int] | Sequence[float],
rel_tol: float = 1e-5,
abs_tol: float = 1e-8,
) -> None:
super().__init__(None)
self._value = value
self._value = list(value) if isinstance(value, Sequence) else value
self._rel_tol = rel_tol
self._abs_tol = abs_tol

Expand All @@ -614,7 +616,7 @@ def clone(self, node_map: dict[NodePattern, NodePattern]) -> Constant:
return Constant(self._value, self._rel_tol, self._abs_tol)

@property
def value(self) -> int | float:
def value(self) -> int | float | list[int] | list[float]:
return self._value

def matches(self, value: ir.Value, match: MatchResult) -> MatchResult:
Expand All @@ -623,6 +625,24 @@ def matches(self, value: ir.Value, match: MatchResult) -> MatchResult:
return match.fail(f"Value is not a constant, expecting {self.value}.")

constant_value_numpy = constant_value.numpy()
if isinstance(self._value, list):
if constant_value_numpy.shape != (len(self._value),):
return match.fail(f"Value has mismatching shape, expecting ({self.value},).")
if not all(
math.isclose(
constant_value_numpy.item(i),
self._value[i],
rel_tol=self._rel_tol,
abs_tol=self._abs_tol,
)
for i in range(len(self._value))
):
return match.fail(
f"Value mismatch: expected {self._value}, got {constant_value_numpy}."
)
return match

# Scalar constant case:
# TODO (rama): allow users to specify shape requirement, if desired.
if constant_value_numpy.size != 1:
return match.fail(f"Value is not a scalar, expecting {self.value}.")
Expand Down Expand Up @@ -664,6 +684,20 @@ def visit(value_patterns: Sequence[ValuePattern | None]) -> None:
return node_patterns


def _add_backward_slice(node: NodePattern, backward_slice: set[NodePattern]) -> None:
"""Adds all nodes in the backward slice of given node to the set `backward_slice`.
The backward slice of a node is the set of all nodes that are reachable from the node
in a backward traversal from the given node.
"""
if node in backward_slice:
return
backward_slice.add(node)
for value_pattern in node.inputs:
if isinstance(value_pattern, NodeOutputPattern):
_add_backward_slice(value_pattern.producer(), backward_slice)


class GraphPattern:
"""Represents a pattern that can be matched against a subgraph."""

Expand All @@ -679,8 +713,10 @@ def __init__(
raise ValueError("GraphPattern must have at least one output")
self._nodes = nodes # _nodes_in_pattern(outputs)

# Check if all outputs are produced by the same node.
# Determine the output nodes of the pattern. These are a minimal set of nodes
# whose backward-slices cover the entire pattern.
output_nodes: set[NodePattern] = set()
covered: set[NodePattern] = set()
for value_pattern in outputs:
if not isinstance(value_pattern, ValuePattern):
raise TypeError(
Expand All @@ -691,7 +727,11 @@ def __init__(
"Constant values are not allowed as graph pattern outputs."
)
if isinstance(value_pattern, NodeOutputPattern):
output_nodes.add(value_pattern.producer())
candidate = value_pattern.producer()
if candidate not in covered:
output_nodes.add(candidate)
_add_backward_slice(candidate, covered)

self.output_nodes: list[NodePattern] = list(output_nodes)

@property
Expand Down Expand Up @@ -924,20 +964,41 @@ def _match_constant(self, pattern_constant: Constant, value: ir.Value) -> bool:
constant_value_numpy = constant_value.numpy()
except FileNotFoundError:
return self.fail(f"Constant value of {value.name} not available.")

pattern_constant_value = pattern_constant._value

if isinstance(pattern_constant_value, list):
expected_shape = (len(pattern_constant_value),)
if constant_value_numpy.shape != expected_shape:
return self.fail(f"Value has mismatching shape, expecting {expected_shape}.")
if not all(
math.isclose(
constant_value_numpy.item(i),
pattern_constant_value[i],
rel_tol=pattern_constant._rel_tol,
abs_tol=pattern_constant._abs_tol,
)
for i in range(len(pattern_constant_value))
):
return self.fail(
f"Value mismatch: expected {pattern_constant_value}, got {constant_value_numpy}."
)
return True

# TODO (rama): allow users to specify shape requirement, if desired.
if constant_value_numpy.size != 1:
return self.fail(
f"Value {value.name} is not a scalar, expecting {pattern_constant.value}.",
f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.",
)

if not math.isclose(
constant_value_numpy.item(),
pattern_constant._value,
pattern_constant_value,
rel_tol=pattern_constant._rel_tol,
abs_tol=pattern_constant._abs_tol,
):
return self.fail(
f"Constant value mismatch: expected {pattern_constant._value}, got {constant_value_numpy.item()}.",
f"Constant value mismatch: expected {pattern_constant_value}, got {constant_value_numpy.item()}.",
)

return True
Expand Down Expand Up @@ -1079,11 +1140,6 @@ def _match_single_output_node(
if not _valid_to_replace(match.nodes, output_values):
return match.fail("Matched nodes have other uses preventing replacement.")

if len(node.outputs) != pattern.num_outputs:
return match.fail(
f"Number of node outputs mismatch: expected {pattern.num_outputs}, got {len(node.outputs)}."
)

match.outputs.extend(output_values)
return match

Expand Down

0 comments on commit f18dadc

Please sign in to comment.