Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Implement pass to remove unused nodes in graph #1841

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions onnxscript/ir/passes/_remove_unused.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
"""Utilities for removing unused nodes from the IR graph."""

from __future__ import annotations

from collections import deque

import onnxscript.ir as ir
from onnxscript.ir import Attr, Graph, Node, Value, _enums
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import modules only



class RemoveUnused:
    """Pass that removes nodes which do not contribute to the main graph outputs.

    The pass operates in place on the graph handed to the constructor and also
    cleans up the subgraphs stored in ``GRAPH`` / ``GRAPHS`` attributes of its
    nodes.
    """

    def __init__(self, graph_like: Graph):
        """Store the graph to clean up.

        Args:
            graph_like: The main graph whose unused nodes will be removed by
                :meth:`purge`.
        """
        self._graph = graph_like

    def purge(self) -> None:
        """Remove unused nodes in this graph (and all subgraphs) that do not contribute to main graph outputs.

        The algorithm has two phases:

        1. Mark: starting from the outputs of the main graph, walk backwards
           (breadth first) from each value to the node producing it. Every
           producer reached is marked as used; its inputs, and the outputs of
           the subgraphs held in its attributes, are enqueued in turn.
        2. Sweep: every node of the main graph and its subgraphs that was not
           marked is removed from the graph that owns it.
        """
        visited_graphs: set[Graph] = set()
        visited_values: set[Value] = set()
        visited_nodes: set[Node] = set()

        # Values whose producer still has to be examined.
        # NOTE(review): a reviewer suggested the queue could be avoided by
        # looping over the nodes in backward order, assuming the IR keeps a
        # consistent node order - not verified here.
        queue: deque[Value] = deque()

        def add_graph_output_values_to_queue(graph: Graph | None) -> None:
            """Enqueue every not-yet-seen output value of ``graph`` (at most once per graph)."""
            # Identity check on purpose: a graph without nodes must not be
            # mistaken for a missing graph, its outputs still have to be marked.
            if graph is None or graph in visited_graphs:
                return
            visited_graphs.add(graph)
            for output in graph.outputs:
                if output is None or output in visited_values:
                    continue
                visited_values.add(output)
                queue.append(output)

        add_graph_output_values_to_queue(self._graph)

        # Phase 1: mark.
        while queue:
            current_value = queue.popleft()
            producer_node = current_value.producer()
            # Values without a producer (e.g. the graph inputs used in the
            # tests) and producers already handled need no further work.
            if producer_node is None or producer_node in visited_nodes:
                continue
            visited_nodes.add(producer_node)

            # The outputs of subgraphs attached to a used node are used too.
            for attr in producer_node.attributes.values():
                if not isinstance(attr, Attr):
                    # Only concrete Attr entries are inspected for subgraphs.
                    continue
                if attr.type == _enums.AttributeType.GRAPH:
                    add_graph_output_values_to_queue(attr.value)
                elif attr.type == _enums.AttributeType.GRAPHS:
                    for subgraph in attr.value:
                        add_graph_output_values_to_queue(subgraph)

            # Walking from a value to its producer is what lets the traversal
            # reach a node of an outer graph from a use inside a subgraph.
            # NOTE(review): this relies on the IR being constructed correctly,
            # i.e. subgraph nodes referencing the outer Value objects directly.
            for input_value in producer_node.inputs:
                if input_value is not None and input_value not in visited_values:
                    visited_values.add(input_value)
                    queue.append(input_value)

        # Phase 2: sweep. The node list is materialised first because the
        # graphs are mutated while it is being walked.
        for node in list(ir.traversal.RecursiveGraphIterator(self._graph)):
            if node in visited_nodes:
                continue
            owning_graph = node.graph
            if owning_graph is None:
                # The node is not attached to any graph; nothing to remove.
                continue
            owning_graph.remove(node)

Check failure

Code scanning / lintrunner

MYPY/union-attr Error

Item "None" of "Graph | None" has no attribute "remove" To disable, use # type: ignore[union-attr]
Copy link
Contributor Author

@yichen-li-ucla yichen-li-ucla Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Acknowledged.

133 changes: 133 additions & 0 deletions onnxscript/ir/passes/_remove_unused_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved Hide resolved

from onnxscript import ir
from onnxscript.ir.passes._remove_unused import RemoveUnused


class RemoveUnusedTest(unittest.TestCase):
    """Exercise ``RemoveUnused.purge`` on flat graphs and on graphs with subgraphs."""

    def test_purge_empty(self):
        """Purging a graph without nodes leaves it without nodes."""
        empty_graph = ir.Graph(
            inputs=(),
            outputs=(),
            nodes=(),
            opset_imports={"": 1},
        )
        RemoveUnused(empty_graph).purge()
        self.assertEqual(tuple(empty_graph), ())

    def test_purge_a_single_node(self):
        """Only the nodes producing a graph output are kept."""
        source = ir.Value(name="v0")
        used_with_input = ir.Node("", "Node0", inputs=(source,), num_outputs=1)
        unused_with_output = ir.Node("", "Node1", inputs=(source,), num_outputs=1)
        unused_without_output = ir.Node("", "Node2", inputs=(source,), num_outputs=0)
        used_without_input = ir.Node("", "Node3", inputs=(), num_outputs=1)
        used_with_none_input = ir.Node("", "Node4", inputs=(None,), num_outputs=1)
        graph_outputs = (
            used_with_input.outputs[0],
            used_without_input.outputs[0],
            used_with_none_input.outputs[0],
        )
        graph = ir.Graph(
            (source,),
            graph_outputs,
            nodes=(
                used_with_input,
                unused_with_output,
                unused_without_output,
                used_without_input,
                used_with_none_input,
            ),
            opset_imports={"": 1},
        )
        RemoveUnused(graph).purge()
        expected = (used_with_input, used_without_input, used_with_none_input)
        self.assertEqual(tuple(graph), expected)

    def test_purge_a_tree(self):
        """A graph without outputs loses all of its nodes."""
        source = ir.Value(name="v0")
        root = ir.Node("", "Node0", inputs=(source,), num_outputs=1)
        left = ir.Node("", "Node1", inputs=(root.outputs[0],), num_outputs=1)
        right = ir.Node("", "Node2", inputs=(root.outputs[0],), num_outputs=1)
        graph = ir.Graph(
            (source,),
            (),
            nodes=(root, left, right),
            opset_imports={"": 1},
        )
        RemoveUnused(graph).purge()
        self.assertEqual(tuple(graph), ())

    def test_purge_subgraph_partial(self):
        """A subgraph with outputs keeps its node; one without outputs is emptied."""
        in_a = ir.Value(name="va")
        in_b = ir.Value(name="vb")
        in_c = ir.Value(name="vc")
        in_d = ir.Value(name="vd")
        node_a = ir.Node("", "a", inputs=(in_a,), num_outputs=1)
        node_b = ir.Node("", "b", inputs=(in_b,), num_outputs=1)
        node_c = ir.Node("", "c", inputs=(in_c,), num_outputs=1)
        node_d = ir.Node("", "d", inputs=(in_d,), num_outputs=1)
        c_out = node_c.outputs[0]
        d_out = node_d.outputs[0]
        sub_node = ir.Node("", "sub", inputs=(c_out, d_out), num_outputs=1)
        add_node = ir.Node("", "add", inputs=(c_out, d_out), num_outputs=1)
        compare_node = ir.Node(
            "", ">", inputs=(node_a.outputs[0], node_b.outputs[0]), num_outputs=1
        )
        then_graph = ir.Graph(
            inputs=(c_out, d_out),
            outputs=(sub_node.outputs[0],),
            nodes=(sub_node,),
            name="then_graph",
        )
        else_graph = ir.Graph(
            inputs=(c_out, d_out),
            outputs=(),
            nodes=(add_node,),
            name="else_graph",
        )

        if_node = ir.Node(
            "",
            "if",
            inputs=(compare_node.outputs[0],),
            num_outputs=1,
            attributes=[
                ir.AttrGraphs("subgraphs", [then_graph, else_graph]),
            ],
        )
        main_nodes = (node_a, node_b, node_c, node_d, compare_node, if_node)
        main_graph = ir.Graph(
            inputs=(in_a, in_b, in_c, in_d),
            outputs=(if_node.outputs[0],),
            nodes=main_nodes,
            name="main_graph",
            opset_imports={"": 1},
        )
        RemoveUnused(main_graph).purge()
        self.assertEqual(tuple(main_graph), main_nodes)
        self.assertEqual(tuple(then_graph), (sub_node,))
        self.assertEqual(tuple(else_graph), ())

    def test_purge_subgraph_all(self):
        """With no main graph outputs, the main graph and its subgraph are emptied."""
        source = ir.Value(name="v0")
        producer = ir.Node("", "c", inputs=(source,), num_outputs=1)
        sub_node = ir.Node("", "sub", inputs=(producer.outputs[0],), num_outputs=1)
        compare_node = ir.Node("", ">", inputs=(source,), num_outputs=1)
        then_graph = ir.Graph(
            inputs=(producer.outputs[0],),
            outputs=(sub_node.outputs[0],),
            nodes=(sub_node,),
            name="then_graph",
        )
        if_node = ir.Node(
            "",
            "if",
            inputs=(compare_node.outputs[0],),
            num_outputs=1,
            attributes=[
                ir.AttrGraph("then_graph", then_graph),
            ],
        )
        main_graph = ir.Graph(
            inputs=(source,),
            outputs=(),
            nodes=(producer, compare_node, if_node),
            name="main_graph",
        )
        RemoveUnused(main_graph).purge()
        self.assertEqual(tuple(main_graph), ())
        self.assertEqual(tuple(then_graph), ())


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

Check warning on line 133 in onnxscript/ir/passes/_remove_unused_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/_remove_unused_test.py#L133

Added line #L133 was not covered by tests
Loading