Manage restarting state for Actors during _task_done_callback
Signed-off-by: Srinath Krishnamachari <[email protected]>
srinathk10 committed Oct 10, 2024
1 parent b8d705a commit 8e4578a
Showing 2 changed files with 27 additions and 15 deletions.
@@ -211,8 +211,17 @@ def _dispatch_tasks(self):
            ).remote(DataContext.get_current(), ctx, *input_blocks)

        def _task_done_callback(actor_to_return):
-            # Return the actor that was running the task to the pool.
-            self._actor_pool.return_actor(actor_to_return)
+            if actor_to_return in self._actor_pool._num_tasks_in_flight:
+                # Return the actor that was running the task to the pool.
+                self._actor_pool.return_actor(actor_to_return)
+            else:
+                assert (
+                    actor_to_return.get_location
+                    in self._actor_pool._restarting_actors
+                )
+                # Move the actor from restarting to running state.
+                self._actor_pool.restarting_to_running(actor_to_return)

            # Dispatch more tasks.
            self._dispatch_tasks()
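
The branch above turns on two pieces of pool state that the diff references but never shows: `_num_tasks_in_flight`, keyed by running actors, and `_restarting_actors`, keyed by actor location. A minimal runnable sketch of that bookkeeping, with plain objects standing in for Ray actor handles (only the two attribute names and the three method names come from the diff; everything else is an assumption):

# Hypothetical stand-in for the pool state consulted by _task_done_callback.
# The real pool keys _restarting_actors by actor location (get_location);
# this sketch keys by the actor object itself to stay self-contained.
class SketchActorPool:
    def __init__(self):
        self._num_tasks_in_flight = {}  # running actor -> in-flight task count
        self._restarting_actors = {}    # actors currently restarting

    def return_actor(self, actor):
        # Normal path: a running actor finished a task.
        self._num_tasks_in_flight[actor] -= 1

    def restarting_to_running(self, actor):
        # Recovery path: a task completed on an actor previously marked as
        # restarting, so move it back to the running side of the ledger.
        self._restarting_actors.pop(actor, None)
        self._num_tasks_in_flight.setdefault(actor, 0)


pool = SketchActorPool()
pool._num_tasks_in_flight["actor-1"] = 1

def task_done(actor):
    # Mirrors the callback in the hunk above.
    if actor in pool._num_tasks_in_flight:
        pool.return_actor(actor)
    else:
        pool.restarting_to_running(actor)

task_done("actor-1")  # the running actor is returned to the pool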

@@ -294,9 +303,11 @@ def current_processor_usage(self) -> ExecutionResources:

    def pending_processor_usage(self) -> ExecutionResources:
        num_pending_workers = self._actor_pool.num_pending_actors()
+        num_restarting_workers = self._actor_pool.num_restarting_actors()
+        num_non_running_workers = num_pending_workers + num_restarting_workers
        return ExecutionResources(
-            cpu=self._ray_remote_args.get("num_cpus", 0) * num_pending_workers,
-            gpu=self._ray_remote_args.get("num_gpus", 0) * num_pending_workers,
+            cpu=self._ray_remote_args.get("num_cpus", 0) * num_non_running_workers,
+            gpu=self._ray_remote_args.get("num_gpus", 0) * num_non_running_workers,
        )

def num_active_actors(self) -> int:
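
Concretely, the change above means restarting workers keep reserving their resources instead of dropping out of the pending ledger while their node recovers. A worked example with assumed remote args and counts (none of these numbers come from the diff):

# Illustration only; values are assumed.
ray_remote_args = {"num_cpus": 1, "num_gpus": 0.5}
num_pending_workers = 2      # actors still being created
num_restarting_workers = 3   # actors restarting after a node failure
num_non_running_workers = num_pending_workers + num_restarting_workers

cpu = ray_remote_args.get("num_cpus", 0) * num_non_running_workers
gpu = ray_remote_args.get("num_gpus", 0) * num_non_running_workers
print(cpu, gpu)  # 5 2.5 -- before this change, only the 2 pending workers counted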
@@ -354,18 +365,19 @@ def get_autoscaling_actor_pools(self) -> List[AutoscalingActorPool]:

    def update_resource_usage(self) -> None:
        """Updates resources usage."""
        # Walk all active actors; each actor that's not ALIVE is a
        # candidate to be marked as a restarting actor.
        actors = list(self._actor_pool._num_tasks_in_flight.keys())
        for actor in actors:
            actor_state = actor._get_local_state()
-            if (actor_state != gcs_pb2.ActorTableData.ActorState.ALIVE):
+            if actor_state != gcs_pb2.ActorTableData.ActorState.ALIVE:
                # If an actor is not ALIVE, it's a candidate to be marked
                # as a restarting actor.
                self._actor_pool.running_to_restarting(actor, actor.get_location)
            else:
                # If an actor is ALIVE, it's a candidate to be marked as a
                # running actor, if not already the case.
                self._actor_pool.restarting_to_running(actor.get_location)
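
For readers unfamiliar with the GCS check above: `gcs_pb2.ActorTableData.ActorState` is a protobuf enum of actor lifecycle states. A self-contained sketch of the classification logic, using a stand-in enum so it runs without Ray's generated protobuf modules (the values mirror Ray's gcs.proto; treat the helper as illustrative):

from enum import IntEnum


class ActorState(IntEnum):
    # Stand-in for gcs_pb2.ActorTableData.ActorState.
    DEPENDENCIES_UNREADY = 0
    PENDING_CREATION = 1
    ALIVE = 2
    RESTARTING = 3
    DEAD = 4


def classify(state: ActorState) -> str:
    # Mirrors update_resource_usage() above: anything that is not ALIVE is
    # a candidate for the restarting bucket; ALIVE actors are kept running.
    return "running" if state == ActorState.ALIVE else "restarting"


print(classify(ActorState.RESTARTING))  # restarting
print(classify(ActorState.ALIVE))       # running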



class _MapWorker:
"""An actor worker for MapOperator."""

@@ -522,8 +534,10 @@ def pending_to_running(self, ready_ref: ray.ObjectRef) -> bool:
        self._actor_locations[actor] = ray.get(ready_ref)
        return True

-    def running_to_restarting(self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef) -> bool:
-        """Mark the actor corresponding to the provided ready future as restaring.
+    def running_to_restarting(
+        self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef
+    ) -> bool:
+        """Mark the actor corresponding to the provided ready future as restarting.

        Args:
            actor: The running actor to add as restarting to the pool.
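
The diff cuts the method body off here. A plausible sketch of the transition, continuing the SketchActorPool stand-in from earlier (only the signature comes from the diff; the body is a guess, not the committed code):

# Hypothetical body -- the committed implementation is not shown in this diff.
def running_to_restarting(pool, actor, ready_ref):
    if actor not in pool._num_tasks_in_flight:
        # Already marked as restarting (or unknown); nothing to do.
        return False
    # Dropping the actor from the running ledger is what makes the
    # membership test in _task_done_callback take its else branch.
    pool._num_tasks_in_flight.pop(actor)
    pool._restarting_actors[ready_ref] = actor
    return True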
8 changes: 3 additions & 5 deletions python/ray/data/tests/test_actor_pool_fault_tolerance.py
@@ -1,11 +1,9 @@
-import asyncio
import threading
import time

import pytest

import ray
-from ray.data.context import DataContext
from ray.tests.conftest import * # noqa


@@ -15,7 +13,6 @@ def test_removed_nodes_and_added_back(ray_start_cluster):
    cluster = ray_start_cluster
    cluster.add_node(num_cpus=0)
    ray.init()
-    #DataContext.get_current().max_errored_blocks = -1

@ray.remote(num_cpus=0)
class Signal:
@@ -50,7 +47,7 @@ async def wait_for_nodes_restarted(self):
    signal_actor = Signal.remote()

    # Spin up nodes
-    num_nodes = 3
+    num_nodes = 5
    nodes = []
    for _ in range(num_nodes):
        nodes.append(cluster.add_node(num_cpus=10, num_gpus=1))
@@ -79,7 +76,7 @@ def __call__(self, batch):
            return batch

    res = []
-    num_items = 10
+    num_items = 100

def run_dataset():
nonlocal res
@@ -116,6 +113,7 @@ def run_dataset():
    thread.join()
    assert sorted(res, key=lambda x: x["id"]) == [{"id": i} for i in range(num_items)]

+
if __name__ == "__main__":
import sys
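
Taken together, the test hunks form a harness that this view only shows in fragments: the dataset is consumed on a background thread while the main thread removes and re-adds nodes, and the final assertion verifies that no rows were lost or duplicated across actor restarts. A condensed sketch of that shape (the function wiring is an assumption; the real test also synchronizes through the Signal actor):

import threading

def run_with_node_churn(consume_dataset, churn_nodes, num_items):
    res = []

    def run_dataset():
        # Consume every row on a background thread while nodes churn.
        res.extend(consume_dataset())

    thread = threading.Thread(target=run_dataset)
    thread.start()
    churn_nodes()  # remove worker nodes and add them back mid-run
    thread.join()
    # Every row must come back exactly once despite restarts.
    assert sorted(res, key=lambda x: x["id"]) == [
        {"id": i} for i in range(num_items)
    ]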

