From d71b74fb0b194e718a1fa78eddef7d89b57cf4a1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 16 May 2024 11:03:04 -0700 Subject: [PATCH] [IR] Create pass infra (#1528) Create PassBase, PassResult, PassManager, NodeTransformer for creating passes with the IR. - Implement the `remove_unused_functions` pass using this infrastructure. - Remove the `_invariance` module because it is unused. Future PRs: - Update rewriter to make it compatible with the `PassManager` ## TODO - Better docs for PassManager - Test PassManager Fix #1524 --- onnxscript/ir/__init__.py | 4 +- onnxscript/ir/_invariants.py | 60 ---- onnxscript/ir/passes/__init__.py | 27 ++ onnxscript/ir/passes/_pass_infra.py | 256 ++++++++++++++++++ onnxscript/optimizer/__init__.py | 2 +- .../optimizer/remove_unused_function.py | 95 +++---- .../optimizer/simple_function_folding_test.py | 12 +- onnxscript/rewriter/__init__.py | 2 +- onnxscript/rewriter/onnxruntime/__init__.py | 2 +- 9 files changed, 344 insertions(+), 116 deletions(-) delete mode 100644 onnxscript/ir/_invariants.py create mode 100644 onnxscript/ir/passes/__init__.py create mode 100644 onnxscript/ir/passes/_pass_infra.py diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 9d0678656..f8d5793ef 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -68,9 +68,11 @@ # Conversion functions "from_proto", "to_proto", + # Pass infrastructure + "passes", ] -from onnxscript.ir import serde +from onnxscript.ir import passes, serde from onnxscript.ir._core import ( Attr, AttrFloat32, diff --git a/onnxscript/ir/_invariants.py b/onnxscript/ir/_invariants.py deleted file mode 100644 index 8d009c3cc..000000000 --- a/onnxscript/ir/_invariants.py +++ /dev/null @@ -1,60 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Utilities to enforce invariants on the IR.""" - -from __future__ import annotations - -import functools -from typing import Any, Callable - - -class InvariantError(Exception): - """Raised when an invariant is violated.""" - - -class PreconditionError(InvariantError): - """Raised when a precondition is violated.""" - - -class PostconditionError(InvariantError): - """Raised when a postcondition is violated.""" - - -def requires( - preconditions: Callable[..., str | None], -) -> Callable[..., Callable[..., Any]]: - """Decorator to enforce preconditions on a function.""" - # TODO(justinchuby): Preserve python function signature with this decorator - - def decorator(func: Callable[..., None]) -> Callable[..., None]: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> None: - message = preconditions(*args, **kwargs) - if message is not None: - raise PreconditionError(message) - return func(*args, **kwargs) - - return wrapper - - return decorator - - -def ensures( - postconditions: Callable[..., str | None], -) -> Callable[..., Callable[..., Any]]: - """Decorator to enforce postconditions on a function.""" - - def decorator(func: Callable[..., None]) -> Callable[..., None]: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> None: - result = func(*args, **kwargs) - message = postconditions(*args, **kwargs) - if message is not None: - raise PostconditionError(message) - return result - - return wrapper - - return decorator diff --git a/onnxscript/ir/passes/__init__.py b/onnxscript/ir/passes/__init__.py new file mode 100644 index 000000000..b594918ee --- /dev/null +++ b/onnxscript/ir/passes/__init__.py @@ -0,0 +1,27 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +__all__ = [ + "PassBase", + "PassResult", + "PassManager", + "NodeTransformer", + # Errors + "InvariantError", + "PreconditionError", + "PostconditionError", + "PassError", +] + +from onnxscript.ir.passes._pass_infra import ( + InvariantError, + NodeTransformer, + PassBase, + PassError, + PassManager, + PassResult, + PostconditionError, + PreconditionError, +) diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py new file mode 100644 index 000000000..ed826b3ad --- /dev/null +++ b/onnxscript/ir/passes/_pass_infra.py @@ -0,0 +1,256 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# +# This module implements some APIs described in +# https://pytorch.org/executorch/stable/compiler-custom-compiler-passes.html +# for the ONNX IR. +# The classes {PassResult and PassManager} are derived from +# https://github.com/pytorch/pytorch/blob/1e47c7b11b312b47a621efd547f5c90081f0d9cb/torch/fx/passes/infra/pass_base.py#L12 +# and +# https://github.com/pytorch/pytorch/blob/1e47c7b11b312b47a621efd547f5c90081f0d9cb/torch/fx/passes/infra/pass_manager.py#L147 +# The original code is licensed under the PyTorch License https://github.com/pytorch/pytorch/blob/main/LICENSE + +"""Passes infrastructure for the IR.""" + +from __future__ import annotations + +import dataclasses +import logging +from typing import Sequence + +__all__ = [ + "NodeTransformer", + "PassBase", + "PassManager", + "PassResult", + # Errors + "InvariantError", + "PreconditionError", + "PostconditionError", + "PassError", +] + +import abc + +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +class InvariantError(Exception): + """Raised when an invariant is violated.""" + + +class PreconditionError(InvariantError): + """Raised when a precondition is violated.""" + + +class PostconditionError(InvariantError): + """Raised when a postcondition is violated.""" + + +class PassError(RuntimeError): + """Raised when an error occurs during a pass.""" + + +@dataclasses.dataclass +class PassResult: + """Result of a pass. + + Attributes: + model: The transformed model. + modified: Whether the model was modified. + """ + + model: ir.Model + modified: bool + + +class PassBase(abc.ABC): + """Base class for all passes. + + Class attributes: + in_place: Whether the pass modifies the model in place. + """ + + in_place: bool = True + + def __call__(self, model: ir.Model) -> PassResult: + return self.call(model) + + @abc.abstractmethod + def call(self, model: ir.Model) -> PassResult: + """The main entry point for the pass.""" + ... + + def requires(self, model: ir.Model) -> None: + """Pre-conditions for the pass. + + This is optional to implement, will be called before call() if run by a pass manager. + """ + del model # Unused + + def ensures(self, model: ir.Model) -> None: + """Post-conditions for the pass. + + This is optional to implement, will be called after call() if run by a pass manager. + """ + del model # Unused + + +class NodeTransformer(PassBase): + """NodeTransformer for the ONNX IR. + + An NodeTransformer is a pass that traverses the IR and performs some + operation on the nodes. The operation can be anything, such as + checking invariants, transforming the IR, or generating code. + + By default, the NodeTransformer updates the model in place. + + .. warning:: + Users should not depend on this class before the warning is removed, because it is not stable. + + Attributes: + model: ir.Model: The model being interpreted. + scope (list[ir.Graph]): The current graph the NodeTransformer is running on. + reversed (bool): Whether to traverse the graph in reverse order. + modified (bool): Whether the model was modified. + """ + + def __init__(self, reversed: bool = False): + self._model: ir.Model | None = None + self.scope: list[ir.Graph] = [] + self.reversed = reversed + self.modified: bool | None = None + + @property + def model(self) -> ir.Model: + """Return the model being interpreted.""" + if self._model is None: + raise ValueError("Model is not set. The model is set during the pass execution.") + return self._model + + def call(self, model: ir.Model) -> PassResult: + self._model = model + self.enter_pass() + self._call_graph(self._model.graph) + self.exit_pass() + if self.modified is None: + raise PassError("The modified attribute was not set. Please set it in the pass.") + return PassResult(self._model, self.modified) + + def _call_graph(self, graph: ir.Graph): + self.enter_graph(graph) + self.scope.append(graph) + iterable = reversed(graph) if self.reversed else graph + for node in iterable: + self.call_node_recursive(node) + self.exit_graph(graph) + self.scope.pop() + + def call_node_recursive(self, node: ir.Node): + self.call_node(node) + for attr in node.attributes.values(): + if not isinstance(attr, ir.Attr): + continue + if attr.type == ir.AttributeType.GRAPH: + self._call_graph(attr.value) + elif attr.type == ir.AttributeType.GRAPHS: + for graph in attr.value: + self._call_graph(graph) + + def enter_pass(self): + """Called when entering the pass. Optional to implement.""" + + def exit_pass(self): + """Called when exiting the pass. Optional to implement.""" + + def enter_graph(self, graph: ir.Graph): + """Called when entering a graph. Optional to implement.""" + del graph # Unused + + def exit_graph(self, graph: ir.Graph): + """Called when exiting a graph. Optional to implement.""" + del graph # Unused + + @abc.abstractmethod + def call_node(self, node: ir.Node): + """Called when visiting a node.""" + ... + + +class PassManager: + """Pass manager for the IR. + + The PassManager is a callable that runs a sequence of passes on a model. + + Attributes: + passes: The passes to run. + check_invariants: Whether to check invariants before and after each pass. + steps: The number of times to run the passes. + """ + + def __init__( + self, + passes: Sequence[PassBase], + check_invariants: bool = False, + steps: int = 1, + ): + # TODO(justinchuby): Implement constraints + self.passes = list(passes) + self.check_invariants = check_invariants + self.steps = steps + + def __call__(self, model: ir.Model) -> PassResult: + """Run the set of passes `steps` number of times or until the graph stops changing.""" + overall_modified = False + for step in range(self.steps): + step_result = self._run_one_step(model, step) + model = step_result.model + modified = step_result.modified + overall_modified = overall_modified or modified + # If the graph no longer changes, then we can stop running these passes + if not modified: + logger.info("PassManager: No more graph changes detected after step %s", step) + break + return PassResult(model, overall_modified) + + def _run_one_step(self, model: ir.Model, step: int) -> PassResult: + modified = False + for i, pass_ in enumerate(self.passes): + logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step) + + # 1. Check preconditions + if self.check_invariants: + try: + pass_.requires(model) + except Exception as e: + raise PreconditionError(f"Pre-condition failed for {pass_}") from e + + # 2. Run the pass + try: + pass_result = pass_(model) + except Exception as e: + prev_pass_names = [str(p) for p in self.passes[:i]] + raise PassError( + f"An error occurred when running the '{pass_}' pass after the " + f"following passes: {prev_pass_names} during step {step}" + ) from e + if not isinstance(pass_result, PassResult): + raise TypeError( + f"The result of the pass {pass_} should be type PassResult." + "Please create one with ir.passes.PassResult()." + ) + + model = pass_result.model + modified = modified or pass_result.modified + + # 3. Check postconditions + if self.check_invariants: + try: + pass_.ensures(model) + except Exception as e: + raise PostconditionError(f"Post-condition failed for {pass_}") from e + return PassResult(model, modified) diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 03c1e748e..0931e45c3 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -74,7 +74,7 @@ def optimize( remove_unused_nodes(model) inline_simple_functions(model) - remove_unused_functions(model) + model = remove_unused_functions(model) inline_functions_with_unused_outputs(model) # NOTE: This is general rewrite rules model = rewriter.rewrite( diff --git a/onnxscript/optimizer/remove_unused_function.py b/onnxscript/optimizer/remove_unused_function.py index 573dfaa8b..55756c062 100644 --- a/onnxscript/optimizer/remove_unused_function.py +++ b/onnxscript/optimizer/remove_unused_function.py @@ -1,56 +1,59 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- from __future__ import annotations import logging import onnx -from google.protobuf.internal.containers import ( # type: ignore - RepeatedCompositeFieldContainer, -) + +from onnxscript import ir logger = logging.getLogger(__name__) -class UnusedFunctionRemover: - def compute_used_in_node(self, n: onnx.NodeProto) -> set[tuple[str, str]]: - used = {(n.domain, n.op_type)} - for attr in n.attribute: - if attr.HasField("g"): - used |= self.process_graph(attr.g) - elif len(attr.graphs) > 0: - for graph in attr.graphs: - used |= self.process_graph(graph) - if (n.domain, n.op_type) in self._functions: - function = self._functions[(n.domain, n.op_type)] - used |= self.process_function(function) - return used - - def process_nodes( - self, nodes: RepeatedCompositeFieldContainer[onnx.NodeProto] - ) -> set[tuple[str, str]]: - used = set() - for node in nodes: - used |= self.compute_used_in_node(node) - return used - - def process_graph(self, graph: onnx.GraphProto) -> set[tuple[str, str]]: - return self.process_nodes(graph.node) - - def process_function(self, function: onnx.FunctionProto) -> set[tuple[str, str]]: - return self.process_nodes(function.node) - - def process_model(self, model: onnx.ModelProto) -> None: - self._functions = {(f.domain, f.name): f for f in model.functions} - used = self.process_graph(model.graph) - count = 0 - logger.debug("Used function protos: %s", used) - for i in range(len(model.functions) - 1, -1, -1): - if (model.functions[i].domain, model.functions[i].name) not in used: - del model.functions[i] - count += 1 - logger.info("Removed %s unused function protos", count) - logger.debug("Function protos left: %s", [f.name for f in model.functions]) - - -def remove_unused_functions(model: onnx.ModelProto) -> None: +class UnusedFunctionRemover(ir.passes.NodeTransformer): + def __init__(self): + super().__init__() + self.used: set[ir.OperatorIdentifier] = set() + + def _call_function(self, function: ir.Function) -> None: + if function.identifier() in self.used: + # The function and its nodes are already recorded as used + return + self.used.add(function.identifier()) + for node in function: + self.call_node_recursive(node) + + def call_node(self, node: ir.Node) -> None: + op_identifier = node.op_identifier() + if op_identifier in self.model.functions: + self._call_function(self.model.functions[op_identifier]) + else: + self.used.add(op_identifier) + + def exit_pass(self) -> None: + # Update the model to remove unused functions + unused = set(self.model.functions) - self.used + if not unused: + logger.info("No unused functions to remove") + self.modified = False + return + for op_identifier in unused: + if op_identifier not in self.used: + del self.model.functions[op_identifier] + self.modified = True + logger.info("Removed %s unused functions", len(unused)) + logger.debug("Functions left: %s", list(self.model.functions)) + logger.debug("Functions removed: %s", unused) + + +def remove_unused_functions(model_proto: onnx.ModelProto) -> onnx.ModelProto: """Removes unused function protos from the model.""" - UnusedFunctionRemover().process_model(model) + # TODO(justinchuby): Update this to accept an ir.Model + model = ir.serde.deserialize_model(model_proto) + UnusedFunctionRemover()(model) + model_proto = ir.serde.serialize_model(model) + + return model_proto diff --git a/onnxscript/optimizer/simple_function_folding_test.py b/onnxscript/optimizer/simple_function_folding_test.py index df7feaec2..34a9e613b 100644 --- a/onnxscript/optimizer/simple_function_folding_test.py +++ b/onnxscript/optimizer/simple_function_folding_test.py @@ -31,7 +31,7 @@ def test_fold_single_node_function(self): ) simple_function_folding.inline_simple_functions(model) - remove_unused_function.remove_unused_functions(model) + model = remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 0) @@ -58,7 +58,7 @@ def test_fold_single_node_function_ref_attr(self): ) simple_function_folding.inline_simple_functions(model) - remove_unused_function.remove_unused_functions(model) + model = remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 0) self.assertFalse(model.graph.node[0].attribute[0].ref_attr_name) @@ -97,7 +97,7 @@ def test_fold_single_node_function_nested(self): ) simple_function_folding.inline_simple_functions(model) - remove_unused_function.remove_unused_functions(model) + model = remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 1) self.assertEqual(model.functions[0].node[0].op_type, "Concat") @@ -126,7 +126,7 @@ def test_fold_single_node_function_create_new_nodes_with_correct_attributes(self """ ) simple_function_folding.inline_simple_functions(model) - remove_unused_function.remove_unused_functions(model) + model = remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 0) self.assertEqual(len(model.graph.node), 3) self.assertEqual(model.graph.node[0].attribute[0].i, 10) @@ -169,7 +169,7 @@ def test_fold_nested_if_function_succeeds(self): ) simple_function_folding.inline_simple_functions(model) - remove_unused_function.remove_unused_functions(model) + model = remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 0) self.assertEqual(len(model.graph.node), 2) @@ -210,7 +210,7 @@ def test_fold_function_with_unused_output(self): ) simple_function_folding.inline_functions_with_unused_outputs(model) - remove_unused_function.remove_unused_functions(model) + model = remove_unused_function.remove_unused_functions(model) self.assertEqual(len(model.functions), 1) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 7dc784650..e3add1ac1 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -39,5 +39,5 @@ def rewrite( print(f"Applied {count} of general pattern rewrite rules.") model = ir.serde.serialize_model(model_ir) remove_unused.remove_unused_nodes(model) - remove_unused_function.remove_unused_functions(model) + model = remove_unused_function.remove_unused_functions(model) return model diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py index 4e9007e36..4a8ffa61b 100644 --- a/onnxscript/rewriter/onnxruntime/__init__.py +++ b/onnxscript/rewriter/onnxruntime/__init__.py @@ -54,5 +54,5 @@ def rewrite( model_proto = ir.serde.serialize_model(model) remove_unused.remove_unused_nodes(model_proto) - remove_unused_function.remove_unused_functions(model_proto) + model_proto = remove_unused_function.remove_unused_functions(model_proto) return model_proto