[IR] Implement pass to remove unused nodes in graph #1841

Open
wants to merge 6 commits into
base: main

yichen-li-ucla
Contributor

yichen-li-ucla commented Sep 2, 2024

yichen-li-ucla changed the title from "Remove unused nodes." to "[IR] Remove unused nodes in graph." on Sep 2, 2024
justinchuby changed the title from "[IR] Remove unused nodes in graph." to "[IR] Implement pass to remove unused nodes in graph" on Sep 2, 2024

codecov bot commented Sep 2, 2024

Codecov Report

Attention: Patch coverage is 91.66667% with 9 lines in your changes missing coverage. Please review.

Project coverage is 75.28%. Comparing base (d7a6411) to head (b69dd83).
Report is 18 commits behind head on main.

Files with missing lines                       Patch %   Lines
onnxscript/ir/passes/_remove_unused.py         84.09%    3 Missing and 4 partials ⚠️
onnxscript/ir/passes/_remove_unused_test.py    96.87%    1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1841      +/-   ##
==========================================
+ Coverage   75.20%   75.28%   +0.07%     
==========================================
  Files         251      253       +2     
  Lines       27429    27537     +108     
  Branches     5032     5047      +15     
==========================================
+ Hits        20629    20730     +101     
- Misses       5828     5830       +2     
- Partials      972      977       +5     


onnxscript/ir/passes/_remove_unused.py Fixed
# Remove
for node in all_nodes:
    if node not in visited_nodes:
        node.graph.remove(node)

Check failure

Code scanning / lintrunner

MYPY/union-attr Error

Item "None" of "Graph | None" has no attribute "remove" To disable, use # type: ignore[union-attr]
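The error above arises because the node's `graph` attribute is typed as `Graph | None`. One common way to satisfy the type checker without a `# type: ignore` is to bind the attribute to a local and check it for `None` first. The sketch below uses hypothetical stand-in `Node` and `Graph` classes, not the onnxscript ones, and `remove_from_owner` is an illustrative helper name.

```python
from typing import Optional


class Node:
    """Hypothetical stand-in for an IR node whose owning graph may be None."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.graph: Optional["Graph"] = None


class Graph:
    """Hypothetical stand-in for an IR graph; not the onnxscript class."""

    def __init__(self) -> None:
        self.nodes: list = []

    def append(self, node: Node) -> None:
        node.graph = self
        self.nodes.append(node)

    def remove(self, node: Node) -> None:
        self.nodes.remove(node)
        node.graph = None


def remove_from_owner(node: Node) -> bool:
    """Detach node from its graph; return False if it has no owner."""
    graph = node.graph
    if graph is None:  # narrows "Graph | None" to "Graph" for the call below
        return False
    graph.remove(node)
    return True
```

Whether a silent skip, an assertion, or an exception is the right response to an ownerless node is a design choice for the pass; the sketch only shows the narrowing pattern.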
Contributor Author

@yichen-li-ucla yichen-li-ucla Sep 3, 2024

Acknowledged.

onnxscript/ir/passes/_remove_unused_test.py Fixed
onnxscript/ir/passes/_remove_unused_test.py Fixed
onnxscript/ir/passes/_remove_unused_test.py Fixed
@yichen-li-ucla
Contributor Author

@justinchuby

@yichen-li-ucla
Contributor Author

Just updated with upstream. No need for rerun.

visited_nodes: set[Node] = set()

# BFS Traversal
value_queue: deque[Value] = deque(
Collaborator

I think it would be better if a node's subgraphs are processed only after the node itself is determined to be useful (that is, added to visited_nodes). This will handle examples such as the one below better:

   x = ...
   y = If ( cond, ... x ..., ...)

Here, if y is not used, then we may not need x either. But the current logic will, I believe, mark x as visited, since it is used to compute the output of the If's then-subgraph.
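A minimal sketch of the suggested ordering, assuming a toy IR in which each node has a single output and an optional list of subgraph output values. The `Value`, `Node`, and `live_nodes` names are hypothetical and only illustrate the idea; this is not the code from the PR. Subgraph outputs are enqueued only once the enclosing node has been marked live.

```python
from collections import deque


class Value:
    """Toy value; producer is None for graph inputs."""

    def __init__(self, name, producer=None):
        self.name = name
        self.producer = producer


class Node:
    """Toy node with one output; subgraph_outputs models an If/Loop body."""

    def __init__(self, name, inputs=(), subgraph_outputs=()):
        self.name = name
        self.inputs = list(inputs)
        self.subgraph_outputs = list(subgraph_outputs)
        self.output = Value(name + "_out", producer=self)


def live_nodes(graph_outputs):
    """Return the set of nodes reachable backward from graph_outputs."""
    live = set()
    queue = deque(graph_outputs)
    while queue:
        node = queue.popleft().producer
        if node is None or node in live:
            continue
        live.add(node)
        queue.extend(node.inputs)
        # Subgraph outputs are enqueued only now that the node is known live,
        # so values used solely inside a dead If are never marked.
        queue.extend(node.subgraph_outputs)
    return live


# x feeds only the then-branch of an If whose result y is unused.
cond = Node("cond")
x = Node("x")
then_add = Node("then_add", inputs=[x.output])
y = Node("If", inputs=[cond.output], subgraph_outputs=[then_add.output])
z = Node("z")

unused_if = live_nodes([z.output])  # only z is live; x and the If can go
used_if = live_nodes([y.output])    # the If, cond, then_add and x are live
```

Because the walk goes from a value to its producer, a use inside the subgraph still reaches a producer in the outer graph once the If itself is live.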

Contributor Author

Agree. Do we really need all subgraphs' outputs? @justinchuby

Contributor Author

I adopted @gramalingam's idea, modified the code, and added a test case.



class RemoveUnused:
    def __init__(self, graph_like: Graph):
Member

ir.Graph?

if not isinstance(attr, Attr):
    continue
if attr.type == _enums.AttributeType.GRAPH:
    add_graph_output_values_to_queue(attr.value)
Member

Some subgraphs use intermediate results declared in the main graph. You need to loop over nodes inside subgraphs as well. You'll have to handle inputs/outputs with the same name.

Collaborator

The RecursiveGraphIterator will loop over all nodes in subgraphs, so all_nodes already includes nodes from the subgraph.

Collaborator

@gramalingam gramalingam Sep 10, 2024

I think the code handles Xavier's case, but not because of the recursive graph iterator. (That seems to be used only in the later loop below to remove nodes). The code above goes from value to the producer of the value: this should go from a use inside a subgraph to a producer outside the subgraph (as long as the IR is constructed correctly.)


add_graph_output_values_to_queue(self._graph)

while queue:
Member

It should be possible to avoid the queue by looping over all nodes in backward order (assuming the IR preserves a consistent order on the nodes; cc @justinchuby).
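A rough sketch of the queue-free alternative, under the assumption that nodes are stored in topological order (every producer appears before its consumers). Nodes are modeled here as plain `(name, input_names, output_names)` tuples and `remove_unused_reverse` is a hypothetical name. The sketch covers a flat graph only and leaves subgraphs aside.

```python
def remove_unused_reverse(nodes, graph_outputs):
    """Keep only nodes whose outputs are needed, scanning in reverse order.

    nodes: list of (name, input_names, output_names), topologically sorted.
    graph_outputs: names of the values the graph must still produce.
    """
    needed = set(graph_outputs)
    kept = []
    for name, inputs, outputs in reversed(nodes):
        # A node is live if any of its outputs is needed downstream.
        if needed.intersection(outputs):
            needed.update(inputs)
            kept.append((name, inputs, outputs))
    kept.reverse()  # restore the original (forward) order
    return kept


example = [
    ("a", ["x"], ["t1"]),
    ("b", ["t1"], ["t2"]),
    ("dead", ["t1"], ["t3"]),  # t3 is never consumed
    ("c", ["t2"], ["y"]),
]
kept_names = [name for name, _, _ in remove_unused_reverse(example, ["y"])]
```

A single reverse pass is enough only when the ordering assumption holds; if it does not, a consumer could be visited after its producer has already been judged dead.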

from collections import deque

import onnxscript.ir as ir
from onnxscript.ir import Attr, Graph, Node, Value, _enums
Collaborator

Import modules only
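This note presumably refers to the convention of importing modules rather than individual names, as recommended by the Google Python Style Guide. Applied to the lines under review, that would likely mean `import collections` with `collections.deque(...)`, and qualified names such as `ir.Node` in place of the bare `Node`; that reading is an assumption, not something stated in the thread. A small standard-library illustration, with `drain` as a hypothetical helper:

```python
import collections  # module import; names are qualified at the use site


def drain(items):
    """Pop everything from a deque in FIFO order and return it as a list."""
    queue = collections.deque(items)
    out = []
    while queue:
        out.append(queue.popleft())
    return out
```

Qualifying each name at the call site makes its origin visible without scrolling back to the import block.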

@justinchuby
Collaborator

I will come back to this later this week. Thanks for your patience!

@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.

# Remove
for node in all_nodes:
    if node not in visited_nodes: # type: ignore[union-attr]`

Check failure

Code scanning / lintrunner

MYPY/syntax Error

Invalid "type: ignore" comment To disable, use # type: ignore[syntax]
# Remove all nodes that have not been marked as visited during the BFS traversal.

# Initialize
all_nodes: list[Node] = list(ir.traversal.RecursiveGraphIterator(self._graph))
Collaborator

nit: now this can be moved down to line 79, which is where it is used, I think ...


Successfully merging this pull request may close these issues.

[IR] Dead code elimination pass
4 participants