Skip to content

Commit

Permalink
Updates to straggler handling functionality (securefederatedai#996)
Browse files Browse the repository at this point in the history
* v1.0 straggler handling added

Signed-off-by: Parth Mandaliya <[email protected]>

* Added start_straggler_cutoff_timer abstract function in StragglerHandlingFunction abstract class.
Renamed start_timer and __timer_expired functions to start_straggler_cutoff_timer and _straggler_cutoff_time_elapsed respectively.
Added docstring to both functions mentioned above.

Signed-off-by: Parth Mandaliya <[email protected]>

* Updated logs straggler handling in aggregator.py.
If one or more collaborator(s) does not even 1 task results in time, all tasks results sent by that collaborator is excluded from aggregation.

Signed-off-by: Parth Mandaliya <[email protected]>

* Only time based straggler handling policies require timer thread,
removing start_straggler_cutoff_timer function from parent class StragglerHandlingFunction

Signed-off-by: Parth Mandaliya <[email protected]>

* Find all unfinished tasks and straggler collaborators

Signed-off-by: Parth Mandaliya <[email protected]>

* Review comments incorporated

Signed-off-by: Parth Mandaliya <[email protected]>

* Added inline comments

Signed-off-by: Parth Mandaliya <[email protected]>

* Changed logic to keep track of collaborators which have reported results for all tasks
Changed straggler handling logs
Added docstring for functions in all straggler handling policies

Signed-off-by: Parth Mandaliya <[email protected]>

* 1. StragglerHandlingFunction: Added an interface method start_policy
2. CutoffTimeBasedStragglerHandling: start_policy method implements a timer to wait
for cutoff-time and then call provided call callback method
3. Aggregator:
    - sendlocaltaskresults: update collaborators_done to keep track of collaborators
	that have finished ALL tasks
    - _straggler_cutoff_time_elapsed: call back function that is called after cutoff
	time has elapsed and applies the straggler policy

Signed-off-by: Parth Mandaliya <[email protected]>

* Removed logger argument from start_policy &
Added logger argument to get_straggler_handling_policy in plan.py

Signed-off-by: Parth Mandaliya <[email protected]>

* If cutoff time is set to infinite, do not start the timer thread.

Signed-off-by: Parth Mandaliya <[email protected]>

* Resolved lint issues.

Signed-off-by: Parth Mandaliya <[email protected]>

* Pytest and code coverage test case failure resolved

Signed-off-by: Parth Mandaliya <[email protected]>

* Default logger value is set to None

Signed-off-by: Parth Mandaliya <[email protected]>

* Redesigned percentage based straggler policy.
minimum_reporting cannot be set 0 in any straggler policy

Signed-off-by: Parth Mandaliya <[email protected]>

* Code cleanup

Signed-off-by: Parth Mandaliya <[email protected]>

* Only collaborators_done are used for aggregation

Signed-off-by: Parth Mandaliya <[email protected]>

* Internal review comments incorporated
Logs updated
Logger argument removed from straggler handling policy classes

Signed-off-by: Parth Mandaliya <[email protected]>

* Resolving potential issues found during testing

Signed-off-by: Parth Mandaliya <[email protected]>

* Use _collaborator_task_completed method to check if all given tasks to
collaborator are completed or not

Signed-off-by: Parth Mandaliya <[email protected]>

* Few test cases failing issue resolved

Signed-off-by: Parth Mandaliya <[email protected]>

* Potential issue in aggregator based workflow tutorial resolved

Signed-off-by: Parth Mandaliya <[email protected]>

* Corner case issue discovered during testing is patched.

Signed-off-by: Parth Mandaliya <[email protected]>

* Added reset_policy_for_round function in straggler handling policy base class.

Signed-off-by: Parth Mandaliya <[email protected]>

* This commit includes following changes:
1. In cutoff time based policy after cutoff time expires wait for all collaborators not just minimum required.
2. Irregardless of tasks assigned to collaborators if minimum required collaborators report resultsw in time apply straggler handling policy.

Signed-off-by: Parth Mandaliya <[email protected]>

* Code cleanup

Signed-off-by: Parth Mandaliya <[email protected]>

* Logs modified

Signed-off-by: Parth Mandaliya <[email protected]>

* Review comments on PR incorporated.

Signed-off-by: Parth Mandaliya <[email protected]>

* Log updated

Signed-off-by: Parth Mandaliya <[email protected]>

* Condition in straggler_cutoff_check fixed

Signed-off-by: Parth Mandaliya <[email protected]>

* If straggler cutoff time set to infinite only wait for minimum required collaborators to report results.

Signed-off-by: Parth Mandaliya <[email protected]>

* If straggler cutoff time is set to infinite wait for ALL collaborators not only for minimum_reporting collaborators

Signed-off-by: Parth Mandaliya <[email protected]>

* Teodor's review comments and internal review comments incorporated.
flake8 issues resovled.

Signed-off-by: Parth Mandaliya <[email protected]>

* Changed minimum_reporting validation.
Merged conditions in straggler_cutoff_check function in CutoffTimeBasedStragglerHandling class.

Signed-off-by: Parth Mandaliya <[email protected]>

* Resolved all flake8 issues.

Signed-off-by: Parth Mandaliya <[email protected]>

* Modified single inline comment

Signed-off-by: Parth Mandaliya <[email protected]>

* Review comments incorporated.

Signed-off-by: Parth Mandaliya <[email protected]>

* Lint fixes

Signed-off-by: Ishant Thakare <[email protected]>

* Incorporated Micah's review comments and removed unused code

Signed-off-by: Ishant Thakare <[email protected]>

* Incorporated Teo's review comments

Signed-off-by: Ishant Thakare <[email protected]>

* Incorporated review comments & added mutex in aggregator for thread safety

Signed-off-by: Ishant Thakare <[email protected]>

* Review comments incorporated and updated logs

Signed-off-by: Ishant Thakare <[email protected]>

* Reverted a comment for Pytest and code coverage fix

Signed-off-by: Ishant Thakare <[email protected]>

---------

Signed-off-by: Parth Mandaliya <[email protected]>
Signed-off-by: Ishant Thakare <[email protected]>
Co-authored-by: Ishant Thakare <[email protected]>
  • Loading branch information
ParthMandaliya and ishant162 authored Sep 23, 2024
1 parent 3881a48 commit 7c33420
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 182 deletions.
2 changes: 1 addition & 1 deletion openfl/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
PercentageBasedStragglerHandling,
)
from openfl.component.straggler_handling_functions.straggler_handling_function import (
StragglerHandlingFunction,
StragglerHandlingPolicy,
)
160 changes: 99 additions & 61 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import queue
import time
from logging import getLogger
from threading import Lock

from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling
from openfl.databases import TensorDB
Expand Down Expand Up @@ -53,9 +54,10 @@ class Aggregator:
collaborator_tasks_results (dict): Dict of collaborator tasks
results.
collaborator_task_weight (dict): Dict of col task weight.
lock: A threading Lock object used to ensure thread-safe operations.
.. note::
- plan setting
- plan setting
"""

def __init__(
Expand Down Expand Up @@ -177,6 +179,13 @@ def __init__(

self.collaborator_task_weight = {} # {TaskResultKey: data_size}

# maintain a list of collaborators that have completed task and
# reported results in a given round
self.collaborators_done = []

# Initialize a lock for thread safety
self.lock = Lock()

def _load_initial_tensors(self):
"""Load all of the tensors required to begin federated learning.
Expand Down Expand Up @@ -391,11 +400,30 @@ def get_tasks(self, collaborator_name):
)
sleep_time = 0

if hasattr(self.straggler_handling_policy, "round_start_time"):
self.straggler_handling_policy.round_start_time = time.time()
# Start straggler handling policy for timer based callback is required
# for %age based policy callback is not required
self.straggler_handling_policy.start_policy(callback=self._straggler_cutoff_time_elapsed)

return tasks, self.round_number, sleep_time, time_to_quit

def _straggler_cutoff_time_elapsed(self) -> None:
"""
This method is called by the straggler handling policy when cutoff timer is elapsed.
It applies straggler handling policy and ends the round early.
Returns:
None
"""
self.logger.warning(
f"Round number: {self.round_number} cutoff timer elapsed after "
f"{self.straggler_handling_policy.straggler_cutoff_time}s. "
f"Applying {self.straggler_handling_policy.__class__.__name__} policy."
)

with self.lock:
# Check if minimum collaborators reported results
self._end_of_round_with_stragglers_check()

def get_aggregated_tensor(
self,
collaborator_name,
Expand Down Expand Up @@ -573,10 +601,10 @@ def send_local_task_results(
Returns:
None
"""
if self._time_to_quit() or self._is_task_done(task_name):
if self._time_to_quit() or collaborator_name in self.stragglers:
self.logger.warning(
f"STRAGGLER: Collaborator {collaborator_name} is reporting results "
"after task {task_name} has finished."
f"after task {task_name} has finished."
)
return

Expand All @@ -596,10 +624,11 @@ def send_local_task_results(

# we mustn't have results already
if self._collaborator_task_completed(collaborator_name, task_name, round_number):
raise ValueError(
self.logger.warning(
f"Aggregator already has task results from collaborator {collaborator_name}"
f" for task {task_key}"
)
return

# By giving task_key it's own weight, we can support different
# training/validation weights
Expand Down Expand Up @@ -632,7 +661,31 @@ def send_local_task_results(
task_results.append(tensor_key)

self.collaborator_tasks_results[task_key] = task_results
self._end_of_task_check(task_name)

with self.lock:
self._is_collaborator_done(collaborator_name, round_number)

self._end_of_round_with_stragglers_check()

def _end_of_round_with_stragglers_check(self):
"""
Checks if the minimum required collaborators have reported their results,
identifies any stragglers, and initiates an early round end if necessary.
Returns:
None
"""
if self.straggler_handling_policy.straggler_cutoff_check(
len(self.collaborators_done), len(self.authorized_cols)
):
self.stragglers = [
collab_name
for collab_name in self.authorized_cols
if collab_name not in self.collaborators_done
]
if len(self.stragglers) != 0:
self.logger.warning(f"Identified stragglers: {self.stragglers}")
self._end_of_round_check()

def _process_named_tensor(self, named_tensor, collaborator_name):
"""Extract the named tensor fields.
Expand Down Expand Up @@ -724,21 +777,6 @@ def _process_named_tensor(self, named_tensor, collaborator_name):

return final_tensor_key, final_nparray

def _end_of_task_check(self, task_name):
"""Check whether all collaborators who are supposed to perform the
task complete.
Args:
task_name (str): Task name.
The task name to check.
Returns:
bool: Whether the task is done.
"""
if self._is_task_done(task_name):
# now check for the end of the round
self._end_of_round_check()

def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results):
"""Prepare aggregated tensorkey tags.
Expand Down Expand Up @@ -839,11 +877,12 @@ def _compute_validation_related_task_metrics(self, task_name):
all_collaborators_for_task = self.assigner.get_collaborators_for_task(
task_name, self.round_number
)
# leave out stragglers for the round
# Leave out straggler for the round even if they've paritally
# completed given tasks
collaborators_for_task = []
for c in all_collaborators_for_task:
if self._collaborator_task_completed(c, task_name, self.round_number):
collaborators_for_task.append(c)
collaborators_for_task = [
c for c in all_collaborators_for_task if c in self.collaborators_done
]

# The collaborator data sizes for that task
collaborator_weights_unnormalized = {
Expand Down Expand Up @@ -919,7 +958,7 @@ def _end_of_round_check(self):
Returns:
None
"""
if not self._is_round_done() or self._end_of_round_check_done[self.round_number]:
if self._end_of_round_check_done[self.round_number]:
return

# Compute all validation related metrics
Expand All @@ -932,6 +971,8 @@ def _end_of_round_check(self):
self.round_number += 1
# resetting stragglers for task for a new round
self.stragglers = []
# resetting collaborators_done for next round
self.collaborators_done = []

# Save the latest model
self.logger.info("Saving round %s model...", self.round_number)
Expand All @@ -945,49 +986,46 @@ def _end_of_round_check(self):

# Cleaning tensor db
self.tensor_db.clean_up(self.db_store_rounds)
# Reset straggler handling policy for the next round.
self.straggler_handling_policy.reset_policy_for_round()

def _is_task_done(self, task_name):
"""Check that task is done.
def _is_collaborator_done(self, collaborator_name: str, round_number: int) -> None:
"""
Check if all tasks given to the collaborator are completed then,
completed or not.
Args:
task_name (str): Task name.
collaborator_name (str): Collaborator name.
round_number (int): Round number.
Returns:
bool: Whether the task is done.
None
"""
all_collaborators = self.assigner.get_collaborators_for_task(task_name, self.round_number)

collaborators_done = []
for c in all_collaborators:
if self._collaborator_task_completed(c, task_name, self.round_number):
collaborators_done.append(c)

straggler_check = self.straggler_handling_policy.straggler_cutoff_check(
len(collaborators_done), all_collaborators
)
if self.round_number != round_number:
self.logger.warning(
f"Collaborator {collaborator_name} is reporting results"
f" for the wrong round: {round_number}. Ignoring..."
)
return

if straggler_check:
for c in all_collaborators:
if c not in collaborators_done:
self.stragglers.append(c)
# Get all tasks given to the collaborator for current round
all_tasks = self.assigner.get_tasks_for_collaborator(collaborator_name, self.round_number)
# Check if all given tasks are completed by the collaborator
all_tasks_completed = True
for task in all_tasks:
if hasattr(task, "name"):
task = task.name
all_tasks_completed = all_tasks_completed and self._collaborator_task_completed(
collaborator=collaborator_name, task_name=task, round_num=self.round_number
)
# If the collaborator has completed ALL tasks for current round,
# update collaborators_done
if all_tasks_completed:
self.collaborators_done.append(collaborator_name)
self.logger.info(
"\tEnding task %s early due to straggler cutoff policy",
task_name,
f"Round: {self.round_number}, Collaborators that have completed all tasks: "
f"{self.collaborators_done}"
)
self.logger.warning("\tIdentified stragglers: %s", self.stragglers)

# all are done or straggler policy calls for early round end.
return straggler_check or len(all_collaborators) == len(collaborators_done)

def _is_round_done(self):
"""Check that round is done.
Returns:
bool: Whether the round is done.
"""
tasks_for_round = self.assigner.get_all_tasks_for_round(self.round_number)

return all(self._is_task_done(task_name) for task_name in tasks_for_round)

def _log_big_warning(self):
"""Warn user about single collaborator cert mode."""
Expand Down
2 changes: 1 addition & 1 deletion openfl/component/straggler_handling_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
PercentageBasedStragglerHandling,
)
from openfl.component.straggler_handling_functions.straggler_handling_function import (
StragglerHandlingFunction,
StragglerHandlingPolicy,
)
Loading

0 comments on commit 7c33420

Please sign in to comment.