graph_compiler.py
from contextlib import contextmanager, nullcontext
from copy import copy
from dataclasses import dataclass
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.distributed as dist
# We need to import _functional_collectives to trigger op registration
import torch.distributed._functional_collectives
import torch.nn as nn
import torch.optim as optim
import torch.utils._pytree as pytree
from graph_compiler_utils import SPMD_DECOMP_TABLE
from graph_profiler import GraphProfiler, ProfilerEngine
from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._functional_collectives import all_reduce
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo, CodeGen
from torch.nn.utils import stateless
from torch.utils.hooks import RemovableHandle
class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen):
# pyre-ignore[3]
def process_inputs(self, *args: Any) -> Any:
return args
# pyre-ignore[2, 3]
def gen_fn_def(self, free_vars, maybe_return_annotation):
return CodeGen.gen_fn_def(self, free_vars, maybe_return_annotation)
def _to_caller_flattened_graph_module(gm: fx.GraphModule) -> fx.GraphModule:
"""Move the responsibility of flattening the input arguments from the
graph module to the caller.
    Example:
        output = gm(my_struct)
        gm = _to_caller_flattened_graph_module(gm)
        output = gm(*pytree.tree_flatten(my_struct)[0])
"""
# pyre-ignore[16]
gm._graph._codegen = _PyTreeCodeGenOutputsOnly(
pytree_info=_PyTreeInfo(
# pyre-ignore[6]
orig_args=None, # type: ignore[arg-type]
# pyre-ignore[6]
in_spec=None, # type: ignore[arg-type]
# pyre-ignore[16]
out_spec=gm._graph._codegen.pytree_info.out_spec,
)
)
gm.graph.eliminate_dead_code()
gm.recompile()
return gm
@contextmanager
def gradients_tagging(params: Dict[str, nn.Parameter]):
"""
    Helper context manager that registers a hook on every parameter which
    all-reduces (averages) its gradient across the default process group, so
    that the collective is captured in the traced graph. (An earlier version
    instead tagged each gradient with a dummy op for SPMD expansion; see the
    commented-out lines below.) The hooks are safe to trace and are removed
    once tracing finishes.
"""
# tagging_hooks: List[RemovableHandle] = []
all_red_hooks: List[RemovableHandle] = []
try:
for p in params.values():
# h = p.register_hook(lambda grad: torch.ops.dummy.tag_grad(grad))
h2 = p.register_hook(
lambda grad: all_reduce(grad, reduceOp="avg", group=dist.group.WORLD)
)
# tagging_hooks.append(h)
all_red_hooks.append(h2)
yield
finally:
# remove those hooks after tracing
# for h in tagging_hooks:
# h.remove()
for h in all_red_hooks:
h.remove()
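# The hooks registered above mean that, during tracing, every parameter
# gradient is rewritten into an averaged all-reduce, so the captured graph
# contains roughly one all_reduce/wait_tensor pair per gradient. Each hook is
# equivalent to the following (illustrative sketch only):
#
#     grad = all_reduce(grad, reduceOp="avg", group=dist.group.WORLD)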
@contextmanager
def _rematerialize_optimizer(
opt: optim.Optimizer,
named_states: Dict[str, Any],
params: Dict[str, nn.Parameter],
):
assert opt is not None
# update opt.state with proxy tensors
orig_states = copy(opt.state)
    for n in named_states:
        # named_states is keyed by parameter name (str), while the optimizer
        # keys its state by the Parameter object itself.
        opt.state[params[n]] = named_states[n]  # type: ignore[index]
# FIXME: support multiple parameter groups
param_group = opt.param_groups[0]
orig_params = param_group["params"]
param_group["params"] = params.values()
try:
yield
finally:
param_group["params"] = orig_params
opt.state = orig_states
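# A minimal sketch of how this context manager is used during tracing
# (illustrative; ``opt``, ``named_states`` and ``params`` follow the shapes
# built in ``_compile`` below):
#
#     with _rematerialize_optimizer(opt, named_states, params):
#         # ``opt.step()`` now reads and writes the proxy state tensors in
#         # ``named_states`` instead of the real optimizer state.
#         opt.step()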
@contextmanager
def _enable_compile():
# The return value of torch._utils.is_compiling changes optimizer behavior.
    # We need that function to return True to include the optimizer in the graph.
# See: https://github.com/pytorch/pytorch/blob/a524123c91ab399c9dd6882c1189596dd77e7734/torch/optim/optimizer.py#L41
def f_true():
return True
orig_is_compiling_code = torch._utils.is_compiling.__code__
torch._utils.is_compiling.__code__ = f_true.__code__
try:
yield
finally:
torch._utils.is_compiling.__code__ = orig_is_compiling_code
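# For reference, the ``__code__`` swap above behaves like this (hypothetical,
# self-contained illustration; the names below are not used elsewhere):
#
#     def _reports_false() -> bool:
#         return False
#
#     def _reports_true() -> bool:
#         return True
#
#     _reports_false.__code__ = _reports_true.__code__
#     assert _reports_false() is True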
@dataclass
class _CompiledResult:
gm: fx.GraphModule
mod: nn.Module
opt: Optional[torch.optim.Optimizer]
flat_state: List[torch.Tensor]
def _compile(func: Callable, *args: Any, **kwargs: Any):
# 1. Extract nn.Module and Optimizer from args and kwargs
mod, opt = None, None
for arg in pytree.tree_flatten(list(args) + list(kwargs.values()))[0]:
if isinstance(arg, nn.Module):
assert mod is None, "Only support single nn.Module for now"
mod = arg
if isinstance(arg, optim.Optimizer):
assert opt is None, "Only support single Optimizer for now"
opt = arg
    assert mod is not None, "Couldn't find an nn.Module instance in the arguments."
# 2. Trace the stateless version of the train_step
params = dict(mod.named_parameters(remove_duplicate=False))
buffers = dict(mod.named_buffers(remove_duplicate=False))
named_states: Dict[str, nn.Parameter] = {}
    # Pass named_states instead of opt.state to stateless_func, because
    # the latter is keyed by nn.Parameter. During tracing, we need to
    # make sure optimizers can find the states using proxy tensors.
    if opt is not None:
        for n, p in params.items():
            if p in opt.state:
                # named_states is keyed by parameter name (str), while
                # opt.state is keyed by the Parameter itself.
                named_states[n] = opt.state[p]
# Lift states and parameters as function arguments so that make_fx
# can trace operations applied to them
def stateless_func(
func: Callable,
params: Dict[str, nn.Parameter],
buffers: Dict[str, torch.Tensor],
named_states: Dict[str, nn.Parameter],
args: Any,
kwargs: Any,
):
with stateless._reparametrize_module(
mod, {**params, **buffers}
), _rematerialize_optimizer(
opt, named_states, params
) if opt else nullcontext():
            # Install hooks that all-reduce each gradient as it is produced
            # during the traced backward pass.
with gradients_tagging(params):
ret = func(*args, **kwargs)
        # The return value must include the original return value, the
        # updated parameters, and the updated optimizer states.
return ret, list(mod.parameters()), list(named_states.values())
tracing_mode = "fake"
fake_mode = FakeTensorMode()
def _get_fake_args(arg: torch.Tensor) -> torch.Tensor:
fake_arg = fake_mode.from_tensor(arg)
return fake_arg
args = pytree.tree_map_only(torch.Tensor, _get_fake_args, args)
kwargs = pytree.tree_map_only(torch.Tensor, _get_fake_args, kwargs)
with _enable_compile(), torch.autograd.detect_anomaly(check_nan=False):
gm = make_fx(
partial(stateless_func, func),
tracing_mode=tracing_mode,
decomposition_table=SPMD_DECOMP_TABLE,
_allow_non_fake_inputs=False,
)(params, buffers, named_states, args, kwargs)
params_and_buffers: Dict[str, Union[torch.Tensor, nn.Parameter]] = {
**params,
**buffers,
}
flat_state, _ = pytree.tree_flatten([params_and_buffers, named_states])
    # Post-process the traced graph:
    #   1. Remove aten.detach nodes by rerouting their users to the detached
    #      input.
    #   2. Collapse chains of nested all_reduce -> wait_tensor pairs produced
    #      by the gradient hooks, so each gradient keeps a single collective;
    #      the orphaned all_reduce nodes are later removed by dead-code
    #      elimination in _to_caller_flattened_graph_module.
    for node in gm.graph.nodes:
        if node.target == torch.ops.aten.detach.default:
            input_node = node.all_input_nodes[0]
            node.replace_all_uses_with(input_node)
            if len(node.users) == 0:
                gm.graph.erase_node(node)
        if node.target == torch.ops.c10d_functional.wait_tensor.default:
            all_red_node = node.all_input_nodes[0]
            grad_node = all_red_node.all_input_nodes[0]
            while grad_node.target == torch.ops.c10d_functional.wait_tensor.default:
                # The gradient feeding this collective was already all-reduced
                # and waited on; redirect users of the outer wait to the inner
                # wait_tensor, erase the redundant node, and keep walking down
                # the chain.
                node.replace_all_uses_with(grad_node)
                if len(node.users) == 0:
                    gm.graph.erase_node(node)
                all_red_node = grad_node.all_input_nodes[0]
                grad_node = all_red_node.all_input_nodes[0]
gm = _to_caller_flattened_graph_module(gm)
return _CompiledResult(gm, mod, opt, flat_state)
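# Sketch of how the pieces above fit together (illustrative only; ``train_step``,
# ``model``, ``opt`` and ``batch`` are placeholder names, and this mirrors what
# the ``compile`` wrapper below does on its first invocation):
#
#     compiled = _compile(train_step, model, opt, batch)
#     flat_inps = compiled.flat_state + pytree.tree_flatten([(model, opt, batch), {}])[0]
#     loss = compiled.gm(*flat_inps)[0]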
# Note that the Python convention of __dict__ requires the key to be str.
# TODO: ensure the key is unique.
COMPILED_OBJECT_KEY = "_compiled_obj"
def compile(
gm_transformations: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
):
r"""
Compile and optimize a callable, which can be a train step within a training
loop. This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
instances from the input arguments and trace operations applied to their
parameters and states.
    Args:
        gm_transformations (Optional[Callable[[fx.GraphModule], fx.GraphModule]]):
            a callback invoked once, on the first iteration after the original
            callable has been compiled, to transform the captured GraphModule
            into a new, optimized one.

    A usage sketch appears at the bottom of this file.
"""
def compile_inner(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
last_train_step = kwargs.pop("last_train_step", False) if kwargs else False
first_iter = False
# Put the COMPILED_OBJECT_KEY in ``wrapper`` instead of ``func`` as
# ``wrapper`` is the one that users will get.
compiled_obj = wrapper.__dict__.get(COMPILED_OBJECT_KEY, None)
if compiled_obj is None:
first_iter = True
compiled_obj = _compile(func, *args, **kwargs)
wrapper.__dict__[COMPILED_OBJECT_KEY] = compiled_obj
print(compiled_obj.gm.graph)
flat_inps = compiled_obj.flat_state + pytree.tree_flatten([args, kwargs])[0]
# profiler_engine = ProfilerEngine(gm = compiled_obj.gm, profile_mode="default")
# profiler_engine.run(*flat_inps)
# profiler_engine.summarize(to_aggregate=True, to_print=True)
with torch.no_grad():
# N.B.: we don't need autograd as backward has already been
# captured in the graph.
if first_iter and gm_transformations:
# print(compiled_obj.gm.graph)
compiled_obj.gm = gm_transformations(compiled_obj.gm)
if not last_train_step:
output = compiled_obj.gm(*flat_inps)[0]
else:
# This is the last train step. Call IterGraphModule.forward()
# with the `last_iter` argument and catch the exception in
# case the compiled_obj is not wrapped with IterGraphModule.
try:
output = compiled_obj.gm(*flat_inps, last_iter=last_train_step)[
0
]
except TypeError as e:
if "last_iter" not in str(e):
raise e
output = compiled_obj.gm(*flat_inps)[0]
return output
return wrapper
return compile_inner
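# A minimal usage sketch (illustrative only; this function is not called
# anywhere in this module). It assumes the default process group has already
# been initialized, e.g. via ``dist.init_process_group``, because the traced
# graph all-reduces gradients over ``dist.group.WORLD``. All names below are
# examples, not part of the compiler API.
def _example_train_loop(num_iters: int = 4) -> None:
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
    sgd = optim.SGD(model.parameters(), lr=0.01, foreach=True)

    @compile()
    def train_step(mod: nn.Module, opt: optim.Optimizer, batch: torch.Tensor):
        # Forward, backward, and optimizer step are all captured into a single
        # graph by the decorator above.
        loss = mod(batch).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()
        return loss

    for i in range(num_iters):
        batch = torch.randn(16, 8)
        # The first call traces and compiles ``train_step``; subsequent calls
        # reuse the cached graph. ``last_train_step`` lets an IterGraphModule
        # flush any deferred work on the final iteration.
        train_step(model, sgd, batch, last_train_step=(i == num_iters - 1))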