From e34d35b07f490c8edf1b4cf3efe1ee062bf70301 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Wed, 16 Oct 2024 00:25:34 +0000
Subject: [PATCH] Parameterize number of senders in
 test_torch_tensor_nccl_overlap

Signed-off-by: Ubuntu
---
 .../experimental/test_torch_tensor_dag.py | 37 ++++++++++-----------
 1 file changed, 16 insertions(+), 21 deletions(-)

diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py
index 885fe8c5cab4b..4306ae9de752e 100644
--- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py
+++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -293,32 +293,27 @@ def test_torch_tensor_nccl_overlap(ray_start_regular, overlap_gpu_communication)
         pytest.skip("NCCL tests require GPUs")
 
     assert (
-        sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 2
-    ), "This test requires at least 3 GPUs"
+        sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) >= 4
+    ), "This test requires at least 4 GPUs"
 
-    actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)
-
-    sender1 = actor_cls.remote()
-    sender2 = actor_cls.remote()
-    receiver = actor_cls.remote()
+    worker_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)
+    num_senders = 3
+    senders = [worker_cls.remote() for _ in range(num_senders)]
+    receiver = worker_cls.remote()
 
     shape = (10000, 10000)
     dtype = torch.float16
 
     with InputNode() as inp:
-        branch1 = sender1.send.bind(shape, dtype, inp)
-
-        branch1 = branch1.with_type_hint(
-            TorchTensorType(shape, dtype, transport="nccl", _direct_return=True)
-        )
-        branch1 = receiver.recv_and_matmul.bind(branch1)
-
-        branch2 = sender2.send.bind(shape, dtype, inp)
-        branch2 = branch2.with_type_hint(
-            TorchTensorType(shape, dtype, transport="nccl", _direct_return=True)
-        )
-        branch2 = receiver.recv_and_matmul.bind(branch2)
-        dag = MultiOutputNode([branch1, branch2])
+        branches = [sender.send.bind(shape, dtype, inp) for sender in senders]
+        branches = [
+            branch.with_type_hint(
+                TorchTensorType(shape, dtype, transport="nccl", _direct_return=True)
+            )
+            for branch in branches
+        ]
+        branches = [receiver.recv_and_matmul.bind(branch) for branch in branches]
+        dag = MultiOutputNode(branches)
 
     # Test normal execution.
     compiled_dag = dag.experimental_compile(
@@ -329,7 +324,7 @@ def test_torch_tensor_nccl_overlap(ray_start_regular, overlap_gpu_communication)
     for i in range(5):
         ref = compiled_dag.execute(i)
         result = ray.get(ref)
-        assert result == [(i, shape, dtype), (i, shape, dtype)]
+        assert result == [(i, shape, dtype)] * num_senders
 
     duration = time.monotonic() - start
     print(f"{overlap_gpu_communication=}, {duration=}")
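
For reference, below is a minimal, self-contained sketch of the N-sender fan-in
pattern this patch introduces. It is not the test itself: the TorchTensorWorker
stub (its send and recv_and_matmul bodies), the "cuda" device choice, and the
bare experimental_compile() call are assumptions filled in for illustration,
since the real actor definition and the compile arguments (the test also passes
an overlap_gpu_communication flag, elided in the hunk context above) live
outside these hunks. Import paths are a best guess at the ones
test_torch_tensor_dag.py uses.

import torch
import ray
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType


# Hypothetical stand-in for the TorchTensorWorker actor defined earlier in
# the test file; only the two methods this test exercises are sketched.
@ray.remote
class TorchTensorWorker:
    def send(self, shape, dtype, value):
        # Produce a GPU tensor filled with `value` so the receiver can
        # tell which input iteration it came from.
        return torch.ones(shape, dtype=dtype, device="cuda") * value

    def recv_and_matmul(self, tensor):
        # Do some GPU compute that can overlap with the NCCL transfers of
        # the other branches, then report what was received.
        torch.matmul(tensor, tensor)
        return (int(tensor[0][0].item()), tensor.shape, tensor.dtype)


num_senders = 3
worker_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)
senders = [worker_cls.remote() for _ in range(num_senders)]
receiver = worker_cls.remote()

shape = (10000, 10000)
dtype = torch.float16

with InputNode() as inp:
    # One send branch per sender, each annotated so the tensor travels to
    # the receiver over NCCL instead of through the object store.
    branches = [sender.send.bind(shape, dtype, inp) for sender in senders]
    branches = [
        branch.with_type_hint(
            TorchTensorType(shape, dtype, transport="nccl", _direct_return=True)
        )
        for branch in branches
    ]
    # All branches fan in to the single receiver actor.
    branches = [receiver.recv_and_matmul.bind(branch) for branch in branches]
    dag = MultiOutputNode(branches)

compiled_dag = dag.experimental_compile()  # the test adds an overlap flag here
result = ray.get(compiled_dag.execute(1))
assert result == [(1, shape, dtype)] * num_senders

Because the receiver owns one GPU and serves all three branches, the three
NCCL receives contend on it; that is what gives the overlap flag something
to overlap, and why the GPU requirement rises to num_senders + 1 = 4.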