-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_parallel.py
242 lines (188 loc) · 7.59 KB
/
test_parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# Owner(s): ["oncall: distributed"]
from copy import deepcopy
from typing import Any
import torch
import torch.nn as nn
from torch.distributed._spmd.api import compile
from torch.distributed._spmd.parallel_mode import DataParallel
from torch.distributed._tensor import Replicate
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.utils import register_prop_rule
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
def sep(x: torch.Tensor) -> torch.Tensor:
    """Forward kernel for the ``separator::sep`` op.

    Acts as an identity: the very same tensor object that was passed in is
    returned, without copying.
    """
    output = x
    return output
def sep_backward(grad: torch.Tensor) -> torch.Tensor:
    """Kernel for the ``separator::sep_backward`` op.

    Acts as an identity on the incoming gradient: the same tensor object is
    returned, without copying.
    """
    passthrough = grad
    return passthrough
# Define a custom "separator" operator namespace holding two ops, `sep` and
# `sep_backward`. Both are implemented by the identity Python kernels `sep` and
# `sep_backward` defined above, registered under the CompositeExplicitAutograd
# dispatch key. SEPFunction calls these ops as torch.ops.separator.sep /
# torch.ops.separator.sep_backward.
separator_lib = torch.library.Library("separator", "DEF")
separator_lib.define("sep(Tensor x) -> Tensor")
separator_lib.impl("sep", sep, "CompositeExplicitAutograd")
separator_lib.define("sep_backward(Tensor x) -> Tensor")
separator_lib.impl("sep_backward", sep_backward, "CompositeExplicitAutograd")
def _identity_prop_rule(op_schema: OpSchema) -> OutputSharding:
    """Sharding propagation rule that copies the input's sharding to the output.

    The op must take exactly one positional argument, and that argument must
    be a ``DTensorSpec``. The output spec is a new ``DTensorSpec`` built from
    the input's mesh and placements.
    """
    (x,) = op_schema.args_schema
    assert isinstance(x, DTensorSpec), f"expecting DTensorSpec but got {x}"
    mirrored_spec = DTensorSpec(x.mesh, x.placements)
    return OutputSharding(output_spec=mirrored_spec)
@register_prop_rule(torch.ops.separator.sep.default)
def _prop_sepm(op_schema: OpSchema) -> OutputSharding:
    """Sharding rule for ``separator::sep``: output sharding mirrors the input."""
    return _identity_prop_rule(op_schema=op_schema)
@register_prop_rule(torch.ops.separator.sep_backward.default)
def _prop_sepm_backward(op_schema: OpSchema) -> OutputSharding:
    """Sharding rule for ``separator::sep_backward``: output sharding mirrors the input."""
    return _identity_prop_rule(op_schema=op_schema)
class SEPFunction(torch.autograd.Function):
    """Autograd wrapper around the custom ``separator`` ops.

    The forward pass dispatches to ``separator::sep`` and the backward pass to
    ``separator::sep_backward``.
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
        # ctx is not used: nothing is saved for the backward pass.
        forward_op = torch.ops.separator.sep
        return forward_op(x)

    @staticmethod
    def backward(ctx: Any, grad_x: torch.Tensor) -> torch.Tensor:
        backward_op = torch.ops.separator.sep_backward
        return backward_op(grad_x)
class SimpleMLP(nn.Module):
    """Small MLP used as the model under test.

    Layout: Linear(50 -> 32) -> ReLU -> Linear(32 -> 8) -> sigmoid, with the
    result passed through ``SEPFunction``.
    """

    def __init__(self):
        super().__init__()
        # Registration order (net1, relu, net2) fixes the order of
        # parameters(); the tests zip parameters of two models together.
        self.net1 = nn.Linear(50, 32)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(32, 8)

    def forward(self, x):
        hidden = self.relu(self.net1(x))
        activated = torch.sigmoid(self.net2(hidden))
        return SEPFunction.apply(activated)

    def reset_parameters(self, *args, **kwargs):
        """Re-initialize both linear layers; any arguments are ignored."""
        for layer in (self.net1, self.net2):
            layer.reset_parameters()
# Minimal training step; TestDataParallel runs it eagerly, on a DDP model, and
# through spmd compile.
def train_step(model, optim, train_batch):
    """Run one optimization step and return the loss.

    Args:
        model: module mapping ``inputs`` to predictions.
        optim: optimizer over ``model``'s parameters.
        train_batch: ``(inputs, labels)`` pair.

    Returns:
        The scalar loss ``(model(inputs) - labels).sum()`` computed before
        the optimizer step.
    """
    inputs, labels = train_batch
    optim.zero_grad()
    predictions = model(inputs)
    loss = (predictions - labels).sum()
    loss.backward()
    optim.step()
    return loss
class TestDataParallel(DTensorTestBase):
    """Compare spmd-compiled data parallelism against a DDP reference.

    Each test builds a ``SimpleMLP`` and optimizer, a DDP copy of the same
    model with its own optimizer, and a random batch. It then calls
    ``_test_data_parallel``, which trains both and compares parameters.
    """

    @property
    def world_size(self):
        # Every test in this class runs on two ranks.
        return 2

    def _test_data_parallel(
        self,
        mod,
        ddp_mod,
        opt,
        ddp_opt,
        inp,
        train_step,
        data_parallel_mode,
        data_parallel_options=None,
    ):
        """Train ``mod`` through compiled data parallelism and compare it with DDP.

        Both models first take one eager ``train_step`` to warm up the
        optimizers. ``ddp_mod`` then takes a second step written out by hand.
        ``mod`` takes its second step through
        ``compile(parallel_mode=DataParallel(...))(train_step)``. Afterwards
        each parameter of ``mod`` is converted to a local tensor and must equal
        the matching ``ddp_mod`` parameter. In ``"fully_shard"`` mode the
        parameter is first redistributed to a ``Replicate()`` placement.

        Args:
            mod, opt: model and optimizer driven through the compiled step.
            ddp_mod, ddp_opt: DDP reference model and its optimizer.
            inp: ``(inputs, labels)`` batch; DDP works on a deep copy.
            train_step: step function called as ``train_step(model, optim, batch)``.
            data_parallel_mode: ``parallel_style`` given to ``DataParallel``
                (``"replicate"`` or ``"fully_shard"`` in this file).
            data_parallel_options: accepted but currently not used here.
        """
        ddp_inp = deepcopy(inp)

        # Warm-up: one eager step per model, then clear the gradients.
        train_step(mod, opt, inp)
        opt.zero_grad()
        train_step(ddp_mod, ddp_opt, ddp_inp)
        ddp_opt.zero_grad()

        # The second DDP step is spelled out manually because DDP grads are
        # different. Its loss is the same summed difference train_step uses.
        ddp_inputs, ddp_labels = ddp_inp
        ddp_loss = torch.sum(ddp_mod(ddp_inputs) - ddp_labels)
        ddp_loss.backward()
        ddp_opt.step()

        # The second step for `mod` goes through the compiled train_step.
        parallel_mode = DataParallel(
            parallel_style=data_parallel_mode, _preserve_node_type=True
        )
        compiled_step = compile(parallel_mode=parallel_mode)(train_step)
        compiled_step(mod, opt, inp)

        # The parameters of `mod` are DTensors. Turn each one into a local
        # tensor, gathering the shards first in fully_shard mode.
        for dist_param, ddp_param in zip(mod.parameters(), ddp_mod.parameters()):
            replica = (
                dist_param.redistribute(placements=[Replicate()])
                if data_parallel_mode == "fully_shard"
                else dist_param
            )
            self.assertEqual(replica.to_local(), ddp_param)

    def _random_train_batch(self):
        """Return a random ``(inputs, labels)`` pair on ``torch.device(self.rank)``.

        Shapes (128, 50) and (128, 8) match SimpleMLP's input and output widths.
        """
        device = torch.device(self.rank)
        inputs = torch.randn((128, 50), device=device)
        labels = torch.randn((128, 8), device=device)
        return inputs, labels

    @skip_if_lt_x_gpu(2)
    @with_comms
    def test_replicate_sgd(self):
        sgd_configs = [
            {"lr": 0.1},
            {"lr": 0.1, "momentum": 0.9},
            {"lr": 0.1, "momentum": 0.9, "foreach": True},
        ]
        for sgd_kwargs in sgd_configs:
            mod = SimpleMLP().cuda(self.rank)
            opt = torch.optim.SGD(mod.parameters(), **sgd_kwargs)
            batch = self._random_train_batch()
            ddp_mod = DDP(deepcopy(mod), device_ids=[self.rank])
            ddp_opt = torch.optim.SGD(ddp_mod.parameters(), **sgd_kwargs)
            self._test_data_parallel(
                mod, ddp_mod, opt, ddp_opt, batch, train_step, "replicate"
            )

    @skip_if_lt_x_gpu(2)
    @with_comms
    def test_replicate_adam_fused(self):
        mod = SimpleMLP().cuda(self.rank)
        opt = torch.optim.Adam(mod.parameters(), lr=0.1, fused=True)
        batch = self._random_train_batch()
        ddp_mod = DDP(deepcopy(mod), device_ids=[self.rank])
        ddp_opt = torch.optim.Adam(ddp_mod.parameters(), lr=0.1, fused=True)
        self._test_data_parallel(
            mod, ddp_mod, opt, ddp_opt, batch, train_step, "replicate"
        )

    # NOTE(review): the fully_shard variants below are disabled (commented
    # out). This file does not record why; re-enable them or delete them once
    # that is known.

    # @skip_if_lt_x_gpu(2)
    # @with_comms
    # def test_fully_shard_sgd(self):
    #     mod = SimpleMLP().cuda(self.rank)
    #     opt = torch.optim.SGD(mod.parameters(), lr=0.1)
    #     train_batch = (
    #         torch.randn(128, 50).to(self.rank),
    #         torch.randn(128, 8).to(self.rank),
    #     )
    #     ddp_mod = DDP(deepcopy(mod), device_ids=[self.rank])
    #     ddp_opt = torch.optim.SGD(ddp_mod.parameters(), lr=0.1)
    #     self._test_data_parallel(
    #         mod, ddp_mod, opt, ddp_opt, train_batch, train_step, "fully_shard"
    #     )

    # @skip_if_lt_x_gpu(2)
    # @with_comms
    # def test_fully_shard_sgd_momentum(self):
    #     mod = SimpleMLP().cuda(self.rank)
    #     opt = torch.optim.SGD(mod.parameters(), lr=0.1, momentum=0.9)
    #     train_batch = (
    #         torch.randn(128, 50).to(self.rank),
    #         torch.randn(128, 8).to(self.rank),
    #     )
    #     ddp_mod = DDP(deepcopy(mod), device_ids=[self.rank])
    #     ddp_opt = torch.optim.SGD(ddp_mod.parameters(), lr=0.1, momentum=0.9)
    #     self._test_data_parallel(
    #         mod, ddp_mod, opt, ddp_opt, train_batch, train_step, "fully_shard"
    #     )

    # @skip_if_lt_x_gpu(2)
    # @with_comms
    # def test_fully_shard_sgd_foreach(self):
    #     mod = SimpleMLP().cuda(self.rank)
    #     opt = torch.optim.SGD(mod.parameters(), lr=0.1, momentum=0.9, foreach=True)
    #     train_batch = (
    #         torch.randn(128, 50).to(self.rank),
    #         torch.randn(128, 8).to(self.rank),
    #     )
    #     ddp_mod = DDP(deepcopy(mod), device_ids=[self.rank])
    #     ddp_opt = torch.optim.SGD(
    #         ddp_mod.parameters(), lr=0.1, momentum=0.9, foreach=True
    #     )
    #     self._test_data_parallel(
    #         mod, ddp_mod, opt, ddp_opt, train_batch, train_step, "fully_shard"
    #     )
if __name__ == "__main__":
    # NOTE(review): run_tests() sits behind `if False:`, so executing this file
    # directly runs no tests at all. The reason for the disable is not recorded
    # here; confirm it is intentional and document why, or restore the call.
    if False:
        run_tests()