From a5ecee2110f40ecb74d84a32dd3ad3350ba95756 Mon Sep 17 00:00:00 2001
From: hasan7n
Date: Sat, 25 Nov 2023 07:06:54 +0100
Subject: [PATCH] stop using benchmark.models for permission reasons

---
 cli/medperf/commands/result/create.py        | 20 +++++++---
 cli/medperf/entities/benchmark.py            | 20 +---------
 .../tests/commands/result/test_create.py     | 21 +++++++++-
 cli/medperf/tests/entities/test_benchmark.py | 39 ++++++++-----------
 cli/medperf/tests/entities/utils.py          |  2 +-
 cli/medperf/tests/mocks/benchmark.py         |  3 +-
 6 files changed, 54 insertions(+), 51 deletions(-)

diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py
index 3024c220f..576cd3555 100644
--- a/cli/medperf/commands/result/create.py
+++ b/cli/medperf/commands/result/create.py
@@ -50,7 +50,6 @@ def run(
         execution.prepare()
         execution.validate()
         execution.prepare_models()
-        execution.validate_models()
         if not no_cache:
             execution.load_cached_results()
         with execution.ui.interactive():
@@ -101,8 +100,19 @@ def validate(self):
     def prepare_models(self):
         if self.models_input_file:
             self.models_uids = self.__get_models_from_file()
-        elif self.models_uids is None:
-            self.models_uids = self.benchmark.models
+
+        if self.models_uids == [self.benchmark.reference_model_mlcube]:
+            # avoid the need to send a request to the server to find
+            # the benchmark's associated models
+            return
+
+        benchmark_models = Benchmark.get_models_uids(self.benchmark_uid)
+        benchmark_models.append(self.benchmark.reference_model_mlcube)
+
+        if self.models_uids is None:
+            self.models_uids = benchmark_models
+        else:
+            self.__validate_models(benchmark_models)
 
     def __get_models_from_file(self):
         if not os.path.exists(self.models_input_file):
@@ -117,9 +127,9 @@ def __get_models_from_file(self):
             msg += "The file should contain a list of comma-separated integers"
             raise InvalidArgumentError(msg)
 
-    def validate_models(self):
+    def __validate_models(self, benchmark_models):
         models_set = set(self.models_uids)
-        benchmark_models_set = set(self.benchmark.models)
+        benchmark_models_set = set(benchmark_models)
         non_assoc_cubes = models_set.difference(benchmark_models_set)
         if non_assoc_cubes:
             if len(non_assoc_cubes) > 1:
diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py
index c76d280a1..01d3d48da 100644
--- a/cli/medperf/entities/benchmark.py
+++ b/cli/medperf/entities/benchmark.py
@@ -3,7 +3,7 @@
 import yaml
 import logging
 from typing import List, Optional, Union
-from pydantic import HttpUrl, Field, validator
+from pydantic import HttpUrl, Field
 
 import medperf.config as config
 from medperf.entities.interface import Entity, Uploadable
@@ -32,18 +32,10 @@ class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableS
     data_preparation_mlcube: int
     reference_model_mlcube: int
     data_evaluator_mlcube: int
-    models: List[int] = None
     metadata: dict = {}
     user_metadata: dict = {}
     is_active: bool = True
 
-    @validator("models", pre=True, always=True)
-    def set_default_models_value(cls, value, values, **kwargs):
-        if not value:
-            # Empty or None value assigned
-            return [values["reference_model_mlcube"]]
-        return value
-
     def __init__(self, *args, **kwargs):
         """Creates a new benchmark instance
 
@@ -91,11 +83,6 @@ def __remote_all(cls, filters: dict) -> List["Benchmark"]:
         try:
             comms_fn = cls.__remote_prefilter(filters)
             bmks_meta = comms_fn()
-            for bmk_meta in bmks_meta:
-                # Loading all related models for all benchmarks could be expensive.
-                # Most probably not necessary when getting all benchmarks.
-                # If associated models for a benchmark are needed then use Benchmark.get()
-                bmk_meta["models"] = [bmk_meta["reference_model_mlcube"]]
             benchmarks = [cls(**meta) for meta in bmks_meta]
         except CommunicationRetrievalError:
             msg = "Couldn't retrieve all benchmarks from the server"
@@ -175,9 +162,6 @@ def __remote_get(cls, benchmark_uid: int) -> "Benchmark":
         """
         logging.debug(f"Retrieving benchmark {benchmark_uid} remotely")
         benchmark_dict = config.comms.get_benchmark(benchmark_uid)
-        ref_model = benchmark_dict["reference_model_mlcube"]
-        add_models = cls.get_models_uids(benchmark_uid)
-        benchmark_dict["models"] = [ref_model] + add_models
         benchmark = cls(**benchmark_dict)
         benchmark.write()
         return benchmark
@@ -273,7 +257,6 @@ def upload(self):
             raise InvalidArgumentError("Cannot upload test benchmarks.")
         body = self.todict()
         updated_body = config.comms.upload_benchmark(body)
-        updated_body["models"] = body["models"]
         return updated_body
 
     def display_dict(self):
@@ -285,7 +268,6 @@ def display_dict(self):
             "Created At": self.created_at,
             "Data Preparation MLCube": int(self.data_preparation_mlcube),
             "Reference Model MLCube": int(self.reference_model_mlcube),
-            "Associated Models": ",".join(map(str, self.models)),
             "Data Evaluator MLCube": int(self.data_evaluator_mlcube),
             "State": self.state,
             "Approval Status": self.approval_status,
diff --git a/cli/medperf/tests/commands/result/test_create.py b/cli/medperf/tests/commands/result/test_create.py
index 8dfe90606..7a11c05bf 100644
--- a/cli/medperf/tests/commands/result/test_create.py
+++ b/cli/medperf/tests/commands/result/test_create.py
@@ -29,10 +29,13 @@ def __get_side_effect(id):
             data_evaluator_mlcube=evaluator["uid"],
             data_preparation_mlcube=benchmark_prep_cube,
             reference_model_mlcube=benchmark_models[0],
-            models=benchmark_models,
         )
 
     mocker.patch(PATCH_EXECUTION.format("Benchmark.get"), side_effect=__get_side_effect)
+    mocker.patch(
+        PATCH_EXECUTION.format("Benchmark.get_models_uids"),
+        return_value=benchmark_models[1:],
+    )
 
 
 def mock_dataset(mocker, state_variables):
@@ -126,12 +129,16 @@ def setup(request, mocker, ui, fs):
     ui_error_spy = mocker.patch.object(ui, "print_error")
     ui_print_spy = mocker.patch.object(ui, "print")
     tabulate_spy = mocker.spy(create_module, "tabulate")
+    validate_models_spy = mocker.spy(
+        create_module.BenchmarkExecution, "_BenchmarkExecution__validate_models"
+    )
 
     spies = {
         "ui_error": ui_error_spy,
         "ui_print": ui_print_spy,
         "tabulate": tabulate_spy,
         "exec": exec_spy,
+        "validate_models": validate_models_spy,
     }
     return state_variables, spies
 
@@ -326,3 +333,15 @@ def test_execution_of_one_model_writes_result(self, mocker, setup):
             yaml.safe_load(open(expected_file))["results"]
             == self.state_variables["models_props"][model_uid]["results"]
         )
+
+    def test_execution_of_reference_model_does_not_call_validate(self, mocker, setup):
+        # Arrange
+        model_uid = self.state_variables["benchmark_models"][0]
+        dset_uid = 2
+        bmk_uid = 1
+
+        # Act
+        BenchmarkExecution.run(bmk_uid, dset_uid, models_uids=[model_uid])
+
+        # Assert
+        self.spies["validate_models"].assert_not_called()
diff --git a/cli/medperf/tests/entities/test_benchmark.py b/cli/medperf/tests/entities/test_benchmark.py
index 9d6ddfe4d..3f1fde2e2 100644
--- a/cli/medperf/tests/entities/test_benchmark.py
+++ b/cli/medperf/tests/entities/test_benchmark.py
@@ -32,36 +32,29 @@ def setup(request, mocker, comms, fs):
 
 
 @pytest.mark.parametrize(
-    "setup",
+    "setup,expected_models",
     [
-        {
-            "remote": [721],
-            "models": [
-                {"model_mlcube": 37, "approval_status": "APPROVED"},
-                {"model_mlcube": 23, "approval_status": "APPROVED"},
-                {"model_mlcube": 495, "approval_status": "APPROVED"},
-            ],
-        }
+        (
+            {
+                "remote": [721],
+                "models": [
+                    {"model_mlcube": 37, "approval_status": "APPROVED"},
+                    {"model_mlcube": 23, "approval_status": "APPROVED"},
+                    {"model_mlcube": 495, "approval_status": "PENDING"},
+                ],
+            },
+            [37, 23],
+        )
     ],
-    indirect=True,
+    indirect=["setup"],
 )
 class TestModels:
-    def test_benchmark_includes_reference_model_in_models(self, setup):
-        # Act
-        id = setup["remote"][0]
-        benchmark = Benchmark.get(id)
-
-        # Assert
-        assert benchmark.reference_model_mlcube in benchmark.models
-
-    def test_benchmark_includes_additional_models_in_models(self, setup):
+    def test_benchmark_get_models_works_as_expected(self, setup, expected_models):
         # Arrange
         id = setup["remote"][0]
-        models = setup["models"]
-        models = [model["model_mlcube"] for model in models]
 
         # Act
-        benchmark = Benchmark.get(id)
+        associated_models = Benchmark.get_models_uids(id)
 
         # Assert
-        assert set(models).issubset(set(benchmark.models))
+        assert set(associated_models) == set(expected_models)
diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py
index 3a87a87a0..2b141da51 100644
--- a/cli/medperf/tests/entities/utils.py
+++ b/cli/medperf/tests/entities/utils.py
@@ -24,7 +24,7 @@ def setup_benchmark_fs(ents, fs):
         id = ent["id"]
         bmk_filepath = os.path.join(bmks_path, str(id), config.benchmarks_filename)
         bmk_contents = TestBenchmark(**ent)
-        cubes_ids = bmk_contents.models
+        cubes_ids = []
         cubes_ids.append(bmk_contents.data_preparation_mlcube)
         cubes_ids.append(bmk_contents.reference_model_mlcube)
         cubes_ids.append(bmk_contents.data_evaluator_mlcube)
diff --git a/cli/medperf/tests/mocks/benchmark.py b/cli/medperf/tests/mocks/benchmark.py
index cd5f6f435..7ede9a7cd 100644
--- a/cli/medperf/tests/mocks/benchmark.py
+++ b/cli/medperf/tests/mocks/benchmark.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 from medperf.enums import Status
 from medperf.entities.benchmark import Benchmark
@@ -11,5 +11,4 @@ class TestBenchmark(Benchmark):
     data_preparation_mlcube: int = 1
     reference_model_mlcube: int = 2
     data_evaluator_mlcube: int = 3
-    models: List[int] = [2]
     approval_status: Status = Status.APPROVED
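
Note: with this patch the benchmark's model list is no longer stored on the entity; callers assemble it on demand from the reference model plus Benchmark.get_models_uids. The snippet below is not part of the patch. It is a minimal caller-side sketch of that flow: Benchmark.get, Benchmark.get_models_uids, and reference_model_mlcube come from the diff above, while the helper name resolve_models_to_run and its error handling are hypothetical.

from typing import List, Optional

from medperf.entities.benchmark import Benchmark


def resolve_models_to_run(
    benchmark_uid: int, requested: Optional[List[int]] = None
) -> List[int]:
    """Hypothetical helper mirroring BenchmarkExecution.prepare_models in this patch."""
    benchmark = Benchmark.get(benchmark_uid)

    # Running only the reference model needs no extra server request.
    if requested == [benchmark.reference_model_mlcube]:
        return requested

    # Associated models are now fetched on demand instead of being read
    # from a stored benchmark.models field.
    benchmark_models = Benchmark.get_models_uids(benchmark_uid)
    benchmark_models.append(benchmark.reference_model_mlcube)

    if requested is None:
        return benchmark_models

    non_assoc = set(requested) - set(benchmark_models)
    if non_assoc:
        raise ValueError(
            f"Models {sorted(non_assoc)} are not associated with benchmark {benchmark_uid}"
        )
    return requested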