-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_transformation.py
358 lines (297 loc) · 10.8 KB
/
test_transformation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
# Owner(s): ["oncall: distributed"]
import unittest
from copy import deepcopy
from functools import wraps
from typing import Any
from unittest.mock import MagicMock
import torch
import torch.nn as nn
from torch._inductor.utils import has_triton
from torch.distributed._spmd.api import compile
from torch.distributed._spmd.gm_transformation import GraphModuleTransformation
from torch.distributed._spmd.graph_optimization import (
_optimized_func,
comm_fusion_with_concat,
find_all_descendants,
get_all_fused_optimizer_blocks,
graph_optimization_pass,
iter_move_grads_and_optimizers,
remove_copy_from_optimizer,
schedule_comm_wait,
split_fused_optimizer,
)
from torch.distributed._spmd.graph_utils import find_node
from torch.distributed._spmd.iter_graph_module import IterGraphModule
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.utils import register_prop_rule
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms as base_with_comms,
)
def with_comms(func):
    """Wrap a test method with ``base_with_comms`` after seeding the RNG per rank.

    The test body is first wrapped in a seeding function (metadata copied from
    ``func`` via ``functools.wraps``), and that wrapper is then passed to
    ``base_with_comms``.
    """

    @wraps(func)
    def seeded(self, *args, **kwargs):
        # Use the rank index as the seed so each rank draws different
        # parameters and inputs. If every rank had identical data, comparing
        # against DDP / SPMD would not exercise any cross-rank behavior.
        torch.manual_seed(self.rank)
        return func(self, *args, **kwargs)

    return base_with_comms(seeded)
def sepm(x: torch.Tensor) -> torch.Tensor:
    """Kernel for the ``separator::sepm`` op: hand back the input tensor unchanged."""
    passthrough = x
    return passthrough
def sepm_backward(grad: torch.Tensor) -> torch.Tensor:
    """Kernel for the ``separator::sepm_backward`` op: hand back the gradient unchanged."""
    passthrough = grad
    return passthrough
# Define the "separator" operator namespace with two identity ops, ``sepm``
# and ``sepm_backward``. Both are backed by the Python kernels above under the
# CompositeExplicitAutograd key and are invoked by SEPMFunction.
separator_lib = torch.library.Library("separator", "DEF")
for _schema, _op_name, _kernel in (
    ("sepm(Tensor x) -> Tensor", "sepm", sepm),
    ("sepm_backward(Tensor x) -> Tensor", "sepm_backward", sepm_backward),
):
    # Each op is defined and then immediately given its implementation.
    separator_lib.define(_schema)
    separator_lib.impl(_op_name, _kernel, "CompositeExplicitAutograd")
# Drop the loop variables so they do not linger as module globals.
del _schema, _op_name, _kernel
def _identity_prop_rule(op_schema: OpSchema) -> OutputSharding:
    """Sharding propagation for a single-argument identity op.

    The op must receive exactly one positional argument, and it must be a
    ``DTensorSpec``. The output spec reuses that input's mesh and placements.
    Only those two fields are copied into the new spec.
    """
    (input_spec,) = op_schema.args_schema
    assert isinstance(
        input_spec, DTensorSpec
    ), f"expecting DTensorSpec but got {input_spec}"
    mirrored = DTensorSpec(input_spec.mesh, input_spec.placements)
    return OutputSharding(output_spec=mirrored)
@register_prop_rule(torch.ops.separator.sepm.default)
def _prop_sepm(op_schema: OpSchema) -> OutputSharding:
    """Sharding rule for ``separator::sepm``: delegate to the identity rule."""
    sharding = _identity_prop_rule(op_schema)
    return sharding
@register_prop_rule(torch.ops.separator.sepm_backward.default)
def _prop_sepm_backward(op_schema: OpSchema) -> OutputSharding:
    """Sharding rule for ``separator::sepm_backward``: delegate to the identity rule."""
    sharding = _identity_prop_rule(op_schema)
    return sharding
class SEPMFunction(torch.autograd.Function):
    """Autograd bridge to the custom separator ops.

    The forward pass calls ``torch.ops.separator.sepm``. The backward pass
    sends the incoming gradient through ``torch.ops.separator.sepm_backward``.
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
        # Nothing is stashed on ctx: backward uses only the incoming gradient.
        forward_op = torch.ops.separator.sepm
        return forward_op(x)

    @staticmethod
    def backward(ctx: Any, grad_x: torch.Tensor) -> torch.Tensor:
        backward_op = torch.ops.separator.sepm_backward
        return backward_op(grad_x)
class DummyModel(nn.Module):
    """MLP of ``layers`` (Linear(dim, dim), ReLU) pairs followed by SEPMFunction.

    Args:
        layers: number of Linear/ReLU pairs to stack.
        dim: input and output feature size of every Linear layer.
    """

    def __init__(self, layers: int, dim: int):
        super().__init__()
        # Inside each pair the Linear is built before its ReLU. Parameters are
        # therefore initialised in sequence order: Linear, ReLU, Linear, ...
        self.mod = nn.Sequential(
            *(
                layer
                for _ in range(layers)
                for layer in (nn.Linear(dim, dim), nn.ReLU())
            )
        )

    def forward(self, x):
        hidden = self.mod(x)
        return SEPMFunction.apply(hidden)
class GraphPassWrapperTest(DTensorTestBase):
    """Checks the ordering constraints enforced by ``graph_optimization_pass``."""

    @property
    def world_size(self):
        # test_order works on a mocked graph module and issues no collectives,
        # so a single rank is used.
        return 1

    def test_order(self):
        # ``_optimized_func`` is a module-level registry imported from
        # graph_optimization. Clearing it resets the ordering checks exercised
        # below. Start from an empty registry, and reset it on exit even if an
        # assertion fails, so that stale entries cannot leak into other tests.
        _optimized_func.clear()
        self.addCleanup(_optimized_func.clear)

        @graph_optimization_pass(
            prerequisites=[],
            apply_after=[],
        )
        def my_pass1(gm) -> None:
            return

        @graph_optimization_pass(
            prerequisites=[my_pass1],
            apply_after=[],
        )
        def my_pass2(gm) -> None:
            return

        @graph_optimization_pass(
            prerequisites=[],
            apply_after=[my_pass1],
        )
        def my_pass3(gm) -> None:
            return

        gm = MagicMock(spec=IterGraphModule)

        # Valid order: my_pass1, then my_pass3 (which must come after
        # my_pass1), then my_pass2 (whose prerequisite my_pass1 has run).
        my_pass1(gm)
        my_pass3(gm)
        my_pass2(gm)
        _optimized_func.clear()

        # my_pass3 has no prerequisites, so it may run on its own.
        my_pass3(gm)
        _optimized_func.clear()

        # my_pass2 requires my_pass1, which has not been applied yet.
        with self.assertRaisesRegex(AssertionError, "are the prerequisites of"):
            my_pass2(gm)
        _optimized_func.clear()

        # my_pass3 declares apply_after=[my_pass1], so applying my_pass1
        # after my_pass3 has already run must be rejected.
        with self.assertRaisesRegex(AssertionError, "must be applied after"):
            my_pass3(gm)
            my_pass1(gm)
        _optimized_func.clear()
class TransformationTest(DTensorTestBase):
    """Compares compiled (SPMD) training steps against a DDP reference."""

    @property
    def world_size(self):
        return 2

    def _init(self, batch_size, layers, dim, foreach: bool = False, fused: bool = True):
        """Build a DummyModel + Adam pair and a DDP-wrapped copy with its own Adam.

        One warm-up step runs on each pair so the optimizer states are
        materialized. The method then asserts that outputs and parameters agree.

        Returns:
            ``(model, optim, ddp_model, ddp_optim)``.
        """
        # NOTE(review): this fixed seed appears to override the per-rank seed
        # set by ``with_comms``. Batches drawn here, and later in
        # ``_test_train_step``, are therefore likely identical on every rank.
        # Confirm this is intended.
        torch.manual_seed(0)
        model = DummyModel(layers, dim).cuda()
        ddp_model = DDP(deepcopy(model), device_ids=[self.rank])
        optim = torch.optim.Adam(
            model.parameters(), lr=0.01, foreach=foreach, fused=fused, capturable=True
        )
        ddp_optim = torch.optim.Adam(
            ddp_model.parameters(),
            lr=0.01,
            foreach=foreach,
            fused=fused,
            capturable=True,
        )
        batch = torch.randn(batch_size, dim).cuda()

        # Materialize optimizer states.
        out = model(batch)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

        ddp_out = ddp_model(batch)
        ddp_out.sum().backward()
        ddp_optim.step()
        ddp_optim.zero_grad()

        self.assertEqual(ddp_out, out)
        self.assertEqual(list(ddp_model.parameters()), list(model.parameters()))
        return model, optim, ddp_model, ddp_optim

    def _test_train_step(
        self, train_step, num_iters, batch_size, layers, dim, use_fused_optimizer=False
    ):
        """Run ``train_step`` for ``num_iters`` steps next to a DDP reference.

        Model parameters are compared after the final step.

        Args:
            train_step: the step under test. It is called as
                ``train_step(model, optim, batch)``, and on the final iteration
                additionally with ``last_train_step=True``. A plain function
                would reject that keyword; it is presumably consumed by the
                SPMD ``compile`` wrapper (confirm).
            num_iters: number of training iterations.
            batch_size, layers, dim: forwarded to ``_init``.
            use_fused_optimizer: use fused Adam if True, otherwise foreach Adam.
        """

        def _ddp_train_step(model, optim, batch):
            model(batch).sum().backward()
            with torch.no_grad():
                # DDP averages gradients across ranks. Scaling by world_size
                # turns that average into a sum, presumably to match the
                # reduction done by the compiled step (confirm).
                for p in model.parameters():
                    p.grad *= self.world_size
            optim.step()
            optim.zero_grad()

        model, optim, ddp_model, ddp_optim = self._init(
            batch_size,
            layers,
            dim,
            foreach=(not use_fused_optimizer),
            fused=use_fused_optimizer,
        )
        for i in range(num_iters):
            batch = torch.randn(batch_size, dim).cuda()
            kwargs = {} if i < num_iters - 1 else {"last_train_step": True}
            train_step(model, optim, batch, **kwargs)
            _ddp_train_step(ddp_model, ddp_optim, batch)
        # Parameters are compared only once all iterations have finished.
        self.assertEqual(list(ddp_model.parameters()), list(model.parameters()))

    # NOTE(review): the disabled tests below still pass ``num_iters`` to
    # ``GraphModuleTransformation``, unlike test_graph_profiling_with_fused.
    # Check the constructor's current signature before re-enabling them.

    # @skip_if_lt_x_gpu(2)
    # @with_comms
    # def test_basic_transformation(self):
    #     batch_size = 100
    #     layers = 10
    #     dim = 100
    #     num_iters = 5
    #     @compile(gm_transformation=GraphModuleTransformation(num_iters=num_iters))
    #     def train_step(model, optim, batch):
    #         model(batch).sum().backward()
    #         optim.step()
    #         optim.zero_grad()
    #     self._test_train_step(train_step, num_iters, batch_size, layers, dim)

    # @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    # @skip_if_lt_x_gpu(2)
    # @with_comms
    # def test_inductor(self):
    #     batch_size = 100
    #     layers = 10
    #     dim = 100
    #     num_iters = 5
    #     @compile(
    #         gm_transformation=GraphModuleTransformation(
    #             num_iters=num_iters, enable_inductor=True
    #         )
    #     )
    #     def train_step(model, optim, batch):
    #         model(batch).sum().backward()
    #         optim.step()
    #         optim.zero_grad()
    #     # TODO: there are issues when lowering the optimizer. Disable
    #     # the test for now.
    #     """
    #     self._test_train_step(
    #         train_step, num_iters, batch_size, layers, dim
    #     )
    #     """

    # @skip_if_lt_x_gpu(2)
    # @with_comms
    # def test_graph_optimization_with_foreach(self):
    #     batch_size = 100
    #     layers = 2
    #     dim = 4096
    #     num_iters = 5
    #     @compile(
    #         gm_transformation=GraphModuleTransformation(
    #             num_iters=num_iters,
    #             enable_graph_optimization=True,
    #             dump_graphs=False,
    #         )
    #     )
    #     def train_step(model, optim, batch):
    #         model(batch).sum().backward()
    #         optim.step()
    #         optim.zero_grad()
    #     self._test_train_step(train_step, num_iters, batch_size, layers, dim)

    # @skip_if_lt_x_gpu(2)
    # @with_comms
    # def test_graph_optimization_with_fused(self):
    #     batch_size = 100
    #     layers = 2
    #     dim = 4096
    #     num_iters = 5
    #     @compile(
    #         gm_transformation=GraphModuleTransformation(
    #             num_iters=num_iters,
    #             enable_graph_optimization=True,
    #             dump_graphs=False,
    #         )
    #     )
    #     def train_step(model, optim, batch):
    #         model(batch).sum().backward()
    #         optim.step()
    #         optim.zero_grad()
    #     model, optim, _, _ = self._init(
    #         batch_size, layers, dim, foreach=False, fused=True
    #     )
    #     for _ in range(num_iters):
    #         batch = torch.randn(batch_size, dim).cuda()
    #         out = train_step(model, optim, batch)

    # @skip_if_lt_x_gpu(2)
    # @with_comms
    # def test_graph_profiling_with_foreach(self):
    #     batch_size = 100
    #     layers = 2
    #     dim = 4096
    #     num_iters = 5
    #     @compile(
    #         gm_transformation=GraphModuleTransformation(
    #             num_iters=num_iters,
    #             enable_graph_optimization=True,
    #             enable_profiling=True,
    #             dump_graphs=False,
    #         )
    #     )
    #     def train_step(model, optim, batch):
    #         model(batch).sum().backward()
    #         optim.step()
    #         optim.zero_grad()
    #     self._test_train_step(train_step, num_iters, batch_size, layers, dim)

    @skip_if_lt_x_gpu(2)
    @skipIfRocm
    @with_comms
    def test_graph_profiling_with_fused(self):
        batch_size = 100
        layers = 2
        dim = 4096
        num_iters = 5

        @compile(
            gm_transformation=GraphModuleTransformation(
                enable_graph_optimization=True,
                enable_profiling=True,
                dump_graphs=False,
            )
        )
        def train_step(model, optim, batch):
            model(batch).sum().backward()
            optim.step()
            optim.zero_grad()

        # Use the fused-Adam path, as the test name says. This is also the
        # path ``_init`` selects by default.
        self._test_train_step(
            train_step, num_iters, batch_size, layers, dim, use_fused_optimizer=True
        )
if __name__ == "__main__":
    # The guard here is hard-coded off, so running this file directly does
    # not call run_tests(). It presumably disables these tests on purpose.
    # Set the flag to True to run the suite as a script.
    _RUN_TESTS_AS_SCRIPT = False
    if _RUN_TESTS_AS_SCRIPT:
        run_tests()