diff --git a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py index cb2e0952fce1a..9e6d8f06ccd21 100644 --- a/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +++ b/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py @@ -211,14 +211,8 @@ def _dispatch_tasks(self): ).remote(DataContext.get_current(), ctx, *input_blocks) def _task_done_callback(actor_to_return): - # If actor is found in restarting, move it to running. - self._actor_pool.restarting_to_running( - actor, actor.get_location.remote() - ) - # Return the actor that was running the task to the pool. self._actor_pool.return_actor(actor_to_return) - # Dipsatch more tasks. self._dispatch_tasks() @@ -228,7 +222,7 @@ def _task_done_callback(actor_to_return): self._submit_data_task( gen, bundle, - lambda: _task_done_callback(actor_to_return), # noqa: B023 + lambda: _task_done_callback(actor_to_return), ) def _refresh_actor_cls(self): @@ -300,11 +294,9 @@ def current_processor_usage(self) -> ExecutionResources: def pending_processor_usage(self) -> ExecutionResources: num_pending_workers = self._actor_pool.num_pending_actors() - num_restarting_workers = self._actor_pool.num_restarting_actors() - num_non_running_workers = num_pending_workers + num_restarting_workers return ExecutionResources( - cpu=self._ray_remote_args.get("num_cpus", 0) * num_non_running_workers, - gpu=self._ray_remote_args.get("num_gpus", 0) * num_non_running_workers, + cpu=self._ray_remote_args.get("num_cpus", 0) * num_pending_workers, + gpu=self._ray_remote_args.get("num_gpus", 0) * num_pending_workers, ) def num_active_actors(self) -> int: @@ -366,20 +358,15 @@ def _manage_actor_restarting_state(self, actor): # If an actor is not ALIVE, it's a candidate to be marked as a # restarting actor. assert actor_state is gcs_pb2.ActorTableData.ActorState.RESTARTING - if actor in self._actor_pool._num_tasks_in_flight: - # Change Actor state from running to restarting - self._actor_pool.running_to_restarting( - actor, actor.get_location.remote() - ) + self._actor_pool.mark_running_actor_as_restarting(actor) else: # If an actor is ALIVE, it's a candidate to be marked as a # running actor, if not already the case. - self._actor_pool.restarting_to_running(actor, actor.get_location.remote()) + self._actor_pool.clear_restarting_from_running_actor(actor) def update_resource_usage(self) -> None: """Updates resources usage.""" - actors = list(self._actor_pool._num_tasks_in_flight.keys()) - for actor in actors: + for actor in self._actor_pool._running_actors.keys(): self._manage_actor_restarting_state(actor) @@ -418,6 +405,23 @@ def __repr__(self): return f"MapWorker({self.src_fn_name})" +class _ActorRunningState: + """Actor running state""" + + def __init__( + self, + # Number of tasks in flight per actor. + num_tasks_in_flight: int, + # Node id of each ready actor. + actor_location: str, + # Is Actor in restarting state + is_restarting: bool, + ): + self._num_tasks_in_flight = num_tasks_in_flight + self._actor_location = actor_location + self._is_restarting = is_restarting + + class _ActorPool(AutoscalingActorPool): """A pool of actors for map task execution. @@ -443,14 +447,8 @@ def __init__( assert self._max_tasks_in_flight >= 1 assert self._create_actor_fn is not None - # Number of tasks in flight per actor. - self._num_tasks_in_flight: Dict[ray.actor.ActorHandle, int] = {} - # Number of tasks in flight per restarting actor. - self._num_tasks_restarting_actors: Dict[ray.actor.ActorHandle, int] = {} - # Node id of each ready actor. - self._actor_locations: Dict[ray.actor.ActorHandle, str] = {} - # Node id of each restarting actor. - self._restarting_actor_locations: Dict[ray.actor.ActorHandle, str] = {} + # Actors that are running. + self._running_actors: Dict[ray.actor.ActorHandle, _ActorRunningState] = {} # Actors that are not yet ready (still pending creation). self._pending_actors: Dict[ObjectRef, ray.actor.ActorHandle] = {} # Whether actors that become idle should be eagerly killed. This is False until @@ -472,25 +470,31 @@ def current_size(self) -> int: return self.num_pending_actors() + self.num_running_actors() def num_running_actors(self) -> int: - return len(self._num_tasks_in_flight) + return len(self._running_actors) + + def num_restarting_actors(self) -> int: + return sum( + 1 if running_actor_state._is_restarting is True else 0 + for running_actor_state in self._running_actors.values() + ) def num_active_actors(self) -> int: return sum( - 1 if num_tasks_in_flight > 0 else 0 - for num_tasks_in_flight in self._num_tasks_in_flight.values() + 1 if running_actor_state._num_tasks_in_flight > 0 else 0 + for running_actor_state in self._running_actors.values() ) def num_pending_actors(self) -> int: return len(self._pending_actors) - def num_restarting_actors(self) -> int: - return len(self._num_tasks_restarting_actors) - def max_tasks_in_flight_per_actor(self) -> int: return self._max_tasks_in_flight def current_in_flight_tasks(self) -> int: - return sum(num for _, num in self._num_tasks_in_flight.items()) + return sum( + running_actor_state._num_tasks_in_flight + for running_actor_state in self._running_actors.values() + ) def scale_up(self, num_actors: int) -> int: for _ in range(num_actors): @@ -507,6 +511,22 @@ def scale_down(self, num_actors: int) -> int: # === End of overriding methods of AutoscalingActorPool === + def mark_running_actor_as_restarting(self, actor: ray.actor.ActorHandle): + """Mark the running actor as restarting + + Args: The running actor to be marked as restarting + """ + assert actor in self._running_actors + self._running_actors[actor]._is_restarting = True + + def clear_restarting_from_running_actor(self, actor: ray.actor.ActorHandle): + """Clear restarting from the running actor + + Args: The running actor to be cleared of restarting state + """ + assert actor in self._running_actors + self._running_actors[actor]._is_restarting = False + def add_pending_actor(self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef): """Adds a pending actor to the pool. @@ -537,55 +557,11 @@ def pending_to_running(self, ready_ref: ray.ObjectRef) -> bool: # The actor has been removed from the pool before becoming running. return False actor = self._pending_actors.pop(ready_ref) - self._num_tasks_in_flight[actor] = 0 - self._actor_locations[actor] = ray.get(ready_ref) - return True - - def running_to_restarting( - self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef - ) -> bool: - """Mark the actor corresponding to the provided ready future as restarting. - - Args: - actor: The running actor to add as restarting to the pool. - ready_ref: The ready future for the actor that we wish to mark as - restarting. - - Returns: - Whether the actor was still running. This can return False if the actor had - already been killed. - """ - if actor not in self._num_tasks_in_flight: - # The actor has been removed from the pool before becoming restarting. - return False - self._num_tasks_restarting_actors[actor] = self._num_tasks_in_flight[actor] - self._num_tasks_in_flight[actor] = 0 - self._restarting_actor_locations[actor] = ray.get(ready_ref) - self._remove_actor(actor) - return True - - def restarting_to_running( - self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef - ) -> bool: - """Mark the actor as running, making the actor pickable. - - Args: - actor: Then restarting actor to add as running to the pool. - ready_ref: The ready future for the actor that we wish to mark as - running. - - Returns: - Whether the actor was still restarting. This can return False if the actor - had already been killed. - """ - if actor not in self._restarting_actor_locations: - # The actor has been removed from the pool before becoming running. - return False - self._num_tasks_in_flight[actor] = self._num_tasks_restarting_actors[actor] - self._num_tasks_restarting_actors[actor] = 0 - self._actor_locations[actor] = ray.get(ready_ref) - del self._restarting_actor_locations[actor] - del self._num_tasks_restarting_actors[actor] + self._running_actors[actor] = _ActorRunningState( + num_tasks_in_flight=0, + actor_location=ray.get(ready_ref), + is_restarting=False, + ) return True def pick_actor( @@ -594,13 +570,13 @@ def pick_actor( """Picks an actor for task submission based on busyness and locality. None will be returned if all actors are either at capacity (according to - max_tasks_in_flight) or are still pending/restarting. + max_tasks_in_flight) or are still pending. Args: locality_hint: Try to pick an actor that is local for this bundle. """ - if not self._num_tasks_in_flight: - # Actor pool is empty or all actors are still pending/restarting. + if not self._running_actors: + # Actor pool is empty or all actors are still pending. return None if locality_hint: @@ -614,31 +590,43 @@ def penalty_key(actor): We prioritize valid actors, those with argument locality, and those that are not busy, in that order. """ - busyness = self._num_tasks_in_flight[actor] - invalid = busyness >= self._max_tasks_in_flight - requires_remote_fetch = self._actor_locations[actor] != preferred_loc + busyness = self._running_actors[actor]._num_tasks_in_flight + is_restarting = self._running_actors[actor]._is_restarting + invalid = busyness >= self._max_tasks_in_flight or is_restarting + requires_remote_fetch = ( + self._running_actors[actor]._actor_location != preferred_loc + ) return invalid, requires_remote_fetch, busyness - actor = min(self._num_tasks_in_flight.keys(), key=penalty_key) - if self._num_tasks_in_flight[actor] >= self._max_tasks_in_flight: - # All actors are at capacity. + actor = min(self._running_actors.keys(), key=penalty_key) + if ( + self._running_actors[actor]._num_tasks_in_flight + >= self._max_tasks_in_flight + or self._running_actors[actor]._is_restarting + ): + # All actors are at capacity or restarting. return None if locality_hint: - if self._actor_locations[actor] == preferred_loc: + if self._running_actors[actor]._actor_location == preferred_loc: self._locality_hits += 1 else: self._locality_misses += 1 - self._num_tasks_in_flight[actor] += 1 + self._running_actors[actor]._num_tasks_in_flight += 1 return actor def return_actor(self, actor: ray.actor.ActorHandle): """Returns the provided actor to the pool.""" - assert actor in self._num_tasks_in_flight - assert self._num_tasks_in_flight[actor] > 0 + assert actor in self._running_actors + assert self._running_actors[actor]._num_tasks_in_flight > 0 - self._num_tasks_in_flight[actor] -= 1 - if self._should_kill_idle_actors and self._num_tasks_in_flight[actor] == 0: + # Mark restarting as false, now that the actor in running + self._running_actors[actor]._is_restarting = False + self._running_actors[actor]._num_tasks_in_flight -= 1 + if ( + self._should_kill_idle_actors + and self._running_actors[actor]._num_tasks_in_flight == 0 + ): self._remove_actor(actor) def get_pending_actor_refs(self) -> List[ray.ObjectRef]: @@ -647,37 +635,31 @@ def get_pending_actor_refs(self) -> List[ray.ObjectRef]: def num_idle_actors(self) -> int: """Return the number of idle actors in the pool.""" return sum( - 1 if tasks_in_flight == 0 else 0 - for tasks_in_flight in self._num_tasks_in_flight.values() + 1 if running_actor._num_tasks_in_flight == 0 else 0 + for running_actor in self._running_actors.values() ) def num_free_slots(self) -> int: """Return the number of free slots for task execution.""" - if not self._num_tasks_in_flight: + if not self._running_actors: return 0 return sum( - max(0, self._max_tasks_in_flight - num_tasks_in_flight) - for num_tasks_in_flight in self._num_tasks_in_flight.values() + max(0, self._max_tasks_in_flight - running_actor._num_tasks_in_flight) + for running_actor in self._running_actors.values() ) def kill_inactive_actor(self) -> bool: - """Kills a single pending, restarting or idle actor, if any actors are - pending/restarting/idle. + """Kills a single pending or idle actor, if any actors are pending/idle. Returns whether an inactive actor was actually killed. """ - # Prioritize killing pending actor. + # We prioritize killing pending actors over idle actors to reduce actor starting + # churn. killed = self._maybe_kill_pending_actor() - if killed: - return True - - # Next prioritize killing restarting actor. - killed = self._maybe_kill_restarting_actor() - if killed: - return True - - # Finally, kill an idle actor. - return self._maybe_kill_idle_actor() + if not killed: + # If no pending actor was killed, so kill actor. + killed = self._maybe_kill_idle_actor() + return killed def _maybe_kill_pending_actor(self) -> bool: if self._pending_actors: @@ -689,21 +671,9 @@ def _maybe_kill_pending_actor(self) -> bool: # No pending actors, so indicate to the caller that no actors were killed. return False - def _maybe_kill_restarting_actor(self) -> bool: - for actor in self._restarting_actor_locations.keys(): - if self._num_tasks_restarting_actors[actor] == 0: - # At least one restarting actor, so kill first one. - self._remove_actor(actor) - del self._num_tasks_restarting_actors[actor] - del self._restarting_actor_locations[actor] - return True - # No candidate restarting actors, so indicate to the caller that no actors were - # killed. - return False - def _maybe_kill_idle_actor(self) -> bool: - for actor, tasks_in_flight in self._num_tasks_in_flight.items(): - if tasks_in_flight == 0: + for actor, running_actor in self._running_actors.items(): + if running_actor._num_tasks_in_flight == 0: # At least one idle actor, so kill first one found. self._remove_actor(actor) return True @@ -715,11 +685,10 @@ def kill_all_inactive_actors(self): idle in the future will be eagerly killed. This is called once the operator is done submitting work to the pool, and this - function is idempotent. Adding new pending/restarting actors after calling this - function will raise an error. + function is idempotent. Adding new pending actors after calling this function + will raise an error. """ self._kill_all_pending_actors() - self._kill_all_restarting_actors() self._kill_all_idle_actors() def kill_all_actors(self): @@ -728,7 +697,6 @@ def kill_all_actors(self): This is called once the operator is shutting down. """ self._kill_all_pending_actors() - self._kill_all_restarting_actors() self._kill_all_running_actors() def _kill_all_pending_actors(self): @@ -736,23 +704,18 @@ def _kill_all_pending_actors(self): self._remove_actor(actor) self._pending_actors.clear() - def _kill_all_restarting_actors(self): - for actor in self._restarting_actor_locations.keys(): - self._remove_actor(actor) - self._restarting_actor_locations.clear() - def _kill_all_idle_actors(self): idle_actors = [ actor - for actor, tasks_in_flight in self._num_tasks_in_flight.items() - if tasks_in_flight == 0 + for actor, running_actor in self._running_actors.items() + if running_actor._num_tasks_in_flight == 0 ] for actor in idle_actors: self._remove_actor(actor) self._should_kill_idle_actors = True def _kill_all_running_actors(self): - actors = list(self._num_tasks_in_flight.keys()) + actors = list(self._running_actors.keys()) for actor in actors: self._remove_actor(actor) @@ -761,12 +724,8 @@ def _remove_actor(self, actor: ray.actor.ActorHandle): # NOTE: we remove references to the actor and let ref counting # garbage collect the actor, instead of using ray.kill. # Because otherwise the actor cannot be restarted upon lineage reconstruction. - for state_dict in [ - self._num_tasks_in_flight, - self._actor_locations, - ]: - if actor in state_dict: - del state_dict[actor] + if actor in self._running_actors: + del self._running_actors[actor] def _get_location(self, bundle: RefBundle) -> Optional[NodeIdStr]: """Ask Ray for the node id of the given bundle. diff --git a/python/ray/data/tests/test_actor_pool_map_operator.py b/python/ray/data/tests/test_actor_pool_map_operator.py index 38e5d781fc2ac..239c5c120fc10 100644 --- a/python/ray/data/tests/test_actor_pool_map_operator.py +++ b/python/ray/data/tests/test_actor_pool_map_operator.py @@ -128,6 +128,37 @@ def test_pending_to_running(self): assert pool.num_idle_actors() == 0 assert pool.num_free_slots() == 3 + def test_restarting_to_running(self): + # Test that actor is correctly transitioned from restarting to running. + pool = self._create_actor_pool(max_tasks_in_flight=1) + actor = self._add_ready_actor(pool) + + # Mark the actor as restarting and test pick_actor fails + pool.mark_running_actor_as_restarting(actor) + assert pool.pick_actor() is None + assert pool.current_size() == 1 + assert pool.num_pending_actors() == 0 + assert pool.num_running_actors() == 1 + assert pool.num_restarting_actors() == 1 + assert pool.num_active_actors() == 0 + assert pool.num_idle_actors() == 1 + assert pool.num_free_slots() == 1 + + # Clear the actor as restarting and test pick_actor succeeds + pool.clear_restarting_from_running_actor(actor) + picked_actor = pool.pick_actor() + assert picked_actor == actor + assert pool.current_size() == 1 + assert pool.num_pending_actors() == 0 + assert pool.num_running_actors() == 1 + assert pool.num_restarting_actors() == 0 + assert pool.num_active_actors() == 1 + assert pool.num_idle_actors() == 0 + assert pool.num_free_slots() == 0 + + # Return the actor + pool.return_actor(picked_actor) + def test_repeated_picking(self): # Test that we can repeatedly pick the same actor. pool = self._create_actor_pool(max_tasks_in_flight=999) @@ -157,6 +188,25 @@ def test_return_actor(self): assert pool.num_idle_actors() == 1 # Actor should now be idle. assert pool.num_free_slots() == 999 + def test_returned_actor_to_running(self): + # Test that we can return the actor and it will be marked as running and clear + # restarting flag. + pool = self._create_actor_pool(max_tasks_in_flight=999) + self._add_ready_actor(pool) + picked_actor = pool.pick_actor() + pool.mark_running_actor_as_restarting(picked_actor) + assert pool.num_restarting_actors() == 1 + # Return the actor as many times as it was picked. + pool.return_actor(picked_actor) + assert pool.num_restarting_actors() == 0 + # Check that the per-state pool sizes are as expected. + assert pool.current_size() == 1 + assert pool.num_pending_actors() == 0 + assert pool.num_running_actors() == 1 + assert pool.num_active_actors() == 0 + assert pool.num_idle_actors() == 1 # Actor should now be idle. + assert pool.num_free_slots() == 999 + def test_pick_max_tasks_in_flight(self): # Test that we can't pick an actor beyond the max_tasks_in_flight cap. pool = self._create_actor_pool(max_tasks_in_flight=2) @@ -213,6 +263,35 @@ def test_pick_all_max_tasks_in_flight(self): # Check that the next pick doesn't return an actor. assert pool.pick_actor() is None + def test_pick_ordering_restarting(self): + # Test that pick ordering is honored by restarting actors + pool = self._create_actor_pool(max_tasks_in_flight=2) + # Add 4 actors to the pool. + actors = [self._add_ready_actor(pool) for _ in range(4)] + + # Pick actors + for _ in range(4): + picked_actor = pool.pick_actor() + assert pool._running_actors[picked_actor]._num_tasks_in_flight == 1 + + # Mark actor[0] as restarting + pool.mark_running_actor_as_restarting(actors[0]) + + # Verify clearing restarting makes the actor pickable + for _ in range(4): + picked_actor = pool.pick_actor() + if picked_actor is not None: + assert pool._running_actors[picked_actor]._num_tasks_in_flight == 2 + else: + picked_actor = actors[0] + assert pool._running_actors[picked_actor]._num_tasks_in_flight == 1 + pool.clear_restarting_from_running_actor(picked_actor) + picked_actor = pool.pick_actor() + picked_actor = actors[0] + assert pool._running_actors[picked_actor]._num_tasks_in_flight == 2 + # Check that the next pick doesn't return an actor. + assert pool.pick_actor() is None + def test_pick_ordering_with_returns(self): # Test that pick ordering works with returns. pool = self._create_actor_pool() @@ -515,12 +594,12 @@ def test_locality_manager_busyness_ranking(self): actor2 = self._add_ready_actor(pool, node_id="node2") # Fake actor 2 as more busy. - pool._num_tasks_in_flight[actor2] = 1 + pool._running_actors[actor2]._num_tasks_in_flight = 1 res1 = pool.pick_actor(bundles[0]) assert res1 == actor1 # Fake actor 2 as more busy again. - pool._num_tasks_in_flight[actor2] = 2 + pool._running_actors[actor2]._num_tasks_in_flight = 2 res2 = pool.pick_actor(bundles[0]) assert res2 == actor1 @@ -601,7 +680,7 @@ async def wait_for_nodes_restarted(self): signal_actor = Signal.remote() # Spin up nodes - num_nodes = 5 + num_nodes = 2 nodes = [] for _ in range(num_nodes): nodes.append(cluster.add_node(num_cpus=10, num_gpus=1))