diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index ca19c45295438b..44c1ce57f6428f 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -211,8 +211,17 @@ def _dispatch_tasks(self): ).remote(DataContext.get_current(), ctx, *input_blocks) def _task_done_callback(actor_to_return): - # Return the actor that was running the task to the pool. - self._actor_pool.return_actor(actor_to_return) + if actor_to_return in self._actor_pool._num_tasks_in_flight: + # Return the actor that was running the task to the pool. + self._actor_pool.return_actor(actor_to_return) + else: + assert ( + actor_to_return.get_location + in self._actor_pool._restarting_actors + ) + # Move the actor from restarting to running state. + self._actor_pool.restarting_to_running(actor_to_return) + # Dipsatch more tasks. self._dispatch_tasks() @@ -294,9 +303,11 @@ def current_processor_usage(self) -> ExecutionResources: def pending_processor_usage(self) -> ExecutionResources: num_pending_workers = self._actor_pool.num_pending_actors() + num_restarting_workers = self._actor_pool.num_restarting_actors() + num_non_running_workers = num_pending_workers + num_restarting_workers return ExecutionResources( - cpu=self._ray_remote_args.get("num_cpus", 0) * num_pending_workers, - gpu=self._ray_remote_args.get("num_gpus", 0) * num_pending_workers, + cpu=self._ray_remote_args.get("num_cpus", 0) * num_non_running_workers, + gpu=self._ray_remote_args.get("num_gpus", 0) * num_non_running_workers, ) def num_active_actors(self) -> int: @@ -354,18 +365,19 @@ def get_autoscaling_actor_pools(self) -> List[AutoscalingActorPool]: def update_resource_usage(self) -> None: """Updates resources usage.""" - # Walk all active actors and for each actor that's not ALIVE, - # it's a candidate to be marked as a pending actor. actors = list(self._actor_pool._num_tasks_in_flight.keys()) for actor in actors: actor_state = actor._get_local_state() - if (actor_state != gcs_pb2.ActorTableData.ActorState.ALIVE): + if actor_state != gcs_pb2.ActorTableData.ActorState.ALIVE: + # If an actor is not ALIVE, it's a candidate to be marked as a + # restarting actor. self._actor_pool.running_to_restarting(actor, actor.get_location) else: + # If an actor is ALIVE, it's a candidate to be marked as a + # running actor, if not already the case. self._actor_pool.restarting_to_running(actor.get_location) - class _MapWorker: """An actor worker for MapOperator.""" @@ -522,8 +534,10 @@ def pending_to_running(self, ready_ref: ray.ObjectRef) -> bool: self._actor_locations[actor] = ray.get(ready_ref) return True - def running_to_restarting(self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef) -> bool: - """Mark the actor corresponding to the provided ready future as restaring. + def running_to_restarting( + self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef + ) -> bool: + """Mark the actor corresponding to the provided ready future as restarting. Args: actor: The running actor to add as restarting to the pool. diff --git a/python/ray/data/tests/test_actor_pool_fault_tolerance.py b/python/ray/data/tests/test_actor_pool_fault_tolerance.py index ab939994b9372d..f03ed5d113802c 100644 --- a/python/ray/data/tests/test_actor_pool_fault_tolerance.py +++ b/python/ray/data/tests/test_actor_pool_fault_tolerance.py @@ -1,11 +1,9 @@ import asyncio import threading -import time import pytest import ray -from ray.data.context import DataContext from ray.tests.conftest import * # noqa @@ -15,7 +13,6 @@ def test_removed_nodes_and_added_back(ray_start_cluster): cluster = ray_start_cluster cluster.add_node(num_cpus=0) ray.init() - #DataContext.get_current().max_errored_blocks = -1 @ray.remote(num_cpus=0) class Signal: @@ -50,7 +47,7 @@ async def wait_for_nodes_restarted(self): signal_actor = Signal.remote() # Spin up nodes - num_nodes = 3 + num_nodes = 5 nodes = [] for _ in range(num_nodes): nodes.append(cluster.add_node(num_cpus=10, num_gpus=1)) @@ -79,7 +76,7 @@ def __call__(self, batch): return batch res = [] - num_items = 10 + num_items = 100 def run_dataset(): nonlocal res @@ -116,6 +113,7 @@ def run_dataset(): thread.join() assert sorted(res, key=lambda x: x["id"]) == [{"id": i} for i in range(num_items)] + if __name__ == "__main__": import sys