
Testing for functional collectives
emmay78 committed Oct 22, 2024
1 parent ebb7999 commit 5829271
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion torch/distributed/_tools/fake_collectives.py
@@ -4,7 +4,7 @@
 from torch.futures import Future
 from torch.testing._internal.distributed.fake_pg import FakeStore
 from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
-from torch.distributed._functional_collectives import all_reduce
+from torch.distributed._functional_collectives import *
 from torch.utils._python_dispatch import TorchDispatchMode
 from functools import wraps
 from contextlib import contextmanager, nullcontext
@@ -178,6 +178,14 @@ def run_test():
     dist.barrier()
     dist.send(test_tensor, dst=1)
     dist.recv(test_tensor, src=1)
+
+    # testing for functional collectives
+    output = wait_tensor(test_tensor)
+    output = broadcast(test_tensor, src=0, group=dist.group.WORLD)
+    output = all_reduce(test_tensor, reduceOp="avg", group=dist.group.WORLD)
+    output = all_gather_tensor(test_tensor, gather_dim=0, group=dist.group.WORLD)
+    output = reduce_scatter_tensor(test_tensor, scatter_dim=0, reduceOp="sum", group=dist.group.WORLD)
+    output = all_to_all_single(test_tensor, output_split_sizes=[0], input_split_sizes=[1], group=dist.group.WORLD)
 
     dist.barrier()

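As context for the new test calls, here is a minimal sketch of driving functional collectives against a fake process group so the ops run without real communication. The "fake" backend name, the two-rank world size, the tensor shape, and the chosen ops are illustrative assumptions based on torch.testing._internal.distributed.fake_pg; they are not part of this commit.

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_reduce, wait_tensor
from torch.testing._internal.distributed.fake_pg import FakeStore

# Importing FakeStore registers a "fake" backend; initializing a process group
# with it lets collectives complete without any real communication.
# (rank=0, world_size=2 are arbitrary values for illustration.)
dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())

t = torch.ones(4, 4)
# Functional collectives return new tensors rather than mutating their input;
# wait_tensor forces completion of the pending (fake) collective.
out = all_reduce(t, reduceOp="sum", group=dist.group.WORLD)
out = wait_tensor(out)

dist.destroy_process_group()

The run_test changes above exercise the wider set (broadcast, all_gather_tensor, reduce_scatter_tensor, all_to_all_single) in the same fashion.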
