From 58292711de020fd4fe03970cd0ae50852644108b Mon Sep 17 00:00:00 2001
From: Emma Yang
Date: Tue, 22 Oct 2024 13:39:56 -0400
Subject: [PATCH] Testing for functional collectives

---
 torch/distributed/_tools/fake_collectives.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py
index c9ee1900bd5d3..80e5d9891bdae 100644
--- a/torch/distributed/_tools/fake_collectives.py
+++ b/torch/distributed/_tools/fake_collectives.py
@@ -4,7 +4,7 @@
 from torch.futures import Future
 from torch.testing._internal.distributed.fake_pg import FakeStore
 from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
-from torch.distributed._functional_collectives import all_reduce
+from torch.distributed._functional_collectives import wait_tensor, broadcast, all_reduce, all_gather_tensor, reduce_scatter_tensor, all_to_all_single
 from torch.utils._python_dispatch import TorchDispatchMode
 from functools import wraps
 from contextlib import contextmanager, nullcontext
@@ -178,6 +178,14 @@ def run_test():
     dist.barrier()
     dist.send(test_tensor, dst=1)
     dist.recv(test_tensor, src=1)
+
+    # Test the functional collectives
+    output = wait_tensor(test_tensor)
+    output = broadcast(test_tensor, src=0, group=dist.group.WORLD)
+    output = all_reduce(test_tensor, reduceOp="avg", group=dist.group.WORLD)
+    output = all_gather_tensor(test_tensor, gather_dim=0, group=dist.group.WORLD)
+    output = reduce_scatter_tensor(test_tensor, scatter_dim=0, reduceOp="sum", group=dist.group.WORLD)
+    output = all_to_all_single(test_tensor, output_split_sizes=[0], input_split_sizes=[1], group=dist.group.WORLD)
     dist.barrier()
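
For reviewers who want to exercise the new calls outside of run_test(), the sketch below drives the same functional collectives end to end against the fake backend. This is a minimal illustration, not part of the patch: the world size of 2, the 4x4 tensor shape, and the split sizes are illustrative assumptions, and it relies on the fact that importing FakeStore from torch.testing._internal.distributed.fake_pg also registers the "fake" process-group backend.

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import (
    wait_tensor,
    broadcast,
    all_reduce,
    all_gather_tensor,
    reduce_scatter_tensor,
    all_to_all_single,
)
# Importing FakeStore also registers the "fake" process-group backend.
from torch.testing._internal.distributed.fake_pg import FakeStore

if __name__ == "__main__":
    world_size = 2  # illustrative; the fake backend lets one process pose as rank 0
    dist.init_process_group("fake", rank=0, world_size=world_size, store=FakeStore())

    t = torch.randn(4, 4)
    out = broadcast(t, src=0, group=dist.group.WORLD)
    out = all_reduce(t, reduceOp="avg", group=dist.group.WORLD)
    out = all_gather_tensor(t, gather_dim=0, group=dist.group.WORLD)  # (8, 4) with world_size=2
    out = reduce_scatter_tensor(out, reduceOp="sum", scatter_dim=0, group=dist.group.WORLD)  # back to (4, 4)
    out = all_to_all_single(t, output_split_sizes=[2, 2], input_split_sizes=[2, 2], group=dist.group.WORLD)
    # Functional collectives are async; wait_tensor forces completion before the result is used.
    out = wait_tensor(out)

    dist.destroy_process_group()

Note that, unlike the in-place c10d ops such as dist.all_reduce, the functional collectives return a new (async) tensor rather than mutating their input, which is why each call above binds a result.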