-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph_utils.py
159 lines (131 loc) · 5.1 KB
/
graph_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import logging
import os
import tempfile
from enum import Enum
from typing import Callable, cast, Dict, Iterable, List, Optional, Set

import torch.fx as fx
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils._pytree import tree_flatten, tree_unflatten
# Module-wide logger. It is registered under the literal name "graph_utils"
# (not ``__name__``), so configure it by that name.
logger: logging.Logger = logging.getLogger(name="graph_utils")
class OP(str, Enum):
    """
    Node opcode strings (the values compared against ``fx.Node.op``).

    Because the enum mixes in ``str``, each member compares equal to its
    raw string value, e.g. ``OP.OUTPUT == "output"``.
    """

    CALL_FUNCTION = "call_function"  # free-function call
    CALL_MODULE = "call_module"  # submodule invocation
    CALL_METHOD = "call_method"  # method call on a value
    GET_ATTR = "get_attr"  # attribute fetch
    OUTPUT = "output"  # the graph's return node
    PLACEHOLDER = "placeholder"  # a graph input
class CommType(str, Enum):
    """
    Names of collective communication operations.

    Every value carries a trailing underscore. NOTE(review): presumably these
    mirror the in-place c10d collective op names; confirm against callers.
    Because the enum mixes in ``str``, members compare equal to their values.
    """

    ALLREDUCE = "allreduce_"
    ALLGATHER = "allgather_"
    BROADCAST = "broadcast_"
    REDUCESCATTER = "reduce_scatter_"
    SCATTER = "scatter_"
def get_node_tensor_metadata(
    node: fx.Node, is_required: bool = True
) -> Optional[TensorMetadata]:
    """
    Return the ``tensor_meta`` entry stored in ``node.meta``.

    Args:
        node: The node whose metadata is read.
        is_required: If True (the default), a missing entry is an error.
            If False, ``None`` is returned when the entry is missing.

    Returns:
        ``node.meta["tensor_meta"]``, or ``None`` when it is absent and
        ``is_required`` is False.

    Raises:
        RuntimeError: If ``is_required`` is True and the entry is missing
            (or is explicitly stored as ``None``).
    """
    metadata = node.meta.get("tensor_meta", None)
    if is_required and metadata is None:
        raise RuntimeError(
            f"Callsite expects that ``tensor_meta`` exists in ``{node.name}``, "
            f"but got None instead. Node: {node.op} {node.name} {node.target}"
        )
    return metadata
def get_output(graph: fx.Graph) -> fx.Node:
    """
    Return the output node of ``graph``.

    The nodes are scanned from last to first, since the output node is
    expected to sit at the end of the graph.

    Raises:
        RuntimeError: If ``graph`` contains no output node.
    """
    found = next(
        (candidate for candidate in reversed(graph.nodes) if candidate.op == OP.OUTPUT),
        None,
    )
    if found is None:
        raise RuntimeError(f"Cannot find the output node in {graph}")
    return found
def find_node(
    graph: fx.Graph, predicate: Callable, reverse_order: bool = False
) -> List[fx.Node]:
    """
    Collect every node of ``graph`` for which ``predicate(node)`` is truthy.

    Args:
        graph: The graph whose nodes are scanned.
        predicate: Called once per node; a truthy result selects the node.
        reverse_order: When True, scan from the last node to the first;
            otherwise scan in graph order.

    Returns:
        The selected nodes, in scan order.
    """
    ordered: Iterable[fx.Node]
    if reverse_order:
        ordered = reversed(graph.nodes)  # type: ignore[call-overload]
    else:
        ordered = graph.nodes
    return list(filter(predicate, ordered))
def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
    """
    Check that nothing outside ``subgraph`` (except the output) depends on it.

    For every node in ``subgraph``, each of its users must be either:
      1. another node in ``subgraph``, or
      2. the output node of ``graph``.
    A node with no users (e.g. a side-effect node) satisfies this trivially.

    Returns:
        True if every user satisfies one of the rules above, else False.
    """
    members: Set[fx.Node] = set(subgraph)
    output_node = get_output(graph)
    return all(
        user in members or user == output_node
        for member in subgraph
        for user in member.users
        if isinstance(user, fx.Node)
    )
def clone_subgraph(
    graph: fx.Graph, subgraph: List[fx.Node], target: fx.Node
) -> List[fx.Node]:
    """
    Clone the given subgraph and insert it before ``target``.

    Each node in ``subgraph`` is re-created with the same op, target, args,
    kwargs and type annotation. Arguments that refer to a node inside
    ``subgraph`` are redirected to that node's clone. Arguments that refer to
    nodes outside ``subgraph``, and non-node constants, are kept unchanged.

    ``subgraph`` must be topologically ordered: a node may only consume
    subgraph nodes that appear earlier in the list.

    This API currently does not support inserting after ``target``.

    Args:
        graph: The graph that owns ``subgraph`` and ``target``.
        subgraph: The nodes to clone, in topological order.
        target: The node before which the clones are inserted.

    Returns:
        The cloned nodes, in the same order as ``subgraph``.
    """
    all_nodes: Set[fx.Node] = set(subgraph)
    mapping: Dict[fx.Node, fx.Node] = {}

    def _remap(arg: fx.Node) -> fx.Node:
        # Redirect references to subgraph members to their clones; leave
        # references to nodes outside the subgraph untouched.
        if arg not in all_nodes:
            return arg
        assert arg in mapping, (
            f"``{arg.name}`` is used before it is cloned; ``subgraph`` must be "
            "topologically ordered."
        )
        return mapping[arg]

    cloned_subgraph: List[fx.Node] = []
    with graph.inserting_before(target):
        for node in subgraph:
            # Preserve the original opcode so call_method / call_module /
            # get_attr nodes are cloned faithfully instead of being forced
            # into call_function nodes.
            cloned_node = graph.create_node(
                node.op,
                node.target,
                fx.map_arg(node.args, _remap),
                fx.map_arg(node.kwargs, _remap),
                type_expr=node.type,
            )
            mapping[node] = cloned_node
            cloned_subgraph.append(cloned_node)
    return cloned_subgraph
def rebuild_graph(gm: fx.GraphModule, remove_dead_code: bool = True) -> None:
    """
    Validate, optionally prune, and regenerate the code of ``gm``.

    Steps, in order:
      1. ``Graph.lint()`` runs the FX well-formedness checks.
      2. ``Graph.eliminate_dead_code()`` runs only when ``remove_dead_code``
         is True. Per the FX docs this pass is not fully precise, which is
         why it can be switched off.
      3. ``GraphModule.recompile()`` regenerates the module's code from the
         (possibly modified) graph.
    """
    graph = gm.graph
    graph.lint()
    if remove_dead_code:
        graph.eliminate_dead_code()
    gm.recompile()
def dump_graphs_to_files(graphs: Dict[str, fx.GraphModule], folder: str = "") -> str:
    """
    Write the string form of each graph module to ``<folder>/<prefix>.graph``.

    Args:
        graphs: Maps a file-name prefix to the graph module to dump.
        folder: Destination directory. When empty, a fresh temporary
            directory is created with ``tempfile.mkdtemp()``. When it does
            not exist yet, it (and any missing parents) is created.

    Returns:
        The directory the files were written to.
    """
    if not folder:
        folder = tempfile.mkdtemp()
    else:
        # Create the destination up front rather than failing inside open().
        os.makedirs(folder, exist_ok=True)
    for prefix, gm in graphs.items():
        path = os.path.join(folder, f"{prefix}.graph")
        # Explicit encoding so the output does not depend on the locale.
        with open(path, "w", encoding="utf-8") as fp:
            fp.write(str(gm))
    logger.warning("Dump graphs to %s", folder)
    return folder
def replace_subsequent_uses_of(
    graph: fx.Graph, old_node: fx.Node, new_node: fx.Node
) -> None:
    """
    Rewire users of ``old_node`` that appear after ``new_node`` to ``new_node``.

    The graph is walked from its last node backwards, and the walk stops as
    soon as ``new_node`` is reached. ``new_node`` itself, and every node
    before it, therefore keep consuming ``old_node``. If ``new_node`` is never
    reached, every user of ``old_node`` in ``graph`` is rewired.
    """
    consumers = old_node.users
    for candidate in reversed(graph.nodes):
        if candidate == new_node:
            return
        if candidate not in consumers:
            continue
        candidate.replace_input_with(old_node, new_node)