forked from mlcommons/medperf
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'fl-poc' of https://github.com/hasan7n/medperf into be_e…
…nable_partial_epochs
- Loading branch information
Showing
39 changed files
with
630 additions
and
288 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from medperf import config | ||
from medperf.account_management.account_management import get_medperf_user_data | ||
from medperf.entities.ca import CA | ||
from medperf.entities.training_exp import TrainingExp | ||
from medperf.entities.cube import Cube | ||
from medperf.utils import ( | ||
get_pki_assets_path, | ||
generate_tmp_path, | ||
dict_pretty_print, | ||
remove_path, | ||
) | ||
from medperf.certificates import trust | ||
import yaml | ||
import os | ||
|
||
|
||
class GetExperimentStatus: | ||
@classmethod | ||
def run(cls, training_exp_id: int): | ||
"""Starts the aggregation server of a training experiment | ||
Args: | ||
training_exp_id (int): Training experiment UID. | ||
""" | ||
execution = cls(training_exp_id) | ||
execution.prepare() | ||
execution.prepare_plan() | ||
execution.prepare_pki_assets() | ||
with config.ui.interactive(): | ||
execution.prepare_admin_cube() | ||
execution.get_experiment_status() | ||
execution.print_experiment_status() | ||
execution.store_status() | ||
|
||
def __init__(self, training_exp_id: int) -> None: | ||
self.training_exp_id = training_exp_id | ||
self.ui = config.ui | ||
|
||
def prepare(self): | ||
self.training_exp = TrainingExp.get(self.training_exp_id) | ||
self.ui.print(f"Training Experiment: {self.training_exp.name}") | ||
self.user_email: str = get_medperf_user_data()["email"] | ||
self.status_output = generate_tmp_path() | ||
self.temp_dir = generate_tmp_path() | ||
|
||
def prepare_plan(self): | ||
self.training_exp.prepare_plan() | ||
|
||
def prepare_pki_assets(self): | ||
ca = CA.from_experiment(self.training_exp_id) | ||
trust(ca) | ||
self.admin_pki_assets = get_pki_assets_path(self.user_email, ca.name) | ||
self.ca = ca | ||
|
||
def prepare_admin_cube(self): | ||
self.cube = self.__get_cube(self.training_exp.fl_admin_mlcube, "FL Admin") | ||
|
||
def __get_cube(self, uid: int, name: str) -> Cube: | ||
self.ui.text = ( | ||
"Retrieving and setting up training MLCube. This may take some time." | ||
) | ||
cube = Cube.get(uid) | ||
cube.download_run_files() | ||
self.ui.print(f"> {name} cube download complete") | ||
return cube | ||
|
||
def get_experiment_status(self): | ||
env_dict = {"MEDPERF_ADMIN_PARTICIPANT_CN": self.user_email} | ||
params = { | ||
"node_cert_folder": self.admin_pki_assets, | ||
"ca_cert_folder": self.ca.pki_assets, | ||
"plan_path": self.training_exp.plan_path, | ||
"output_status_file": self.status_output, | ||
"temp_dir": self.temp_dir, | ||
} | ||
|
||
self.ui.text = "Getting training experiment status" | ||
self.cube.run(task="get_experiment_status", env_dict=env_dict, **params) | ||
|
||
def print_experiment_status(self): | ||
with open(self.status_output) as f: | ||
contents = yaml.safe_load(f) | ||
dict_pretty_print(contents) | ||
|
||
def store_status(self): | ||
new_status_path = self.training_exp.status_path | ||
remove_path(new_status_path) | ||
os.rename(self.status_output, new_status_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from medperf import config | ||
from medperf.account_management.account_management import get_medperf_user_data | ||
from medperf.entities.ca import CA | ||
from medperf.entities.training_exp import TrainingExp | ||
from medperf.entities.cube import Cube | ||
from medperf.utils import get_pki_assets_path, generate_tmp_path | ||
from medperf.certificates import trust | ||
|
||
|
||
class UpdatePlan: | ||
@classmethod | ||
def run(cls, training_exp_id: int, field_name: str, field_value: str): | ||
"""Starts the aggregation server of a training experiment | ||
Args: | ||
training_exp_id (int): Training experiment UID. | ||
""" | ||
execution = cls(training_exp_id, field_name, field_value) | ||
execution.prepare() | ||
execution.prepare_plan() | ||
execution.prepare_pki_assets() | ||
with config.ui.interactive(): | ||
execution.prepare_admin_cube() | ||
execution.update_plan() | ||
|
||
def __init__(self, training_exp_id: int, field_name: str, field_value: str) -> None: | ||
self.training_exp_id = training_exp_id | ||
self.field_name = field_name | ||
self.field_value = field_value | ||
self.ui = config.ui | ||
|
||
def prepare(self): | ||
self.training_exp = TrainingExp.get(self.training_exp_id) | ||
self.ui.print(f"Training Experiment: {self.training_exp.name}") | ||
self.user_email: str = get_medperf_user_data()["email"] | ||
self.temp_dir = generate_tmp_path() | ||
|
||
def prepare_plan(self): | ||
self.training_exp.prepare_plan() | ||
|
||
def prepare_pki_assets(self): | ||
ca = CA.from_experiment(self.training_exp_id) | ||
trust(ca) | ||
self.admin_pki_assets = get_pki_assets_path(self.user_email, ca.name) | ||
self.ca = ca | ||
|
||
def prepare_admin_cube(self): | ||
self.cube = self.__get_cube(self.training_exp.fl_admin_mlcube, "FL Admin") | ||
|
||
def __get_cube(self, uid: int, name: str) -> Cube: | ||
self.ui.text = ( | ||
"Retrieving and setting up training MLCube. This may take some time." | ||
) | ||
cube = Cube.get(uid) | ||
cube.download_run_files() | ||
self.ui.print(f"> {name} cube download complete") | ||
return cube | ||
|
||
def update_plan(self): | ||
env_dict = { | ||
"MEDPERF_ADMIN_PARTICIPANT_CN": self.user_email, | ||
"MEDPERF_UPDATE_FIELD_NAME": self.field_name, | ||
"MEDPERF_UPDATE_FIELD_VALUE": self.field_value, | ||
} | ||
|
||
params = { | ||
"node_cert_folder": self.admin_pki_assets, | ||
"ca_cert_folder": self.ca.pki_assets, | ||
"plan_path": self.training_exp.plan_path, | ||
"temp_dir": self.temp_dir, | ||
} | ||
|
||
self.ui.text = "Updating plan" | ||
self.cube.run(task="update_plan", env_dict=env_dict, **params) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -121,6 +121,7 @@ METRIC_PARAMS="$ASSETS_URL/metrics/mlcube/workspace/parameters.yaml" | |
# FL cubes | ||
TRAIN_MLCUBE="https://raw.githubusercontent.com/hasan7n/medperf/19c80d88deaad27b353d1cb9bc180757534027aa/examples/fl/fl/mlcube/mlcube.yaml" | ||
TRAIN_WEIGHTS="https://storage.googleapis.com/medperf-storage/testfl/init_weights_miccai.tar.gz" | ||
FLADMIN_MLCUBE="https://raw.githubusercontent.com/hasan7n/medperf/bc431ffe6c3b761b28674816e6f26511e8b27042/examples/fl/fl_admin/mlcube/mlcube.yaml" | ||
|
||
# test users credentials | ||
MODELOWNER="[email protected]" | ||
|
@@ -129,6 +130,7 @@ BENCHMARKOWNER="[email protected]" | |
ADMIN="[email protected]" | ||
DATAOWNER2="[email protected]" | ||
AGGOWNER="[email protected]" | ||
FLADMIN="[email protected]" | ||
|
||
# local MLCubes for local compatibility tests | ||
PREP_LOCAL="$(dirname $(dirname $(realpath "$0")))/examples/chestxray_tutorial/data_preparator/mlcube" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.