From d072303e0bd457030da03ea9d1ab94920e0fd8af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Aristiz=C3=A1bal?= Date: Thu, 1 Aug 2024 10:28:36 -0500 Subject: [PATCH 1/4] Fix package buttons not displaying in relevant stages (#608) * Fix package buttons not displaying in relevant stages * Fix style issue --- scripts/monitor/rano_monitor/widgets/summary.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts/monitor/rano_monitor/widgets/summary.py b/scripts/monitor/rano_monitor/widgets/summary.py index 570b0088f..34cf22409 100644 --- a/scripts/monitor/rano_monitor/widgets/summary.py +++ b/scripts/monitor/rano_monitor/widgets/summary.py @@ -1,6 +1,6 @@ import os import pandas as pd -from rano_monitor.constants import REVIEW_FILENAME, REVIEWED_FILENAME +from rano_monitor.constants import REVIEW_FILENAME, REVIEWED_FILENAME, MANUAL_REVIEW_STAGE, DONE_STAGE from rano_monitor.messages import InvalidSubjectsUpdated from rano_monitor.messages import ReportUpdated from rano_monitor.messages import AnnotationsLoaded @@ -66,7 +66,10 @@ def update_summary(self): # Attach status_percents["DONE"] = 0.0 - package_btns.display = "MANUAL_REVIEW_REQUIRED" in status_percents + abs_status = display_report_df["status"].abs() + is_beyond_manual_review = (abs_status >= MANUAL_REVIEW_STAGE) + is_not_done = (abs_status < DONE_STAGE) + package_btns.display = any(is_beyond_manual_review & is_not_done) widgets = [] for name, val in status_percents.items(): From d8e3cba2f42151a489b4430af9c23052ac651f83 Mon Sep 17 00:00:00 2001 From: hasan7n <78664424+hasan7n@users.noreply.github.com> Date: Thu, 15 Aug 2024 01:41:13 +0200 Subject: [PATCH 2/4] refactor entities (#586) * refactor entities * Update cli/medperf/entities/interface.py Co-authored-by: Viacheslav Kukushkin * Apply suggestions from code review Co-authored-by: Viacheslav Kukushkin * update outdated result submission code * refactor schemas * use local_id in place of generated uid for clarity * update tests * no need to complicate things * use dynamic local_id * test some type annotations * modify __remote_prefilter * rename outdated intermediate vars * use TypeVar for type hints * Typo fix --------- Co-authored-by: Viacheslav Kukushkin --- README.md | 2 +- cli/cli_tests.sh | 3 +- cli/medperf/cli.py | 2 +- cli/medperf/commands/benchmark/benchmark.py | 16 +- cli/medperf/commands/benchmark/submit.py | 2 +- .../compatibility_test/compatibility_test.py | 8 +- .../commands/compatibility_test/run.py | 2 +- .../commands/compatibility_test/utils.py | 16 +- cli/medperf/commands/dataset/dataset.py | 16 +- cli/medperf/commands/execution.py | 14 +- cli/medperf/commands/list.py | 29 ++- cli/medperf/commands/mlcube/mlcube.py | 16 +- cli/medperf/commands/result/create.py | 7 +- cli/medperf/commands/result/result.py | 18 +- cli/medperf/commands/result/submit.py | 29 ++- cli/medperf/commands/view.py | 23 +- cli/medperf/entities/benchmark.py | 220 +++--------------- cli/medperf/entities/cube.py | 170 +++----------- cli/medperf/entities/dataset.py | 196 +++------------- cli/medperf/entities/interface.py | 216 ++++++++++++++--- cli/medperf/entities/report.py | 107 +++------ cli/medperf/entities/result.py | 188 +++------------ cli/medperf/entities/schemas.py | 32 ++- .../tests/commands/benchmark/test_submit.py | 4 +- .../tests/commands/mlcube/test_submit.py | 2 +- .../tests/commands/result/test_create.py | 3 + .../tests/commands/result/test_submit.py | 1 + cli/medperf/tests/commands/test_execution.py | 18 +- cli/medperf/tests/commands/test_list.py | 8 
+- cli/medperf/tests/commands/test_view.py | 211 ++++++----------- cli/medperf/tests/entities/test_benchmark.py | 5 +- cli/medperf/tests/entities/test_cube.py | 13 +- cli/medperf/tests/entities/test_entity.py | 73 +++--- cli/medperf/tests/entities/utils.py | 81 ++++--- 34 files changed, 675 insertions(+), 1076 deletions(-) diff --git a/README.md b/README.md index 2f07c511a..550d281ca 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ Inside this repo you can find all important pieces for running MedPerf. In its c If you use MedPerf, please cite our main paper: Karargyris, A., Umeton, R., Sheller, M.J. et al. Federated benchmarking of medical artificial intelligence with MedPerf. *Nature Machine Intelligence* **5**, 799–810 (2023). [https://www.nature.com/articles/s42256-023-00652-2](https://www.nature.com/articles/s42256-023-00652-2) -Additonally, here you can see how others used MedPerf already: [https://scholar.google.com/scholar?q="medperf"](https://scholar.google.com/scholar?q="medperf"). +Additionally, here you can see how others used MedPerf already: [https://scholar.google.com/scholar?q="medperf"](https://scholar.google.com/scholar?q="medperf"). ## Experiments diff --git a/cli/cli_tests.sh b/cli/cli_tests.sh index 697c40105..4640fcca0 100755 --- a/cli/cli_tests.sh +++ b/cli/cli_tests.sh @@ -5,7 +5,6 @@ ################### Start Testing ######################## ########################################################## - ########################################################## echo "==========================================" echo "Printing MedPerf version" @@ -195,7 +194,7 @@ echo "Running data submission step" echo "=====================================" print_eval "medperf dataset submit -p $PREP_UID -d $DIRECTORY/dataset_a -l $DIRECTORY/dataset_a --name='dataset_a' --description='mock dataset a' --location='mock location a' -y" checkFailed "Data submission step failed" -DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | cut -d ' ' -f 1) +DSET_A_UID=$(medperf dataset ls | grep dataset_a | tr -s ' ' | awk '{$1=$1;print}' | cut -d ' ' -f 1) echo "DSET_A_UID=$DSET_A_UID" ########################################################## diff --git a/cli/medperf/cli.py b/cli/medperf/cli.py index 4fc7102c4..0910c3ed8 100644 --- a/cli/medperf/cli.py +++ b/cli/medperf/cli.py @@ -71,7 +71,7 @@ def execute( please run the command again with the --no-cache option.\n""" ) else: - ResultSubmission.run(result.generated_uid, approved=approval) + ResultSubmission.run(result.local_id, approved=approval) config.ui.print("✅ Done!") diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index f02d67cb4..35d719b0d 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -16,14 +16,16 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local benchmarks"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered benchmarks" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"), ): - """List benchmarks stored locally and remotely from the user""" + """List benchmarks""" EntityList.run( Benchmark, fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -162,10 +164,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( + unregistered: bool = typer.Option( False, - "--local", - help="Display local benchmarks if benchmark ID is not provided", + "--unregistered", + help="Display unregistered benchmarks if benchmark ID is not provided", ), mine: bool = typer.Option( False, @@ -180,4 +182,4 @@ def view( ), ): """Displays the information of one or more benchmarks""" - EntityView.run(entity_id, Benchmark, format, local, mine, output) + EntityView.run(entity_id, Benchmark, format, unregistered, mine, output) diff --git a/cli/medperf/commands/benchmark/submit.py b/cli/medperf/commands/benchmark/submit.py index ebace1880..05d1a0d10 100644 --- a/cli/medperf/commands/benchmark/submit.py +++ b/cli/medperf/commands/benchmark/submit.py @@ -79,7 +79,7 @@ def run_compatibility_test(self): self.ui.print("Running compatibility test") self.bmk.write() data_uid, results = CompatibilityTestExecution.run( - benchmark=self.bmk.generated_uid, + benchmark=self.bmk.local_id, no_cache=self.no_cache, skip_data_preparation_step=self.skip_data_preparation_step, ) diff --git a/cli/medperf/commands/compatibility_test/compatibility_test.py b/cli/medperf/commands/compatibility_test/compatibility_test.py index a3b25ac78..0bd4a4695 100644 --- a/cli/medperf/commands/compatibility_test/compatibility_test.py +++ b/cli/medperf/commands/compatibility_test/compatibility_test.py @@ -95,7 +95,11 @@ def run( @clean_except def list(): """List previously executed tests reports.""" - EntityList.run(TestReport, fields=["UID", "Data Source", "Model", "Evaluator"]) + EntityList.run( + TestReport, + fields=["UID", "Data Source", "Model", "Evaluator"], + unregistered=True, + ) @app.command("view") @@ -116,4 +120,4 @@ def view( ), ): """Displays the information of one or more test reports""" - EntityView.run(entity_id, TestReport, format, output=output) + EntityView.run(entity_id, TestReport, format, unregistered=True, output=output) diff --git a/cli/medperf/commands/compatibility_test/run.py b/cli/medperf/commands/compatibility_test/run.py index 2e3082849..f06603d57 100644 --- a/cli/medperf/commands/compatibility_test/run.py +++ b/cli/medperf/commands/compatibility_test/run.py @@ -239,7 +239,7 @@ def cached_results(self): """ if self.no_cache: return - uid = self.report.generated_uid + uid = self.report.local_id try: report = TestReport.get(uid) except InvalidArgumentError: diff --git a/cli/medperf/commands/compatibility_test/utils.py b/cli/medperf/commands/compatibility_test/utils.py index a12ac5ea2..c56a57d41 100644 --- a/cli/medperf/commands/compatibility_test/utils.py +++ b/cli/medperf/commands/compatibility_test/utils.py @@ -138,23 +138,23 @@ def create_test_dataset( # TODO: existing dataset could make problems # make some changes since this is a test dataset config.tmp_paths.remove(data_creation.dataset.path) - data_creation.dataset.write() if skip_data_preparation_step: data_creation.make_dataset_prepared() dataset = data_creation.dataset + old_generated_uid = dataset.generated_uid + old_path = dataset.path # prepare/check dataset DataPreparation.run(dataset.generated_uid) # update dataset generated_uid - old_path = dataset.path - generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path]) - dataset.generated_uid = generated_uid - dataset.write() - if dataset.input_data_hash != dataset.generated_uid: + new_generated_uid = get_folders_hash([dataset.data_path, dataset.labels_path]) + if new_generated_uid != old_generated_uid: # move to a correct location if it underwent 
preparation - new_path = old_path.replace(dataset.input_data_hash, generated_uid) + new_path = old_path.replace(old_generated_uid, new_generated_uid) remove_path(new_path) os.rename(old_path, new_path) + dataset.generated_uid = new_generated_uid + dataset.write() - return generated_uid + return new_generated_uid diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index a27e36814..fc18022ac 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -17,17 +17,19 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local datasets"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered datasets" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user datasets"), mlcube: int = typer.Option( None, "--mlcube", "-m", help="Get datasets for a given data prep mlcube" ), ): - """List datasets stored locally and remotely from the user""" + """List datasets""" EntityList.run( Dataset, fields=["UID", "Name", "Data Preparation Cube UID", "State", "Status", "Owner"], - local_only=local, + unregistered=unregistered, mine_only=mine, mlcube=mlcube, ) @@ -149,8 +151,10 @@ def view( "--format", help="Format to display contents. Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local datasets if dataset ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered datasets if dataset ID is not provided", ), mine: bool = typer.Option( False, @@ -165,4 +169,4 @@ def view( ), ): """Displays the information of one or more datasets""" - EntityView.run(entity_id, Dataset, format, local, mine, output) + EntityView.run(entity_id, Dataset, format, unregistered, mine, output) diff --git a/cli/medperf/commands/execution.py b/cli/medperf/commands/execution.py index d8afb2244..85416fe96 100644 --- a/cli/medperf/commands/execution.py +++ b/cli/medperf/commands/execution.py @@ -47,12 +47,12 @@ def prepare(self): logging.debug(f"tmp results output: {self.results_path}") def __setup_logs_path(self): - model_uid = self.model.generated_uid - eval_uid = self.evaluator.generated_uid - data_hash = self.dataset.generated_uid + model_uid = self.model.local_id + eval_uid = self.evaluator.local_id + data_uid = self.dataset.local_id logs_path = os.path.join( - config.experiments_logs_folder, str(model_uid), str(data_hash) + config.experiments_logs_folder, str(model_uid), str(data_uid) ) os.makedirs(logs_path, exist_ok=True) model_logs_path = os.path.join(logs_path, "model.log") @@ -60,10 +60,10 @@ def __setup_logs_path(self): return model_logs_path, metrics_logs_path def __setup_predictions_path(self): - model_uid = self.model.generated_uid - data_hash = self.dataset.generated_uid + model_uid = self.model.local_id + data_uid = self.dataset.local_id preds_path = os.path.join( - config.predictions_folder, str(model_uid), str(data_hash) + config.predictions_folder, str(model_uid), str(data_uid) ) if os.path.exists(preds_path): msg = f"Found existing predictions for model {self.model.id} on dataset " diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index 5fd462bf7..99236ac3f 100644 --- a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -1,3 +1,5 @@ +from typing import List, Type +from medperf.entities.interface import Entity from medperf.exceptions import InvalidArgumentError from tabulate import tabulate @@ -8,29 
+10,38 @@ class EntityList: @staticmethod def run( - entity_class, - fields, - local_only: bool = False, + entity_class: Type[Entity], + fields: List[str], + unregistered: bool = False, mine_only: bool = False, **kwargs, ): """Lists all local datasets Args: - local_only (bool, optional): Display all local results. Defaults to False. - mine_only (bool, optional): Display all current-user results. Defaults to False. + unregistered (bool, optional): Display only local unregistered results. Defaults to False. + mine_only (bool, optional): Display all registered current-user results. Defaults to False. kwargs (dict): Additional parameters for filtering entity lists. """ - entity_list = EntityList(entity_class, fields, local_only, mine_only, **kwargs) + entity_list = EntityList( + entity_class, fields, unregistered, mine_only, **kwargs + ) entity_list.prepare() entity_list.validate() entity_list.filter() entity_list.display() - def __init__(self, entity_class, fields, local_only, mine_only, **kwargs): + def __init__( + self, + entity_class: Type[Entity], + fields: List[str], + unregistered: bool, + mine_only: bool, + **kwargs, + ): self.entity_class = entity_class self.fields = fields - self.local_only = local_only + self.unregistered = unregistered self.mine_only = mine_only self.filters = kwargs self.data = [] @@ -40,7 +51,7 @@ def prepare(self): self.filters["owner"] = get_medperf_user_data()["id"] entities = self.entity_class.all( - local_only=self.local_only, filters=self.filters + unregistered=self.unregistered, filters=self.filters ) self.data = [entity.display_dict() for entity in entities] diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index 4c365e574..9256f35f2 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -16,14 +16,16 @@ @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local mlcubes"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered mlcubes" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"), ): - """List mlcubes stored locally and remotely from the user""" + """List mlcubes""" EntityList.run( Cube, fields=["UID", "Name", "State", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, ) @@ -148,8 +150,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local mlcubes if mlcube ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered mlcubes if mlcube ID is not provided", ), mine: bool = typer.Option( False, @@ -164,4 +168,4 @@ def view( ), ): """Displays the information of one or more mlcubes""" - EntityView.run(entity_id, Cube, format, local, mine, output) + EntityView.run(entity_id, Cube, format, unregistered, mine, output) diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py index 42f97d990..26d52fa2e 100644 --- a/cli/medperf/commands/result/create.py +++ b/cli/medperf/commands/result/create.py @@ -1,5 +1,6 @@ import os from typing import List, Optional +from medperf.account_management.account_management import get_medperf_user_data from medperf.commands.execution import Execution from medperf.entities.result import Result from tabulate import tabulate @@ -143,7 +144,9 @@ def __validate_models(self, benchmark_models): raise InvalidArgumentError(msg) def load_cached_results(self): - results = Result.all() + user_id = get_medperf_user_data()["id"] + results = Result.all(filters={"owner": user_id}) + results += Result.all(unregistered=True) benchmark_dset_results = [ result for result in results @@ -254,7 +257,7 @@ def print_summary(self): data_lists_for_display.append( [ experiment["model_uid"], - experiment["result"].generated_uid, + experiment["result"].local_id, experiment["result"].metadata["partial"], experiment["cached"], experiment["error"], diff --git a/cli/medperf/commands/result/result.py b/cli/medperf/commands/result/result.py index 6fbb3b08a..40b65c52e 100644 --- a/cli/medperf/commands/result/result.py +++ b/cli/medperf/commands/result/result.py @@ -62,17 +62,19 @@ def submit( @app.command("ls") @clean_except def list( - local: bool = typer.Option(False, "--local", help="Get local results"), + unregistered: bool = typer.Option( + False, "--unregistered", help="Get unregistered results" + ), mine: bool = typer.Option(False, "--mine", help="Get current-user results"), benchmark: int = typer.Option( None, "--benchmark", "-b", help="Get results for a given benchmark" ), ): - """List results stored locally and remotely from the user""" + """List results""" EntityList.run( Result, fields=["UID", "Benchmark", "Model", "Dataset", "Registered"], - local_only=local, + unregistered=unregistered, mine_only=mine, benchmark=benchmark, ) @@ -88,8 +90,10 @@ def view( "--format", help="Format to display contents. 
Available formats: [yaml, json]", ), - local: bool = typer.Option( - False, "--local", help="Display local results if result ID is not provided" + unregistered: bool = typer.Option( + False, + "--unregistered", + help="Display unregistered results if result ID is not provided", ), mine: bool = typer.Option( False, @@ -107,4 +111,6 @@ def view( ), ): """Displays the information of one or more results""" - EntityView.run(entity_id, Result, format, local, mine, output, benchmark=benchmark) + EntityView.run( + entity_id, Result, format, unregistered, mine, output, benchmark=benchmark + ) diff --git a/cli/medperf/commands/result/submit.py b/cli/medperf/commands/result/submit.py index 15649ee04..b69a596ce 100644 --- a/cli/medperf/commands/result/submit.py +++ b/cli/medperf/commands/result/submit.py @@ -3,7 +3,6 @@ from medperf.exceptions import CleanExit from medperf.utils import remove_path, dict_pretty_print, approval_prompt from medperf.entities.result import Result -from medperf.enums import Status from medperf import config @@ -11,6 +10,7 @@ class ResultSubmission: @classmethod def run(cls, result_uid, approved=False): sub = cls(result_uid, approved=approved) + sub.get_result() updated_result_dict = sub.upload_results() sub.to_permanent_path(updated_result_dict) sub.write(updated_result_dict) @@ -21,27 +21,26 @@ def __init__(self, result_uid, approved=False): self.ui = config.ui self.approved = approved - def request_approval(self, result): - if result.approval_status == Status.APPROVED: - return True + def get_result(self): + self.result = Result.get(self.result_uid) - dict_pretty_print(result.results) + def request_approval(self): + dict_pretty_print(self.result.results) self.ui.print("Above are the results generated by the model") approved = approval_prompt( - "Do you approve uploading the presented results to the MLCommons comms? [Y/n]" + "Do you approve uploading the presented results to MedPerf? 
[Y/n]" ) return approved def upload_results(self): - result = Result.get(self.result_uid) - approved = self.approved or self.request_approval(result) + approved = self.approved or self.request_approval() if not approved: raise CleanExit("Results upload operation cancelled") - updated_result_dict = result.upload() + updated_result_dict = self.result.upload() return updated_result_dict def to_permanent_path(self, result_dict: dict): @@ -50,12 +49,12 @@ def to_permanent_path(self, result_dict: dict): Args: result_dict (dict): updated results dictionary """ - result = Result(**result_dict) - result_storage = config.results_folder - old_res_loc = os.path.join(result_storage, result.generated_uid) - new_res_loc = result.path - remove_path(new_res_loc) - os.rename(old_res_loc, new_res_loc) + + old_result_loc = self.result.path + updated_result = Result(**result_dict) + new_result_loc = updated_result.path + remove_path(new_result_loc) + os.rename(old_result_loc, new_result_loc) def write(self, updated_result_dict): result = Result(**updated_result_dict) diff --git a/cli/medperf/commands/view.py b/cli/medperf/commands/view.py index b4c242f0a..d19aedec0 100644 --- a/cli/medperf/commands/view.py +++ b/cli/medperf/commands/view.py @@ -1,6 +1,6 @@ import yaml import json -from typing import Union +from typing import Union, Type from medperf import config from medperf.account_management import get_medperf_user_data @@ -12,9 +12,9 @@ class EntityView: @staticmethod def run( entity_id: Union[int, str], - entity_class: Entity, + entity_class: Type[Entity], format: str = "yaml", - local_only: bool = False, + unregistered: bool = False, mine_only: bool = False, output: str = None, **kwargs, @@ -24,14 +24,14 @@ def run( Args: entity_id (Union[int, str]): Entity identifies entity_class (Entity): Entity type - local_only (bool, optional): Display all local entities. Defaults to False. + unregistered (bool, optional): Display only local unregistered entities. Defaults to False. mine_only (bool, optional): Display all current-user entities. Defaults to False. format (str, optional): What format to use to display the contents. Valid formats: [yaml, json]. Defaults to yaml. output (str, optional): Path to a file for storing the entity contents. If not provided, the contents are printed. kwargs (dict): Additional parameters for filtering entity lists. 
""" entity_view = EntityView( - entity_id, entity_class, format, local_only, mine_only, output, **kwargs + entity_id, entity_class, format, unregistered, mine_only, output, **kwargs ) entity_view.validate() entity_view.prepare() @@ -41,12 +41,19 @@ def run( entity_view.store() def __init__( - self, entity_id, entity_class, format, local_only, mine_only, output, **kwargs + self, + entity_id: Union[int, str], + entity_class: Type[Entity], + format: str, + unregistered: bool, + mine_only: bool, + output: str, + **kwargs, ): self.entity_id = entity_id self.entity_class = entity_class self.format = format - self.local_only = local_only + self.unregistered = unregistered self.mine_only = mine_only self.output = output self.filters = kwargs @@ -65,7 +72,7 @@ def prepare(self): self.filters["owner"] = get_medperf_user_data()["id"] entities = self.entity_class.all( - local_only=self.local_only, filters=self.filters + unregistered=self.unregistered, filters=self.filters ) self.data = [entity.todict() for entity in entities] diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index 849ea3fcd..e03fcdb4f 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -1,18 +1,13 @@ -import os -from medperf.exceptions import MedperfException -import yaml -import logging -from typing import List, Optional, Union +from typing import List, Optional from pydantic import HttpUrl, Field import medperf.config as config -from medperf.entities.interface import Entity, Uploadable -from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError -from medperf.entities.schemas import MedperfSchema, ApprovableSchema, DeployableSchema +from medperf.entities.interface import Entity +from medperf.entities.schemas import ApprovableSchema, DeployableSchema from medperf.account_management import get_medperf_user_data -class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableSchema): +class Benchmark(Entity, ApprovableSchema, DeployableSchema): """ Class representing a Benchmark @@ -35,6 +30,26 @@ class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableS user_metadata: dict = {} is_active: bool = True + @staticmethod + def get_type(): + return "benchmark" + + @staticmethod + def get_storage_path(): + return config.benchmarks_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_benchmark + + @staticmethod + def get_metadata_filename(): + return config.benchmarks_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_benchmark + def __init__(self, *args, **kwargs): """Creates a new benchmark instance @@ -43,54 +58,12 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.generated_uid = f"p{self.data_preparation_mlcube}m{self.reference_model_mlcube}e{self.data_evaluator_mlcube}" - path = config.benchmarks_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - self.path = path - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Benchmark"]: - """Gets and creates instances of all retrievable benchmarks - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Benchmark]: a list of Benchmark instances. 
- """ - logging.info("Retrieving all benchmarks") - benchmarks = [] - - if not local_only: - benchmarks = cls.__remote_all(filters=filters) - - remote_uids = set([bmk.id for bmk in benchmarks]) + @property + def local_id(self): + return self.name - local_benchmarks = cls.__local_all() - - benchmarks += [bmk for bmk in local_benchmarks if bmk.id not in remote_uids] - - return benchmarks - - @classmethod - def __remote_all(cls, filters: dict) -> List["Benchmark"]: - benchmarks = [] - try: - comms_fn = cls.__remote_prefilter(filters) - bmks_meta = comms_fn() - benchmarks = [cls(**meta) for meta in bmks_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all benchmarks from the server" - logging.warning(msg) - - return benchmarks - - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + @staticmethod + def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -104,104 +77,6 @@ def __remote_prefilter(cls, filters: dict) -> callable: comms_fn = config.comms.get_user_benchmarks return comms_fn - @classmethod - def __local_all(cls) -> List["Benchmark"]: - benchmarks = [] - bmks_storage = config.benchmarks_folder - try: - uids = next(os.walk(bmks_storage))[1] - except StopIteration: - msg = "Couldn't iterate over benchmarks directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - benchmark = cls(**meta) - benchmarks.append(benchmark) - - return benchmarks - - @classmethod - def get( - cls, benchmark_uid: Union[str, int], local_only: bool = False - ) -> "Benchmark": - """Retrieves and creates a Benchmark instance from the server. - If benchmark already exists in the platform then retrieve that - version. - - Args: - benchmark_uid (str): UID of the benchmark. - comms (Comms): Instance of a communication interface. - - Returns: - Benchmark: a Benchmark instance with the retrieved data. - """ - - if not str(benchmark_uid).isdigit() or local_only: - return cls.__local_get(benchmark_uid) - - try: - return cls.__remote_get(benchmark_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Benchmark {benchmark_uid} from comms failed") - logging.info(f"Looking for benchmark {benchmark_uid} locally") - return cls.__local_get(benchmark_uid) - - @classmethod - def __remote_get(cls, benchmark_uid: int) -> "Benchmark": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving benchmark {benchmark_uid} remotely") - benchmark_dict = config.comms.get_benchmark(benchmark_uid) - benchmark = cls(**benchmark_dict) - benchmark.write() - return benchmark - - @classmethod - def __local_get(cls, benchmark_uid: Union[str, int]) -> "Benchmark": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. 
- - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving benchmark {benchmark_uid} locally") - benchmark_dict = cls.__get_local_dict(benchmark_uid) - benchmark = cls(**benchmark_dict) - return benchmark - - @classmethod - def __get_local_dict(cls, benchmark_uid) -> dict: - """Retrieves a local benchmark information - - Args: - benchmark_uid (str): uid of the local benchmark - - Returns: - dict: information of the benchmark - """ - logging.info(f"Retrieving benchmark {benchmark_uid} from local storage") - storage = config.benchmarks_folder - bmk_storage = os.path.join(storage, str(benchmark_uid)) - bmk_file = os.path.join(bmk_storage, config.benchmarks_filename) - if not os.path.exists(bmk_file): - raise InvalidArgumentError("No benchmark with the given uid could be found") - with open(bmk_file, "r") as f: - data = yaml.safe_load(f) - - return data - @classmethod def get_models_uids(cls, benchmark_uid: int) -> List[int]: """Retrieves the list of models associated to the benchmark @@ -221,43 +96,6 @@ def get_models_uids(cls, benchmark_uid: int) -> List[int]: ] return models_uids - def todict(self) -> dict: - """Dictionary representation of the benchmark instance - - Returns: - dict: Dictionary containing benchmark information - """ - return self.extended_dict() - - def write(self) -> str: - """Writes the benchmark into disk - - Args: - filename (str, optional): name of the file. Defaults to config.benchmarks_filename. - - Returns: - str: path to the created benchmark file - """ - data = self.todict() - bmk_file = os.path.join(self.path, config.benchmarks_filename) - if not os.path.exists(bmk_file): - os.makedirs(self.path, exist_ok=True) - with open(bmk_file, "w") as f: - yaml.dump(data, f) - return bmk_file - - def upload(self): - """Uploads a benchmark to the server - - Args: - comms (Comms): communications entity to submit through - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test benchmarks.") - body = self.todict() - updated_body = config.comms.upload_benchmark(body) - return updated_body - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index 98d2b95a8..714342c53 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -1,7 +1,7 @@ import os import yaml import logging -from typing import List, Dict, Optional, Union +from typing import Dict, Optional, Union from pydantic import Field from pathlib import Path @@ -12,21 +12,15 @@ generate_tmp_path, spawn_and_kill, ) -from medperf.entities.interface import Entity, Uploadable -from medperf.entities.schemas import MedperfSchema, DeployableSchema -from medperf.exceptions import ( - InvalidArgumentError, - ExecutionError, - InvalidEntityError, - MedperfException, - CommunicationRetrievalError, -) +from medperf.entities.interface import Entity +from medperf.entities.schemas import DeployableSchema +from medperf.exceptions import InvalidArgumentError, ExecutionError, InvalidEntityError import medperf.config as config from medperf.comms.entity_resources import resources from medperf.account_management import get_medperf_user_data -class Cube(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Cube(Entity, DeployableSchema): """ Class representing an MLCube Container @@ -48,6 +42,26 @@ class Cube(Entity, Uploadable, MedperfSchema, DeployableSchema): metadata: dict = {} user_metadata: dict = {} + @staticmethod + def get_type(): + 
return "cube" + + @staticmethod + def get_storage_path(): + return config.cubes_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_cube_metadata + + @staticmethod + def get_metadata_filename(): + return config.cube_metadata_filename + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_mlcube + def __init__(self, *args, **kwargs): """Creates a Cube instance @@ -56,60 +70,17 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) - self.generated_uid = self.name - path = config.cubes_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - # NOTE: maybe have these as @property, to have the same entity reusable - # before and after submission - self.path = path - self.cube_path = os.path.join(path, config.cube_filename) + self.cube_path = os.path.join(self.path, config.cube_filename) self.params_path = None if self.git_parameters_url: - self.params_path = os.path.join(path, config.params_filename) - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Cube"]: - """Class method for retrieving all retrievable MLCubes - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Cube]: List containing all cubes - """ - logging.info("Retrieving all cubes") - cubes = [] - if not local_only: - cubes = cls.__remote_all(filters=filters) - - remote_uids = set([cube.id for cube in cubes]) - - local_cubes = cls.__local_all() - - cubes += [cube for cube in local_cubes if cube.id not in remote_uids] - - return cubes - - @classmethod - def __remote_all(cls, filters: dict) -> List["Cube"]: - cubes = [] - - try: - comms_fn = cls.__remote_prefilter(filters) - cubes_meta = comms_fn() - cubes = [cls(**meta) for meta in cubes_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all cubes from the server" - logging.warning(msg) + self.params_path = os.path.join(self.path, config.params_filename) - return cubes + @property + def local_id(self): + return self.name - @classmethod - def __remote_prefilter(cls, filters: dict): + @staticmethod + def remote_prefilter(filters: dict): """Applies filtering logic that must be done before retrieving remote entities Args: @@ -124,25 +95,6 @@ def __remote_prefilter(cls, filters: dict): return comms_fn - @classmethod - def __local_all(cls) -> List["Cube"]: - cubes = [] - cubes_folder = config.cubes_folder - try: - uids = next(os.walk(cubes_folder))[1] - logging.debug(f"Local cubes found: {uids}") - except StopIteration: - msg = "Couldn't iterate over cubes directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - meta = cls.__get_local_dict(uid) - cube = cls(**meta) - cubes.append(cube) - - return cubes - @classmethod def get(cls, cube_uid: Union[str, int], local_only: bool = False) -> "Cube": """Retrieves and creates a Cube instance from the comms. If cube already exists @@ -155,36 +107,12 @@ def get(cls, cube_uid: Union[str, int], local_only: bool = False) -> "Cube": Cube : a Cube instance with the retrieved data. 
""" - if not str(cube_uid).isdigit() or local_only: - cube = cls.__local_get(cube_uid) - else: - try: - cube = cls.__remote_get(cube_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting MLCube {cube_uid} from comms failed") - logging.info(f"Retrieving MLCube {cube_uid} from local storage") - cube = cls.__local_get(cube_uid) - + cube = super().get(cube_uid, local_only) if not cube.is_valid: raise InvalidEntityError("The requested MLCube is marked as INVALID.") cube.download_config_files() return cube - @classmethod - def __remote_get(cls, cube_uid: int) -> "Cube": - logging.debug(f"Retrieving mlcube {cube_uid} remotely") - meta = config.comms.get_cube_metadata(cube_uid) - cube = cls(**meta) - cube.write() - return cube - - @classmethod - def __local_get(cls, cube_uid: Union[str, int]) -> "Cube": - logging.debug(f"Retrieving cube {cube_uid} locally") - local_meta = cls.__get_local_dict(cube_uid) - cube = cls(**local_meta) - return cube - def download_mlcube(self): url = self.git_mlcube_url path, file_hash = resources.get_cube(url, self.path, self.mlcube_hash) @@ -320,7 +248,7 @@ def run( """ kwargs.update(string_params) cmd = f"mlcube --log-level {config.loglevel} run" - cmd += f" --mlcube=\"{self.cube_path}\" --task={task} --platform={config.platform} --network=none" + cmd += f' --mlcube="{self.cube_path}" --task={task} --platform={config.platform} --network=none' if config.gpus is not None: cmd += f" --gpus={config.gpus}" if read_protected_input: @@ -430,36 +358,6 @@ def get_config(self, identifier): return cube - def todict(self) -> Dict: - return self.extended_dict() - - def write(self): - cube_loc = str(Path(self.cube_path).parent) - meta_file = os.path.join(cube_loc, config.cube_metadata_filename) - os.makedirs(cube_loc, exist_ok=True) - with open(meta_file, "w") as f: - yaml.dump(self.todict(), f) - return meta_file - - def upload(self): - if self.for_test: - raise InvalidArgumentError("Cannot upload test mlcubes.") - cube_dict = self.todict() - updated_cube_dict = config.comms.upload_mlcube(cube_dict) - return updated_cube_dict - - @classmethod - def __get_local_dict(cls, uid): - cubes_folder = config.cubes_folder - meta_file = os.path.join(cubes_folder, str(uid), config.cube_metadata_filename) - if not os.path.exists(meta_file): - raise InvalidArgumentError( - "The requested mlcube information could not be found locally" - ) - with open(meta_file, "r") as f: - meta = yaml.safe_load(f) - return meta - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index 4c210431f..7f13c2185 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -1,22 +1,17 @@ import os import yaml -import logging from pydantic import Field, validator -from typing import List, Optional, Union +from typing import Optional, Union from medperf.utils import remove_path -from medperf.entities.interface import Entity, Uploadable -from medperf.entities.schemas import MedperfSchema, DeployableSchema -from medperf.exceptions import ( - InvalidArgumentError, - MedperfException, - CommunicationRetrievalError, -) +from medperf.entities.interface import Entity +from medperf.entities.schemas import DeployableSchema + import medperf.config as config from medperf.account_management import get_medperf_user_data -class Dataset(Entity, Uploadable, MedperfSchema, DeployableSchema): +class Dataset(Entity, DeployableSchema): """ Class representing a Dataset @@ -37,6 +32,26 @@ class Dataset(Entity, 
Uploadable, MedperfSchema, DeployableSchema): report: dict = {} submitted_as_prepared: bool + @staticmethod + def get_type(): + return "dataset" + + @staticmethod + def get_storage_path(): + return config.datasets_folder + + @staticmethod + def get_comms_retriever(): + return config.comms.get_dataset + + @staticmethod + def get_metadata_filename(): + return config.reg_file + + @staticmethod + def get_comms_uploader(): + return config.comms.upload_dataset + @validator("data_preparation_mlcube", pre=True, always=True) def check_data_preparation_mlcube(cls, v, *, values, **kwargs): if not isinstance(v, int) and not values["for_test"]: @@ -47,20 +62,16 @@ def check_data_preparation_mlcube(cls, v, *, values, **kwargs): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - - path = config.datasets_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - - self.path = path self.data_path = os.path.join(self.path, "data") self.labels_path = os.path.join(self.path, "labels") self.report_path = os.path.join(self.path, config.report_file) self.metadata_path = os.path.join(self.path, config.metadata_folder) self.statistics_path = os.path.join(self.path, config.statistics_filename) + @property + def local_id(self): + return self.generated_uid + def set_raw_paths(self, raw_data_path: str, raw_labels_path: str): raw_paths_file = os.path.join(self.path, config.dataset_raw_paths_file) data = {"data_path": raw_data_path, "labels_path": raw_labels_path} @@ -86,48 +97,8 @@ def is_ready(self): flag_file = os.path.join(self.path, config.ready_flag_file) return os.path.exists(flag_file) - def todict(self): - return self.extended_dict() - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Dataset"]: - """Gets and creates instances of all the locally prepared datasets - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. - - Returns: - List[Dataset]: a list of Dataset instances. 
- """ - logging.info("Retrieving all datasets") - dsets = [] - if not local_only: - dsets = cls.__remote_all(filters=filters) - - remote_uids = set([dset.id for dset in dsets]) - - local_dsets = cls.__local_all() - - dsets += [dset for dset in local_dsets if dset.id not in remote_uids] - - return dsets - - @classmethod - def __remote_all(cls, filters: dict) -> List["Dataset"]: - dsets = [] - try: - comms_fn = cls.__remote_prefilter(filters) - dsets_meta = comms_fn() - dsets = [cls(**meta) for meta in dsets_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all datasets from the server" - logging.warning(msg) - - return dsets - - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + @staticmethod + def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -149,111 +120,6 @@ def func(): return comms_fn - @classmethod - def __local_all(cls) -> List["Dataset"]: - dsets = [] - datasets_folder = config.datasets_folder - try: - uids = next(os.walk(datasets_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the dataset directory" - logging.warning(msg) - raise MedperfException(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - dset = cls(**local_meta) - dsets.append(dset) - - return dsets - - @classmethod - def get(cls, dset_uid: Union[str, int], local_only: bool = False) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - if not str(dset_uid).isdigit() or local_only: - return cls.__local_get(dset_uid) - - try: - return cls.__remote_get(dset_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Dataset {dset_uid} from comms failed") - logging.info(f"Looking for dataset {dset_uid} locally") - return cls.__local_get(dset_uid) - - @classmethod - def __remote_get(cls, dset_uid: int) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving dataset {dset_uid} remotely") - meta = config.comms.get_dataset(dset_uid) - dataset = cls(**meta) - dataset.write() - return dataset - - @classmethod - def __local_get(cls, dset_uid: Union[str, int]) -> "Dataset": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - dset_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving dataset {dset_uid} locally") - local_meta = cls.__get_local_dict(dset_uid) - dataset = cls(**local_meta) - return dataset - - def write(self): - logging.info(f"Updating registration information for dataset: {self.id}") - logging.debug(f"registration information: {self.todict()}") - regfile = os.path.join(self.path, config.reg_file) - os.makedirs(self.path, exist_ok=True) - with open(regfile, "w") as f: - yaml.dump(self.todict(), f) - return regfile - - def upload(self): - """Uploads the registration information to the comms. - - Args: - comms (Comms): Instance of the comms interface. 
- """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test datasets.") - dataset_dict = self.todict() - updated_dataset_dict = config.comms.upload_dataset(dataset_dict) - return updated_dataset_dict - - @classmethod - def __get_local_dict(cls, data_uid): - dataset_path = os.path.join(config.datasets_folder, str(data_uid)) - regfile = os.path.join(dataset_path, config.reg_file) - if not os.path.exists(regfile): - raise InvalidArgumentError( - "The requested dataset information could not be found locally" - ) - with open(regfile, "r") as f: - reg = yaml.safe_load(f) - return reg - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index af2afabd7..835fbdf22 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -1,77 +1,225 @@ -from typing import List, Dict, Union -from abc import ABC, abstractmethod +from typing import List, Dict, Union, Callable +from abc import ABC +import logging +import os +import yaml +from medperf.exceptions import MedperfException, InvalidArgumentError +from medperf.entities.schemas import MedperfSchema +from typing import Type, TypeVar +EntityType = TypeVar("EntityType", bound="Entity") -class Entity(ABC): - @abstractmethod + +class Entity(MedperfSchema, ABC): + @staticmethod + def get_type() -> str: + raise NotImplementedError() + + @staticmethod + def get_storage_path() -> str: + raise NotImplementedError() + + @staticmethod + def get_comms_retriever() -> Callable[[int], dict]: + raise NotImplementedError() + + @staticmethod + def get_metadata_filename() -> str: + raise NotImplementedError() + + @staticmethod + def get_comms_uploader() -> Callable[[dict], dict]: + raise NotImplementedError() + + @property + def local_id(self) -> str: + raise NotImplementedError() + + @property + def identifier(self) -> Union[int, str]: + return self.id or self.local_id + + @property + def is_registered(self) -> bool: + return self.id is not None + + @property + def path(self) -> str: + storage_path = self.get_storage_path() + return os.path.join(storage_path, str(self.identifier)) + + @classmethod def all( - cls, local_only: bool = False, comms_func: callable = None - ) -> List["Entity"]: + cls: Type[EntityType], unregistered: bool = False, filters: dict = {} + ) -> List[EntityType]: """Gets a list of all instances of the respective entity. - Wether the list is local or remote depends on the implementation. + Whether the list is local or remote depends on the implementation. Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - comms_func (callable, optional): Function to use to retrieve remote entities. - If not provided, will use the default entrypoint. + unregistered (bool, optional): Wether to retrieve only unregistered local entities. Defaults to False. + filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. + Returns: List[Entity]: a list of entities. 
""" + logging.info(f"Retrieving all {cls.get_type()} entities") + if unregistered: + if filters: + raise InvalidArgumentError( + "Filtering is not supported for unregistered entities" + ) + return cls.__unregistered_all() + return cls.__remote_all(filters=filters) + + @classmethod + def __remote_all(cls: Type[EntityType], filters: dict) -> List[EntityType]: + comms_fn = cls.remote_prefilter(filters) + entity_meta = comms_fn() + entities = [cls(**meta) for meta in entity_meta] + return entities + + @classmethod + def __unregistered_all(cls: Type[EntityType]) -> List[EntityType]: + entities = [] + storage_path = cls.get_storage_path() + try: + uids = next(os.walk(storage_path))[1] + except StopIteration: + msg = f"Couldn't iterate over the {cls.get_type()} storage" + logging.warning(msg) + raise MedperfException(msg) + + for uid in uids: + if uid.isdigit(): + continue + entity = cls.__local_get(uid) + entities.append(entity) + + return entities + + @staticmethod + def remote_prefilter(filters: dict) -> callable: + """Applies filtering logic that must be done before retrieving remote entities - @abstractmethod - def get(cls, uid: Union[str, int]) -> "Entity": + Args: + filters (dict): filters to apply + + Returns: + callable: A function for retrieving remote entities with the applied prefilters + """ + raise NotImplementedError + + @classmethod + def get( + cls: Type[EntityType], uid: Union[str, int], local_only: bool = False + ) -> EntityType: """Gets an instance of the respective entity. Wether this requires only local read or remote calls depends on the implementation. Args: uid (str): Unique Identifier to retrieve the entity + local_only (bool): If True, the entity will be retrieved locally Returns: Entity: Entity Instance associated to the UID """ - @abstractmethod - def todict(self) -> Dict: - """Dictionary representation of the entity + if not str(uid).isdigit() or local_only: + return cls.__local_get(uid) + return cls.__remote_get(uid) + + @classmethod + def __remote_get(cls: Type[EntityType], uid: int) -> EntityType: + """Retrieves and creates an entity instance from the comms instance. + + Args: + uid (int): server UID of the entity Returns: - Dict: Dictionary containing information about the entity + Entity: Specified Entity Instance """ + logging.debug(f"Retrieving {cls.get_type()} {uid} remotely") + comms_func = cls.get_comms_retriever() + entity_dict = comms_func(uid) + entity = cls(**entity_dict) + entity.write() + return entity - @abstractmethod - def write(self) -> str: - """Writes the entity to the local storage + @classmethod + def __local_get(cls: Type[EntityType], uid: Union[str, int]) -> EntityType: + """Retrieves and creates an entity instance from the local storage. 
+ + Args: + uid (str|int): UID of the entity Returns: - str: Path to the stored entity + Entity: Specified Entity Instance """ + logging.debug(f"Retrieving {cls.get_type()} {uid} locally") + entity_dict = cls.__get_local_dict(uid) + entity = cls(**entity_dict) + return entity - @abstractmethod - def display_dict(self) -> dict: - """Returns a dictionary of entity properties that can be displayed - to a user interface using a verbose name of the property rather than - the internal names + @classmethod + def __get_local_dict(cls: Type[EntityType], uid: Union[str, int]) -> dict: + """Retrieves a local entity information + + Args: + uid (str): uid of the local entity Returns: - dict: the display dictionary + dict: information of the entity """ + logging.info(f"Retrieving {cls.get_type()} {uid} from local storage") + storage_path = cls.get_storage_path() + metadata_filename = cls.get_metadata_filename() + entity_file = os.path.join(storage_path, str(uid), metadata_filename) + if not os.path.exists(entity_file): + raise InvalidArgumentError( + f"No {cls.get_type()} with the given uid could be found" + ) + with open(entity_file, "r") as f: + data = yaml.safe_load(f) + + return data + + def write(self) -> str: + """Writes the entity to the local storage + Returns: + str: Path to the stored entity + """ + data = self.todict() + metadata_filename = self.get_metadata_filename() + entity_file = os.path.join(self.path, metadata_filename) + os.makedirs(self.path, exist_ok=True) + with open(entity_file, "w") as f: + yaml.dump(data, f) + return entity_file -class Uploadable: - @abstractmethod def upload(self) -> Dict: """Upload the entity-related information to the communication's interface Returns: Dict: Dictionary with the updated entity information """ + if self.for_test: + raise InvalidArgumentError( + f"This test {self.get_type()} cannot be uploaded." + ) + body = self.todict() + comms_func = self.get_comms_uploader() + updated_body = comms_func(body) + return updated_body - @property - def identifier(self): - return self.id or self.generated_uid + def display_dict(self) -> dict: + """Returns a dictionary of entity properties that can be displayed + to a user interface using a verbose name of the property rather than + the internal names - @property - def is_registered(self): - return self.id is not None + Returns: + dict: the display dictionary + """ + raise NotImplementedError diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index c76f09894..cefd168b3 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -1,16 +1,11 @@ import hashlib -import os -import yaml -import logging from typing import List, Union, Optional -from medperf.entities.schemas import MedperfBaseSchema import medperf.config as config -from medperf.exceptions import InvalidArgumentError from medperf.entities.interface import Entity -class TestReport(Entity, MedperfBaseSchema): +class TestReport(Entity): """ Class representing a compatibility test report entry @@ -23,8 +18,16 @@ class TestReport(Entity, MedperfBaseSchema): - model cube - evaluator cube - results + + Note: This entity is only a local one, there is no TestReports on the server + However, we still use the same Entity interface used by other entities + in order to reduce repeated code. Consequently, we mocked a few methods + and attributes inherited from the Entity interface that are not relevant to + this entity, such as the `name` and `id` attributes, and such as + the `get` and `all` methods. 
""" + name: Optional[str] = "name" demo_dataset_url: Optional[str] demo_dataset_hash: Optional[str] data_path: Optional[str] @@ -35,13 +38,25 @@ class TestReport(Entity, MedperfBaseSchema): data_evaluator_mlcube: Union[int, str] results: Optional[dict] + @staticmethod + def get_type(): + return "report" + + @staticmethod + def get_storage_path(): + return config.tests_folder + + @staticmethod + def get_metadata_filename(): + return config.test_report_file + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.generated_uid = self.__generate_uid() - path = config.tests_folder - self.path = os.path.join(path, self.generated_uid) + self.id = None + self.for_test = True - def __generate_uid(self): + @property + def local_id(self): """A helper that generates a unique hash for a test report.""" params = self.todict() del params["results"] @@ -52,71 +67,21 @@ def set_results(self, results): self.results = results @classmethod - def all( - cls, local_only: bool = False, mine_only: bool = False - ) -> List["TestReport"]: - """Gets and creates instances of test reports. - Arguments are only specified for compatibility with - `Entity.List` and `Entity.View`, but they don't contribute to - the logic. - - Returns: - List[TestReport]: List containing all test reports - """ - logging.info("Retrieving all reports") - reports = [] - tests_folder = config.tests_folder - try: - uids = next(os.walk(tests_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the tests directory" - logging.warning(msg) - raise RuntimeError(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - report = cls(**local_meta) - reports.append(report) - - return reports + def all(cls, unregistered: bool = False, filters: dict = {}) -> List["TestReport"]: + assert unregistered, "Reports are only unregistered" + assert filters == {}, "Reports cannot be filtered" + return super().all(unregistered=True, filters={}) @classmethod - def get(cls, report_uid: str) -> "TestReport": - """Retrieves and creates a TestReport instance obtained the user's machine - + def get(cls, uid: str, local_only: bool = False) -> "TestReport": + """Gets an instance of the TestReport. ignores local_only inherited flag as TestReport is always a local entity. Args: - report_uid (str): UID of the TestReport instance - + uid (str): Report Unique Identifier + local_only (bool): ignored. 
Left for aligning with parent Entity class Returns: - TestReport: Specified TestReport instance + TestReport: Report Instance associated to the UID """ - logging.debug(f"Retrieving report {report_uid}") - report_dict = cls.__get_local_dict(report_uid) - report = cls(**report_dict) - report.write() - return report - - def todict(self): - return self.extended_dict() - - def write(self): - report_file = os.path.join(self.path, config.test_report_file) - os.makedirs(self.path, exist_ok=True) - with open(report_file, "w") as f: - yaml.dump(self.todict(), f) - return report_file - - @classmethod - def __get_local_dict(cls, local_uid): - report_path = os.path.join(config.tests_folder, str(local_uid)) - report_file = os.path.join(report_path, config.test_report_file) - if not os.path.exists(report_file): - raise InvalidArgumentError( - f"The requested report {local_uid} could not be retrieved" - ) - with open(report_file, "r") as f: - report_info = yaml.safe_load(f) - return report_info + return super().get(uid, local_only=True) def display_dict(self): if self.data_path: @@ -127,7 +92,7 @@ def display_dict(self): data_source = f"{self.prepared_data_hash}" return { - "UID": self.generated_uid, + "UID": self.local_id, "Data Source": data_source, "Model": ( self.model if isinstance(self.model, int) else self.model[:27] + "..." diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index c82add87b..0e96d1feb 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -1,16 +1,10 @@ -import os -import yaml -import logging -from typing import List, Union - -from medperf.entities.interface import Entity, Uploadable -from medperf.entities.schemas import MedperfSchema, ApprovableSchema +from medperf.entities.interface import Entity +from medperf.entities.schemas import ApprovableSchema import medperf.config as config -from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError from medperf.account_management import get_medperf_user_data -class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema): +class Result(Entity, ApprovableSchema): """ Class representing a Result entry @@ -28,59 +22,36 @@ class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema): metadata: dict = {} user_metadata: dict = {} - def __init__(self, *args, **kwargs): - """Creates a new result instance""" - super().__init__(*args, **kwargs) - - self.generated_uid = f"b{self.benchmark}m{self.model}d{self.dataset}" - path = config.results_folder - if self.id: - path = os.path.join(path, str(self.id)) - else: - path = os.path.join(path, self.generated_uid) - - self.path = path - - @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Result"]: - """Gets and creates instances of all the user's results - - Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. 
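Aside for reviewers: the `TestReport.get` override above leans on the consolidated lookup that now lives in the shared `Entity` base. The snippet below is a hedged, standalone sketch of that assumed control flow, not the actual interface.py code; `FakeCommsError`, `fake_comms_get` and `load_local` are made-up stand-ins for `CommunicationRetrievalError`, the bound call returned by `get_comms_retriever()`, and the local YAML lookup.

    from typing import Union


    class FakeCommsError(Exception):
        """Stand-in for medperf's CommunicationRetrievalError."""


    def fake_comms_get(uid: int) -> dict:
        # Stand-in for the bound comms call returned by cls.get_comms_retriever()
        server = {1: {"id": 1, "name": "res-1"}}
        if uid not in server:
            raise FakeCommsError(f"no entity {uid} on the server")
        return server[uid]


    def load_local(uid: Union[str, int]) -> dict:
        # Placeholder for reading <storage_path>/<uid>/<metadata_filename> as YAML
        return {"id": None, "local_id": str(uid)}


    def entity_get(uid: Union[str, int], local_only: bool = False) -> dict:
        # Non-numeric uids can only ever identify unregistered (local) entities
        if local_only or not str(uid).isdigit():
            return load_local(uid)
        # Registered entities come from the server; a failed lookup propagates,
        # matching the tests that expect an error for nonexistent ids
        return fake_comms_get(int(uid))


    print(entity_get("report-abc123", local_only=True))  # the TestReport.get path
    print(entity_get(1))                                 # regular registered-entity path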
- - Returns: - List[Result]: List containing all results - """ - logging.info("Retrieving all results") - results = [] - if not local_only: - results = cls.__remote_all(filters=filters) - - remote_uids = set([result.id for result in results]) + @staticmethod + def get_type(): + return "result" - local_results = cls.__local_all() + @staticmethod + def get_storage_path(): + return config.results_folder - results += [res for res in local_results if res.id not in remote_uids] + @staticmethod + def get_comms_retriever(): + return config.comms.get_result - return results + @staticmethod + def get_metadata_filename(): + return config.results_info_file - @classmethod - def __remote_all(cls, filters: dict) -> List["Result"]: - results = [] + @staticmethod + def get_comms_uploader(): + return config.comms.upload_result - try: - comms_fn = cls.__remote_prefilter(filters) - results_meta = comms_fn() - results = [cls(**meta) for meta in results_meta] - except CommunicationRetrievalError: - msg = "Couldn't retrieve all results from the server" - logging.warning(msg) + def __init__(self, *args, **kwargs): + """Creates a new result instance""" + super().__init__(*args, **kwargs) - return results + @property + def local_id(self): + return self.name - @classmethod - def __remote_prefilter(cls, filters: dict) -> callable: + @staticmethod + def remote_prefilter(filters: dict) -> callable: """Applies filtering logic that must be done before retrieving remote entities Args: @@ -104,113 +75,6 @@ def get_benchmark_results(): return comms_fn - @classmethod - def __local_all(cls) -> List["Result"]: - results = [] - results_folder = config.results_folder - try: - uids = next(os.walk(results_folder))[1] - except StopIteration: - msg = "Couldn't iterate over the dataset directory" - logging.warning(msg) - raise RuntimeError(msg) - - for uid in uids: - local_meta = cls.__get_local_dict(uid) - result = cls(**local_meta) - results.append(result) - - return results - - @classmethod - def get(cls, result_uid: Union[str, int], local_only: bool = False) -> "Result": - """Retrieves and creates a Result instance obtained from the platform. - If the result instance already exists in the user's machine, it loads - the local instance - - Args: - result_uid (str): UID of the Result instance - - Returns: - Result: Specified Result instance - """ - if not str(result_uid).isdigit() or local_only: - return cls.__local_get(result_uid) - - try: - return cls.__remote_get(result_uid) - except CommunicationRetrievalError: - logging.warning(f"Getting Result {result_uid} from comms failed") - logging.info(f"Looking for result {result_uid} locally") - return cls.__local_get(result_uid) - - @classmethod - def __remote_get(cls, result_uid: int) -> "Result": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. - - Args: - result_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving result {result_uid} remotely") - meta = config.comms.get_result(result_uid) - result = cls(**meta) - result.write() - return result - - @classmethod - def __local_get(cls, result_uid: Union[str, int]) -> "Result": - """Retrieves and creates a Dataset instance from the comms instance. - If the dataset is present in the user's machine then it retrieves it from there. 
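Aside for reviewers: the hunk above is the heart of the refactor. Entity-specific behaviour shrinks to a handful of static hooks (`get_type`, `get_storage_path`, `get_metadata_filename`, `get_comms_uploader`, `remote_prefilter`, ...) plus a `local_id` property, and the shared `Entity` base drives `write`/`get`/`all`/`upload` through them. A minimal standalone sketch of that hook-driven design follows; the class names, temp directory and composite id are made up for illustration (the composite mirrors the old `generated_uid` scheme, while the real `Result.local_id` now simply returns its name).

    import os
    import tempfile

    import yaml


    class MiniEntity:
        """Bare-bones stand-in for the hook-driven Entity base."""

        @staticmethod
        def get_storage_path() -> str:
            raise NotImplementedError

        @staticmethod
        def get_metadata_filename() -> str:
            raise NotImplementedError

        def todict(self) -> dict:
            raise NotImplementedError

        @property
        def path(self) -> str:
            # One folder per entity: <storage_path>/<local identifier>/
            # (local_id is supplied by the concrete subclass)
            return os.path.join(self.get_storage_path(), str(self.local_id))

        def write(self) -> str:
            # Mirrors Entity.write() in interface.py: dump todict() as YAML in self.path
            os.makedirs(self.path, exist_ok=True)
            entity_file = os.path.join(self.path, self.get_metadata_filename())
            with open(entity_file, "w") as f:
                yaml.dump(self.todict(), f)
            return entity_file


    class MiniResult(MiniEntity):
        def __init__(self, benchmark: int, model: int, dataset: int):
            self.benchmark, self.model, self.dataset = benchmark, model, dataset

        @property
        def local_id(self) -> str:
            return f"b{self.benchmark}m{self.model}d{self.dataset}"

        @staticmethod
        def get_storage_path() -> str:
            return os.path.join(tempfile.gettempdir(), "mini_results")

        @staticmethod
        def get_metadata_filename() -> str:
            return "result_info.yaml"

        def todict(self) -> dict:
            return {"benchmark": self.benchmark, "model": self.model, "dataset": self.dataset}


    print(MiniResult(1, 2, 3).write())  # e.g. /tmp/mini_results/b1m2d3/result_info.yaml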
- - Args: - result_uid (str): server UID of the dataset - - Returns: - Dataset: Specified Dataset Instance - """ - logging.debug(f"Retrieving result {result_uid} locally") - local_meta = cls.__get_local_dict(result_uid) - result = cls(**local_meta) - return result - - def todict(self): - return self.extended_dict() - - def upload(self): - """Uploads the results to the comms - - Args: - comms (Comms): Instance of the communications interface. - """ - if self.for_test: - raise InvalidArgumentError("Cannot upload test results.") - results_info = self.todict() - updated_results_info = config.comms.upload_result(results_info) - return updated_results_info - - def write(self): - result_file = os.path.join(self.path, config.results_info_file) - os.makedirs(self.path, exist_ok=True) - with open(result_file, "w") as f: - yaml.dump(self.todict(), f) - return result_file - - @classmethod - def __get_local_dict(cls, local_uid): - result_path = os.path.join(config.results_folder, str(local_uid)) - result_file = os.path.join(result_path, config.results_info_file) - if not os.path.exists(result_file): - raise InvalidArgumentError( - f"The requested result {local_uid} could not be retrieved" - ) - with open(result_file, "r") as f: - results_info = yaml.safe_load(f) - return results_info - def display_dict(self): return { "UID": self.identifier, diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py index 0e7a54291..79926abd9 100644 --- a/cli/medperf/entities/schemas.py +++ b/cli/medperf/entities/schemas.py @@ -8,7 +8,15 @@ from medperf.utils import format_errors_dict -class MedperfBaseSchema(BaseModel): +class MedperfSchema(BaseModel): + for_test: bool = False + id: Optional[int] + name: str = Field(..., max_length=64) + owner: Optional[int] + is_valid: bool = True + created_at: Optional[datetime] + modified_at: Optional[datetime] + def __init__(self, *args, **kwargs): """Override the ValidationError procedure so we can format the error message in our desired way @@ -46,7 +54,7 @@ def dict(self, *args, **kwargs) -> dict: out_dict = {k: v for k, v in model_dict.items() if k in valid_fields} return out_dict - def extended_dict(self) -> dict: + def todict(self) -> dict: """Dictionary containing both original and alias fields Returns: @@ -68,27 +76,17 @@ def empty_str_to_none(cls, v): return None return v - class Config: - allow_population_by_field_name = True - extra = "allow" - use_enum_values = True - - -class MedperfSchema(MedperfBaseSchema): - for_test: bool = False - id: Optional[int] - name: str = Field(..., max_length=64) - owner: Optional[int] - is_valid: bool = True - created_at: Optional[datetime] - modified_at: Optional[datetime] - @validator("name", pre=True, always=True) def name_max_length(cls, v, *, values, **kwargs): if not values["for_test"] and len(v) > 20: raise ValueError("The name must have no more than 20 characters") return v + class Config: + allow_population_by_field_name = True + extra = "allow" + use_enum_values = True + class DeployableSchema(BaseModel): state: str = "DEVELOPMENT" diff --git a/cli/medperf/tests/commands/benchmark/test_submit.py b/cli/medperf/tests/commands/benchmark/test_submit.py index b00e1c5a8..7e2d5b23b 100644 --- a/cli/medperf/tests/commands/benchmark/test_submit.py +++ b/cli/medperf/tests/commands/benchmark/test_submit.py @@ -94,7 +94,7 @@ def test_run_compatibility_test_uses_expected_default_parameters(mocker, comms, # Assert comp_spy.assert_called_once_with( - benchmark=bmk.generated_uid, no_cache=True, 
skip_data_preparation_step=False + benchmark=bmk.local_id, no_cache=True, skip_data_preparation_step=False ) @@ -117,7 +117,7 @@ def test_run_compatibility_test_with_passed_parameters(mocker, force, skip, comm # Assert comp_spy.assert_called_once_with( - benchmark=bmk.generated_uid, no_cache=force, skip_data_preparation_step=skip + benchmark=bmk.local_id, no_cache=force, skip_data_preparation_step=skip ) diff --git a/cli/medperf/tests/commands/mlcube/test_submit.py b/cli/medperf/tests/commands/mlcube/test_submit.py index 630390205..a946c1fef 100644 --- a/cli/medperf/tests/commands/mlcube/test_submit.py +++ b/cli/medperf/tests/commands/mlcube/test_submit.py @@ -57,7 +57,7 @@ def test_to_permanent_path_renames_correctly(mocker, comms, ui, cube, uid): submission.cube = cube spy = mocker.patch("os.rename") mocker.patch("os.path.exists", return_value=False) - old_path = os.path.join(config.cubes_folder, cube.generated_uid) + old_path = os.path.join(config.cubes_folder, cube.local_id) new_path = os.path.join(config.cubes_folder, str(uid)) # Act submission.to_permanent_path({**cube.todict(), "id": uid}) diff --git a/cli/medperf/tests/commands/result/test_create.py b/cli/medperf/tests/commands/result/test_create.py index 74299c77e..c69544781 100644 --- a/cli/medperf/tests/commands/result/test_create.py +++ b/cli/medperf/tests/commands/result/test_create.py @@ -57,6 +57,9 @@ def mock_result_all(mocker, state_variables): TestResult(benchmark=triplet[0], model=triplet[1], dataset=triplet[2]) for triplet in cached_results_triplets ] + mocker.patch( + PATCH_EXECUTION.format("get_medperf_user_data", return_value={"id": 1}) + ) mocker.patch(PATCH_EXECUTION.format("Result.all"), return_value=results) diff --git a/cli/medperf/tests/commands/result/test_submit.py b/cli/medperf/tests/commands/result/test_submit.py index 10680fbe1..26b03fbcc 100644 --- a/cli/medperf/tests/commands/result/test_submit.py +++ b/cli/medperf/tests/commands/result/test_submit.py @@ -25,6 +25,7 @@ def submission(mocker, comms, ui, result, dataset): sub = ResultSubmission(1) mocker.patch(PATCH_SUBMISSION.format("Result"), return_value=result) mocker.patch(PATCH_SUBMISSION.format("Result.get"), return_value=result) + sub.get_result() return sub diff --git a/cli/medperf/tests/commands/test_execution.py b/cli/medperf/tests/commands/test_execution.py index 669d7dfd9..d50ca5d31 100644 --- a/cli/medperf/tests/commands/test_execution.py +++ b/cli/medperf/tests/commands/test_execution.py @@ -102,8 +102,8 @@ def test_failure_with_existing_predictions(mocker, setup, ignore_model_errors, f # Arrange preds_path = os.path.join( config.predictions_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, ) fs.create_dir(preds_path) @@ -149,22 +149,22 @@ def test_cube_run_are_called_properly(mocker, setup): # Arrange exp_preds_path = os.path.join( config.predictions_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, ) exp_model_logs_path = os.path.join( config.experiments_logs_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, "model.log", ) exp_metrics_logs_path = os.path.join( config.experiments_logs_folder, - INPUT_MODEL.generated_uid, - INPUT_DATASET.generated_uid, - f"metrics_{INPUT_EVALUATOR.generated_uid}.log", + INPUT_MODEL.local_id, + INPUT_DATASET.local_id, + f"metrics_{INPUT_EVALUATOR.local_id}.log", ) exp_model_call = call( diff --git 
a/cli/medperf/tests/commands/test_list.py b/cli/medperf/tests/commands/test_list.py index 1c2dc3267..ce7035960 100644 --- a/cli/medperf/tests/commands/test_list.py +++ b/cli/medperf/tests/commands/test_list.py @@ -47,18 +47,18 @@ def set_common_attributes(self, setup): self.state_variables = state_variables self.spies = spies - @pytest.mark.parametrize("local_only", [False, True]) + @pytest.mark.parametrize("unregistered", [False, True]) @pytest.mark.parametrize("mine_only", [False, True]) - def test_entity_all_is_called_properly(self, mocker, local_only, mine_only): + def test_entity_all_is_called_properly(self, mocker, unregistered, mine_only): # Arrange filters = {"owner": 1} if mine_only else {} # Act - EntityList.run(Entity, [], local_only, mine_only) + EntityList.run(Entity, [], unregistered, mine_only) # Assert self.spies["all"].assert_called_once_with( - local_only=local_only, filters=filters + unregistered=unregistered, filters=filters ) @pytest.mark.parametrize("fields", [["UID", "MLCube"]]) diff --git a/cli/medperf/tests/commands/test_view.py b/cli/medperf/tests/commands/test_view.py index a2dddfeda..0ffe0fb13 100644 --- a/cli/medperf/tests/commands/test_view.py +++ b/cli/medperf/tests/commands/test_view.py @@ -1,143 +1,86 @@ import pytest -import yaml -import json from medperf.entities.interface import Entity -from medperf.exceptions import InvalidArgumentError from medperf.commands.view import EntityView - -def expected_output(entities, format): - if isinstance(entities, list): - data = [entity.todict() for entity in entities] - else: - data = entities.todict() - - if format == "yaml": - return yaml.dump(data) - if format == "json": - return json.dumps(data) - - -def generate_entity(id, mocker): - entity = mocker.create_autospec(spec=Entity) - mocker.patch.object(entity, "todict", return_value={"id": id}) - return entity +PATCH_VIEW = "medperf.commands.view.{}" @pytest.fixture -def ui_spy(mocker, ui): - return mocker.patch.object(ui, "print") +def entity(mocker): + return mocker.create_autospec(Entity) -@pytest.fixture( - params=[{"local": ["1", "2", "3"], "remote": ["4", "5", "6"], "user": ["4"]}] -) -def setup(request, mocker): - local_ids = request.param.get("local", []) - remote_ids = request.param.get("remote", []) - user_ids = request.param.get("user", []) - all_ids = list(set(local_ids + remote_ids + user_ids)) - - local_entities = [generate_entity(id, mocker) for id in local_ids] - remote_entities = [generate_entity(id, mocker) for id in remote_ids] - user_entities = [generate_entity(id, mocker) for id in user_ids] - all_entities = list(set(local_entities + remote_entities + user_entities)) - - def mock_all(filters={}, local_only=False): - if "owner" in filters: - return user_entities - if local_only: - return local_entities - return all_entities - - def mock_get(entity_id): - if entity_id in all_ids: - return generate_entity(entity_id, mocker) - else: - raise InvalidArgumentError - - mocker.patch("medperf.commands.view.get_medperf_user_data", return_value={"id": 1}) - mocker.patch.object(Entity, "all", side_effect=mock_all) - mocker.patch.object(Entity, "get", side_effect=mock_get) - - return local_entities, remote_entities, user_entities, all_entities - - -class TestViewEntityID: - def test_view_displays_entity_if_given(self, mocker, setup, ui_spy): - # Arrange - entity_id = "1" - entity = generate_entity(entity_id, mocker) - output = expected_output(entity, "yaml") - - # Act - EntityView.run(entity_id, Entity) - - # Assert - ui_spy.assert_called_once_with(output) - 
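Aside for reviewers: the `local_only` -> `unregistered` rename in the test_list.py hunk above only changes the name of a pass-through flag; the behaviour being pinned down is a small flag-to-query translation performed before `Entity.all` is called. A tiny hedged sketch (hypothetical helper name; the real logic lives in `EntityList`/`EntityView`):

    from typing import Dict, Tuple


    def build_listing_args(unregistered: bool, mine_only: bool, user_id: int) -> Tuple[bool, Dict]:
        filters: Dict = {}
        if mine_only:
            filters["owner"] = user_id       # "mine" becomes an owner filter
        return unregistered, filters         # forwarded as-is to Entity.all(...)


    # Matches the expectations asserted in test_entity_all_is_called_properly:
    assert build_listing_args(False, True, user_id=1) == (False, {"owner": 1})
    assert build_listing_args(True, False, user_id=1) == (True, {})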
- def test_view_displays_all_if_no_id(self, setup, ui_spy): - # Arrange - *_, entities = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity) - - # Assert - ui_spy.assert_called_once_with(output) - - -class TestViewFilteredEntities: - def test_view_displays_local_entities(self, setup, ui_spy): - # Arrange - entities, *_ = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity, local_only=True) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_displays_user_entities(self, setup, ui_spy): - # Arrange - *_, entities, _ = setup - output = expected_output(entities, "yaml") - - # Act - EntityView.run(None, Entity, mine_only=True) - - # Assert - ui_spy.assert_called_once_with(output) - - -@pytest.mark.parametrize("entity_id", ["4", None]) -@pytest.mark.parametrize("format", ["yaml", "json"]) -class TestViewOutput: - @pytest.fixture - def output(self, setup, mocker, entity_id, format): - if entity_id is None: - *_, entities = setup - return expected_output(entities, format) - else: - entity = generate_entity(entity_id, mocker) - return expected_output(entity, format) - - def test_view_displays_specified_format(self, entity_id, output, ui_spy, format): - # Act - EntityView.run(entity_id, Entity, format=format) - - # Assert - ui_spy.assert_called_once_with(output) - - def test_view_stores_specified_format(self, entity_id, output, format, fs): - # Arrange - filename = "file.txt" - - # Act - EntityView.run(entity_id, Entity, format=format, output=filename) - - # Assert - contents = open(filename, "r").read() - assert contents == output +@pytest.fixture +def entity_view(mocker): + view_class = EntityView(None, Entity, "", "", "", "") + return view_class + + +def test_prepare_with_id_given(mocker, entity_view, entity): + # Arrange + entity_view.entity_id = 1 + get_spy = mocker.patch(PATCH_VIEW.format("Entity.get"), return_value=entity) + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + get_spy.assert_called_once_with(1) + all_spy.assert_not_called() + assert not isinstance(entity_view.data, list) + + +def test_prepare_with_no_id_given(mocker, entity_view, entity): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + get_spy = mocker.patch(PATCH_VIEW.format("Entity.get"), return_value=entity) + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once() + get_spy.assert_not_called() + assert isinstance(entity_view.data, list) + + +@pytest.mark.parametrize("unregistered", [False, True]) +def test_prepare_with_no_id_calls_all_with_unregistered_properly( + mocker, entity_view, entity, unregistered +): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + entity_view.unregistered = unregistered + all_spy = mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once_with(unregistered=unregistered, filters={}) + + +@pytest.mark.parametrize("filters", [{}, {"f1": "v1"}]) +@pytest.mark.parametrize("mine_only", [False, True]) +def test_prepare_with_no_id_calls_all_with_proper_filters( + mocker, entity_view, entity, filters, mine_only +): + # Arrange + entity_view.entity_id = None + entity_view.mine_only = False + entity_view.unregistered = False + entity_view.filters = filters + all_spy = 
mocker.patch(PATCH_VIEW.format("Entity.all"), return_value=[entity]) + mocker.patch(PATCH_VIEW.format("get_medperf_user_data"), return_value={"id": 1}) + if mine_only: + filters["owner"] = 1 + + # Act + entity_view.prepare() + + # Assert + all_spy.assert_called_once_with(unregistered=False, filters=filters) diff --git a/cli/medperf/tests/entities/test_benchmark.py b/cli/medperf/tests/entities/test_benchmark.py index 3f1fde2e2..c36771e12 100644 --- a/cli/medperf/tests/entities/test_benchmark.py +++ b/cli/medperf/tests/entities/test_benchmark.py @@ -9,8 +9,9 @@ @pytest.fixture( params={ - "local": [1, 2, 3], - "remote": [4, 5, 6], + "unregistered": ["b1", "b2"], + "local": ["b1", "b2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], "user": [4], "models": [10, 11], } diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py index 96f81dba0..89e7cc5a9 100644 --- a/cli/medperf/tests/entities/test_cube.py +++ b/cli/medperf/tests/entities/test_cube.py @@ -24,7 +24,14 @@ } -@pytest.fixture(params={"local": [1, 2, 3], "remote": [4, 5, 6], "user": [4]}) +@pytest.fixture( + params={ + "unregistered": ["c1", "c2"], + "local": ["c1", "c2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], + "user": [4], + } +) def setup(request, mocker, comms, fs): local_ents = request.param.get("local", []) remote_ents = request.param.get("remote", []) @@ -282,7 +289,9 @@ def test_run_stops_execution_if_child_fails(self, mocker, setup, task): cube.run(task) -@pytest.mark.parametrize("setup", [{"local": [DEFAULT_CUBE]}], indirect=True) +@pytest.mark.parametrize( + "setup", [{"local": [DEFAULT_CUBE], "remote": [DEFAULT_CUBE]}], indirect=True +) @pytest.mark.parametrize("task", ["task"]) @pytest.mark.parametrize( "out_key,out_value", diff --git a/cli/medperf/tests/entities/test_entity.py b/cli/medperf/tests/entities/test_entity.py index c636b2c26..5f2d24b3a 100644 --- a/cli/medperf/tests/entities/test_entity.py +++ b/cli/medperf/tests/entities/test_entity.py @@ -15,7 +15,7 @@ setup_result_fs, setup_result_comms, ) -from medperf.exceptions import InvalidArgumentError +from medperf.exceptions import CommunicationRetrievalError, InvalidArgumentError @pytest.fixture(params=[Benchmark, Cube, Dataset, Result]) @@ -23,7 +23,14 @@ def Implementation(request): return request.param -@pytest.fixture(params={"local": [1, 2, 3], "remote": [4, 5, 6], "user": [4]}) +@pytest.fixture( + params={ + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 1, 2, 3], + "remote": [1, 2, 3, 4, 5, 6], + "user": [4], + } +) def setup(request, mocker, comms, Implementation, fs): local_ids = request.param.get("local", []) remote_ids = request.param.get("remote", []) @@ -54,39 +61,52 @@ def setup(request, mocker, comms, Implementation, fs): @pytest.mark.parametrize( "setup", - [{"local": [283, 17, 493], "remote": [283, 1, 2], "user": [2]}], + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 283], + "remote": [283, 1, 2], + "user": [2], + } + ], indirect=True, ) class TestAll: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): self.ids = setup + self.unregistered_ids = set(self.ids["unregistered"]) self.local_ids = set(self.ids["local"]) self.remote_ids = set(self.ids["remote"]) self.user_ids = set(self.ids["user"]) - def test_all_returns_all_remote_and_local(self, Implementation): - # Arrange - all_ids = self.local_ids.union(self.remote_ids) - + def test_all_returns_all_remote_by_default(self, Implementation): # Act entities = Implementation.all() # Assert retrieved_ids = set([e.todict()["id"] 
for e in entities]) - assert all_ids == retrieved_ids + assert self.remote_ids == retrieved_ids - def test_all_local_only_returns_all_local(self, Implementation): + def test_all_unregistered_returns_all_unregistered(self, Implementation): # Act - entities = Implementation.all(local_only=True) + entities = Implementation.all(unregistered=True) # Assert - retrieved_ids = set([e.todict()["id"] for e in entities]) - assert self.local_ids == retrieved_ids + retrieved_ids = set([e.local_id for e in entities]) + assert self.unregistered_ids == retrieved_ids @pytest.mark.parametrize( - "setup", [{"local": [78], "remote": [479, 42, 7, 1]}], indirect=True + "setup", + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2", 479], + "remote": [479, 42, 7, 1], + } + ], + indirect=True, ) class TestGet: def test_get_retrieves_entity_from_server(self, Implementation, setup): @@ -99,30 +119,20 @@ def test_get_retrieves_entity_from_server(self, Implementation, setup): # Assert assert entity.todict()["id"] == id - def test_get_retrieves_entity_local_if_not_on_server(self, Implementation, setup): - # Arrange - id = setup["local"][0] - - # Act - entity = Implementation.get(id) - - # Assert - assert entity.todict()["id"] == id - def test_get_raises_error_if_nonexistent(self, Implementation, setup): # Arrange id = str(19283) # Act & Assert - with pytest.raises(InvalidArgumentError): + with pytest.raises(CommunicationRetrievalError): Implementation.get(id) -@pytest.mark.parametrize("setup", [{"local": [742]}], indirect=True) +@pytest.mark.parametrize("setup", [{"remote": [742]}], indirect=True) class TestToDict: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): - self.id = setup["local"][0] + self.id = setup["remote"][0] def test_todict_returns_dict_representation(self, Implementation): # Arrange @@ -147,7 +157,16 @@ def test_todict_can_recreate_object(self, Implementation): assert ent_dict == ent_copy_dict -@pytest.mark.parametrize("setup", [{"local": [36]}], indirect=True) +@pytest.mark.parametrize( + "setup", + [ + { + "unregistered": ["e1", "e2"], + "local": ["e1", "e2"], + } + ], + indirect=True, +) class TestUpload: @pytest.fixture(autouse=True) def set_common_attributes(self, setup): diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py index 522251ca7..c3bde6feb 100644 --- a/cli/medperf/tests/entities/utils.py +++ b/cli/medperf/tests/entities/utils.py @@ -15,14 +15,16 @@ # Setup Benchmark def setup_benchmark_fs(ents, fs): - bmks_path = config.benchmarks_folder for ent in ents: - if not isinstance(ent, dict): - # Assume we're passing ids - ent = {"id": str(ent)} - id = ent["id"] - bmk_filepath = os.path.join(bmks_path, str(id), config.benchmarks_filename) - bmk_contents = TestBenchmark(**ent) + # Assume we're passing ids, local_ids, or dicts + if isinstance(ent, dict): + bmk_contents = TestBenchmark(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + bmk_contents = TestBenchmark(id=str(ent)) + else: + bmk_contents = TestBenchmark(id=None, name=ent) + + bmk_filepath = os.path.join(bmk_contents.path, config.benchmarks_filename) cubes_ids = [] cubes_ids.append(bmk_contents.data_preparation_mlcube) cubes_ids.append(bmk_contents.reference_model_mlcube) @@ -30,7 +32,7 @@ def setup_benchmark_fs(ents, fs): cubes_ids = list(set(cubes_ids)) setup_cube_fs(cubes_ids, fs) try: - fs.create_file(bmk_filepath, contents=yaml.dump(bmk_contents.dict())) + fs.create_file(bmk_filepath, contents=yaml.dump(bmk_contents.todict())) except 
FileExistsError: pass @@ -51,17 +53,17 @@ def setup_benchmark_comms(mocker, comms, all_ents, user_ents, uploaded): # Setup Cube def setup_cube_fs(ents, fs): - cubes_path = config.cubes_folder for ent in ents: - if not isinstance(ent, dict): - # Assume we're passing ids - ent = {"id": str(ent)} - id = ent["id"] - meta_cube_file = os.path.join( - cubes_path, str(id), config.cube_metadata_filename - ) - cube = TestCube(**ent) - meta = cube.dict() + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + cube = TestCube(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + cube = TestCube(id=str(ent)) + else: + cube = TestCube(id=None, name=ent) + + meta_cube_file = os.path.join(cube.path, config.cube_metadata_filename) + meta = cube.todict() try: fs.create_file(meta_cube_file, contents=yaml.dump(meta)) except FileExistsError: @@ -124,18 +126,20 @@ def setup_cube_comms_downloads(mocker, fs): # Setup Dataset def setup_dset_fs(ents, fs): - dsets_path = config.datasets_folder for ent in ents: - if not isinstance(ent, dict): - # Assume passing ids - ent = {"id": str(ent)} - id = ent["id"] - reg_dset_file = os.path.join(dsets_path, str(id), config.reg_file) - dset_contents = TestDataset(**ent) + # Assume we're passing ids, generated_uids, or dicts + if isinstance(ent, dict): + dset_contents = TestDataset(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + dset_contents = TestDataset(id=str(ent)) + else: + dset_contents = TestDataset(id=None, generated_uid=ent) + + reg_dset_file = os.path.join(dset_contents.path, config.reg_file) cube_id = dset_contents.data_preparation_mlcube setup_cube_fs([cube_id], fs) try: - fs.create_file(reg_dset_file, contents=yaml.dump(dset_contents.dict())) + fs.create_file(reg_dset_file, contents=yaml.dump(dset_contents.todict())) except FileExistsError: pass @@ -155,22 +159,25 @@ def setup_dset_comms(mocker, comms, all_ents, user_ents, uploaded): # Setup Result def setup_result_fs(ents, fs): - results_path = config.results_folder for ent in ents: - if not isinstance(ent, dict): - # Assume passing ids - ent = {"id": str(ent)} - id = ent["id"] - result_file = os.path.join(results_path, str(id), config.results_info_file) - bmk_id = ent.get("benchmark", 1) - cube_id = ent.get("model", 1) - dataset_id = ent.get("dataset", 1) + # Assume we're passing ids, names, or dicts + if isinstance(ent, dict): + result_contents = TestResult(**ent) + elif isinstance(ent, int) or isinstance(ent, str) and ent.isdigit(): + result_contents = TestResult(id=str(ent)) + else: + result_contents = TestResult(id=None, name=ent) + + result_file = os.path.join(result_contents.path, config.results_info_file) + bmk_id = result_contents.benchmark + cube_id = result_contents.model + dataset_id = result_contents.dataset setup_benchmark_fs([bmk_id], fs) setup_cube_fs([cube_id], fs) setup_dset_fs([dataset_id], fs) - result_contents = TestResult(**ent) + try: - fs.create_file(result_file, contents=yaml.dump(result_contents.dict())) + fs.create_file(result_file, contents=yaml.dump(result_contents.todict())) except FileExistsError: pass From 37073cf072ba44d415260ac6df2346a8acde1dde Mon Sep 17 00:00:00 2001 From: hasan7n Date: Thu, 15 Aug 2024 02:37:07 +0200 Subject: [PATCH 3/4] update versions of chrome and selenium for tests --- .github/workflows/auth-ci.yml | 2 +- cli/test-requirements.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/auth-ci.yml b/.github/workflows/auth-ci.yml index 
b316397f7..90c55a200 100644 --- a/.github/workflows/auth-ci.yml +++ b/.github/workflows/auth-ci.yml @@ -29,7 +29,7 @@ jobs: - name: Setup Chrome run: | sudo apt-get install -y wget - wget -O chrome.deb https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb + wget -O chrome.deb https://dl.google.com/linux/chrome/deb/pool/main/g/google-chrome-stable/google-chrome-stable_126.0.6478.126-1_amd64.deb sudo dpkg -i chrome.deb rm chrome.deb diff --git a/cli/test-requirements.txt b/cli/test-requirements.txt index 469901472..2f0712691 100644 --- a/cli/test-requirements.txt +++ b/cli/test-requirements.txt @@ -1,2 +1,2 @@ -selenium==4.10.0 -webdriver-manager==3.8.6 +selenium==4.23.1 +webdriver-manager==4.0.2 From 61863163f87d4ac2a13b2955e2f8776266da5617 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Aristiz=C3=A1bal?= Date: Thu, 15 Aug 2024 06:53:47 -0500 Subject: [PATCH 4/4] Check for refresh token absolute expiration time (#604) * Handle expired tokens * Implement tests for token expiration * Use absolute expiration time instead of expiration time * Fix style issues * refresh token expiration: bugfix and migration * fix local auth bug * don't logout if access token not expired --------- Co-authored-by: hasan7n Co-authored-by: hasan7n <78664424+hasan7n@users.noreply.github.com> --- .../account_management/account_management.py | 15 +++++++- cli/medperf/comms/auth/auth0.py | 37 ++++++++++++++++--- cli/medperf/comms/auth/local.py | 1 + cli/medperf/config.py | 2 + cli/medperf/exceptions.py | 4 ++ cli/medperf/storage/__init__.py | 14 +++++++ cli/medperf/tests/comms/test_auth0.py | 35 +++++++++++++++++- 7 files changed, 99 insertions(+), 9 deletions(-) diff --git a/cli/medperf/account_management/account_management.py b/cli/medperf/account_management/account_management.py index d19cf8cad..f8ea4e88e 100644 --- a/cli/medperf/account_management/account_management.py +++ b/cli/medperf/account_management/account_management.py @@ -19,16 +19,29 @@ def set_credentials( id_token_payload, token_issued_at, token_expires_in, + login_event=False, ): email = id_token_payload["email"] TokenStore().set_tokens(email, access_token, refresh_token) + config_p = read_config() + + if login_event: + # Set the time the user logged in, so that we can track the lifetime of + # the refresh token + logged_in_at = token_issued_at + else: + # This means this is a refresh event. Preserve the logged_in_at timestamp. 
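+        # For context (as wired elsewhere in this patch): Auth0.login() and the
+        # local auth backend call set_credentials(..., login_event=True), so
+        # logged_in_at is (re)set to the fresh token_issued_at; refresh paths
+        # keep the default login_event=False, so the original login timestamp
+        # stored under the active profile is carried forward here. That value
+        # later feeds the absolute-expiry check in Auth0._access_token
+        # (logged_in_at + token_absolute_expiry - refresh_token_expiration_leeway).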
+ logged_in_at = config_p.active_profile[config.credentials_keyword][ + "logged_in_at" + ] account_info = { "email": email, "token_issued_at": token_issued_at, "token_expires_in": token_expires_in, + "logged_in_at": logged_in_at, } - config_p = read_config() + config_p.active_profile[config.credentials_keyword] = account_info write_config(config_p) diff --git a/cli/medperf/comms/auth/auth0.py b/cli/medperf/comms/auth/auth0.py index 3da6ecaac..60a052b72 100644 --- a/cli/medperf/comms/auth/auth0.py +++ b/cli/medperf/comms/auth/auth0.py @@ -4,7 +4,7 @@ import sqlite3 from medperf.comms.auth.interface import Auth from medperf.comms.auth.token_verifier import verify_token -from medperf.exceptions import CommunicationError +from medperf.exceptions import CommunicationError, AuthenticationError import requests import medperf.config as config from medperf.utils import log_response_error @@ -66,6 +66,7 @@ def login(self, email): id_token_payload, token_issued_at, token_expires_in, + login_event=True, ) def __request_device_code(self): @@ -191,11 +192,35 @@ def _access_token(self): refresh_token = creds["refresh_token"] token_expires_in = creds["token_expires_in"] token_issued_at = creds["token_issued_at"] - if ( - time.time() - > token_issued_at + token_expires_in - config.token_expiration_leeway - ): - access_token = self.__refresh_access_token(refresh_token) + logged_in_at = creds["logged_in_at"] + + # token_issued_at and expires_in are for the access token + sliding_expiration_time = ( + token_issued_at + token_expires_in - config.token_expiration_leeway + ) + absolute_expiration_time = ( + logged_in_at + + config.token_absolute_expiry + - config.refresh_token_expiration_leeway + ) + current_time = time.time() + + if current_time < sliding_expiration_time: + # Access token not expired. No need to refresh. + return access_token + + # So we need to refresh. + if current_time > absolute_expiration_time: + # Expired refresh token. Force logout and ask the user to re-authenticate + logging.debug( + f"Refresh token expired: {absolute_expiration_time=} <> {current_time=}" + ) + self.logout() + raise AuthenticationError("Token expired. Please re-authenticate") + + # Expired access token and not expired refresh token. Refresh. + access_token = self.__refresh_access_token(refresh_token) + return access_token def __refresh_access_token(self, refresh_token): diff --git a/cli/medperf/comms/auth/local.py b/cli/medperf/comms/auth/local.py index 1c12ba1cc..a597d4ce3 100644 --- a/cli/medperf/comms/auth/local.py +++ b/cli/medperf/comms/auth/local.py @@ -39,6 +39,7 @@ def login(self, email): id_token_payload, token_issued_at, token_expires_in, + login_event=True, ) def logout(self): diff --git a/cli/medperf/config.py b/cli/medperf/config.py index f43af8f4b..2c5b520be 100644 --- a/cli/medperf/config.py +++ b/cli/medperf/config.py @@ -37,6 +37,8 @@ auth_jwks_cache_ttl = 600 # fetch jwks every 10 mins. Default value in auth0 python SDK token_expiration_leeway = 10 # Refresh tokens 10 seconds before expiration +refresh_token_expiration_leeway = 10 # Logout users 10 seconds before absolute token expiration. +token_absolute_expiry = 2592000 # Refresh token absolute expiration time (seconds). 
This value is set on auth0's configuration access_token_storage_id = "medperf_access_token" refresh_token_storage_id = "medperf_refresh_token" diff --git a/cli/medperf/exceptions.py b/cli/medperf/exceptions.py index 6844f28d4..68e858cfe 100644 --- a/cli/medperf/exceptions.py +++ b/cli/medperf/exceptions.py @@ -30,6 +30,10 @@ class ExecutionError(MedperfException): """Raised when an execution component fails""" +class AuthenticationError(MedperfException): + """Raised when authentication can't be processed""" + + class CleanExit(MedperfException): """Raised when Medperf needs to stop for non erroneous reasons""" diff --git a/cli/medperf/storage/__init__.py b/cli/medperf/storage/__init__.py index acebb6f31..74e4e7962 100644 --- a/cli/medperf/storage/__init__.py +++ b/cli/medperf/storage/__init__.py @@ -1,5 +1,6 @@ import os import shutil +import time from medperf import config from medperf.config_management import read_config, write_config @@ -25,6 +26,7 @@ def apply_configuration_migrations(): config_p = read_config() + # Migration for moving the logs folder to a new location if "logs_folder" not in config_p.storage: return @@ -35,4 +37,16 @@ def apply_configuration_migrations(): del config_p.storage["logs_folder"] + # Migration for tracking the login timestamp (i.e., refresh token issuance timestamp) + if config.credentials_keyword in config_p.active_profile: + # So the user is logged in + if "logged_in_at" not in config_p.active_profile[config.credentials_keyword]: + # Apply migration. We will set it to the current time, since this + # will make sure they will not be logged out before the actual refresh + # token expiration (for a better user experience). However, currently logged + # in users will still face a confusing error when the refresh token expires. 
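+            # Rough illustration with the config.py defaults: token_absolute_expiry
+            # is 2592000 s (30 days), so a user who actually logged in 29 days ago
+            # gets logged_in_at = "now" after this migration and is not logged out
+            # by this client-side check for another ~30 days, even though the
+            # server-side refresh token may expire sooner (hence the note above).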
+ config_p.active_profile[config.credentials_keyword][ + "logged_in_at" + ] = time.time() + write_config(config_p) diff --git a/cli/medperf/tests/comms/test_auth0.py b/cli/medperf/tests/comms/test_auth0.py index 313f4f613..9eadd9662 100644 --- a/cli/medperf/tests/comms/test_auth0.py +++ b/cli/medperf/tests/comms/test_auth0.py @@ -2,9 +2,12 @@ from unittest.mock import ANY from medperf.tests.mocks import MockResponse from medperf.comms.auth.auth0 import Auth0 +from medperf import config +from medperf.exceptions import AuthenticationError import sqlite3 import pytest + PATCH_AUTH = "medperf.comms.auth.auth0.{}" @@ -35,6 +38,7 @@ def test_token_is_not_refreshed_if_not_expired(mocker, setup): "access_token": "", "token_expires_in": 900, "token_issued_at": time.time(), + "logged_in_at": time.time(), } mocker.patch(PATCH_AUTH.format("read_credentials"), return_value=creds) spy = mocker.patch(PATCH_AUTH.format("Auth0._Auth0__refresh_access_token")) @@ -48,11 +52,14 @@ def test_token_is_not_refreshed_if_not_expired(mocker, setup): def test_token_is_refreshed_if_expired(mocker, setup): # Arrange + expiration_time = 900 + mocked_issued_at = time.time() - expiration_time creds = { "refresh_token": "", "access_token": "", - "token_expires_in": 900, - "token_issued_at": time.time() - 1000, + "token_expires_in": expiration_time, + "token_issued_at": mocked_issued_at, + "logged_in_at": time.time(), } mocker.patch(PATCH_AUTH.format("read_credentials"), return_value=creds) spy = mocker.patch(PATCH_AUTH.format("Auth0._Auth0__refresh_access_token")) @@ -64,6 +71,30 @@ def test_token_is_refreshed_if_expired(mocker, setup): spy.assert_called_once() +def test_logs_out_if_session_reaches_token_absolute_expiration_time(mocker, setup): + # Arrange + expiration_time = 900 + absolute_expiration_time = config.token_absolute_expiry + mocked_logged_in_at = time.time() - absolute_expiration_time + mocked_issued_at = time.time() - expiration_time + creds = { + "refresh_token": "", + "access_token": "", + "token_expires_in": expiration_time, + "token_issued_at": mocked_issued_at, + "logged_in_at": mocked_logged_in_at, + } + mocker.patch(PATCH_AUTH.format("read_credentials"), return_value=creds) + spy = mocker.patch(PATCH_AUTH.format("Auth0.logout")) + + # Act + with pytest.raises(AuthenticationError): + Auth0().access_token + + # Assert + spy.assert_called_once() + + def test_refresh_token_sets_new_tokens(mocker, setup): # Arrange access_token = "access_token"