Commit
stop using benchmark.models for permission reasons
hasan7n committed Nov 25, 2023
1 parent c7b6578 commit a5ecee2
Showing 6 changed files with 54 additions and 51 deletions.
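In outline, this commit drops the Benchmark.models field and resolves a benchmark's model list on demand instead. A minimal sketch of the new calling pattern, combining only the APIs that appear in this diff (the helper function below is illustrative, not part of the commit):

from medperf.entities.benchmark import Benchmark

def resolve_benchmark_models(benchmark_uid: int) -> list:
    # Fetch the benchmark record; it no longer carries a `models` field.
    benchmark = Benchmark.get(benchmark_uid)
    # Ask the server for the model UIDs associated with this benchmark.
    model_uids = Benchmark.get_models_uids(benchmark_uid)
    # The reference model is always part of the benchmark's model set.
    model_uids.append(benchmark.reference_model_mlcube)
    return model_uids

This mirrors what BenchmarkExecution.prepare_models now does in create.py, with a short-circuit that skips the server request when only the reference model is being run.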
20 changes: 15 additions & 5 deletions cli/medperf/commands/result/create.py
@@ -50,7 +50,6 @@ def run(
execution.prepare()
execution.validate()
execution.prepare_models()
execution.validate_models()
if not no_cache:
execution.load_cached_results()
with execution.ui.interactive():
@@ -101,8 +100,19 @@ def validate(self):
def prepare_models(self):
if self.models_input_file:
self.models_uids = self.__get_models_from_file()
elif self.models_uids is None:
self.models_uids = self.benchmark.models

if self.models_uids == [self.benchmark.reference_model_mlcube]:
# avoid the need of sending a request to the server for
# finding the benchmark's associated models
return

benchmark_models = Benchmark.get_models_uids(self.benchmark_uid)
benchmark_models.append(self.benchmark.reference_model_mlcube)

if self.models_uids is None:
self.models_uids = benchmark_models
else:
self.__validate_models(benchmark_models)

def __get_models_from_file(self):
if not os.path.exists(self.models_input_file):
@@ -117,9 +127,9 @@ def __get_models_from_file(self):
msg += "The file should contain a list of comma-separated integers"
raise InvalidArgumentError(msg)

def validate_models(self):
def __validate_models(self, benchmark_models):
models_set = set(self.models_uids)
benchmark_models_set = set(self.benchmark.models)
benchmark_models_set = set(benchmark_models)
non_assoc_cubes = models_set.difference(benchmark_models_set)
if non_assoc_cubes:
if len(non_assoc_cubes) > 1:
20 changes: 1 addition & 19 deletions cli/medperf/entities/benchmark.py
@@ -3,7 +3,7 @@
import yaml
import logging
from typing import List, Optional, Union
from pydantic import HttpUrl, Field, validator
from pydantic import HttpUrl, Field

import medperf.config as config
from medperf.entities.interface import Entity, Uploadable
@@ -32,18 +32,10 @@ class Benchmark(Entity, Uploadable, MedperfSchema, ApprovableSchema, DeployableS
data_preparation_mlcube: int
reference_model_mlcube: int
data_evaluator_mlcube: int
models: List[int] = None
metadata: dict = {}
user_metadata: dict = {}
is_active: bool = True

@validator("models", pre=True, always=True)
def set_default_models_value(cls, value, values, **kwargs):
if not value:
# Empty or None value assigned
return [values["reference_model_mlcube"]]
return value

def __init__(self, *args, **kwargs):
"""Creates a new benchmark instance
@@ -91,11 +83,6 @@ def __remote_all(cls, filters: dict) -> List["Benchmark"]:
try:
comms_fn = cls.__remote_prefilter(filters)
bmks_meta = comms_fn()
for bmk_meta in bmks_meta:
# Loading all related models for all benchmarks could be expensive.
# Most probably not necessary when getting all benchmarks.
# If associated models for a benchmark are needed then use Benchmark.get()
bmk_meta["models"] = [bmk_meta["reference_model_mlcube"]]
benchmarks = [cls(**meta) for meta in bmks_meta]
except CommunicationRetrievalError:
msg = "Couldn't retrieve all benchmarks from the server"
@@ -175,9 +162,6 @@ def __remote_get(cls, benchmark_uid: int) -> "Benchmark":
"""
logging.debug(f"Retrieving benchmark {benchmark_uid} remotely")
benchmark_dict = config.comms.get_benchmark(benchmark_uid)
ref_model = benchmark_dict["reference_model_mlcube"]
add_models = cls.get_models_uids(benchmark_uid)
benchmark_dict["models"] = [ref_model] + add_models
benchmark = cls(**benchmark_dict)
benchmark.write()
return benchmark
@@ -273,7 +257,6 @@ def upload(self):
raise InvalidArgumentError("Cannot upload test benchmarks.")
body = self.todict()
updated_body = config.comms.upload_benchmark(body)
updated_body["models"] = body["models"]
return updated_body

def display_dict(self):
@@ -285,7 +268,6 @@ def display_dict(self):
"Created At": self.created_at,
"Data Preparation MLCube": int(self.data_preparation_mlcube),
"Reference Model MLCube": int(self.reference_model_mlcube),
"Associated Models": ",".join(map(str, self.models)),
"Data Evaluator MLCube": int(self.data_evaluator_mlcube),
"State": self.state,
"Approval Status": self.approval_status,
21 changes: 20 additions & 1 deletion cli/medperf/tests/commands/result/test_create.py
@@ -29,10 +29,13 @@ def __get_side_effect(id):
data_evaluator_mlcube=evaluator["uid"],
data_preparation_mlcube=benchmark_prep_cube,
reference_model_mlcube=benchmark_models[0],
models=benchmark_models,
)

mocker.patch(PATCH_EXECUTION.format("Benchmark.get"), side_effect=__get_side_effect)
mocker.patch(
PATCH_EXECUTION.format("Benchmark.get_models_uids"),
return_value=benchmark_models[1:],
)


def mock_dataset(mocker, state_variables):
@@ -126,12 +129,16 @@ def setup(request, mocker, ui, fs):
ui_error_spy = mocker.patch.object(ui, "print_error")
ui_print_spy = mocker.patch.object(ui, "print")
tabulate_spy = mocker.spy(create_module, "tabulate")
validate_models_spy = mocker.spy(
create_module.BenchmarkExecution, "_BenchmarkExecution__validate_models"
)

spies = {
"ui_error": ui_error_spy,
"ui_print": ui_print_spy,
"tabulate": tabulate_spy,
"exec": exec_spy,
"validate_models": validate_models_spy,
}
return state_variables, spies

@@ -326,3 +333,15 @@ def test_execution_of_one_model_writes_result(self, mocker, setup):
yaml.safe_load(open(expected_file))["results"]
== self.state_variables["models_props"][model_uid]["results"]
)

def test_execution_of_reference_model_does_not_call_validate(self, mocker, setup):
# Arrange
model_uid = self.state_variables["benchmark_models"][0]
dset_uid = 2
bmk_uid = 1

# Act
BenchmarkExecution.run(bmk_uid, dset_uid, models_uids=[model_uid])

# Assert
self.spies["validate_models"].assert_not_called()
39 changes: 16 additions & 23 deletions cli/medperf/tests/entities/test_benchmark.py
@@ -32,36 +32,29 @@ def setup(request, mocker, comms, fs):


@pytest.mark.parametrize(
"setup",
"setup,expected_models",
[
{
"remote": [721],
"models": [
{"model_mlcube": 37, "approval_status": "APPROVED"},
{"model_mlcube": 23, "approval_status": "APPROVED"},
{"model_mlcube": 495, "approval_status": "APPROVED"},
],
}
(
{
"remote": [721],
"models": [
{"model_mlcube": 37, "approval_status": "APPROVED"},
{"model_mlcube": 23, "approval_status": "APPROVED"},
{"model_mlcube": 495, "approval_status": "PENDING"},
],
},
[37, 23],
)
],
indirect=True,
indirect=["setup"],
)
class TestModels:
def test_benchmark_includes_reference_model_in_models(self, setup):
# Act
id = setup["remote"][0]
benchmark = Benchmark.get(id)

# Assert
assert benchmark.reference_model_mlcube in benchmark.models

def test_benchmark_includes_additional_models_in_models(self, setup):
def test_benchmark_get_models_works_as_expected(self, setup, expected_models):
# Arrange
id = setup["remote"][0]
models = setup["models"]
models = [model["model_mlcube"] for model in models]

# Act
benchmark = Benchmark.get(id)
associated_models = Benchmark.get_models_uids(id)

# Assert
assert set(models).issubset(set(benchmark.models))
assert set(associated_models) == set(expected_models)
2 changes: 1 addition & 1 deletion cli/medperf/tests/entities/utils.py
@@ -24,7 +24,7 @@ def setup_benchmark_fs(ents, fs):
id = ent["id"]
bmk_filepath = os.path.join(bmks_path, str(id), config.benchmarks_filename)
bmk_contents = TestBenchmark(**ent)
cubes_ids = bmk_contents.models
cubes_ids = []
cubes_ids.append(bmk_contents.data_preparation_mlcube)
cubes_ids.append(bmk_contents.reference_model_mlcube)
cubes_ids.append(bmk_contents.data_evaluator_mlcube)
3 changes: 1 addition & 2 deletions cli/medperf/tests/mocks/benchmark.py
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional
from medperf.enums import Status
from medperf.entities.benchmark import Benchmark

@@ -11,5 +11,4 @@ class TestBenchmark(Benchmark):
data_preparation_mlcube: int = 1
reference_model_mlcube: int = 2
data_evaluator_mlcube: int = 3
models: List[int] = [2]
approval_status: Status = Status.APPROVED
