Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
Signed-off-by: Ubuntu <[email protected]>
  • Loading branch information
Ubuntu committed Oct 16, 2024
1 parent f188a5c commit e34d35b
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,32 +293,27 @@ def test_torch_tensor_nccl_overlap(ray_start_regular, overlap_gpu_communication)
pytest.skip("NCCL tests require GPUs")

assert (
sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 2
), "This test requires at least 3 GPUs"
sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) >= 4
), "This test requires at least 4 GPUs"

actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)

sender1 = actor_cls.remote()
sender2 = actor_cls.remote()
receiver = actor_cls.remote()
worker_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)
num_senders = 3
senders = [worker_cls.remote() for _ in range(num_senders)]
receiver = worker_cls.remote()

shape = (10000, 10000)
dtype = torch.float16

with InputNode() as inp:
branch1 = sender1.send.bind(shape, dtype, inp)

branch1 = branch1.with_type_hint(
TorchTensorType(shape, dtype, transport="nccl", _direct_return=True)
)
branch1 = receiver.recv_and_matmul.bind(branch1)

branch2 = sender2.send.bind(shape, dtype, inp)
branch2 = branch2.with_type_hint(
TorchTensorType(shape, dtype, transport="nccl", _direct_return=True)
)
branch2 = receiver.recv_and_matmul.bind(branch2)
dag = MultiOutputNode([branch1, branch2])
branches = [sender.send.bind(shape, dtype, inp) for sender in senders]
branches = [
branch.with_type_hint(
TorchTensorType(shape, dtype, transport="nccl", _direct_return=True)
)
for branch in branches
]
branches = [receiver.recv_and_matmul.bind(branch) for branch in branches]
dag = MultiOutputNode(branches)

# Test normal execution.
compiled_dag = dag.experimental_compile(
Expand All @@ -329,7 +324,7 @@ def test_torch_tensor_nccl_overlap(ray_start_regular, overlap_gpu_communication)
for i in range(5):
ref = compiled_dag.execute(i)
result = ray.get(ref)
assert result == [(i, shape, dtype), (i, shape, dtype)]
assert result == [(i, shape, dtype)] * num_senders
duration = time.monotonic() - start
print(f"{overlap_gpu_communication=}, {duration=}")

Expand Down

0 comments on commit e34d35b

Please sign in to comment.