[core][compiled-graphs] Don't persist input_nodes in _CollectiveOperation to avoid wrong understanding about DAGs (ray-project#48463)

If we persist input_nodes in _CollectiveOperation, all input_nodes are added to upstream_nodes when the DAG is built. However, not all input_nodes belong to the args of a given DAG node, which could cause issues when compiling the graph.
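The effect described above can be illustrated with a toy model. This is only a sketch under stated assumptions: `ToyNode`, `PersistingOp`, `NonPersistingOp`, and `find_nodes` are hypothetical stand-ins, not Ray's classes, and the attribute scan merely mimics the idea that any DAG node reachable from a node's arguments is treated as upstream.

```python
from typing import Any, Dict, List, Optional, Tuple


class ToyNode:
    """Hypothetical stand-in for a DAG node (not Ray's DAGNode)."""

    def __init__(
        self,
        name: str,
        args: Tuple[Any, ...] = (),
        other_args: Optional[Dict[str, Any]] = None,
    ):
        self.name = name
        self.args = args
        self.other_args = other_args or {}

    def collect_upstream(self) -> List["ToyNode"]:
        # Assumption: every node reachable from args or other_args is upstream.
        found: List["ToyNode"] = []
        for obj in list(self.args) + list(self.other_args.values()):
            found.extend(find_nodes(obj))
        return found


def find_nodes(obj: Any) -> List[ToyNode]:
    """Recursively find ToyNode instances held by an object."""
    if isinstance(obj, ToyNode):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return [n for item in obj for n in find_nodes(item)]
    if hasattr(obj, "__dict__"):
        return [n for value in vars(obj).values() for n in find_nodes(value)]
    return []


class PersistingOp:
    """Keeps the input nodes as an attribute (the pre-change pattern)."""

    def __init__(self, input_nodes: List[ToyNode]):
        self._input_nodes = input_nodes
        self._actor_names = [n.name for n in input_nodes]


class NonPersistingOp:
    """Uses the input nodes only during construction (the post-change pattern)."""

    def __init__(self, input_nodes: List[ToyNode]):
        self._actor_names = [n.name for n in input_nodes]


def build_outputs(op_cls) -> List[ToyNode]:
    inputs = [ToyNode("a"), ToyNode("b")]
    op = op_cls(inputs)
    # Each output node takes exactly one input node as its arg and shares the op.
    return [ToyNode(f"out_{n.name}", args=(n,), other_args={"op": op}) for n in inputs]


def upstream_names(node: ToyNode) -> List[str]:
    return [n.name for n in node.collect_upstream()]
```

In this toy, the first output node built with `PersistingOp` reports both `a` and `b` as upstream even though only `a` is in its args, while the `NonPersistingOp` variant reports only `a`.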
kevin85421 authored Nov 4, 2024
1 parent ec9775d commit 3581e62
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 4 additions & 6 deletions python/ray/dag/collective_node.py
@@ -37,22 +37,21 @@ def __init__(
         op: _CollectiveOp,
         transport: Optional[Union[str, GPUCommunicator]] = None,
     ):
-        self._input_nodes: List[DAGNode] = input_nodes
-        if len(self._input_nodes) == 0:
+        if len(input_nodes) == 0:
             raise ValueError("Expected input nodes for a collective operation")
-        if len(set(self._input_nodes)) != len(self._input_nodes):
+        if len(set(input_nodes)) != len(input_nodes):
             raise ValueError("Expected unique input nodes for a collective operation")

         self._actor_handles: List["ray.actor.ActorHandle"] = []
-        for input_node in self._input_nodes:
+        for input_node in input_nodes:
             actor_handle = input_node._get_actor_handle()
             if actor_handle is None:
                 raise ValueError("Expected an actor handle from the input node")
             self._actor_handles.append(actor_handle)
         if len(set(self._actor_handles)) != len(self._actor_handles):
             invalid_input_nodes = [
                 input_node
-                for input_node in self._input_nodes
+                for input_node in input_nodes
                 if self._actor_handles.count(input_node._get_actor_handle()) > 1
             ]
             raise ValueError(
@@ -76,7 +75,6 @@ def __init__(
     def __str__(self) -> str:
         return (
             f"CollectiveGroup("
-            f"_input_nodes={self._input_nodes}, "
             f"_actor_handles={self._actor_handles}, "
             f"_op={self._op}, "
             f"_type_hint={self._type_hint})"
5 changes: 5 additions & 0 deletions python/ray/dag/dag_node.py
@@ -79,6 +79,11 @@ def _collect_upstream_nodes(self) -> List["DAGNode"]:
         """
         Retrieve upstream nodes and update their downstream dependencies.

+        Currently, the DAG assumes that all DAGNodes in `args`, `kwargs`, and
+        `other_args_to_resolve` are upstream nodes. However, Ray Compiled Graphs
+        builds the upstream/downstream relationship based only on args. Be cautious
+        when persisting DAGNodes in `other_args_to_resolve` and kwargs in the future.
+
         TODO (kevin85421): Currently, the upstream nodes and downstream nodes have
         circular references. Therefore, it relies on the garbage collector to clean
         them up instead of reference counting. We should consider using weak references
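The TODO in the docstring above mentions circular references between upstream and downstream nodes. As a rough illustration of that point, and not of Ray's implementation, the sketch below compares a strong back-reference with a `weakref.ref` back-reference. `StrongNode`, `WeakNode`, `link_strong`, `link_weak`, and `freed_without_gc` are hypothetical names, and the behaviour shown assumes CPython's reference counting.

```python
import gc
import weakref


class StrongNode:
    """Toy node whose upstream and downstream links are both strong references."""

    def __init__(self, name: str):
        self.name = name
        self.upstream = []
        self.downstream = []


def link_strong(up: StrongNode, down: StrongNode) -> None:
    down.upstream.append(up)
    up.downstream.append(down)  # creates a reference cycle


class WeakNode:
    """Toy node whose downstream links are weak references."""

    def __init__(self, name: str):
        self.name = name
        self.upstream = []
        self.downstream = []  # holds weakref.ref objects


def link_weak(up: WeakNode, down: WeakNode) -> None:
    down.upstream.append(up)
    up.downstream.append(weakref.ref(down))  # no cycle of strong references


def freed_without_gc(node_cls, link) -> bool:
    """Return True if the nodes are reclaimed by reference counting alone."""
    was_enabled = gc.isenabled()
    gc.disable()  # rule out the cyclic garbage collector for this check
    try:
        up, down = node_cls("up"), node_cls("down")
        link(up, down)
        probe = weakref.ref(up)
        del up, down
        return probe() is None
    finally:
        if was_enabled:
            gc.enable()
```

Under these assumptions, the strongly linked pair survives until the cyclic collector runs, while the weakly linked pair is released as soon as the last external reference is dropped.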
