diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 7ebe0096d..11c570b73 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -22,7 +22,8 @@ jobs: pip install flake8 pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi if [ -f cli/requirements.txt ]; then pip install -e cli; fi - if [ -f server/requirements.txt ]; then pip install -r server/requirements.txt; fi + pip install -r server/requirements.txt + pip install -r server/test-requirements.txt - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names @@ -35,6 +36,7 @@ jobs: # Ignore E231, as it is raising warnings with auto-generated code. flake8 . --count --max-complexity=10 --max-line-length=127 --ignore F821,W503,E231 --statistics --exclude=examples/,"*/migrations/*",cli/medperf/templates/ - name: Test with pytest + working-directory: ./cli run: | pytest - name: Set server environment vars @@ -45,4 +47,4 @@ jobs: run: python manage.py migrate - name: Run server unit tests working-directory: ./server - run: python manage.py test + run: python manage.py test --parallel diff --git a/cli/cli_tests.sh b/cli/cli_tests.sh index d770e3839..7ebcc3cfa 100755 --- a/cli/cli_tests.sh +++ b/cli/cli_tests.sh @@ -244,16 +244,6 @@ checkFailed "Failing model association failed" echo "\n" -########################################################## -echo "=====================================" -echo "Changing priority of model2" -echo "=====================================" -medperf association set_priority -b $BMK_UID -m $MODEL2_UID -p 77 -checkFailed "Priority set of model2 failed" -########################################################## - -echo "\n" - ########################################################## echo "=====================================" echo "Activate modelowner profile" @@ -278,6 +268,26 @@ checkFailed "failing model association approval failed" echo "\n" +########################################################## +echo "=====================================" +echo "Activate benchmarkowner profile" +echo "=====================================" +medperf profile activate testbenchmark +checkFailed "testbenchmark profile activation failed" +########################################################## + +echo "\n" + +########################################################## +echo "=====================================" +echo "Changing priority of model2" +echo "=====================================" +medperf association set_priority -b $BMK_UID -m $MODEL2_UID -p 77 +checkFailed "Priority set of model2 failed" +########################################################## + +echo "\n" + ########################################################## echo "=====================================" echo "Activate dataowner profile" diff --git a/cli/medperf/commands/association/priority.py b/cli/medperf/commands/association/priority.py index 7edaa6cfc..c58db2450 100644 --- a/cli/medperf/commands/association/priority.py +++ b/cli/medperf/commands/association/priority.py @@ -1,12 +1,11 @@ from medperf import config from medperf.exceptions import InvalidArgumentError +from medperf.entities.benchmark import Benchmark class AssociationPriority: @staticmethod - def run( - benchmark_uid: int, mlcube_uid: int, priority: int, - ): + def run(benchmark_uid: int, mlcube_uid: int, priority: int): """Sets priority for an association between a benchmark and an mlcube Args: @@ -15,7 +14,7 @@ def run( priority (int): priority value """ 
- associated_cubes = config.comms.get_benchmark_models(benchmark_uid) + associated_cubes = Benchmark.get_models_uids(benchmark_uid) if mlcube_uid not in associated_cubes: raise InvalidArgumentError( "The given mlcube doesn't exist or is not associated with the benchmark" diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index aa35a859c..0fdfbc1fe 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -37,7 +37,7 @@ def submit( ), docs_url: str = typer.Option("", "--docs-url", "-u", help="URL to documentation"), demo_url: str = typer.Option( - "", + ..., "--demo-url", help="""Identifier to download the demonstration dataset tarball file.\n See `medperf mlcube submit --help` for more information""", diff --git a/cli/medperf/commands/compatibility_test/utils.py b/cli/medperf/commands/compatibility_test/utils.py index 4f8784c7d..a956a5796 100644 --- a/cli/medperf/commands/compatibility_test/utils.py +++ b/cli/medperf/commands/compatibility_test/utils.py @@ -1,4 +1,4 @@ -from medperf.utils import storage_path, get_folder_hash +from medperf.utils import storage_path, get_folders_hash from medperf.exceptions import InvalidArgumentError, InvalidEntityError from medperf.comms.entity_resources import resources @@ -34,7 +34,7 @@ def download_demo_data(dset_url, dset_hash): def prepare_local_cube(path): - temp_uid = get_folder_hash(path) + temp_uid = get_folders_hash([path]) cubes_storage = storage_path(config.cubes_storage) dst = os.path.join(cubes_storage, temp_uid) os.symlink(path, dst) diff --git a/cli/medperf/commands/dataset/create.py b/cli/medperf/commands/dataset/create.py index 422d708ad..ab8fc15c1 100644 --- a/cli/medperf/commands/dataset/create.py +++ b/cli/medperf/commands/dataset/create.py @@ -13,7 +13,7 @@ from medperf.utils import ( remove_path, generate_tmp_path, - get_folder_hash, + get_folders_hash, storage_path, ) from medperf.exceptions import InvalidArgumentError, ExecutionError @@ -195,10 +195,6 @@ def run_prepare(self): "output_path": out_datapath, "output_labels_path": out_labelspath, } - prepare_str_params = { - "Ptasks.prepare.parameters.input.data_path.opts": "ro", - "Ptasks.prepare.parameters.input.labels_path.opts": "ro", - } def sigint_handler(sig, frame): metadata = {"execution_status": "interrupted"} @@ -242,7 +238,6 @@ def sigint_handler(sig, frame): with self.ui.interactive(): self.cube.run( task="prepare", - string_params=prepare_str_params, timeout=prepare_timeout, **prepare_params, ) @@ -291,32 +286,19 @@ def run_sanity_check(self): # Specify parameters for the tasks sanity_params = { "data_path": out_datapath, - } - sanity_str_params = { - "Ptasks.sanity_check.parameters.input.data_path.opts": "ro", + "labels_path": out_labelspath, } - if self.labels_specified: - # Add the labels parameter - sanity_params["labels_path"] = out_labelspath - if self.metadata_specified: sanity_params["metadata_path"] = self.metadata_path - sanity_params[ - "Ptasks.sanity_check.parameters.input.metadata_paths.opts" - ] = "ro" if self.report_specified: sanity_params["report_file"] = out_report - sanity_str_params[ - "Ptasks.sanity_check.parameters.input.report_file.opts" - ] = "ro" self.ui.text = "Running sanity check..." 
try:
self.cube.run(
task="sanity_check",
- string_params=sanity_str_params,
timeout=sanity_check_timeout,
**sanity_params,
)
@@ -339,21 +321,14 @@ def run_statistics(self):
"labels_path": out_labelspath,
"output_path": self.out_statistics_path,
}
- statistics_str_params = {
- "Ptasks.statistics.parameters.input.data_path.opts": "ro"
- }
if self.metadata_specified:
statistics_params["metadata_path"] = self.metadata_path
- statistics_params[
- "Ptasks.statistics.parameters.input.metadata_path.opts"
- ] = "ro"
self.ui.text = "Generating statistics..."
self.cube.run(
task="statistics",
- string_params=statistics_str_params,
timeout=statistics_timeout,
**statistics_params,
)
@@ -366,8 +341,8 @@ def get_statistics(self):
def generate_uids(self):
"""Auto-generates dataset UIDs for both input and output paths"""
- self.in_uid = get_folder_hash(self.data_path)
- self.generated_uid = get_folder_hash(self.out_datapath)
+ self.in_uid = get_folders_hash([self.data_path, self.labels_path])
+ self.generated_uid = get_folders_hash([self.out_datapath, self.out_labelspath])
def to_permanent_path(self) -> str:
"""Renames the temporary data folder to permanent one using the hash of
diff --git a/cli/medperf/commands/execution.py b/cli/medperf/commands/execution.py
index 66730c272..5669a4d2b 100644
--- a/cli/medperf/commands/execution.py
+++ b/cli/medperf/commands/execution.py
@@ -86,7 +86,6 @@ def run_inference(self):
timeout=infer_timeout,
data_path=data_path,
output_path=preds_path,
- string_params={"Ptasks.infer.parameters.input.data_path.opts": "ro"},
)
self.ui.print("> Model execution complete")
@@ -113,10 +112,6 @@ def run_evaluation(self):
predictions=preds_path,
labels=labels_path,
output_path=results_path,
- string_params={
- "Ptasks.evaluate.parameters.input.predictions.opts": "ro",
- "Ptasks.evaluate.parameters.input.labels.opts": "ro",
- },
)
except ExecutionError as e:
logging.error(f"Metrics MLCube Execution failed: {e}")
diff --git a/cli/medperf/commands/result/create.py b/cli/medperf/commands/result/create.py
index 3024c220f..576cd3555 100644
--- a/cli/medperf/commands/result/create.py
+++ b/cli/medperf/commands/result/create.py
@@ -50,7 +50,6 @@ def run(
execution.prepare()
execution.validate()
execution.prepare_models()
- execution.validate_models()
if not no_cache:
execution.load_cached_results()
with execution.ui.interactive():
@@ -101,8 +100,19 @@ def validate(self):
def prepare_models(self):
if self.models_input_file:
self.models_uids = self.__get_models_from_file()
- elif self.models_uids is None:
- self.models_uids = self.benchmark.models
+
+ if self.models_uids == [self.benchmark.reference_model_mlcube]:
+ # avoid the need to send a request to the server to find
+ # the benchmark's associated models
+ return
+
+ benchmark_models = Benchmark.get_models_uids(self.benchmark_uid)
+ benchmark_models.append(self.benchmark.reference_model_mlcube)
+
+ if self.models_uids is None:
+ self.models_uids = benchmark_models
+ else:
+ self.__validate_models(benchmark_models)
def __get_models_from_file(self):
if not os.path.exists(self.models_input_file):
@@ -117,9 +127,9 @@
msg += "The file should contain a list of comma-separated integers"
raise InvalidArgumentError(msg)
- def validate_models(self):
+ def __validate_models(self, benchmark_models):
models_set = set(self.models_uids)
- benchmark_models_set = set(self.benchmark.models)
+ benchmark_models_set = set(benchmark_models)
non_assoc_cubes = models_set.difference(benchmark_models_set)
if non_assoc_cubes:
if len(non_assoc_cubes) > 1:
diff --git a/cli/medperf/comms/interface.py b/cli/medperf/comms/interface.py
index 82c359f51..4cce0fb71 100644
--- a/cli/medperf/comms/interface.py
+++ b/cli/medperf/comms/interface.py
@@ -50,14 +50,14 @@ def get_benchmark(self, benchmark_uid: int) -> dict:
"""
@abstractmethod
- def get_benchmark_models(self, benchmark_uid: int) -> List[int]:
- """Retrieves all the models associated with a benchmark. reference model not included
+ def get_benchmark_model_associations(self, benchmark_uid: int) -> List[dict]:
+ """Retrieves all the model associations of a benchmark.
Args:
benchmark_uid (int): UID of the desired benchmark
Returns:
- list[int]: List of model UIDS
+ list[dict]: List of benchmark model associations
"""
@abstractmethod
diff --git a/cli/medperf/comms/rest.py b/cli/medperf/comms/rest.py
index 285698495..6e6bb5057 100644
--- a/cli/medperf/comms/rest.py
+++ b/cli/medperf/comms/rest.py
@@ -5,7 +5,12 @@
from medperf.enums import Status
import medperf.config as config
from medperf.comms.interface import Comms
-from medperf.utils import sanitize_json, log_response_error, format_errors_dict
+from medperf.utils import (
+ sanitize_json,
+ log_response_error,
+ format_errors_dict,
+ filter_latest_associations,
+)
from medperf.exceptions import (
CommunicationError,
CommunicationRetrievalError,
@@ -174,18 +179,17 @@ def get_benchmark(self, benchmark_uid: int) -> dict:
)
return res.json()
- def get_benchmark_models(self, benchmark_uid: int) -> List[int]:
- """Retrieves all the models associated with a benchmark. reference model not included
+ def get_benchmark_model_associations(self, benchmark_uid: int) -> List[dict]:
+ """Retrieves all the model associations of a benchmark.
Args:
benchmark_uid (int): UID of the desired benchmark
Returns:
- list[int]: List of model UIDS
+ list[dict]: List of benchmark model associations
"""
- models = self.__get_list(f"{self.server_url}/benchmarks/{benchmark_uid}/models")
- model_uids = [model["id"] for model in models]
- return model_uids
+ assocs = self.__get_list(f"{self.server_url}/benchmarks/{benchmark_uid}/models")
+ return filter_latest_associations(assocs, "model_mlcube")
def get_user_benchmarks(self) -> List[dict]:
"""Retrieves all benchmarks created by the user
@@ -475,7 +479,7 @@ def get_datasets_associations(self) -> List[dict]:
List[dict]: List containing all associations information
"""
assocs = self.__get_list(f"{self.server_url}/me/datasets/associations/")
- return assocs
+ return filter_latest_associations(assocs, "dataset")
def get_cubes_associations(self) -> List[dict]:
"""Get all cube associations related to the current user
@@ -484,7 +488,7 @@
Returns:
List[dict]: List containing all associations information
"""
assocs = self.__get_list(f"{self.server_url}/me/mlcubes/associations/")
- return assocs
+ return filter_latest_associations(assocs, "model_mlcube")
def set_mlcube_association_priority(
self, benchmark_uid: int, mlcube_uid: int, priority: int
diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py
index b3b1a90f7..01d3d48da 100644
--- a/cli/medperf/entities/benchmark.py
+++ b/cli/medperf/entities/benchmark.py
@@ -3,7 +3,7 @@
import yaml
import logging
from typing import List, Optional, Union
-from pydantic import HttpUrl, Field, validator
+from pydantic import HttpUrl, Field
import medperf.config as config
from medperf.entities.interface import Entity, Uploadable
@@ -26,24 +26,16 @@ class Benchmark(Entity, Uploadable, MedperfSchema, 
ApprovableSchema, DeployableS description: Optional[str] = Field(None, max_length=20) docs_url: Optional[HttpUrl] - demo_dataset_tarball_url: Optional[str] + demo_dataset_tarball_url: str demo_dataset_tarball_hash: Optional[str] demo_dataset_generated_uid: Optional[str] data_preparation_mlcube: int reference_model_mlcube: int data_evaluator_mlcube: int - models: List[int] = None metadata: dict = {} user_metadata: dict = {} is_active: bool = True - @validator("models", pre=True, always=True) - def set_default_models_value(cls, value, values, **kwargs): - if not value: - # Empty or None value assigned - return [values["reference_model_mlcube"]] - return value - def __init__(self, *args, **kwargs): """Creates a new benchmark instance @@ -91,11 +83,6 @@ def __remote_all(cls, filters: dict) -> List["Benchmark"]: try: comms_fn = cls.__remote_prefilter(filters) bmks_meta = comms_fn() - for bmk_meta in bmks_meta: - # Loading all related models for all benchmarks could be expensive. - # Most probably not necessary when getting all benchmarks. - # If associated models for a benchmark are needed then use Benchmark.get() - bmk_meta["models"] = [bmk_meta["reference_model_mlcube"]] benchmarks = [cls(**meta) for meta in bmks_meta] except CommunicationRetrievalError: msg = "Couldn't retrieve all benchmarks from the server" @@ -175,9 +162,6 @@ def __remote_get(cls, benchmark_uid: int) -> "Benchmark": """ logging.debug(f"Retrieving benchmark {benchmark_uid} remotely") benchmark_dict = config.comms.get_benchmark(benchmark_uid) - ref_model = benchmark_dict["reference_model_mlcube"] - add_models = cls.get_models_uids(benchmark_uid) - benchmark_dict["models"] = [ref_model] + add_models benchmark = cls(**benchmark_dict) benchmark.write() return benchmark @@ -230,7 +214,13 @@ def get_models_uids(cls, benchmark_uid: int) -> List[int]: Returns: List[int]: List of mlcube uids """ - return config.comms.get_benchmark_models(benchmark_uid) + associations = config.comms.get_benchmark_model_associations(benchmark_uid) + models_uids = [ + assoc["model_mlcube"] + for assoc in associations + if assoc["approval_status"] == "APPROVED" + ] + return models_uids def todict(self) -> dict: """Dictionary representation of the benchmark instance @@ -267,7 +257,6 @@ def upload(self): raise InvalidArgumentError("Cannot upload test benchmarks.") body = self.todict() updated_body = config.comms.upload_benchmark(body) - updated_body["models"] = body["models"] return updated_body def display_dict(self): @@ -279,7 +268,6 @@ def display_dict(self): "Created At": self.created_at, "Data Preparation MLCube": int(self.data_preparation_mlcube), "Reference Model MLCube": int(self.reference_model_mlcube), - "Associated Models": ",".join(map(str, self.models)), "Data Evaluator MLCube": int(self.data_evaluator_mlcube), "State": self.state, "Approval Status": self.approval_status, diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index ac2d58b55..2e33f6406 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -277,6 +277,7 @@ def run( output_logs: str = None, string_params: Dict[str, str] = {}, timeout: int = None, + read_protected_input: bool = True, **kwargs, ): """Executes a given task on the cube instance @@ -286,6 +287,7 @@ def run( string_params (Dict[str], optional): Extra parameters that can't be passed as normal function args. Defaults to {}. timeout (int, optional): timeout for the task in seconds. Defaults to None. 
+ read_protected_input (bool, optional): Whether to disable write permissions on input volumes. Defaults to True.
kwargs (dict): additional arguments that are passed directly to the mlcube command
"""
kwargs.update(string_params)
@@ -294,6 +296,8 @@
cmd += f" --mlcube={self.cube_path} --task={task} --platform={config.platform} --network=none"
if config.gpus is not None:
cmd += f" --gpus={config.gpus}"
+ if read_protected_input:
+ cmd += " --mount=ro"
for k, v in kwargs.items():
cmd_arg = f'{k}="{v}"'
cmd = " ".join([cmd, cmd_arg])
diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py
index 1eaeaeadf..861247397 100644
--- a/cli/medperf/entities/result.py
+++ b/cli/medperf/entities/result.py
@@ -27,6 +27,7 @@ class Result(Entity, Uploadable, MedperfSchema, ApprovableSchema):
dataset: int
results: dict
metadata: dict = {}
+ user_metadata: dict = {}
def __init__(self, *args, **kwargs):
"""Creates a new result instance"""
diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py
index ad0f5f596..ca1db3ad2 100644
--- a/cli/medperf/entities/schemas.py
+++ b/cli/medperf/entities/schemas.py
@@ -79,6 +79,7 @@ class MedperfSchema(MedperfBaseSchema):
id: Optional[int]
name: str = Field(..., max_length=64)
owner: Optional[int]
+ is_valid: bool = True
created_at: Optional[datetime]
modified_at: Optional[datetime]
@@ -92,7 +93,6 @@ def name_max_length(cls, v, *, values, **kwargs):
class DeployableSchema(BaseModel):
# TODO: This must change after allowing edits
state: str = "OPERATION"
- is_valid: bool = True
class ApprovableSchema(BaseModel):
diff --git a/cli/medperf/tests/commands/association/test_priority.py b/cli/medperf/tests/commands/association/test_priority.py
index c74118af9..8d7a70392 100644
--- a/cli/medperf/tests/commands/association/test_priority.py
+++ b/cli/medperf/tests/commands/association/test_priority.py
@@ -21,9 +21,9 @@ def func(benchmark_uid, mlcube_uid, priority):
return func
-def get_benchmark_models_behavior(associations):
+def get_benchmark_model_associations_behavior(associations):
def func(benchmark_uid):
- return [assoc["model_mlcube"] for assoc in associations]
+ return associations
return func
@@ -31,8 +31,8 @@ def func(benchmark_uid):
def setup_comms(mocker, comms, associations):
mocker.patch.object(
comms,
- "get_benchmark_models",
- side_effect=get_benchmark_models_behavior(associations),
+ "get_benchmark_model_associations",
+ side_effect=get_benchmark_model_associations_behavior(associations),
)
mocker.patch.object(
comms,
@@ -49,7 +49,9 @@ def setup(request, mocker, comms):
@pytest.mark.parametrize(
- "setup", [{"associations": TEST_ASSOCIATIONS}], indirect=True,
+ "setup",
+ [{"associations": TEST_ASSOCIATIONS}],
+ indirect=True,
)
class TestRun:
@pytest.fixture(autouse=True)
diff --git a/cli/medperf/tests/commands/dataset/test_create.py b/cli/medperf/tests/commands/dataset/test_create.py
index 1b23b91ce..2026c5a27 100644
--- a/cli/medperf/tests/commands/dataset/test_create.py
+++ b/cli/medperf/tests/commands/dataset/test_create.py
@@ -123,17 +123,12 @@ def test_run_cube_tasks_runs_required_tasks(self, mocker, preparation):
labels_path=LABELS_PATH,
output_path=OUT_DATAPATH,
output_labels_path=OUT_LABELSPATH,
- string_params={
- "Ptasks.prepare.parameters.input.data_path.opts": "ro",
- "Ptasks.prepare.parameters.input.labels_path.opts": "ro",
- },
)
check = call(
task="sanity_check",
timeout=None,
data_path=OUT_DATAPATH,
labels_path=OUT_LABELSPATH,
-
string_params={"Ptasks.sanity_check.parameters.input.data_path.opts": "ro"}, ) stats = call( task="statistics", @@ -141,7 +136,6 @@ def test_run_cube_tasks_runs_required_tasks(self, mocker, preparation): data_path=OUT_DATAPATH, labels_path=OUT_LABELSPATH, output_path=STATISTICS_PATH, - string_params={"Ptasks.statistics.parameters.input.data_path.opts": "ro"}, ) calls = [prepare, check, stats] @@ -214,15 +208,30 @@ def test_fails_if_invalid_params(self, mocker, benchmark_uid, cube_uid, comms, u else: preparation.validate() - @pytest.mark.parametrize("in_path", ["data_path", "input_path", "/usr/data/path"]) - @pytest.mark.parametrize("out_path", ["out_path", "~/.medperf/data/123"]) + @pytest.mark.parametrize( + "in_path", + [ + ["data", "labels"], + ["in_data", "in_labels"], + ["/usr/data/path", "usr/labels/path"], + ], + ) + @pytest.mark.parametrize( + "out_path", + [ + ["out_data", "out_labels"], + ["~/.medperf/data/123/data", "~/.medperf/data/123/labels"], + ], + ) def test_generate_uids_assigns_uids_to_obj_properties( self, mocker, in_path, out_path, preparation ): # Arrange - mocker.patch(PATCH_DATAPREP.format("get_folder_hash"), side_effect=lambda x: x) - preparation.data_path = in_path - preparation.out_datapath = out_path + mocker.patch(PATCH_DATAPREP.format("get_folders_hash"), side_effect=lambda x: x) + preparation.data_path = in_path[0] + preparation.labels_path = in_path[1] + preparation.out_datapath = out_path[0] + preparation.out_labelspath = out_path[1] # Act preparation.generate_uids() diff --git a/cli/medperf/tests/commands/result/test_create.py b/cli/medperf/tests/commands/result/test_create.py index 8dfe90606..7a11c05bf 100644 --- a/cli/medperf/tests/commands/result/test_create.py +++ b/cli/medperf/tests/commands/result/test_create.py @@ -29,10 +29,13 @@ def __get_side_effect(id): data_evaluator_mlcube=evaluator["uid"], data_preparation_mlcube=benchmark_prep_cube, reference_model_mlcube=benchmark_models[0], - models=benchmark_models, ) mocker.patch(PATCH_EXECUTION.format("Benchmark.get"), side_effect=__get_side_effect) + mocker.patch( + PATCH_EXECUTION.format("Benchmark.get_models_uids"), + return_value=benchmark_models[1:], + ) def mock_dataset(mocker, state_variables): @@ -126,12 +129,16 @@ def setup(request, mocker, ui, fs): ui_error_spy = mocker.patch.object(ui, "print_error") ui_print_spy = mocker.patch.object(ui, "print") tabulate_spy = mocker.spy(create_module, "tabulate") + validate_models_spy = mocker.spy( + create_module.BenchmarkExecution, "_BenchmarkExecution__validate_models" + ) spies = { "ui_error": ui_error_spy, "ui_print": ui_print_spy, "tabulate": tabulate_spy, "exec": exec_spy, + "validate_models": validate_models_spy, } return state_variables, spies @@ -326,3 +333,15 @@ def test_execution_of_one_model_writes_result(self, mocker, setup): yaml.safe_load(open(expected_file))["results"] == self.state_variables["models_props"][model_uid]["results"] ) + + def test_execution_of_reference_model_does_not_call_validate(self, mocker, setup): + # Arrange + model_uid = self.state_variables["benchmark_models"][0] + dset_uid = 2 + bmk_uid = 1 + + # Act + BenchmarkExecution.run(bmk_uid, dset_uid, models_uids=[model_uid]) + + # Assert + self.spies["validate_models"].assert_not_called() diff --git a/cli/medperf/tests/commands/test_execution.py b/cli/medperf/tests/commands/test_execution.py index a94e11b29..c007dbdee 100644 --- a/cli/medperf/tests/commands/test_execution.py +++ b/cli/medperf/tests/commands/test_execution.py @@ -178,7 +178,6 @@ def 
test_cube_run_are_called_properly(mocker, setup): timeout=config.infer_timeout, data_path=INPUT_DATASET.data_path, output_path=exp_preds_path, - string_params={"Ptasks.infer.parameters.input.data_path.opts": "ro"}, ) exp_eval_call = call( task="evaluate", @@ -187,10 +186,6 @@ def test_cube_run_are_called_properly(mocker, setup): predictions=exp_preds_path, labels=INPUT_DATASET.labels_path, output_path=ANY, - string_params={ - "Ptasks.evaluate.parameters.input.predictions.opts": "ro", - "Ptasks.evaluate.parameters.input.labels.opts": "ro", - }, ) # Act Execution.run(INPUT_DATASET, INPUT_MODEL, INPUT_EVALUATOR) diff --git a/cli/medperf/tests/comms/test_rest.py b/cli/medperf/tests/comms/test_rest.py index 8f820ed9e..25babeb54 100644 --- a/cli/medperf/tests/comms/test_rest.py +++ b/cli/medperf/tests/comms/test_rest.py @@ -24,7 +24,7 @@ def server(mocker, ui): [ ("get_benchmark", "get", 200, [1], {}, (f"{full_url}/benchmarks/1",), {}), ( - "get_benchmark_models", + "get_benchmark_model_associations", "get_list", 200, [1], @@ -288,17 +288,21 @@ def test_get_benchmarks_calls_benchmarks_path(mocker, server, body): assert bmarks == [body] -@pytest.mark.parametrize("exp_uids", [[142, 437, 196], [303, 27, 24], [40, 19, 399]]) -def test_get_benchmark_models_return_uids(mocker, server, exp_uids): +def test_get_benchmark_model_associations_calls_expected_functions(mocker, server): # Arrange - body = [{"id": uid} for uid in exp_uids] - mocker.patch(patch_server.format("REST._REST__get_list"), return_value=body) - + assocs = [{"model_mlcube": uid} for uid in [1, 2, 3]] + spy_list = mocker.patch( + patch_server.format("REST._REST__get_list"), return_value=assocs + ) + spy_filter = mocker.patch( + patch_server.format("filter_latest_associations"), side_effect=lambda x, y: x + ) # Act - uids = server.get_benchmark_models(1) + server.get_benchmark_model_associations(1) # Assert - assert set(uids) == set(exp_uids) + spy_list.assert_called_once() + spy_filter.assert_called_once() def test_get_user_benchmarks_calls_auth_get_for_expected_path(mocker, server): diff --git a/cli/medperf/tests/entities/test_benchmark.py b/cli/medperf/tests/entities/test_benchmark.py index dd3166f5f..3f1fde2e2 100644 --- a/cli/medperf/tests/entities/test_benchmark.py +++ b/cli/medperf/tests/entities/test_benchmark.py @@ -8,7 +8,12 @@ @pytest.fixture( - params={"local": [1, 2, 3], "remote": [4, 5, 6], "user": [4], "models": [10, 11],} + params={ + "local": [1, 2, 3], + "remote": [4, 5, 6], + "user": [4], + "models": [10, 11], + } ) def setup(request, mocker, comms, fs): local_ids = request.param.get("local", []) @@ -20,31 +25,36 @@ def setup(request, mocker, comms, fs): setup_benchmark_fs(local_ids, fs) setup_benchmark_comms(mocker, comms, remote_ids, user_ids, uploaded) - mocker.patch.object(comms, "get_benchmark_models", return_value=models) + mocker.patch.object(comms, "get_benchmark_model_associations", return_value=models) request.param["uploaded"] = uploaded return request.param @pytest.mark.parametrize( - "setup", [{"remote": [721], "models": [37, 23, 495],}], indirect=True, + "setup,expected_models", + [ + ( + { + "remote": [721], + "models": [ + {"model_mlcube": 37, "approval_status": "APPROVED"}, + {"model_mlcube": 23, "approval_status": "APPROVED"}, + {"model_mlcube": 495, "approval_status": "PENDING"}, + ], + }, + [37, 23], + ) + ], + indirect=["setup"], ) class TestModels: - def test_benchmark_includes_reference_model_in_models(self, setup): - # Act - id = setup["remote"][0] - benchmark = Benchmark.get(id) - - # Assert - 
assert benchmark.reference_model_mlcube in benchmark.models
-
- def test_benchmark_includes_additional_models_in_models(self, setup):
+ def test_benchmark_get_models_works_as_expected(self, setup, expected_models):
# Arrange
id = setup["remote"][0]
- models = setup["models"]
# Act
- benchmark = Benchmark.get(id)
+ associated_models = Benchmark.get_models_uids(id)
# Assert
- assert set(models).issubset(set(benchmark.models))
+ assert set(associated_models) == set(expected_models)
diff --git a/cli/medperf/tests/entities/test_cube.py b/cli/medperf/tests/entities/test_cube.py
index aee7f4009..eb0196a44 100644
--- a/cli/medperf/tests/entities/test_cube.py
+++ b/cli/medperf/tests/entities/test_cube.py
@@ -162,7 +162,7 @@ def test_cube_runs_command(self, mocker, timeout, setup, task):
)
expected_cmd = (
f"mlcube run --mlcube={self.manifest_path} --task={task} "
- + f"--platform={self.platform} --network=none"
+ + f"--platform={self.platform} --network=none --mount=ro"
)
# Act
@@ -172,13 +172,29 @@
# Assert
spy.assert_any_call(expected_cmd, timeout=timeout)
+ def test_cube_runs_command_with_rw_access(self, mocker, setup, task):
+ # Arrange
+ mpexpect = MockPexpect(0, "expected_hash")
+ spy = mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn)
+ expected_cmd = (
+ f"mlcube run --mlcube={self.manifest_path} --task={task} "
+ + f"--platform={self.platform} --network=none"
+ )
+
+ # Act
+ cube = Cube.get(self.id)
+ cube.run(task, read_protected_input=False)
+
+ # Assert
+ spy.assert_any_call(expected_cmd, timeout=None)
+
def test_cube_runs_command_with_extra_args(self, mocker, setup, task):
# Arrange
mpexpect = MockPexpect(0, "expected_hash")
spy = mocker.patch("pexpect.spawn", side_effect=mpexpect.spawn)
expected_cmd = (
f"mlcube run --mlcube={self.manifest_path} --task={task} "
- + f'--platform={self.platform} --network=none test="test"'
+ + f'--platform={self.platform} --network=none --mount=ro test="test"'
)
# Act
diff --git a/cli/medperf/tests/entities/utils.py b/cli/medperf/tests/entities/utils.py
index f73bd73ab..2b141da51 100644
--- a/cli/medperf/tests/entities/utils.py
+++ b/cli/medperf/tests/entities/utils.py
@@ -24,7 +24,7 @@ def setup_benchmark_fs(ents, fs):
id = ent["id"]
bmk_filepath = os.path.join(bmks_path, str(id), config.benchmarks_filename)
bmk_contents = TestBenchmark(**ent)
- cubes_ids = bmk_contents.models
+ cubes_ids = []
cubes_ids.append(bmk_contents.data_preparation_mlcube)
cubes_ids.append(bmk_contents.reference_model_mlcube)
cubes_ids.append(bmk_contents.data_evaluator_mlcube)
@@ -44,7 +44,7 @@ def setup_benchmark_comms(mocker, comms, all_ents, user_ents, uploaded):
"get_instance": "get_benchmark",
"upload_instance": "upload_benchmark",
}
- mocker.patch.object(comms, "get_benchmark_models", return_value=[])
+ mocker.patch.object(comms, "get_benchmark_model_associations", return_value=[])
mock_comms_entity_gets(
mocker, comms, generate_fn, comms_calls, all_ents, user_ents, uploaded
)
diff --git a/cli/medperf/tests/mocks/benchmark.py b/cli/medperf/tests/mocks/benchmark.py
index d5e4c3230..7ede9a7cd 100644
--- a/cli/medperf/tests/mocks/benchmark.py
+++ b/cli/medperf/tests/mocks/benchmark.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
from medperf.enums import Status
from medperf.entities.benchmark import Benchmark
@@ -6,9 +6,9 @@
class TestBenchmark(Benchmark):
id: Optional[int] = 1
name: str = "name"
+ demo_dataset_tarball_url: str = "tarball_url"
demo_dataset_tarball_hash: 
str = "tarball_hash" data_preparation_mlcube: int = 1 reference_model_mlcube: int = 2 data_evaluator_mlcube: int = 3 - models: List[int] = [2] approval_status: Status = Status.APPROVED diff --git a/cli/medperf/tests/test_utils.py b/cli/medperf/tests/test_utils.py index fbfff9be8..4130c70ce 100644 --- a/cli/medperf/tests/test_utils.py +++ b/cli/medperf/tests/test_utils.py @@ -1,3 +1,4 @@ +from datetime import datetime import os import pytest import logging @@ -343,7 +344,7 @@ def test_dict_pretty_print_passes_clean_dict_to_yaml(mocker, ui, dict_with_nones spy.assert_called_once_with(exp_dict) -def test_get_folder_hash_hashes_all_files_in_folder(mocker, filesystem): +def test_get_folders_hash_hashes_all_files_in_folder(mocker, filesystem): # Arrange fs = filesystem[0] files = filesystem[1] @@ -352,13 +353,13 @@ def test_get_folder_hash_hashes_all_files_in_folder(mocker, filesystem): spy = mocker.patch(patch_utils.format("get_file_hash"), side_effect=files) # Act - utils.get_folder_hash("test") + utils.get_folders_hash(["test"]) # Assert spy.assert_has_calls(exp_calls) -def test_get_folder_hash_sorts_individual_hashes(mocker, filesystem): +def test_get_folders_hash_sorts_individual_hashes(mocker, filesystem): # Arrange fs = filesystem[0] files = filesystem[1] @@ -367,13 +368,13 @@ def test_get_folder_hash_sorts_individual_hashes(mocker, filesystem): spy = mocker.patch("builtins.sorted", side_effect=sorted) # Act - utils.get_folder_hash("test") + utils.get_folders_hash(["test"]) # Assert spy.assert_called_once_with(files) -def test_get_folder_hash_returns_expected_hash(mocker, filesystem): +def test_get_folders_hash_returns_expected_hash(mocker, filesystem): # Arrange fs = filesystem[0] files = filesystem[1] @@ -381,7 +382,7 @@ def test_get_folder_hash_returns_expected_hash(mocker, filesystem): mocker.patch(patch_utils.format("get_file_hash"), side_effect=files) # Act - hash = utils.get_folder_hash("test") + hash = utils.get_folders_hash(["test"]) # Assert assert hash == "b7e9365f1e796ba29e9e6b1b94b5f4cc7238530601fad8ec96ece9fee68c3d7f" @@ -458,3 +459,44 @@ def test_get_cube_image_name_fails_if_cube_not_configured(mocker, fs): # Act & Assert with pytest.raises(MedperfException): utils.get_cube_image_name(cube_path) + + +@pytest.mark.parametrize( + "associations,expected_result", + [ + ( + [ + {"dataset": 1, "created_at": datetime.fromtimestamp(5).isoformat()}, + {"dataset": 2, "created_at": datetime.fromtimestamp(6).isoformat()}, + {"dataset": 1, "created_at": datetime.fromtimestamp(7).isoformat()}, + ], + [ + {"dataset": 1, "created_at": datetime.fromtimestamp(7).isoformat()}, + {"dataset": 2, "created_at": datetime.fromtimestamp(6).isoformat()}, + ], + ), + ( + [ + {"dataset": 1, "created_at": datetime.fromtimestamp(5).isoformat()}, + {"dataset": 2, "created_at": datetime.fromtimestamp(6).isoformat()}, + {"dataset": 3, "created_at": datetime.fromtimestamp(7).isoformat()}, + {"dataset": 2, "created_at": datetime.fromtimestamp(4).isoformat()}, + ], + [ + {"dataset": 1, "created_at": datetime.fromtimestamp(5).isoformat()}, + {"dataset": 2, "created_at": datetime.fromtimestamp(6).isoformat()}, + {"dataset": 3, "created_at": datetime.fromtimestamp(7).isoformat()}, + ], + ), + ], +) +def test_filter_latest_associations_works_as_expected( + mocker, associations, expected_result +): + # Act + filtered = utils.filter_latest_associations(associations, "dataset") + + # Assert + assert sorted(filtered, key=lambda x: x["dataset"]) == sorted( + expected_result, key=lambda x: x["dataset"] + ) diff --git 
a/cli/medperf/utils.py b/cli/medperf/utils.py
index cd367a9a6..d4c8d7ff3 100644
--- a/cli/medperf/utils.py
+++ b/cli/medperf/utils.py
@@ -16,10 +16,10 @@
import shutil
from pexpect import spawn
from datetime import datetime
+from pydantic.datetime_parse import parse_datetime
from typing import List
from colorama import Fore, Style
from pexpect.exceptions import TIMEOUT
-
import medperf.config as config
from medperf.logging.filters.redacting_filter import RedactingFilter
from medperf.exceptions import ExecutionError, MedperfException, InvalidEntityError
@@ -382,9 +382,9 @@ def combine_proc_sp_text(proc: spawn) -> str:
return proc_out
-def get_folder_hash(path: str) -> str:
- """Generates a hash for all the contents of the folder. This procedure
- hashes all of the files in the folder, sorts them and then hashes that list.
+def get_folders_hash(paths: List[str]) -> str:
+ """Generates a hash for all the contents of the given folders. This procedure
+ hashes all of the files in all passed folders, sorts them and then hashes that list.
Args:
paths (List[str]): Folders to hash
Returns:
str: sha256 hash of the combined contents of the given folders
"""
hashes = []
- for root, _, files in os.walk(path, topdown=False):
- for file in files:
- logging.debug(f"Hashing file {file}")
- filepath = os.path.join(root, file)
- hashes.append(get_file_hash(filepath))
+
+ # The hash doesn't depend on the order of paths or folders, as the hashes get sorted after the fact
+ for path in paths:
+ for root, _, files in os.walk(path, topdown=False):
+ for file in files:
+ logging.debug(f"Hashing file {file}")
+ filepath = os.path.join(root, file)
+ hashes.append(get_file_hash(filepath))
hashes = sorted(hashes)
sha = hashlib.sha256()
@@ -501,3 +504,27 @@ def verify_hash(obtained_hash: str, expected_hash: str):
raise InvalidEntityError(
f"Hash mismatch. Expected {expected_hash}, found {obtained_hash}."
)
+
+
+def filter_latest_associations(associations, entity_key):
+ """Given a list of entity-benchmark associations, this function
+ retrieves a list containing the latest association of each
+ entity instance.
+
+ Args:
+ associations (list[dict]): the list of associations
+ entity_key (str): either "dataset" or "model_mlcube"
+
+ Returns:
+ list[dict]: the list containing the latest association of each
+ entity instance.
+ """
+
+ associations.sort(key=lambda assoc: parse_datetime(assoc["created_at"]))
+ latest_associations = {}
+ for assoc in associations:
+ entity_id = assoc[entity_key]
+ latest_associations[entity_id] = assoc
+
+ latest_associations = list(latest_associations.values())
+ return latest_associations
diff --git a/server/README.md b/server/README.md
index ee6360eda..63b940bf1 100644
--- a/server/README.md
+++ b/server/README.md
@@ -1,3 +1,64 @@
# Server Documentation
TBD
+
+## Writing Tests
+
+Each endpoint must have a test file. The exception is the endpoints defined in the utils folder: a single file contains tests for all of them.
+
+### Naming conventions
+
+A test file in a module is named according to the relative endpoint it tests. For example, the test files for the `/datasets/` and `/benchmarks/` endpoints (POST and GET list) are named `test_.py`. The test file for the `/results/<pk>/` endpoint is named `test_pk.py`.
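+
+Concretely, the benchmark module's tests added in this PR are laid out as:
+
+```text
+server/benchmark/tests/__init__.py
+server/benchmark/tests/test_.py   # POST /benchmarks/ and GET /benchmarks/ (list)
+server/benchmark/tests/test_pk.py # GET/PUT /benchmarks/<pk>/ (detail)
+```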
+
+### What to keep in mind when testing
+
+Testing an endpoint means testing, for each HTTP method it supports, the following:
+
+- Serializer validation rules (`serializers.py`)
+- Database constraints (`models.py`)
+- Permissions (referred to in `views.py`)
+
+Testing focuses on the actions that should be rejected, and less on the actions that are allowed (for example, the tests should ensure that an unauthenticated user cannot access an endpoint, but they may not verify that a certain type of user can edit a certain field).
+
+### How tests should work
+
+Each test class should inherit from `MedPerfTest`, which sets up the local authentication and provides utils to create assets (users, datasets, mlcubes, ...). A condensed example skeleton is given at the end of this README.
+
+Each test class contains at least one test function. Both test classes and test class functions can be parameterized. **Each instance of a parameterized test is run independently**; a fresh database is used and the class's `setUp` method is called before each test execution.
+
+### Running tests
+
+#### Run all tests
+
+To run the whole test suite, run:
+
+```bash
+python manage.py test
+```
+
+Use the `--parallel` option to run tests in parallel:
+
+```bash
+python manage.py test --parallel
+```
+
+#### Run individual files
+
+You can run individual test files. For example:
+
+```bash
+python manage.py test dataset.tests.test_pk
+```
+
+#### Run individual tests
+
+Individual test classes or test functions can be run as follows:
+
+```bash
+python manage.py test benchmark.tests.test_ -k BenchmarkPostTest
+python manage.py test benchmark.tests.test_ -k test_creation_of_duplicate_name_gets_rejected
+```
+
+### Debugging tests
+
+These tests are not isolated "unit tests". For example, the test suite for `dataset` relies on `mlcube` functionality, because the `dataset` tests use utils to create a preparation MLCube for the datasets. When debugging, it might be useful to run test suites in a certain order and use the `--failfast` option to exit on the first failure. A script is provided for this: `debug_tests.sh`.
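+
+### Example test skeleton
+
+The snippet below is a condensed, illustrative sketch of the conventions above. The class and test names are made up for this example; `create_user`, `set_credentials`, `api_prefix`, and `client` are the `MedPerfTest` helpers used throughout this PR:
+
+```python
+from rest_framework import status
+
+from medperf.tests import MedPerfTest
+
+from parameterized import parameterized_class
+
+
+# The whole class runs once per actor; each instance gets a fresh database.
+@parameterized_class([{"actor": "bmk_owner"}, {"actor": "other_user"}])
+class BenchmarkExampleTest(MedPerfTest):
+    def setUp(self):
+        super(BenchmarkExampleTest, self).setUp()
+        # Create the users this suite needs, then act as the parameterized actor
+        self.create_user("bmk_owner")
+        self.create_user("other_user")
+        self.url = self.api_prefix + "/benchmarks/"
+        self.set_credentials(self.actor)
+
+    def test_list_requires_authentication(self):
+        # Arrange: drop credentials to simulate an unauthenticated user
+        self.set_credentials(None)
+
+        # Act
+        response = self.client.get(self.url)
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+```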
diff --git a/server/benchmark/migrations/0002_alter_benchmark_demo_dataset_tarball_url.py b/server/benchmark/migrations/0002_alter_benchmark_demo_dataset_tarball_url.py
new file mode 100644
index 000000000..f0863dff6
--- /dev/null
+++ b/server/benchmark/migrations/0002_alter_benchmark_demo_dataset_tarball_url.py
@@ -0,0 +1,18 @@
+# Generated by Django 3.2.20 on 2023-11-24 04:21
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('benchmark', '0001_initial'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='benchmark',
+ name='demo_dataset_tarball_url',
+ field=models.CharField(max_length=256),
+ ),
+ ]
diff --git a/server/benchmark/models.py b/server/benchmark/models.py
index 739bc5a44..2e0d51e3d 100644
--- a/server/benchmark/models.py
+++ b/server/benchmark/models.py
@@ -20,7 +20,7 @@ class Benchmark(models.Model):
description = models.CharField(max_length=100, blank=True)
docs_url = models.CharField(max_length=100, blank=True)
owner = models.ForeignKey(User, on_delete=models.PROTECT)
- demo_dataset_tarball_url = models.CharField(max_length=256, blank=True)
+ demo_dataset_tarball_url = models.CharField(max_length=256)
demo_dataset_tarball_hash = models.CharField(max_length=100)
demo_dataset_generated_uid = models.CharField(max_length=128)
data_preparation_mlcube = models.ForeignKey(
@@ -34,7 +34,9 @@
related_name="reference_model_mlcube",
)
data_evaluator_mlcube = models.ForeignKey(
- "mlcube.MlCube", on_delete=models.PROTECT, related_name="data_evaluator_mlcube",
+ "mlcube.MlCube",
+ on_delete=models.PROTECT,
+ related_name="data_evaluator_mlcube",
)
metadata = models.JSONField(default=dict, blank=True, null=True)
state = models.CharField(
diff --git a/server/benchmark/permissions.py b/server/benchmark/permissions.py
index 9ea74498a..807587a3f 100644
--- a/server/benchmark/permissions.py
+++ b/server/benchmark/permissions.py
@@ -1,5 +1,7 @@
from rest_framework.permissions import BasePermission
from .models import Benchmark
+from benchmarkdataset.models import BenchmarkDataset
+from django.db.models import OuterRef, Subquery
class IsAdmin(BasePermission):
@@ -25,3 +27,34 @@ def has_permission(self, request, view):
return True
else:
return False
+
+
+# TODO: check efficiency / database costs
+class IsAssociatedDatasetOwner(BasePermission):
+ def has_permission(self, request, view):
+ pk = view.kwargs.get("pk", None)
+ if not pk:
+ return False
+
+ if not request.user.is_authenticated:
+ # This check is to prevent internal server error
+ # since user.dataset_set is used below
+ return False
+
+ latest_datasets_assocs_status = (
+ BenchmarkDataset.objects.all()
+ .filter(benchmark__id=pk, dataset__id=OuterRef("id"))
+ .order_by("-created_at")[:1]
+ .values("approval_status")
+ )
+
+ user_associated_datasets = (
+ request.user.dataset_set.all()
+ .annotate(assoc_status=Subquery(latest_datasets_assocs_status))
+ .filter(assoc_status="APPROVED")
+ )
+
+ if user_associated_datasets.exists():
+ return True
+ else:
+ return False
diff --git a/server/benchmark/serializers.py b/server/benchmark/serializers.py
index 224bf6698..11f007d37 100644
--- a/server/benchmark/serializers.py
+++ b/server/benchmark/serializers.py
@@ -18,6 +18,19 @@ def validate(self, data):
raise serializers.ValidationError(
"User can own at most one pending benchmark"
)
+
+ if "state" in data and data["state"] == "OPERATION":
+ dev_mlcubes = [
+ data["data_preparation_mlcube"].state == "DEVELOPMENT",
data["reference_model_mlcube"].state == "DEVELOPMENT", + data["data_evaluator_mlcube"].state == "DEVELOPMENT", + ] + if any(dev_mlcubes): + raise serializers.ValidationError( + "User cannot mark a benchmark as operational" + " if its MLCubes are not operational" + ) + return data @@ -28,34 +41,49 @@ class Meta: fields = "__all__" def update(self, instance, validated_data): + if "approval_status" in validated_data: + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() + validated_data.pop("approval_status", None) for k, v in validated_data.items(): setattr(instance, k, v) - if instance.approval_status != "PENDING": - instance.approved_at = timezone.now() instance.save() return instance - def validate(self, data): - owner = self.instance.owner - if "approval_status" in data: - if ( - data["approval_status"] == "PENDING" - and self.instance.approval_status != "PENDING" - ): - pending_benchmarks = Benchmark.objects.filter( - owner=owner, approval_status="PENDING" + def validate_approval_status(self, approval_status): + if approval_status == "PENDING": + raise serializers.ValidationError( + "User can only approve or reject a benchmark" + ) + if self.instance.state == "DEVELOPMENT": + raise serializers.ValidationError( + "User cannot approve or reject when benchmark is in development stage" + ) + + if approval_status == "APPROVED": + if self.instance.approval_status == "REJECTED": + raise serializers.ValidationError( + "User can approve only a pending request" ) - if len(pending_benchmarks) > 0: - raise serializers.ValidationError( - "User can own at most one pending benchmark" - ) - if ( - data["approval_status"] != "PENDING" - and self.instance.state == "DEVELOPMENT" - ): + return approval_status + + def validate_state(self, state): + if state == "OPERATION" and self.instance.state != "OPERATION": + dev_mlcubes = [ + self.instance.data_preparation_mlcube.state == "DEVELOPMENT", + self.instance.reference_model_mlcube.state == "DEVELOPMENT", + self.instance.data_evaluator_mlcube.state == "DEVELOPMENT", + ] + if any(dev_mlcubes): raise serializers.ValidationError( - "User cannot approve or reject when benchmark is in development stage" + "User cannot mark a benchmark as operational" + " if its MLCubes are not operational" ) + return state + + def validate(self, data): if self.instance.state == "OPERATION": editable_fields = [ "is_valid", diff --git a/server/benchmark/tests.py b/server/benchmark/tests/__init__.py similarity index 100% rename from server/benchmark/tests.py rename to server/benchmark/tests/__init__.py diff --git a/server/benchmark/tests/test_.py b/server/benchmark/tests/test_.py new file mode 100644 index 000000000..866dfb1c5 --- /dev/null +++ b/server/benchmark/tests/test_.py @@ -0,0 +1,336 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class BenchmarkTest(MedPerfTest): + def generic_setup(self): + # setup users + bmk_owner = "bmk_owner" + prep_mlcube_owner = "prep_mlcube_owner" + ref_mlcube_owner = "ref_mlcube_owner" + eval_mlcube_owner = "eval_mlcube_owner" + + self.create_user(bmk_owner) + self.create_user(prep_mlcube_owner) + self.create_user(ref_mlcube_owner) + self.create_user(eval_mlcube_owner) + + # create mlcubes + self.set_credentials(prep_mlcube_owner) + prep = self.mock_mlcube(name="prep", mlcube_hash="prep", 
state="OPERATION") + prep = self.create_mlcube(prep).data + + self.set_credentials(ref_mlcube_owner) + ref_model = self.mock_mlcube( + name="ref_model", mlcube_hash="ref_model", state="OPERATION" + ) + ref_model = self.create_mlcube(ref_model).data + + self.set_credentials(eval_mlcube_owner) + eval = self.mock_mlcube(name="eval", mlcube_hash="eval", state="OPERATION") + eval = self.create_mlcube(eval).data + + # setup globals + self.bmk_owner = bmk_owner + self.prep_mlcube_owner = prep_mlcube_owner + self.ref_mlcube_owner = ref_mlcube_owner + self.eval_mlcube_owner = eval_mlcube_owner + + self.prep = prep + self.ref_model = ref_model + self.eval = eval + + self.url = self.api_prefix + "/benchmarks/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"actor": "prep_mlcube_owner"}, + {"actor": "ref_mlcube_owner"}, + {"actor": "eval_mlcube_owner"}, + {"actor": "bmk_owner"}, + ] +) +class BenchmarkPostTest(BenchmarkTest): + """Test module for POST /benchmarks""" + + def setUp(self): + super(BenchmarkPostTest, self).setUp() + self.generic_setup() + self.set_credentials(self.actor) + + def test_created_benchmark_fields_are_saved_as_expected(self): + """Testing the valid scenario""" + # Arrange + benchmark = self.mock_benchmark( + self.prep["id"], self.ref_model["id"], self.eval["id"] + ) + get_bmk_url = self.api_prefix + "/benchmarks/{0}/" + + # Act + response = self.client.post(self.url, benchmark, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + uid = response.data["id"] + response = self.client.get(get_bmk_url.format(uid)) + + self.assertEqual( + response.status_code, status.HTTP_200_OK, "benchmark retreival faild" + ) + + for k, v in response.data.items(): + if k in benchmark: + self.assertEqual(benchmark[k], v, f"Unexpected value for {k}") + + @parameterized.expand([(True,), (False,)]) + def test_creation_of_duplicate_name_gets_rejected(self, different_name): + """Testing the model fields rules""" + # Arrange + benchmark = self.mock_benchmark( + self.prep["id"], self.ref_model["id"], self.eval["id"] + ) + self.create_benchmark(benchmark) + if different_name: + benchmark["name"] = "different name" + + # Act + response = self.client.post(self.url, benchmark, format="json") + + # Assert + if different_name: + exp_status = status.HTTP_201_CREATED + else: + exp_status = status.HTTP_400_BAD_REQUEST + self.assertEqual(response.status_code, exp_status) + + def test_default_values_are_as_expected(self): + """Testing the model fields rules""" + + # Arrange + default_values = { + "description": "", + "docs_url": "", + "metadata": {}, + "state": "DEVELOPMENT", + "is_valid": True, + "is_active": True, + "approval_status": "PENDING", + "user_metadata": {}, + "approved_at": None, + } + + benchmark = self.mock_benchmark( + self.prep["id"], self.ref_model["id"], self.eval["id"] + ) + for key in default_values: + if key in benchmark: + del benchmark[key] + + # Act + response = self.client.post(self.url, benchmark, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in default_values.items(): + self.assertEqual( + val, response.data[key], f"Unexpected default value for {key}" + ) + + @parameterized.expand([("PENDING",), ("APPROVED",), ("REJECTED",)]) + def test_creation_of_new_benchmark_while_previous_pending_is_rejected( + self, approval_status + ): + """Testing the serializer rules""" + # Arrange + benchmark = self.mock_benchmark( + self.prep["id"], self.ref_model["id"], self.eval["id"] + ) + 
self.create_benchmark(benchmark, target_approval_status=approval_status) + benchmark["name"] = "new name" + + # Act + response = self.client.post(self.url, benchmark, format="json") + + # Assert + if approval_status == "PENDING": + exp_status = status.HTTP_400_BAD_REQUEST + else: + exp_status = status.HTTP_201_CREATED + + self.assertEqual(response.status_code, exp_status) + + def test_readonly_fields(self): + """Testing the serializer rules""" + # Arrange + readonly = { + "owner": 10, + "approved_at": "some time", + "created_at": "some time", + "modified_at": "some time", + "approval_status": "APPROVED", + } + benchmark = self.mock_benchmark( + self.prep["id"], self.ref_model["id"], self.eval["id"] + ) + benchmark.update(readonly) + + # Act + response = self.client.post(self.url, benchmark, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in readonly.items(): + self.assertNotEqual( + val, response.data[key], f"readonly field {key} was modified" + ) + + def test_creating_operational_benchmark_with_prep_in_development(self): + # Arrange + + self.set_credentials(self.prep_mlcube_owner) + devprep = self.mock_mlcube( + name="devprep", mlcube_hash="devprep", state="DEVELOPMENT" + ) + devprep = self.create_mlcube(devprep).data + self.set_credentials(self.actor) + + benchmark = self.mock_benchmark( + devprep["id"], self.ref_model["id"], self.eval["id"], state="OPERATION" + ) + + # Act + response = self.client.post(self.url, benchmark, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_creating_operational_benchmark_with_refmodel_in_development(self): + # Arrange + + self.set_credentials(self.ref_mlcube_owner) + devrefmodel = self.mock_mlcube( + name="devrefmodel", mlcube_hash="devrefmodel", state="DEVELOPMENT" + ) + devrefmodel = self.create_mlcube(devrefmodel).data + self.set_credentials(self.actor) + + benchmark = self.mock_benchmark( + self.prep["id"], devrefmodel["id"], self.eval["id"], state="OPERATION" + ) + + # Act + response = self.client.post(self.url, benchmark, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_creating_operational_benchmark_with_eval_in_development(self): + # Arrange + self.set_credentials(self.eval_mlcube_owner) + deveval = self.mock_mlcube( + name="deveval", mlcube_hash="deveval", state="DEVELOPMENT" + ) + deveval = self.create_mlcube(deveval).data + self.set_credentials(self.actor) + + benchmark = self.mock_benchmark( + self.prep["id"], self.ref_model["id"], deveval["id"], state="OPERATION" + ) + + # Act + response = self.client.post(self.url, benchmark, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + +@parameterized_class( + [ + {"actor": "prep_mlcube_owner"}, + {"actor": "ref_mlcube_owner"}, + {"actor": "eval_mlcube_owner"}, + {"actor": "bmk_owner"}, + {"actor": "other_user"}, + ] +) +class BenchmarkGetListTest(BenchmarkTest): + """Test module for GET /benchmarks/ endpoint""" + + def setUp(self): + super(BenchmarkGetListTest, self).setUp() + self.generic_setup() + self.set_credentials(self.bmk_owner) + benchmark = self.mock_benchmark( + self.prep["id"], self.ref_model["id"], self.eval["id"] + ) + benchmark = self.create_benchmark(benchmark).data + + other_user = "other_user" + self.create_user(other_user) + self.other_user = other_user + + self.testbenchmark = benchmark + self.set_credentials(self.actor) + + def 
test_generic_get_benchmark_list(self): + # Arrange + benchmark_id = self.testbenchmark["id"] + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["id"], benchmark_id) + + +class PermissionTest(BenchmarkTest): + """Test module for permissions of /benchmarks/ endpoint + Non-permitted actions: both GET and POST for unauthenticated users.""" + + def setUp(self): + super(PermissionTest, self).setUp() + self.generic_setup() + self.set_credentials(self.bmk_owner) + benchmark = self.mock_benchmark( + self.prep["id"], self.ref_model["id"], self.eval["id"] + ) + self.testbenchmark = benchmark + + @parameterized.expand( + [ + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_get_permissions(self, user, exp_status): + # Arrange + self.set_credentials(self.bmk_owner) + self.create_benchmark(self.testbenchmark) + self.set_credentials(user) + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, exp_status) + + @parameterized.expand( + [ + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_post_permissions(self, user, exp_status): + # Arrange + self.set_credentials(user) + + # Act + response = self.client.post(self.url, self.testbenchmark, format="json") + + # Assert + self.assertEqual(response.status_code, exp_status) diff --git a/server/benchmark/tests/test_pk.py b/server/benchmark/tests/test_pk.py new file mode 100644 index 000000000..18d6567dd --- /dev/null +++ b/server/benchmark/tests/test_pk.py @@ -0,0 +1,612 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class BenchmarkTest(MedPerfTest): + def generic_setup(self): + # setup users + bmk_owner = "bmk_owner" + prep_mlcube_owner = "prep_mlcube_owner" + ref_mlcube_owner = "ref_mlcube_owner" + eval_mlcube_owner = "eval_mlcube_owner" + other_user = "other_user" + + self.create_user(bmk_owner) + self.create_user(prep_mlcube_owner) + self.create_user(ref_mlcube_owner) + self.create_user(eval_mlcube_owner) + self.create_user(other_user) + + # setup globals + self.bmk_owner = bmk_owner + self.prep_mlcube_owner = prep_mlcube_owner + self.ref_mlcube_owner = ref_mlcube_owner + self.eval_mlcube_owner = eval_mlcube_owner + self.other_user = other_user + + self.url = self.api_prefix + "/benchmarks/{0}/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"actor": "prep_mlcube_owner"}, + {"actor": "ref_mlcube_owner"}, + {"actor": "eval_mlcube_owner"}, + {"actor": "bmk_owner"}, + {"actor": "other_user"}, + ] +) +class BenchmarkGetTest(BenchmarkTest): + """Test module for GET /benchmarks/""" + + def setUp(self): + super(BenchmarkGetTest, self).setUp() + self.generic_setup() + _, _, _, testbenchmark = self.shortcut_create_benchmark( + self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="PENDING", + state="DEVELOPMENT", + ) + self.testbenchmark = testbenchmark + self.set_credentials(self.actor) + + def test_generic_get_benchmark(self): + # Arrange + bmk_id = self.testbenchmark["id"] + url = self.url.format(bmk_id) + + # Act + response = self.client.get(url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + for k, v in response.data.items(): + if k in self.testbenchmark: + self.assertEqual(self.testbenchmark[k], v, f"Unexpected value for {k}") 
+ + def test_benchmark_not_found(self): + # Arrange + invalid_id = 9999 + url = self.url.format(invalid_id) + + # Act + response = self.client.get(url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + +@parameterized_class( + [ + {"actor": "bmk_owner"}, + ] +) +class BenchmarkPutTest(BenchmarkTest): + """Test module for PUT /benchmarks/ without approval_status""" + + def setUp(self): + super(BenchmarkPutTest, self).setUp() + self.generic_setup() + self.set_credentials(self.actor) + + def test_put_modifies_as_expected_in_development(self): + """All editable fields except approval status. It will + be tested in permissions testcases""" + # Arrange + _, _, _, testbenchmark = self.shortcut_create_benchmark( + self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="PENDING", + state="DEVELOPMENT", + ) + + new_data_preproc_mlcube = self.mock_mlcube( + name="new prep", mlcube_hash="new prep" + ) + new_ref_mlcube = self.mock_mlcube(name="new ref", mlcube_hash="new ref") + new_eval_mlcube = self.mock_mlcube(name="new eval", mlcube_hash="new eval") + new_prep_id = self.create_mlcube(new_data_preproc_mlcube).data["id"] + new_ref_id = self.create_mlcube(new_ref_mlcube).data["id"] + new_eval_id = self.create_mlcube(new_eval_mlcube).data["id"] + + newtestbenchmark = { + "name": "newstring", + "description": "newstring", + "docs_url": "newstring", + "demo_dataset_tarball_url": "newstring", + "demo_dataset_tarball_hash": "newstring", + "demo_dataset_generated_uid": "newstring", + "data_preparation_mlcube": new_prep_id, + "reference_model_mlcube": new_ref_id, + "data_evaluator_mlcube": new_eval_id, + "metadata": {"newkey": "newvalue"}, + "state": "OPERATION", + "is_valid": False, + "is_active": False, + "user_metadata": {"newkey2": "newvalue2"}, + } + url = self.url.format(testbenchmark["id"]) + + # Act + response = self.client.put(url, newtestbenchmark, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + for k, v in response.data.items(): + if k in newtestbenchmark: + self.assertEqual(newtestbenchmark[k], v, f"{k} was not modified") + + @parameterized.expand([("APPROVED",), ("PENDING",)]) + def test_put_modifies_editable_fields_in_operation(self, benchmark_approval_status): + """All editable fields except approval status. 
It will
+        be tested in permissions testcases"""
+
+        # Arrange
+        _, _, _, testbenchmark = self.shortcut_create_benchmark(
+            self.prep_mlcube_owner,
+            self.ref_mlcube_owner,
+            self.eval_mlcube_owner,
+            self.bmk_owner,
+            target_approval_status=benchmark_approval_status,
+            state="OPERATION",
+        )
+        url = self.url.format(testbenchmark["id"])
+
+        newtestbenchmark = {
+            "is_valid": False,
+            "is_active": False,
+            "user_metadata": {"newkey": "newval"},
+            "demo_dataset_tarball_url": "newstring",
+        }
+
+        # Act
+        response = self.client.put(url, newtestbenchmark, format="json")
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        response = self.client.get(url)
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+        for k, v in response.data.items():
+            if k in newtestbenchmark:
+                self.assertEqual(newtestbenchmark[k], v, f"{k} was not modified")
+
+    @parameterized.expand([("APPROVED",), ("PENDING",)])
+    def test_put_does_not_modify_non_editable_fields_in_operation(
+        self, benchmark_approval_status
+    ):
+        # Arrange
+        _, _, _, testbenchmark = self.shortcut_create_benchmark(
+            self.prep_mlcube_owner,
+            self.ref_mlcube_owner,
+            self.eval_mlcube_owner,
+            self.bmk_owner,
+            target_approval_status=benchmark_approval_status,
+            state="OPERATION",
+        )
+
+        new_data_preproc_mlcube = self.mock_mlcube(
+            name="new prep", mlcube_hash="new prep"
+        )
+        new_ref_mlcube = self.mock_mlcube(name="new ref", mlcube_hash="new ref")
+        new_eval_mlcube = self.mock_mlcube(name="new eval", mlcube_hash="new eval")
+        new_prep_id = self.create_mlcube(new_data_preproc_mlcube).data["id"]
+        new_ref_id = self.create_mlcube(new_ref_mlcube).data["id"]
+        new_eval_id = self.create_mlcube(new_eval_mlcube).data["id"]
+
+        newtestbenchmark = {
+            "name": "newstring",
+            "description": "newstring",
+            "docs_url": "newstring",
+            "demo_dataset_tarball_hash": "newstring",
+            "demo_dataset_generated_uid": "newstring",
+            "data_preparation_mlcube": new_prep_id,
+            "reference_model_mlcube": new_ref_id,
+            "data_evaluator_mlcube": new_eval_id,
+            "metadata": {"newkey": "newvalue"},
+            "state": "DEVELOPMENT",
+        }
+
+        url = self.url.format(testbenchmark["id"])
+
+        for key in newtestbenchmark:
+            # Act
+            response = self.client.put(url, {key: newtestbenchmark[key]}, format="json")
+
+            # Assert
+            self.assertEqual(
+                response.status_code,
+                status.HTTP_400_BAD_REQUEST,
+                f"{key} was modified",
+            )
+
+    @parameterized.expand(
+        [
+            ("DEVELOPMENT", "PENDING"),
+            ("OPERATION", "PENDING"),
+            ("OPERATION", "APPROVED"),
+        ]
+    )
+    def test_put_does_not_modify_readonly_fields_in_all_states(
+        self, state, benchmark_approval_status
+    ):
+        # Arrange
+        _, _, _, testbenchmark = self.shortcut_create_benchmark(
+            self.prep_mlcube_owner,
+            self.ref_mlcube_owner,
+            self.eval_mlcube_owner,
+            self.bmk_owner,
+            target_approval_status=benchmark_approval_status,
+            state=state,
+        )
+
+        newtestbenchmark = {
+            "owner": 10,
+            "approved_at": "some time",
+            "created_at": "some time",
+            "modified_at": "some time",
+        }
+        url = self.url.format(testbenchmark["id"])
+
+        # Act
+        response = self.client.put(url, newtestbenchmark, format="json")
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        response = self.client.get(url)
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        for k, v in newtestbenchmark.items():
+            self.assertNotEqual(v, response.data[k], f"{k} was modified")
+
+    def test_put_respects_unique_name(self):
+        # Arrange
+        _, _, _, testbenchmark = self.shortcut_create_benchmark(
self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="APPROVED", + ) + + _, _, _, newtestbenchmark = self.shortcut_create_benchmark( + self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="PENDING", + prep_mlcube_kwargs={"name": "newprep", "mlcube_hash": "newprephash"}, + ref_mlcube_kwargs={"name": "newref", "mlcube_hash": "newrefhash"}, + eval_mlcube_kwargs={"name": "neweval", "mlcube_hash": "newevalhash"}, + state="DEVELOPMENT", + name="newname", + ) + + put_body = {"name": testbenchmark["name"]} + url = self.url.format(newtestbenchmark["id"]) + + # Act + response = self.client.put(url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_marking_as_operation_requires_prep_to_be_operation(self): + # Arrange + _, _, _, testbenchmark = self.shortcut_create_benchmark( + self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="PENDING", + state="DEVELOPMENT", + prep_mlcube_kwargs={"state": "DEVELOPMENT"}, + ) + + url = self.url.format(testbenchmark["id"]) + + # Act + response = self.client.put(url, {"state": "OPERATION"}, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_marking_as_operation_requires_refmodel_to_be_operation(self): + # Arrange + _, _, _, testbenchmark = self.shortcut_create_benchmark( + self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="PENDING", + state="DEVELOPMENT", + ref_mlcube_kwargs={"state": "DEVELOPMENT"}, + ) + + url = self.url.format(testbenchmark["id"]) + + # Act + response = self.client.put(url, {"state": "OPERATION"}, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_marking_as_operation_requires_evaluator_to_be_operation(self): + # Arrange + _, _, _, testbenchmark = self.shortcut_create_benchmark( + self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="PENDING", + state="DEVELOPMENT", + eval_mlcube_kwargs={"state": "DEVELOPMENT"}, + ) + + url = self.url.format(testbenchmark["id"]) + + # Act + response = self.client.put(url, {"state": "OPERATION"}, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + +@parameterized_class( + [ + {"actor": "api_admin"}, + ] +) +class BenchmarkApproveTest(BenchmarkTest): + """Test module for PUT /benchmarks/ with approval_status field""" + + def setUp(self): + super(BenchmarkApproveTest, self).setUp() + self.generic_setup() + self.set_credentials(self.actor) + + @parameterized.expand( + [ + ("PENDING", "APPROVED"), + ("PENDING", "REJECTED"), + ] + ) + def test_approval_status_cannot_be_changed_in_development( + self, prev_approval_status, new_approval_status + ): + # Arrange + _, _, _, testbenchmark = self.shortcut_create_benchmark( + self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status=prev_approval_status, + state="DEVELOPMENT", + ) + url = self.url.format(testbenchmark["id"]) + + # Act + response = self.client.put( + url, {"approval_status": new_approval_status}, format="json" + ) + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @parameterized.expand( + [ + ("PENDING", 
"APPROVED", status.HTTP_200_OK), + ("PENDING", "REJECTED", status.HTTP_200_OK), + ("APPROVED", "REJECTED", status.HTTP_200_OK), + ("APPROVED", "PENDING", status.HTTP_400_BAD_REQUEST), + ("REJECTED", "PENDING", status.HTTP_400_BAD_REQUEST), + ("REJECTED", "APPROVED", status.HTTP_400_BAD_REQUEST), + ] + ) + def test_approval_status_change_in_operation( + self, prev_approval_status, new_approval_status, exp_status_code + ): + # Arrange + _, _, _, testbenchmark = self.shortcut_create_benchmark( + self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status=prev_approval_status, + state="OPERATION", + ) + url = self.url.format(testbenchmark["id"]) + + # Act + response = self.client.put( + url, {"approval_status": new_approval_status}, format="json" + ) + + # Assert + self.assertEqual(response.status_code, exp_status_code) + + +@parameterized_class( + [ + {"actor": "api_admin"}, + ] +) +class BenchmarkDeleteTest(BenchmarkTest): + """Test module for DELETE /benchmarks/""" + + def setUp(self): + super(BenchmarkDeleteTest, self).setUp() + self.generic_setup() + _, _, _, testbenchmark = self.shortcut_create_benchmark( + self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="PENDING", + state="DEVELOPMENT", + ) + self.testbenchmark = testbenchmark + self.set_credentials(self.actor) + + def test_deletion_works_as_expected(self): + # Arrange + url = self.url.format(self.testbenchmark["id"]) + + # Act + response = self.client.delete(url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + +class PermissionTest(BenchmarkTest): + """Test module for permissions of /benchmarks/{pk} endpoint + Non-permitted actions: + GET: for unauthenticated users + DELETE: for all users except admin + PUT: + including approval_status: for all users except admin + not including approval_status: for all users except bmk_owner and admin + """ + + def setUp(self): + super(PermissionTest, self).setUp() + self.generic_setup() + + _, _, _, testbenchmark = self.shortcut_create_benchmark( + self.prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="PENDING", + state="DEVELOPMENT", + ) + self.testbenchmark = testbenchmark + self.url = self.url.format(self.testbenchmark["id"]) + + @parameterized.expand( + [ + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_get_permissions(self, user, expected_status): + # Arrange + self.set_credentials(user) + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, expected_status) + + @parameterized.expand( + [ + ("prep_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_put_permissions(self, user, expected_status): + # Arrange + self.set_credentials(self.bmk_owner) + + new_data_preproc_mlcube = self.mock_mlcube( + name="new prep", mlcube_hash="new prep" + ) + new_ref_mlcube = self.mock_mlcube(name="new ref", mlcube_hash="new ref") + new_eval_mlcube = self.mock_mlcube(name="new eval", mlcube_hash="new eval") + new_prep_id = self.create_mlcube(new_data_preproc_mlcube).data["id"] + new_ref_id = self.create_mlcube(new_ref_mlcube).data["id"] + 
new_eval_id = self.create_mlcube(new_eval_mlcube).data["id"]
+
+        newtestbenchmark = {
+            "name": "newstring",
+            "description": "newstring",
+            "docs_url": "newstring",
+            "demo_dataset_tarball_url": "newstring",
+            "demo_dataset_tarball_hash": "newstring",
+            "demo_dataset_generated_uid": "newstring",
+            "data_preparation_mlcube": new_prep_id,
+            "reference_model_mlcube": new_ref_id,
+            "data_evaluator_mlcube": new_eval_id,
+            "metadata": {"newkey": "newvalue"},
+            "state": "OPERATION",
+            "is_valid": False,
+            "is_active": False,
+            "user_metadata": {"newkey2": "newvalue2"},
+            "owner": 10,
+            "approved_at": "some time",
+            "created_at": "some time",
+            "modified_at": "some time",
+        }
+
+        self.set_credentials(user)
+
+        for key in newtestbenchmark:
+            # Act
+            response = self.client.put(
+                self.url, {key: newtestbenchmark[key]}, format="json"
+            )
+
+            # Assert
+            self.assertEqual(
+                response.status_code, expected_status, f"{key} was modified"
+            )
+
+    @parameterized.expand(
+        [
+            ("bmk_owner", status.HTTP_403_FORBIDDEN),
+            ("prep_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("other_user", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_put_approval_status_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = self.client.put(
+            self.url, {"approval_status": "APPROVED"}, format="json"
+        )
+
+        # Assert
+        self.assertEqual(response.status_code, expected_status)
+
+    @parameterized.expand(
+        [
+            ("bmk_owner", status.HTTP_403_FORBIDDEN),
+            ("prep_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("other_user", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_delete_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = self.client.delete(self.url)
+
+        # Assert
+        self.assertEqual(response.status_code, expected_status)
diff --git a/server/benchmark/tests/test_pk_datasets.py b/server/benchmark/tests/test_pk_datasets.py
new file mode 100644
index 000000000..ca154ec88
--- /dev/null
+++ b/server/benchmark/tests/test_pk_datasets.py
@@ -0,0 +1,145 @@
+from rest_framework import status
+
+from medperf.tests import MedPerfTest
+
+from parameterized import parameterized, parameterized_class
+
+
+class BenchmarkTest(MedPerfTest):
+    def generic_setup(self):
+        # setup users
+        bmk_owner = "bmk_owner"
+        prep_mlcube_owner = "prep_mlcube_owner"
+        ref_mlcube_owner = "ref_mlcube_owner"
+        eval_mlcube_owner = "eval_mlcube_owner"
+        data1_owner = "data1_owner"
+        data2_owner = "data2_owner"
+        other_user = "other_user"
+
+        self.create_user(bmk_owner)
+        self.create_user(prep_mlcube_owner)
+        self.create_user(ref_mlcube_owner)
+        self.create_user(eval_mlcube_owner)
+        self.create_user(data1_owner)
+        self.create_user(data2_owner)
+        self.create_user(other_user)
+
+        # create benchmark and datasets
+        prep, _, _, benchmark = self.shortcut_create_benchmark(
+            prep_mlcube_owner, ref_mlcube_owner, eval_mlcube_owner, bmk_owner
+        )
+        data1 = self.mock_dataset(
+            prep["id"], generated_uid="dataset1", state="OPERATION"
+        )
+        data2 = self.mock_dataset(
+            prep["id"], generated_uid="dataset2", state="OPERATION"
+        )
+
+        self.set_credentials(data1_owner)
+        data1 = self.create_dataset(data1).data
+        self.set_credentials(data2_owner)
+        data2 = self.create_dataset(data2).data
+
+        # setup globals
+        self.bmk_owner = bmk_owner
+        self.prep_mlcube_owner = prep_mlcube_owner
+        self.ref_mlcube_owner = ref_mlcube_owner
+        self.eval_mlcube_owner = eval_mlcube_owner
+        self.data1_owner = data1_owner
+        self.data2_owner = data2_owner
+        self.other_user = other_user
+        self.benchmark_id = benchmark["id"]
+        self.data1_id = data1["id"]
+        self.data2_id = data2["id"]
+
+        self.url = self.api_prefix + "/benchmarks/{0}/datasets/"
+        self.set_credentials(None)
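+
+# NOTE (editorial, illustrative only -- not part of this patch): the
+# association helpers used below come from MedPerfTest; their semantics are
+# inferred from usage in this suite. mock_dataset_association appears to build
+# a plain request payload, while create_dataset_association appears to play
+# both sides of the approval handshake, e.g.:
+#
+#     assoc = self.mock_dataset_association(
+#         self.benchmark_id, self.data1_id, approval_status="APPROVED"
+#     )
+#     # presumably: POST as the dataset owner, then approve as the
+#     # benchmark owner
+#     self.create_dataset_association(assoc, self.data1_owner, self.bmk_owner)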
+
+
+@parameterized_class(
+    [
+        {"actor": "bmk_owner"},
+    ]
+)
+class BenchmarkDatasetGetListTest(BenchmarkTest):
+    """Test module for GET /benchmarks/{pk}/datasets/ endpoint"""
+
+    def setUp(self):
+        super(BenchmarkDatasetGetListTest, self).setUp()
+        self.generic_setup()
+        # create two associations for data1 (one rejected, one approved)
+        # and one pending association for data2 (just an arbitrary example)
+
+        assoc = self.mock_dataset_association(
+            self.benchmark_id, self.data1_id, approval_status="REJECTED"
+        )
+        self.create_dataset_association(assoc, self.data1_owner, self.bmk_owner)
+        assoc = self.mock_dataset_association(
+            self.benchmark_id, self.data1_id, approval_status="APPROVED"
+        )
+        self.create_dataset_association(assoc, self.data1_owner, self.bmk_owner)
+
+        assoc = self.mock_dataset_association(
+            self.benchmark_id, self.data2_id, approval_status="PENDING"
+        )
+        self.create_dataset_association(assoc, self.data2_owner, self.bmk_owner)
+
+        self.visible_fields = ["approval_status", "created_at", "dataset"]
+        self.set_credentials(self.actor)
+
+    def test_generic_get_benchmark_datasets_list(self):
+        # Arrange
+        url = self.url.format(self.benchmark_id)
+
+        # Act
+        response = self.client.get(url)
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(
+            len(response.data["results"]), 3, "unexpected number of assocs"
+        )
+        for assoc in response.data["results"]:
+            for key in assoc:
+                self.assertIn(key, self.visible_fields, f"{key} shouldn't be visible")
+
+
+class PermissionTest(BenchmarkTest):
+    """Test module for permissions of /benchmarks/{pk}/datasets/ endpoint
+    Non-permitted actions:
+    GET: for all users except benchmark owner and admin
+    """
+
+    def setUp(self):
+        super(PermissionTest, self).setUp()
+        self.generic_setup()
+        # create a rejected association for data1 (just an arbitrary example)
+
+        assoc = self.mock_dataset_association(
+            self.benchmark_id, self.data1_id, approval_status="REJECTED"
+        )
+        self.create_dataset_association(assoc, self.data1_owner, self.bmk_owner)
+        self.url = self.url.format(self.benchmark_id)
+        self.set_credentials(None)
+
+    @parameterized.expand(
+        [
+            ("prep_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("data1_owner", status.HTTP_403_FORBIDDEN),
+            ("data2_owner", status.HTTP_403_FORBIDDEN),
+            ("other_user", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_get_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = self.client.get(self.url)
+
+        # Assert
+        self.assertEqual(response.status_code, expected_status)
diff --git a/server/benchmark/tests/test_pk_models.py b/server/benchmark/tests/test_pk_models.py
new file mode 100644
index 000000000..bdc979d63
--- /dev/null
+++ b/server/benchmark/tests/test_pk_models.py
@@ -0,0 +1,176 @@
+from rest_framework import status
+
+from medperf.tests import MedPerfTest
+
+from parameterized import parameterized, parameterized_class
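+
+# NOTE (editorial, illustrative only -- not part of this patch): these tests
+# exercise the IsAssociatedDatasetOwner permission newly wired into
+# BenchmarkModelList in server/benchmark/views.py. Its implementation is not
+# shown in this diff; a plausible DRF sketch consistent with the behavior the
+# tests expect (owners of a dataset with an APPROVED association to the
+# benchmark may list its models) might look like:
+#
+#     from rest_framework.permissions import BasePermission
+#     from benchmarkdataset.models import BenchmarkDataset
+#
+#     class IsAssociatedDatasetOwner(BasePermission):
+#         def has_permission(self, request, view):
+#             pk = view.kwargs.get("pk")
+#             return BenchmarkDataset.objects.filter(
+#                 benchmark__id=pk,
+#                 dataset__owner=request.user,
+#                 approval_status="APPROVED",
+#             ).exists()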
+
+
+class BenchmarkTest(MedPerfTest):
+    def generic_setup(self):
+        # setup users
+        bmk_owner = "bmk_owner"
+        prep_mlcube_owner = "prep_mlcube_owner"
+        ref_mlcube_owner = "ref_mlcube_owner"
+        eval_mlcube_owner = "eval_mlcube_owner"
+        mlcube1_owner = "mlcube1_owner"
+        mlcube2_owner = "mlcube2_owner"
+        data1_owner = "data1_owner"
+        data2_owner = "data2_owner"
+        other_user = "other_user"
+
+        self.create_user(bmk_owner)
+        self.create_user(prep_mlcube_owner)
+        self.create_user(ref_mlcube_owner)
+        self.create_user(eval_mlcube_owner)
+        self.create_user(mlcube1_owner)
+        self.create_user(mlcube2_owner)
+        self.create_user(data1_owner)
+        self.create_user(data2_owner)
+        self.create_user(other_user)
+
+        # create benchmark and mlcubes
+        prep, _, _, benchmark = self.shortcut_create_benchmark(
+            prep_mlcube_owner, ref_mlcube_owner, eval_mlcube_owner, bmk_owner
+        )
+        mlcube1 = self.mock_mlcube(
+            name="mlcube1", mlcube_hash="mlcube1hash", state="OPERATION"
+        )
+        mlcube2 = self.mock_mlcube(
+            name="mlcube2", mlcube_hash="mlcube2hash", state="OPERATION"
+        )
+        self.set_credentials(mlcube1_owner)
+        mlcube1 = self.create_mlcube(mlcube1).data
+        self.set_credentials(mlcube2_owner)
+        mlcube2 = self.create_mlcube(mlcube2).data
+
+        # create datasets
+        dataset1 = self.mock_dataset(
+            prep["id"], generated_uid="dataset1", state="OPERATION"
+        )
+        dataset2 = self.mock_dataset(
+            prep["id"], generated_uid="dataset2", state="OPERATION"
+        )
+        self.set_credentials(data1_owner)
+        dataset1 = self.create_dataset(dataset1).data
+        self.set_credentials(data2_owner)
+        dataset2 = self.create_dataset(dataset2).data
+
+        # create dataset associations: one approved and one not.
+        # This is to test that data1_owner can't see the mlcube associations
+        # but data2_owner can
+        assoc1 = self.mock_dataset_association(
+            benchmark["id"], dataset1["id"], approval_status="PENDING"
+        )
+        self.create_dataset_association(assoc1, data1_owner, bmk_owner)
+        assoc2 = self.mock_dataset_association(
+            benchmark["id"], dataset2["id"], approval_status="APPROVED"
+        )
+        self.create_dataset_association(assoc2, data2_owner, bmk_owner)
+
+        # setup globals
+        self.bmk_owner = bmk_owner
+        self.prep_mlcube_owner = prep_mlcube_owner
+        self.ref_mlcube_owner = ref_mlcube_owner
+        self.eval_mlcube_owner = eval_mlcube_owner
+        self.mlcube1_owner = mlcube1_owner
+        self.mlcube2_owner = mlcube2_owner
+        self.data1_owner = data1_owner
+        self.data2_owner = data2_owner
+        self.other_user = other_user
+        self.benchmark_id = benchmark["id"]
+        self.mlcube1_id = mlcube1["id"]
+        self.mlcube2_id = mlcube2["id"]
+
+        self.url = self.api_prefix + "/benchmarks/{0}/models/"
+        self.set_credentials(None)
+
+
+@parameterized_class(
+    [
+        {"actor": "bmk_owner"},
+        {"actor": "data2_owner"},
+    ]
+)
+class BenchmarkModelGetListTest(BenchmarkTest):
+    """Test module for GET /benchmarks/{pk}/models/ endpoint"""
+
+    def setUp(self):
+        super(BenchmarkModelGetListTest, self).setUp()
+        self.generic_setup()
+        # create two associations for mlcube1 (one rejected, one approved)
+        # and one pending association for mlcube2 (just an arbitrary example)
+
+        assoc = self.mock_mlcube_association(
+            self.benchmark_id, self.mlcube1_id, approval_status="REJECTED"
+        )
+        self.create_mlcube_association(assoc, self.mlcube1_owner, self.bmk_owner)
+        assoc = self.mock_mlcube_association(
+            self.benchmark_id, self.mlcube1_id, approval_status="APPROVED"
+        )
+        self.create_mlcube_association(assoc, self.mlcube1_owner, self.bmk_owner)
+
+        assoc = self.mock_mlcube_association(
+            self.benchmark_id, self.mlcube2_id, approval_status="PENDING"
+        )
+        self.create_mlcube_association(assoc, self.mlcube2_owner, self.bmk_owner)
+
+        self.visible_fields = ["approval_status", "created_at", "model_mlcube"]
+        self.set_credentials(self.actor)
+
+    def test_generic_get_benchmark_models_list(self):
+        # Arrange
+        url = self.url.format(self.benchmark_id)
+
+        # Act
+        response = self.client.get(url)
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        self.assertEqual(
+            len(response.data["results"]), 3, "unexpected number of assocs"
+        )
+        for assoc in response.data["results"]:
+            for key in assoc:
+                self.assertIn(key, self.visible_fields, f"{key} shouldn't be visible")
+
+
+class PermissionTest(BenchmarkTest):
+    """Test module for permissions of /benchmarks/{pk}/models/ endpoint
+    Non-permitted actions:
+    GET: for unauthenticated users, and for all authenticated users except
+    the benchmark owner, owners of datasets with an approved association,
+    and admin
+    """
+
+    def setUp(self):
+        super(PermissionTest, self).setUp()
+        self.generic_setup()
+        # create a rejected association for mlcube1 (just an arbitrary example)
+
+        assoc = self.mock_mlcube_association(
+            self.benchmark_id, self.mlcube1_id, approval_status="REJECTED"
+        )
+        self.create_mlcube_association(assoc, self.mlcube1_owner, self.bmk_owner)
+        self.url = self.url.format(self.benchmark_id)
+        self.set_credentials(None)
+
+    @parameterized.expand(
+        [
+            ("prep_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("mlcube1_owner", status.HTTP_403_FORBIDDEN),
+            ("mlcube2_owner", status.HTTP_403_FORBIDDEN),
+            ("data1_owner", status.HTTP_403_FORBIDDEN),
+            ("other_user", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_get_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = self.client.get(self.url)
+
+        # Assert
+        self.assertEqual(response.status_code, expected_status)
diff --git a/server/benchmark/views.py b/server/benchmark/views.py
index 2314d93d8..834233bbe 100644
--- a/server/benchmark/views.py
+++ b/server/benchmark/views.py
@@ -1,5 +1,5 @@
-from mlcube.serializers import MlCubeSerializer
-from dataset.serializers import DatasetSerializer
+from benchmarkmodel.serializers import BenchmarkListofModelsSerializer
+from benchmarkdataset.serializers import BenchmarkListofDatasetsSerializer
 from result.serializers import ModelResultSerializer
 from report.serializers import ReportSerializer
 from django.http import Http404
@@ -10,7 +10,7 @@
 from .models import Benchmark
 from .serializers import BenchmarkSerializer, BenchmarkApprovalSerializer
-from .permissions import IsAdmin, IsBenchmarkOwner
+from .permissions import IsAdmin, IsBenchmarkOwner, IsAssociatedDatasetOwner
 class BenchmarkList(GenericAPIView):
@@ -41,7 +41,8 @@ def post(self, request, format=None):
 class BenchmarkModelList(GenericAPIView):
-    serializer_class = MlCubeSerializer
+    permission_classes = [IsAdmin | IsBenchmarkOwner | IsAssociatedDatasetOwner]
+    serializer_class = BenchmarkListofModelsSerializer
     queryset = ""
     def get_object(self, pk):
@@ -55,15 +56,15 @@ def get(self, request, pk, format=None):
         Retrieve models associated with a benchmark instance.
""" benchmark = self.get_object(pk) - modelgroups = benchmark.benchmarkmodel_set.all() - models = [gp.model_mlcube for gp in modelgroups] + models = benchmark.benchmarkmodel_set.all() models = self.paginate_queryset(models) - serializer = MlCubeSerializer(models, many=True) + serializer = BenchmarkListofModelsSerializer(models, many=True) return self.get_paginated_response(serializer.data) class BenchmarkDatasetList(GenericAPIView): - serializer_class = DatasetSerializer + permission_classes = [IsAdmin | IsBenchmarkOwner] + serializer_class = BenchmarkListofDatasetsSerializer queryset = "" def get_object(self, pk): @@ -77,10 +78,9 @@ def get(self, request, pk, format=None): Retrieve datasets associated with a benchmark instance. """ benchmark = self.get_object(pk) - datasetgroups = benchmark.benchmarkdataset_set.all() - datasets = [gp.dataset for gp in datasetgroups] + datasets = benchmark.benchmarkdataset_set.all() datasets = self.paginate_queryset(datasets) - serializer = DatasetSerializer(datasets, many=True) + serializer = BenchmarkListofDatasetsSerializer(datasets, many=True) return self.get_paginated_response(serializer.data) @@ -135,6 +135,8 @@ class BenchmarkDetail(GenericAPIView): def get_permissions(self): if self.request.method == "PUT": self.permission_classes = [IsAdmin | IsBenchmarkOwner] + if "approval_status" in self.request.data: + self.permission_classes = [IsAdmin] elif self.request.method == "DELETE": self.permission_classes = [IsAdmin] return super(self.__class__, self).get_permissions() diff --git a/server/benchmarkdataset/serializers.py b/server/benchmarkdataset/serializers.py index 0b7e70e35..9cc120079 100644 --- a/server/benchmarkdataset/serializers.py +++ b/server/benchmarkdataset/serializers.py @@ -12,10 +12,33 @@ class Meta: read_only_fields = ["initiated_by", "approved_at"] fields = "__all__" + def __validate_approval_status(self, last_benchmarkdataset, approval_status): + if not last_benchmarkdataset: + if approval_status != "PENDING": + raise serializers.ValidationError( + "User can approve or reject association request only if there are prior requests" + ) + else: + if approval_status == "PENDING": + if last_benchmarkdataset.approval_status != "REJECTED": + raise serializers.ValidationError( + "User can create a new request only if prior request is rejected" + ) + elif approval_status == "APPROVED": + raise serializers.ValidationError( + "User cannot create an approved association request" + ) + # approval_status == "REJECTED": + else: + if last_benchmarkdataset.approval_status != "APPROVED": + raise serializers.ValidationError( + "User can reject request only if prior request is approved" + ) + def validate(self, data): bid = self.context["request"].data.get("benchmark") dataset = self.context["request"].data.get("dataset") - approval_status = self.context["request"].data.get("approval_status") + approval_status = self.context["request"].data.get("approval_status", "PENDING") benchmark = Benchmark.objects.get(pk=bid) benchmark_state = benchmark.state if benchmark_state != "OPERATION": @@ -27,37 +50,24 @@ def validate(self, data): raise serializers.ValidationError( "Association requests can be made only on an approved benchmark" ) - dataset_state = Dataset.objects.get(pk=dataset).state + dataset_obj = Dataset.objects.get(pk=dataset) + dataset_state = dataset_obj.state if dataset_state != "OPERATION": raise serializers.ValidationError( "Association requests can be made only on an operational dataset" ) + if dataset_obj.data_preparation_mlcube != 
benchmark.data_preparation_mlcube: + raise serializers.ValidationError( + "Dataset association request can be made only if the dataset" + " was prepared with benchmark's data preparation MLCube" + ) last_benchmarkdataset = ( BenchmarkDataset.objects.filter(benchmark__id=bid, dataset__id=dataset) .order_by("-created_at") .first() ) - if not last_benchmarkdataset: - if approval_status != "PENDING": - raise serializers.ValidationError( - "User can approve or reject association request only if there are prior requests" - ) - else: - if approval_status == "PENDING": - if last_benchmarkdataset.approval_status != "REJECTED": - raise serializers.ValidationError( - "User can create a new request only if prior request is rejected" - ) - elif approval_status == "APPROVED": - raise serializers.ValidationError( - "User cannot create an approved association request" - ) - # approval_status == "REJECTED": - else: - if last_benchmarkdataset.approval_status != "APPROVED": - raise serializers.ValidationError( - "User can reject request only if prior request is approved" - ) + self.__validate_approval_status(last_benchmarkdataset, approval_status) + return data def create(self, validated_data): @@ -89,27 +99,34 @@ class Meta: def validate(self, data): if not self.instance: raise serializers.ValidationError("No dataset association found") + return data + + def validate_approval_status(self, cur_approval_status): last_approval_status = self.instance.approval_status - cur_approval_status = data["approval_status"] if last_approval_status != "PENDING": raise serializers.ValidationError( "User can approve or reject only a pending request" ) initiated_user = self.instance.initiated_by current_user = self.context["request"].user - if ( - last_approval_status != cur_approval_status - and cur_approval_status == "APPROVED" - ): + if cur_approval_status == "APPROVED": if current_user.id == initiated_user.id: raise serializers.ValidationError( "Same user cannot approve the association request" ) - return data + return cur_approval_status def update(self, instance, validated_data): - instance.approval_status = validated_data["approval_status"] - if instance.approval_status != "PENDING": - instance.approved_at = timezone.now() + if "approval_status" in validated_data: + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() instance.save() return instance + + +class BenchmarkListofDatasetsSerializer(serializers.ModelSerializer): + class Meta: + model = BenchmarkDataset + fields = ["dataset", "approval_status", "created_at"] diff --git a/server/benchmarkdataset/views.py b/server/benchmarkdataset/views.py index f4ab47b50..6aba370a8 100644 --- a/server/benchmarkdataset/views.py +++ b/server/benchmarkdataset/views.py @@ -30,7 +30,7 @@ def post(self, request, format=None): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) -class BenchmarkDatasetApproval(GenericAPIView): +class DatasetBenchmarksList(GenericAPIView): serializer_class = BenchmarkDatasetListSerializer queryset = "" @@ -52,10 +52,15 @@ def get(self, request, pk, format=None): class DatasetApproval(GenericAPIView): - permission_classes = [IsAdmin | IsBenchmarkOwner | IsDatasetOwner] serializer_class = DatasetApprovalSerializer queryset = "" + def get_permissions(self): + self.permission_classes = [IsAdmin | IsBenchmarkOwner | IsDatasetOwner] + if self.request.method == "DELETE": + 
self.permission_classes = [IsAdmin] + return super(self.__class__, self).get_permissions() + def get_object(self, dataset_id, benchmark_id): try: return BenchmarkDataset.objects.filter( diff --git a/server/benchmarkmodel/serializers.py b/server/benchmarkmodel/serializers.py index 9d0659b5c..afa34acd4 100644 --- a/server/benchmarkmodel/serializers.py +++ b/server/benchmarkmodel/serializers.py @@ -101,10 +101,7 @@ def validate_approval_status(self, cur_approval_status): ) initiated_user = self.instance.initiated_by current_user = self.context["request"].user - if ( - last_approval_status != cur_approval_status - and cur_approval_status == "APPROVED" - ): + if cur_approval_status == "APPROVED": if current_user.id == initiated_user.id: raise serializers.ValidationError( "Same user cannot approve the association request" @@ -113,10 +110,17 @@ def validate_approval_status(self, cur_approval_status): def update(self, instance, validated_data): if "approval_status" in validated_data: - instance.approval_status = validated_data["approval_status"] - if instance.approval_status != "PENDING": - instance.approved_at = timezone.now() + if validated_data["approval_status"] != instance.approval_status: + instance.approval_status = validated_data["approval_status"] + if instance.approval_status != "PENDING": + instance.approved_at = timezone.now() if "priority" in validated_data: instance.priority = validated_data["priority"] instance.save() return instance + + +class BenchmarkListofModelsSerializer(serializers.ModelSerializer): + class Meta: + model = BenchmarkModel + fields = ["model_mlcube", "approval_status", "created_at"] diff --git a/server/benchmarkmodel/views.py b/server/benchmarkmodel/views.py index ccf841022..c4ec01301 100644 --- a/server/benchmarkmodel/views.py +++ b/server/benchmarkmodel/views.py @@ -30,7 +30,7 @@ def post(self, request, format=None): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) -class BenchmarkModelApproval(GenericAPIView): +class ModelBenchmarksList(GenericAPIView): serializer_class = BenchmarkModelListSerializer queryset = "" @@ -59,6 +59,8 @@ def get_permissions(self): self.permission_classes = [IsAdmin | IsBenchmarkOwner | IsMlCubeOwner] if self.request.method == "PUT" and "priority" in self.request.data: self.permission_classes = [IsAdmin | IsBenchmarkOwner] + elif self.request.method == "DELETE": + self.permission_classes = [IsAdmin] return super(self.__class__, self).get_permissions() def get_object(self, model_id, benchmark_id): diff --git a/server/dataset/tests.py b/server/dataset/tests.py deleted file mode 100644 index 8756d5290..000000000 --- a/server/dataset/tests.py +++ /dev/null @@ -1,145 +0,0 @@ -from django.conf import settings -from django.contrib.auth import get_user_model -from rest_framework.test import APIClient -from rest_framework import status - -from medperf.tests import MedPerfTest - -User = get_user_model() - - -class DatasetTest(MedPerfTest): - """Test module for Dataset APIs""" - - def setUp(self): - super(DatasetTest, self).setUp() - username = "dataowner" - token, _ = self.create_user(username) - self.api_prefix = "/api/" + settings.SERVER_API_VERSION - self.client = APIClient() - self.token = token - self.client.credentials(HTTP_AUTHORIZATION="Bearer " + self.token) - data_preproc_mlcube = { - "name": "testmlcube", - "git_mlcube_url": "string", - "mlcube_hash": "string", - "git_parameters_url": "string", - "parameters_hash": "string", - "image_tarball_url": "", - "image_tarball_hash": "", - "image_hash": "string", - 
"additional_files_tarball_url": "string", - "additional_files_tarball_hash": "string", - "metadata": {"key": "value"}, - } - - response = self.client.post( - self.api_prefix + "/mlcubes/", data_preproc_mlcube, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.data_preproc_mlcube_id = response.data["id"] - - def test_unauthenticated_user(self): - client = APIClient() - response = client.get(self.api_prefix + "/datasets/1/") - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.delete(self.api_prefix + "/datasets/1/") - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.put(self.api_prefix + "/datasets/1/") - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.post(self.api_prefix + "/datasets/", {}) - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.get(self.api_prefix + "/datasets/") - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - - def test_crud_user(self): - testdataset = { - "name": "dataset", - "description": "dataset-sample", - "location": "string", - "input_data_hash": "string", - "generated_uid": "string", - "split_seed": 0, - "metadata": {"key": "value"}, - "data_preparation_mlcube": self.data_preproc_mlcube_id, - } - - response = self.client.post( - self.api_prefix + "/datasets/", testdataset, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - uid = response.data["id"] - response = self.client.get(self.api_prefix + "/datasets/{0}/".format(uid)) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - for k, v in response.data.items(): - if k in testdataset: - self.assertEqual(testdataset[k], v) - - response = self.client.get(self.api_prefix + "/datasets/") - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data["results"]), 1) - - newtestdataset = { - "name": "newdataset", - "description": "newdataset-sample", - "location": "newstring", - "input_data_hash": "newstring", - "generated_uid": "newstring", - "split_seed": 0, - "metadata": {"newkey": "newvalue"}, - "data_preparation_mlcube": self.data_preproc_mlcube_id, - } - - response = self.client.put( - self.api_prefix + "/datasets/{0}/".format(uid), - newtestdataset, - format="json", - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self.client.get(self.api_prefix + "/datasets/{0}/".format(uid)) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - for k, v in response.data.items(): - if k in newtestdataset: - self.assertEqual(newtestdataset[k], v) - - # TODO Revisit when delete permissions are fixed - # response = self.client.delete(self.api_prefix + "/datasets/{0}/".format(uid)) - # self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - - # response = self.client.get(self.api_prefix + "/datasets/{0}/".format(uid)) - # self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_invalid_dataset(self): - invalid_id = 9999 - response = self.client.get( - self.api_prefix + "/datasets/{0}/".format(invalid_id) - ) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_duplicate_gen_uid(self): - testdataset = { - "name": "dataset", - "description": "dataset-sample", - "location": "string", - "input_data_hash": "string", - "generated_uid": "string", - "split_seed": 0, - "is_valid": True, - "metadata": {"key": "value"}, - 
"data_preparation_mlcube": self.data_preproc_mlcube_id, - } - - response = self.client.post( - self.api_prefix + "/datasets/", testdataset, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - - response = self.client.post( - self.api_prefix + "/datasets/", testdataset, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_optional_fields(self): - pass diff --git a/server/benchmarkdataset/tests.py b/server/dataset/tests/__init__.py similarity index 100% rename from server/benchmarkdataset/tests.py rename to server/dataset/tests/__init__.py diff --git a/server/dataset/tests/test_.py b/server/dataset/tests/test_.py new file mode 100644 index 000000000..53fa62a90 --- /dev/null +++ b/server/dataset/tests/test_.py @@ -0,0 +1,224 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class DatasetTest(MedPerfTest): + def generic_setup(self): + # setup users + data_owner = "data_owner" + prep_mlcube_owner = "prep_mlcube_owner" + + self.create_user(data_owner) + self.create_user(prep_mlcube_owner) + + # create prep mlcube + self.set_credentials(prep_mlcube_owner) + data_preproc_mlcube = self.mock_mlcube() + response = self.create_mlcube(data_preproc_mlcube) + + # setup globals + self.data_owner = data_owner + self.prep_mlcube_owner = prep_mlcube_owner + self.data_preproc_mlcube_id = response.data["id"] + self.url = self.api_prefix + "/datasets/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"actor": "data_owner"}, + {"actor": "prep_mlcube_owner"}, + ] +) +class DatasetPostTest(DatasetTest): + """Test module for POST /datasets""" + + def setUp(self): + super(DatasetPostTest, self).setUp() + self.generic_setup() + self.set_credentials(self.actor) + + def test_created_dataset_fields_are_saved_as_expected(self): + """Testing the valid scenario""" + # Arrange + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id + ) + get_dataset_url = self.api_prefix + "/datasets/{0}/" + + # Act + response = self.client.post(self.url, testdataset, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + uid = response.data["id"] + + response = self.client.get(get_dataset_url.format(uid)) + self.assertEqual( + response.status_code, status.HTTP_200_OK, "Dataset retreival failed" + ) + + for k, v in response.data.items(): + if k in testdataset: + self.assertEqual(testdataset[k], v, f"Unexpected value for {k}") + + @parameterized.expand([(True,), (False,)]) + def test_creation_of_duplicate_generated_uid_gets_rejected(self, different_uid): + """Testing the model fields rules""" + # Arrange + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id + ) + self.create_dataset(testdataset) + if different_uid: + testdataset["generated_uid"] = "different uid" + + # Act + response = self.client.post(self.url, testdataset, format="json") + + # Assert + if different_uid: + exp_status = status.HTTP_201_CREATED + else: + exp_status = status.HTTP_400_BAD_REQUEST + + self.assertEqual(response.status_code, exp_status) + + def test_default_values_are_as_expected(self): + """Testing the model fields rules""" + # Arrange + default_values = { + "is_valid": True, + "state": "DEVELOPMENT", + "generated_metadata": {}, + "user_metadata": {}, + "description": "", + "location": "", + } + testdataset = self.mock_dataset( + 
data_preparation_mlcube=self.data_preproc_mlcube_id + ) + for key in default_values: + if key in testdataset: + del testdataset[key] + + # Act + response = self.client.post(self.url, testdataset, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in default_values.items(): + self.assertEqual( + val, response.data[key], f"Unexpected default value for {key}" + ) + + def test_readonly_fields(self): + """Testing the serializer rules""" + # Arrange + readonly = { + "owner": 55, + "created_at": "time", + "modified_at": "time2", + } + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id + ) + testdataset.update(readonly) + + # Act + response = self.client.post(self.url, testdataset, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in readonly.items(): + self.assertNotEqual( + val, response.data[key], f"readonly field {key} was modified" + ) + + +@parameterized_class( + [ + {"actor": "data_owner"}, + {"actor": "prep_mlcube_owner"}, + {"actor": "other_user"}, + ] +) +class DatasetGetListTest(DatasetTest): + """Test module for GET /datasets/ endpoint""" + + def setUp(self): + super(DatasetGetListTest, self).setUp() + self.generic_setup() + self.set_credentials(self.data_owner) + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id + ) + testdataset = self.create_dataset(testdataset).data + + other_user = "other_user" + self.create_user("other_user") + self.other_user = other_user + + self.testdataset = testdataset + self.set_credentials(self.actor) + + def test_generic_get_dataset_list(self): + # Arrange + dataset_id = self.testdataset["id"] + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["id"], dataset_id) + + +class PermissionTest(DatasetTest): + """Test module for permissions of /datasets/ endpoint + Non-permitted actions: both GET and POST for unauthenticated users.""" + + def setUp(self): + super(PermissionTest, self).setUp() + self.generic_setup() + self.set_credentials(self.data_owner) + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id + ) + self.testdataset = testdataset + + @parameterized.expand( + [ + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_get_permissions(self, user, exp_status): + # Arrange + self.set_credentials(self.data_owner) + self.create_dataset(self.testdataset) + self.set_credentials(user) + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, exp_status) + + @parameterized.expand( + [ + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_post_permissions(self, user, exp_status): + # Arrange + self.set_credentials(user) + + # Act + response = self.client.post(self.url, self.testdataset, format="json") + + # Assert + self.assertEqual(response.status_code, exp_status) diff --git a/server/dataset/tests/test_benchmarks.py b/server/dataset/tests/test_benchmarks.py new file mode 100644 index 000000000..aaf527f13 --- /dev/null +++ b/server/dataset/tests/test_benchmarks.py @@ -0,0 +1,423 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class DatasetBenchmarksTest(MedPerfTest): + def generic_setup(self): + # setup users + 
data_owner = "data_owner" + bmk_owner = "bmk_owner" + bmk_prep_mlcube_owner = "bmk_prep_mlcube_owner" + ref_mlcube_owner = "ref_mlcube_owner" + eval_mlcube_owner = "eval_mlcube_owner" + other_user = "other_user" + + self.create_user(data_owner) + self.create_user(bmk_owner) + self.create_user(bmk_prep_mlcube_owner) + self.create_user(ref_mlcube_owner) + self.create_user(eval_mlcube_owner) + self.create_user(other_user) + + # setup globals + self.data_owner = data_owner + self.bmk_owner = bmk_owner + self.bmk_prep_mlcube_owner = bmk_prep_mlcube_owner + self.ref_mlcube_owner = ref_mlcube_owner + self.eval_mlcube_owner = eval_mlcube_owner + self.other_user = other_user + + self.url = self.api_prefix + "/datasets/benchmarks/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"actor": "data_owner"}, + {"actor": "bmk_owner"}, + ], +) +class GenericDatasetBenchmarksPostTest(DatasetBenchmarksTest): + """Test module for POST /datasets/benchmarks""" + + def setUp(self): + super(GenericDatasetBenchmarksPostTest, self).setUp() + self.generic_setup() + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + self.set_credentials(self.data_owner) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + dataset = self.create_dataset(dataset).data + + self.bmk_id = benchmark["id"] + self.dataset_id = dataset["id"] + self.set_credentials(self.actor) + + def test_created_association_fields_are_saved_as_expected(self): + """Testing the valid scenario""" + # Arrange + testassoc = self.mock_dataset_association(self.bmk_id, self.dataset_id) + get_association_url = ( + self.api_prefix + f"/datasets/{self.dataset_id}/benchmarks/{self.bmk_id}/" + ) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + response = self.client.get(get_association_url) + + self.assertEqual( + response.status_code, + status.HTTP_200_OK, + "association retrieval failed", + ) + + for k, v in response.data[0].items(): + if k in testassoc: + self.assertEqual(testassoc[k], v, f"unexpected value for {k}") + + def test_default_values_are_as_expected(self): + """Testing the model fields rules""" + + # Arrange + default_values = { + "approved_at": None, + "approval_status": "PENDING", + } + testassoc = self.mock_dataset_association(self.bmk_id, self.dataset_id) + + for key in default_values: + if key in testassoc: + del testassoc[key] + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in default_values.items(): + self.assertEqual( + val, response.data[key], f"unexpected default value for {key}" + ) + + def test_readonly_fields(self): + """Testing the serializer rules""" + + # Arrange + readonly = { + "initiated_by": 55, + "created_at": "time", + "modified_at": "time2", + "approved_at": "time3", + } + testassoc = self.mock_dataset_association(self.bmk_id, self.dataset_id) + + testassoc.update(readonly) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in readonly.items(): + self.assertNotEqual( + val, response.data[key], f"readonly field {key} was modified" + ) + + +@parameterized_class( + [ + {"actor": "data_owner"}, + {"actor": "bmk_owner"}, + ] +) +class 
SerializersDatasetBenchmarksPostTest(DatasetBenchmarksTest): + """Test module for serializers rules of POST /datasets/benchmarks""" + + def setUp(self): + super(SerializersDatasetBenchmarksPostTest, self).setUp() + self.generic_setup() + self.set_credentials(self.actor) + + @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)]) + def test_association_with_unapproved_benchmark(self, state): + # Arrange + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="PENDING", + state=state, + ) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + + self.set_credentials(self.data_owner) + dataset = self.create_dataset(dataset).data + self.set_credentials(self.actor) + + testassoc = self.mock_dataset_association(benchmark["id"], dataset["id"]) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_association_failure_with_development_dataset(self): + # Arrange + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="DEVELOPMENT" + ) + + self.set_credentials(self.data_owner) + dataset = self.create_dataset(dataset).data + self.set_credentials(self.actor) + + testassoc = self.mock_dataset_association(benchmark["id"], dataset["id"]) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_association_failure_with_dataset_not_prepared_with_benchmark_prep_cube( + self, + ): + # Arrange + _, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + prep = self.mock_mlcube( + name="someprep", mlcube_hash="someprep", state="OPERATION" + ) + prep = self.create_mlcube(prep).data + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + + self.set_credentials(self.data_owner) + dataset = self.create_dataset(dataset).data + self.set_credentials(self.actor) + + testassoc = self.mock_dataset_association(benchmark["id"], dataset["id"]) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @parameterized.expand( + [ + ("PENDING", status.HTTP_201_CREATED), + ("APPROVED", status.HTTP_400_BAD_REQUEST), + ("REJECTED", status.HTTP_400_BAD_REQUEST), + ] + ) + def test_specified_association_approval_status_while_not_having_previous_association( + self, approval_status, exp_statuscode + ): + # Arrange + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + + self.set_credentials(self.data_owner) + dataset = self.create_dataset(dataset).data + self.set_credentials(self.actor) + + testassoc = self.mock_dataset_association( + benchmark["id"], dataset["id"], approval_status=approval_status + ) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual( + response.status_code, + 
exp_statuscode, + f"test failed for approval_status={approval_status}", + ) + + @parameterized.expand( + [ + ("PENDING", "PENDING", status.HTTP_400_BAD_REQUEST), + ("PENDING", "APPROVED", status.HTTP_400_BAD_REQUEST), + ("PENDING", "REJECTED", status.HTTP_400_BAD_REQUEST), + ("APPROVED", "PENDING", status.HTTP_400_BAD_REQUEST), + ("APPROVED", "APPROVED", status.HTTP_400_BAD_REQUEST), + ("APPROVED", "REJECTED", status.HTTP_201_CREATED), + ("REJECTED", "PENDING", status.HTTP_201_CREATED), + ("REJECTED", "APPROVED", status.HTTP_400_BAD_REQUEST), + ("REJECTED", "REJECTED", status.HTTP_400_BAD_REQUEST), + ] + ) + def test_specified_association_approval_status_while_having_previous_association( + self, prev_approval_status, new_approval_status, exp_statuscode + ): + # Arrange + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + + self.set_credentials(self.data_owner) + dataset = self.create_dataset(dataset).data + self.set_credentials(self.actor) + + prev_assoc = self.mock_dataset_association( + benchmark["id"], dataset["id"], approval_status=prev_approval_status + ) + self.create_dataset_association(prev_assoc, self.data_owner, self.bmk_owner) + + new_assoc = self.mock_dataset_association( + benchmark["id"], dataset["id"], approval_status=new_approval_status + ) + + # Act + response = self.client.post(self.url, new_assoc, format="json") + + # Assert + self.assertEqual( + response.status_code, + exp_statuscode, + f"test failed when creating {new_approval_status} association " + f"with a previous {prev_approval_status} one", + ) + + def test_creation_of_rejected_association_sets_approval_time(self): + # Arrange + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + + self.set_credentials(self.data_owner) + dataset = self.create_dataset(dataset).data + self.set_credentials(self.actor) + + prev_assoc = self.mock_dataset_association( + benchmark["id"], dataset["id"], approval_status="APPROVED" + ) + self.create_dataset_association(prev_assoc, self.data_owner, self.bmk_owner) + + new_assoc = self.mock_dataset_association( + benchmark["id"], dataset["id"], approval_status="REJECTED" + ) + + # Act + response = self.client.post(self.url, new_assoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertNotEqual(response.data["approved_at"], None) + + def test_creation_of_pending_association_by_same_user_is_auto_approved(self): + # Arrange + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.actor, + ) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + dataset = self.create_dataset(dataset).data + + testassoc = self.mock_dataset_association(benchmark["id"], dataset["id"]) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data["approval_status"], "APPROVED") + self.assertNotEqual(response.data["approved_at"], None) + + +class PermissionTest(DatasetBenchmarksTest): + """Test module for permissions of 
/datasets/benchmarks endpoint + Non-permitted actions: POST for all users except data_owner, bmk_owner, and admins + """ + + def setUp(self): + super(PermissionTest, self).setUp() + self.generic_setup() + + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + self.set_credentials(self.data_owner) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + dataset = self.create_dataset(dataset).data + + self.bmk_id = benchmark["id"] + self.dataset_id = dataset["id"] + self.set_credentials(None) + + @parameterized.expand( + [ + ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_post_permissions(self, user, exp_status): + # Arrange + self.set_credentials(user) + assoc = self.mock_dataset_association(self.bmk_id, self.dataset_id) + + # Act + response = self.client.post(self.url, assoc, format="json") + + # Assert + self.assertEqual(response.status_code, exp_status) diff --git a/server/dataset/tests/test_pk.py b/server/dataset/tests/test_pk.py new file mode 100644 index 000000000..abfb1c4b8 --- /dev/null +++ b/server/dataset/tests/test_pk.py @@ -0,0 +1,360 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class DatasetTest(MedPerfTest): + def generic_setup(self): + # setup users + data_owner = "data_owner" + prep_mlcube_owner = "prep_mlcube_owner" + other_user = "other_user" + + self.create_user(data_owner) + self.create_user(prep_mlcube_owner) + self.create_user(other_user) + + # create prep mlcube + self.set_credentials(prep_mlcube_owner) + data_preproc_mlcube = self.mock_mlcube() + response = self.create_mlcube(data_preproc_mlcube) + + # setup globals + self.data_owner = data_owner + self.prep_mlcube_owner = prep_mlcube_owner + self.other_user = other_user + self.data_preproc_mlcube_id = response.data["id"] + self.url = self.api_prefix + "/datasets/{0}/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"actor": "data_owner"}, + {"actor": "prep_mlcube_owner"}, + {"actor": "other_user"}, + ] +) +class DatasetGetTest(DatasetTest): + """Test module for GET /datasets/""" + + def setUp(self): + super(DatasetGetTest, self).setUp() + self.generic_setup() + self.set_credentials(self.data_owner) + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id + ) + testdataset = self.create_dataset(testdataset).data + self.testdataset = testdataset + self.set_credentials(self.actor) + + def test_generic_get_dataset(self): + # Arrange + dataset_id = self.testdataset["id"] + url = self.url.format(dataset_id) + + # Act + response = self.client.get(url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + for k, v in response.data.items(): + if k in self.testdataset: + self.assertEqual(self.testdataset[k], v, f"Unexpected value for {k}") + + def test_dataset_not_found(self): + # Arrange + invalid_id = 9999 + url = self.url.format(invalid_id) + + # Act + response = self.client.get(url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + +@parameterized_class( + [ + {"actor": "data_owner"}, + ] +) +class DatasetPutTest(DatasetTest): + """Test module for PUT /datasets/""" + + 
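# The PUT tests below exercise three groups of fields: those editable in + # both states, those editable only in DEVELOPMENT, and read-only fields. +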
def setUp(self): + super(DatasetPutTest, self).setUp() + self.generic_setup() + self.set_credentials(self.actor) + + def test_put_modifies_editable_fields_in_development(self): + # Arrange + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id, state="DEVELOPMENT" + ) + testdataset = self.create_dataset(testdataset).data + + new_data_preproc_mlcube = self.mock_mlcube( + name="new name", mlcube_hash="new hash" + ) + new_prep_id = self.create_mlcube(new_data_preproc_mlcube).data["id"] + newtestdataset = { + "name": "newdataset", + "description": "newdataset-sample", + "location": "newstring", + "input_data_hash": "newstring", + "generated_uid": "newstring", + "split_seed": 1, + "data_preparation_mlcube": new_prep_id, + "is_valid": False, + "state": "OPERATION", + "generated_metadata": {"newkey": "newvalue"}, + "user_metadata": {"newkey2": "newvalue2"}, + } + url = self.url.format(testdataset["id"]) + + # Act + response = self.client.put(url, newtestdataset, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + for k, v in response.data.items(): + if k in newtestdataset: + self.assertEqual(newtestdataset[k], v, f"{k} was not modified") + + def test_put_modifies_editable_fields_in_operation(self): + # Arrange + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id, state="OPERATION" + ) + testdataset = self.create_dataset(testdataset).data + + newtestdataset = {"is_valid": False, "user_metadata": {"newkey": "newval"}} + url = self.url.format(testdataset["id"]) + + # Act + response = self.client.put(url, newtestdataset, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + for k, v in response.data.items(): + if k in newtestdataset: + self.assertEqual(newtestdataset[k], v, f"{k} was not modified") + + def test_put_does_not_modify_non_editable_fields_in_operation(self): + # Arrange + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id, state="OPERATION" + ) + testdataset = self.create_dataset(testdataset).data + + new_data_preproc_mlcube = self.mock_mlcube( + name="new name", mlcube_hash="new hash" + ) + new_prep_id = self.create_mlcube(new_data_preproc_mlcube).data["id"] + newtestdataset = { + "name": "newdataset", + "description": "newdataset-sample", + "location": "newstring", + "input_data_hash": "newstring", + "generated_uid": "newstring", + "split_seed": 6, + "data_preparation_mlcube": new_prep_id, + "state": "DEVELOPMENT", + "generated_metadata": {"newkey": "value"}, + } + + url = self.url.format(testdataset["id"]) + + for key in newtestdataset: + # Act + response = self.client.put(url, {key: newtestdataset[key]}, format="json") + # Assert + self.assertEqual( + response.status_code, + status.HTTP_400_BAD_REQUEST, + f"{key} was modified", + ) + + @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)]) + def test_put_does_not_modify_readonly_fields_in_both_states(self, state): + # Arrange + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id, state=state + ) + testdataset = self.create_dataset(testdataset).data + + newtestdataset = {"owner": 5, "created_at": "time", "modified_at": "time"} + url = self.url.format(testdataset["id"]) + + # Act + response = self.client.put(url, newtestdataset, 
format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + for k, v in newtestdataset.items(): + self.assertNotEqual(v, response.data[k], f"{k} was modified") + + def test_put_respects_unique_generated_uid(self): + # Arrange + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id + ) + testdataset = self.create_dataset(testdataset).data + + newtestdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id, + state="DEVELOPMENT", + generated_uid="new", + ) + newtestdataset = self.create_dataset(newtestdataset).data + + put_body = {"generated_uid": testdataset["generated_uid"]} + url = self.url.format(newtestdataset["id"]) + + # Act + response = self.client.put(url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + +@parameterized_class( + [ + {"actor": "api_admin"}, + ] +) +class DatasetDeleteTest(DatasetTest): + def setUp(self): + super(DatasetDeleteTest, self).setUp() + self.generic_setup() + self.set_credentials(self.data_owner) + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id + ) + testdataset = self.create_dataset(testdataset).data + self.testdataset = testdataset + + self.set_credentials(self.actor) + + def test_deletion_works_as_expected(self): + # Arrange + url = self.url.format(self.testdataset["id"]) + + # Act + response = self.client.delete(url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + +class PermissionTest(DatasetTest): + """Test module for permissions of /datasets/{pk} endpoint + Non-permitted actions: + GET: for unauthenticated users + DELETE: for all users except admin + PUT: for all users except data owner and admin + """ + + def setUp(self): + super(PermissionTest, self).setUp() + self.generic_setup() + + self.set_credentials(self.data_owner) + testdataset = self.mock_dataset( + data_preparation_mlcube=self.data_preproc_mlcube_id + ) + testdataset = self.create_dataset(testdataset).data + + self.testdataset = testdataset + self.url = self.url.format(self.testdataset["id"]) + + @parameterized.expand( + [ + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_get_permissions(self, user, expected_status): + # Arrange + self.set_credentials(user) + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, expected_status) + + @parameterized.expand( + [ + ("prep_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_put_permissions(self, user, expected_status): + # Arrange + self.set_credentials(self.prep_mlcube_owner) + new_data_preproc_mlcube = self.mock_mlcube( + name="new name", mlcube_hash="new hash" + ) + new_prep_id = self.create_mlcube(new_data_preproc_mlcube).data["id"] + newtestdataset = { + "name": "newdataset", + "description": "newdataset-sample", + "location": "newstring", + "input_data_hash": "newstring", + "generated_uid": "newstring", + "split_seed": 1, + "data_preparation_mlcube": new_prep_id, + "is_valid": False, + "state": "OPERATION", + "generated_metadata": {"newkey": "newvalue"}, + "user_metadata": {"newkey2": "newvalue2"}, + "owner": 5, + "created_at": "time", + "modified_at": "time", + } + 
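# each field is sent in its own PUT request below, so one rejected key + # cannot mask the handling of another +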
self.set_credentials(user) + + for key in newtestdataset: + # Act + response = self.client.put( + self.url, {key: newtestdataset[key]}, format="json" + ) + # Assert + self.assertEqual( + response.status_code, + expected_status, + f"{key} was modified", + ) + + @parameterized.expand( + [ + ("data_owner", status.HTTP_403_FORBIDDEN), + ("prep_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_delete_permissions(self, user, expected_status): + # Arrange + self.set_credentials(user) + + # Act + response = self.client.delete(self.url) + + # Assert + self.assertEqual(response.status_code, expected_status) diff --git a/server/dataset/tests/test_pk_benchmarks_bid.py b/server/dataset/tests/test_pk_benchmarks_bid.py new file mode 100644 index 000000000..a2c41ced8 --- /dev/null +++ b/server/dataset/tests/test_pk_benchmarks_bid.py @@ -0,0 +1,470 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class DatasetTest(MedPerfTest): + def generic_setup(self): + # setup users + data_owner = "data_owner" + bmk_owner = "bmk_owner" + bmk_prep_mlcube_owner = "bmk_prep_mlcube_owner" + ref_mlcube_owner = "ref_mlcube_owner" + eval_mlcube_owner = "eval_mlcube_owner" + other_user = "other_user" + + self.create_user(data_owner) + self.create_user(bmk_owner) + self.create_user(bmk_prep_mlcube_owner) + self.create_user(ref_mlcube_owner) + self.create_user(eval_mlcube_owner) + self.create_user(other_user) + + # create benchmark and dataset + self.set_credentials(bmk_owner) + prep, _, _, benchmark = self.shortcut_create_benchmark( + bmk_prep_mlcube_owner, + ref_mlcube_owner, + eval_mlcube_owner, + bmk_owner, + ) + self.set_credentials(data_owner) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + dataset = self.create_dataset(dataset).data + + # setup globals + self.data_owner = data_owner + self.bmk_owner = bmk_owner + self.bmk_prep_mlcube_owner = bmk_prep_mlcube_owner + self.ref_mlcube_owner = ref_mlcube_owner + self.eval_mlcube_owner = eval_mlcube_owner + self.other_user = other_user + + self.bmk_id = benchmark["id"] + self.dataset_id = dataset["id"] + self.url = self.api_prefix + "/datasets/{0}/benchmarks/{1}/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"initiator": "data_owner", "actor": "bmk_owner"}, + {"initiator": "bmk_owner", "actor": "data_owner"}, + ] +) +class BenchmarkDatasetGetTest(DatasetTest): + """Test module for GET /datasets/""" + + def setUp(self): + super(BenchmarkDatasetGetTest, self).setUp() + self.generic_setup() + + self.url = self.url.format(self.dataset_id, self.bmk_id) + self.visible_fields = [ + "approval_status", + "initiated_by", + "approved_at", + "created_at", + "modified_at", + ] + + if self.initiator == self.data_owner: + self.approving_user = self.bmk_owner + else: + self.approving_user = self.data_owner + + self.set_credentials(self.actor) + + @parameterized.expand([("PENDING",), ("APPROVED",), ("REJECTED",)]) + def test_generic_get_benchmark_dataset(self, approval_status): + # Arrange + testassoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status=approval_status + ) + + testassoc = self.create_dataset_association( + testassoc, self.initiator, self.approving_user + ).data + if isinstance(testassoc, list): + testassoc = testassoc[0] + + # Act + response = self.client.get(self.url) + + # Assert + 
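# only association fields listed in visible_fields should appear in + # the response; everything else must stay hidden +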
self.assertEqual(response.status_code, status.HTTP_200_OK) + for k, v in testassoc.items(): + if k in self.visible_fields: + self.assertIn(k, response.data[0]) + self.assertEqual(response.data[0][k], v, f"Unexpected value for {k}") + else: + self.assertNotIn(k, response.data[0], f"{k} should not be visible") + + def test_benchmark_dataset_returns_a_list(self): + # Arrange + testassoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="REJECTED" + ) + + self.create_dataset_association(testassoc, self.initiator, self.approving_user) + + testassoc2 = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="PENDING" + ) + + self.create_dataset_association(testassoc2, self.initiator, self.approving_user) + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 2) + + +@parameterized_class( + [ + {"initiator": "data_owner", "actor": "bmk_owner"}, + {"initiator": "bmk_owner", "actor": "data_owner"}, + ] +) +class DatasetPutTest(DatasetTest): + """Test module for PUT /datasets/""" + + def setUp(self): + super(DatasetPutTest, self).setUp() + self.generic_setup() + self.url = self.url.format(self.dataset_id, self.bmk_id) + if self.initiator == self.data_owner: + self.approving_user = self.bmk_owner + else: + self.approving_user = self.data_owner + self.set_credentials(self.actor) + + @parameterized.expand([("PENDING",), ("APPROVED",), ("REJECTED",)]) + def test_put_does_not_modify_readonly_fields(self, approval_status): + # Arrange + testassoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status=approval_status + ) + + self.create_dataset_association(testassoc, self.initiator, self.approving_user) + + put_body = { + "initiated_by": 55, + "approved_at": "time", + "created_at": "time", + "modified_at": "time", + } + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(self.url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + for k, v in put_body.items(): + self.assertNotEqual(v, response.data[0][k], f"{k} was modified") + + @parameterized.expand([("APPROVED",), ("REJECTED",)]) + def test_modifying_approval_status_to_pending_is_not_allowed( + self, original_approval_status + ): + # Arrange + testassoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status=original_approval_status + ) + + testassoc = self.create_dataset_association( + testassoc, self.initiator, self.approving_user + ).data + if isinstance(testassoc, list): + testassoc = testassoc[0] + + put_body = { + "approval_status": "PENDING", + } + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @parameterized.expand( + [ + ("PENDING", status.HTTP_200_OK), + ("APPROVED", status.HTTP_400_BAD_REQUEST), + ("REJECTED", status.HTTP_400_BAD_REQUEST), + ] + ) + def test_rejecting_an_association_is_allowed_only_if_it_was_pending( + self, original_approval_status, exp_status + ): + # Arrange + testassoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status=original_approval_status + ) + + testassoc = self.create_dataset_association( + testassoc, self.initiator, self.approving_user + ).data + if isinstance(testassoc, list): + testassoc = testassoc[0] + + put_body = { + 
"approval_status": "REJECTED", + } + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, exp_status) + + @parameterized.expand([("APPROVED",), ("REJECTED",)]) + def test_approving_an_association_is_disallowed_if_it_was_not_pending( + self, original_approval_status + ): + # Arrange + testassoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status=original_approval_status + ) + + testassoc = self.create_dataset_association( + testassoc, self.initiator, self.approving_user + ).data + if isinstance(testassoc, list): + testassoc = testassoc[0] + + put_body = { + "approval_status": "APPROVED", + } + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_approving_an_association_is_allowed_if_user_is_different(self): + """This is also a permission test""" + + # Arrange + testassoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="PENDING" + ) + + self.create_dataset_association(testassoc, self.initiator, None) + + put_body = { + "approval_status": "APPROVED", + } + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + if self.initiator == self.actor: + exp_status = status.HTTP_400_BAD_REQUEST + else: + exp_status = status.HTTP_200_OK + + self.assertEqual(response.status_code, exp_status) + + def test_put_works_on_latest_association(self): + # Arrange + testassoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="REJECTED" + ) + + self.create_dataset_association(testassoc, self.initiator, self.approving_user) + + testassoc2 = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="PENDING" + ) + + self.create_dataset_association(testassoc2, self.initiator, self.approving_user) + + put_body = {"approval_status": "REJECTED"} + # this will fail if latest assoc is not pending. 
+ # so, success of this test implies the PUT acts on testassoc2 (latest) + + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +@parameterized_class( + [ + {"actor": "api_admin"}, + ] +) +class DatasetDeleteTest(DatasetTest): + def setUp(self): + super(DatasetDeleteTest, self).setUp() + self.generic_setup() + self.url = self.url.format(self.dataset_id, self.bmk_id) + self.set_credentials(self.actor) + + def test_deletion_works_as_expected_for_single_assoc(self): + # Arrange + testassoc = self.mock_dataset_association(self.bmk_id, self.dataset_id) + self.create_dataset_association(testassoc, self.data_owner, self.bmk_owner) + + # Act + response = self.client.delete(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + response = self.client.get(self.url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 0) + + def test_deletion_works_as_expected_for_multiple_assoc(self): + # Arrange + testassoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="REJECTED" + ) + self.create_dataset_association(testassoc, self.data_owner, self.bmk_owner) + + testassoc2 = self.mock_dataset_association(self.bmk_id, self.dataset_id) + self.create_dataset_association(testassoc2, self.data_owner, self.bmk_owner) + + # Act + response = self.client.delete(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + response = self.client.get(self.url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 0) + + +@parameterized_class( + [ + {"initiator": "data_owner", "association_status": "PENDING"}, + {"initiator": "bmk_owner", "association_status": "PENDING"}, + {"initiator": "data_owner", "association_status": "REJECTED"}, + {"initiator": "bmk_owner", "association_status": "REJECTED"}, + {"initiator": "data_owner", "association_status": "APPROVED"}, + {"initiator": "bmk_owner", "association_status": "APPROVED"}, + ] +) +class PermissionTest(DatasetTest): + """Test module for permissions of the /datasets/{pk}/benchmarks/{bid} endpoint + Non-permitted actions: + GET: for all users except data owner, bmk_owner and admin + DELETE: for all users except admin + PUT: for all users except data owner, bmk_owner and admin; + if approval_status == APPROVED, the initiating user is not allowed either + """ + + def setUp(self): + super(PermissionTest, self).setUp() + self.generic_setup() + + self.url = self.url.format(self.dataset_id, self.bmk_id) + + if self.initiator == self.data_owner: + self.approving_user = self.bmk_owner + else: + self.approving_user = self.data_owner + + testassoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status=self.association_status + ) + + self.create_dataset_association(testassoc, self.initiator, self.approving_user) + + # TODO: determine for all tests what should be 404 instead of 400 or 403 + @parameterized.expand( + [ + ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_get_permissions(self, user, expected_status): + # Arrange + self.set_credentials(user) + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, expected_status) + + @parameterized.expand( 
+ [ + ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_put_permissions(self, user, expected_status): + # Arrange + + put_body = { + "approval_status": "APPROVED", + "initiated_by": 55, + "approved_at": "time", + "created_at": "time", + "modified_at": "time", + } + self.set_credentials(user) + + for key in put_body: + # Act + response = self.client.put(self.url, {key: put_body[key]}, format="json") + # Assert + self.assertEqual( + response.status_code, + expected_status, + f"{key} was modified", + ) + + def test_put_permissions_for_approval_status(self): + # Arrange + if self.association_status != "PENDING": + # skip cases that will fail and already tested before + # they are for reasons are other permissions + return + + put_body = { + "approval_status": "APPROVED", + } + self.set_credentials(self.initiator) + + # Act + response = self.client.put(self.url, put_body, format="json") + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + # TODO: move this check to permission checks in serializers to return 403 + + @parameterized.expand( + [ + ("data_owner", status.HTTP_403_FORBIDDEN), + ("bmk_owner", status.HTTP_403_FORBIDDEN), + ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_delete_permissions(self, user, expected_status): + # Arrange + self.set_credentials(user) + + # Act + response = self.client.delete(self.url) + + # Assert + self.assertEqual(response.status_code, expected_status) diff --git a/server/dataset/urls.py b/server/dataset/urls.py index 9b8ce2601..5aa23fd5a 100644 --- a/server/dataset/urls.py +++ b/server/dataset/urls.py @@ -8,6 +8,7 @@ path("", views.DatasetList.as_view()), path("/", views.DatasetDetail.as_view()), path("benchmarks/", bviews.BenchmarkDatasetList.as_view()), - path("/benchmarks/", bviews.BenchmarkDatasetApproval.as_view()), path("/benchmarks//", bviews.DatasetApproval.as_view()), + # path("/benchmarks/", bviews.DatasetBenchmarksList.as_view()), + # NOTE: when activating this endpoint later, check permissions and write tests ] diff --git a/server/debug_tests.sh b/server/debug_tests.sh new file mode 100644 index 000000000..423d24e93 --- /dev/null +++ b/server/debug_tests.sh @@ -0,0 +1,19 @@ +python manage.py test \ + mlcube.tests.test_ \ + mlcube.tests.test_pk \ + dataset.tests.test_ \ + dataset.tests.test_pk \ + benchmark.tests.test_ \ + benchmark.tests.test_pk \ + benchmark.tests.test_pk_datasets \ + benchmark.tests.test_pk_models \ + dataset.tests.test_benchmarks \ + dataset.tests.test_pk_benchmarks_bid \ + mlcube.tests.test_benchmarks \ + mlcube.tests.test_pk_benchmarks_bid \ + result.tests.test_ \ + result.tests.test_pk \ + utils.tests \ + user.tests.test_ \ + user.tests.test_pk \ + --failfast \ No newline at end of file diff --git a/server/medperf/testing_utils.py b/server/medperf/testing_utils.py index ff9d59c69..5304858d6 100644 --- a/server/medperf/testing_utils.py +++ b/server/medperf/testing_utils.py @@ -47,8 +47,141 @@ def set_user_as_admin(user_id): user_obj.save() -def setup_api_admin(): - token, user_info = create_user("apiadmin") +def setup_api_admin(username): + token, user_info = create_user(username) 
user_id = user_info["id"] set_user_as_admin(user_id) return token + + +def mock_mlcube(**kwargs): + data = { + "name": "testmlcube", + "git_mlcube_url": "string", + "mlcube_hash": "string", + "git_parameters_url": "string", + "parameters_hash": "string", + "image_tarball_url": "", + "image_tarball_hash": "", + "image_hash": "string", + "additional_files_tarball_url": "string", + "additional_files_tarball_hash": "string", + "state": "DEVELOPMENT", + "is_valid": True, + "metadata": {"key": "value"}, + "user_metadata": {"key2": "value2"}, + } + + for key, val in kwargs.items(): + if key not in data: + raise ValueError(f"Invalid argument: {key}") + data[key] = val + + return data + + +def mock_dataset(data_preparation_mlcube, **kwargs): + data = { + "name": "dataset", + "description": "dataset-sample", + "location": "string", + "input_data_hash": "string", + "generated_uid": "string", + "split_seed": 0, + "data_preparation_mlcube": data_preparation_mlcube, + "is_valid": True, + "state": "DEVELOPMENT", + "generated_metadata": {"key": "value"}, + "user_metadata": {"key2": "value2"}, + } + + for key, val in kwargs.items(): + if key not in data: + raise ValueError(f"Invalid argument: {key}") + data[key] = val + + return data + + +def mock_benchmark( + data_preparation_mlcube, + reference_model_mlcube, + data_evaluator_mlcube, + **kwargs, +): + data = { + "name": "string", + "description": "string", + "docs_url": "string", + "demo_dataset_tarball_url": "string", + "demo_dataset_tarball_hash": "string", + "demo_dataset_generated_uid": "string", + "data_preparation_mlcube": data_preparation_mlcube, + "reference_model_mlcube": reference_model_mlcube, + "data_evaluator_mlcube": data_evaluator_mlcube, + "metadata": {"key": "value"}, + "state": "DEVELOPMENT", + "is_valid": True, + "is_active": True, + "user_metadata": {"key2": "value2"}, + } + + for key, val in kwargs.items(): + if key not in data: + raise ValueError(f"Invalid argument: {key}") + data[key] = val + + return data + + +def mock_result(benchmark, model, dataset, **kwargs): + data = { + "name": "string", + "benchmark": benchmark, + "model": model, + "dataset": dataset, + "results": {"key": "value"}, + "metadata": {"key2": "value2"}, + "user_metadata": {"key3": "value3"}, + "approval_status": "PENDING", + "is_valid": True, + } + + for key, val in kwargs.items(): + if key not in data: + raise ValueError(f"Invalid argument: {key}") + data[key] = val + + return data + + +def mock_dataset_association(benchmark, dataset, **kwargs): + data = { + "dataset": dataset, + "benchmark": benchmark, + "metadata": {"key": "value"}, + "approval_status": "PENDING", + } + + for key, val in kwargs.items(): + if key not in data: + raise ValueError(f"Invalid argument: {key}") + data[key] = val + + return data + + +def mock_mlcube_association(benchmark, mlcube, **kwargs): + data = { + "model_mlcube": mlcube, + "benchmark": benchmark, + "metadata": {"key": "value"}, + "approval_status": "PENDING", + } + + for key, val in kwargs.items(): + if key not in data: + raise ValueError(f"Invalid argument: {key}") + data[key] = val + + return data diff --git a/server/medperf/tests.py b/server/medperf/tests.py index ed4c3d518..360b26e75 100644 --- a/server/medperf/tests.py +++ b/server/medperf/tests.py @@ -1,11 +1,26 @@ from django.test import TestCase from django.test import override_settings -from .testing_utils import PUBLIC_KEY, setup_api_admin, create_user +from django.conf import settings +from rest_framework.test import APIClient +from rest_framework import status +from 
.testing_utils import ( + PUBLIC_KEY, + setup_api_admin, + create_user, + mock_benchmark, + mock_dataset, + mock_mlcube, + mock_result, + mock_dataset_association, + mock_mlcube_association, +) class MedPerfTest(TestCase): """Common settings module for MedPerf APIs""" + # TODO: for all DELETE tests, we should revisit when we allow users + # to delete. We should test the effects of model.CASCADE and model.PROTECT def setUp(self): SIMPLE_JWT = { "ALGORITHM": "RS256", @@ -24,8 +39,170 @@ def setUp(self): SECURE_SSL_REDIRECT=False, SIMPLE_JWT=SIMPLE_JWT ) settings_manager.enable() - self.admin_token = setup_api_admin() self.addCleanup(settings_manager.disable) + self.tokens = {} + self.current_user = None + + self.api_admin = "api_admin" + admin_token = setup_api_admin(self.api_admin) + self.tokens[self.api_admin] = admin_token + self.api_prefix = "/api/" + settings.SERVER_API_VERSION + self.client = APIClient() + self.mock_benchmark = mock_benchmark + self.mock_dataset = mock_dataset + self.mock_mlcube = mock_mlcube + self.mock_result = mock_result + self.mock_dataset_association = mock_dataset_association + self.mock_mlcube_association = mock_mlcube_association + def create_user(self, username): - return create_user(username) + token, _ = create_user(username) + self.tokens[username] = token + + def set_credentials(self, username): + self.current_user = username + if username is None: + self.client.credentials() + else: + token = self.tokens[username] + self.client.credentials(HTTP_AUTHORIZATION="Bearer " + token) + + def __create_asset(self, data, url): + response = self.client.post(url, data, format="json") + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + return response + + def create_benchmark(self, data, target_approval_status="APPROVED"): + # preserve current credentials + backup_user = self.current_user + + if target_approval_status != "PENDING": + data["state"] = "OPERATION" + response = self.__create_asset(data, self.api_prefix + "/benchmarks/") + if target_approval_status != "PENDING": + self.set_credentials(self.api_admin) + uid = response.data["id"] + url = self.api_prefix + "/benchmarks/{0}/".format(uid) + response = self.client.put( + url, {"approval_status": target_approval_status}, format="json" + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # restore user + self.set_credentials(backup_user) + return response + + def create_dataset(self, data): + return self.__create_asset(data, self.api_prefix + "/datasets/") + + def create_mlcube(self, data): + return self.__create_asset(data, self.api_prefix + "/mlcubes/") + + def create_result(self, data): + return self.__create_asset(data, self.api_prefix + "/results/") + + def create_dataset_association( + self, data, initiating_user, approving_user, set_status_directly=False + ): + # preserve current credentials + backup_user = self.current_user + + self.set_credentials(initiating_user) + target_approval_status = data["approval_status"] + if not set_status_directly: + data["approval_status"] = "PENDING" + response = self.__create_asset(data, self.api_prefix + "/datasets/benchmarks/") + if target_approval_status != "PENDING" and not set_status_directly: + dataset_id = data["dataset"] + benchmark_id = data["benchmark"] + url = self.api_prefix + f"/datasets/{dataset_id}/benchmarks/{benchmark_id}/" + self.set_credentials(approving_user) + response = self.client.put( + url, {"approval_status": 
target_approval_status}, format="json" + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # restore user + self.set_credentials(backup_user) + + return response + + def create_mlcube_association( + self, data, initiating_user, approving_user, set_status_directly=False + ): + # preserve current credentials + backup_user = self.current_user + + self.set_credentials(initiating_user) + target_approval_status = data["approval_status"] + if not set_status_directly: + data["approval_status"] = "PENDING" + response = self.__create_asset(data, self.api_prefix + "/mlcubes/benchmarks/") + if target_approval_status != "PENDING" and not set_status_directly: + mlcube_id = data["model_mlcube"] + benchmark_id = data["benchmark"] + url = self.api_prefix + f"/mlcubes/{mlcube_id}/benchmarks/{benchmark_id}/" + self.set_credentials(approving_user) + response = self.client.put( + url, {"approval_status": target_approval_status}, format="json" + ) + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # restore user + self.set_credentials(backup_user) + + return response + + def shortcut_create_benchmark( + self, + prep_mlcube_owner, + ref_mlcube_owner, + eval_mlcube_owner, + bmk_owner, + target_approval_status="APPROVED", + prep_mlcube_kwargs={}, + ref_mlcube_kwargs={}, + eval_mlcube_kwargs={}, + **kwargs, + ): + # preserve current credentials + backup_user = self.current_user + + # create mlcubes + self.set_credentials(prep_mlcube_owner) + prep = self.mock_mlcube(name="prep", mlcube_hash="prep", state="OPERATION") + prep.update(prep_mlcube_kwargs) + prep = self.create_mlcube(prep).data + + self.set_credentials(ref_mlcube_owner) + ref_model = self.mock_mlcube( + name="ref_model", mlcube_hash="ref_model", state="OPERATION" + ) + ref_model.update(ref_mlcube_kwargs) + ref_model = self.create_mlcube(ref_model).data + + self.set_credentials(eval_mlcube_owner) + eval = self.mock_mlcube(name="eval", mlcube_hash="eval", state="OPERATION") + eval.update(eval_mlcube_kwargs) + eval = self.create_mlcube(eval).data + + # create benchmark + self.set_credentials(bmk_owner) + benchmark = self.mock_benchmark( + prep["id"], ref_model["id"], eval["id"], **kwargs + ) + benchmark = self.create_benchmark( + benchmark, target_approval_status=target_approval_status + ).data + + # restore user + self.set_credentials(backup_user) + + return prep, ref_model, eval, benchmark diff --git a/server/mlcube/migrations/0002_alter_mlcube_unique_together.py b/server/mlcube/migrations/0002_alter_mlcube_unique_together.py new file mode 100644 index 000000000..74dd5f91f --- /dev/null +++ b/server/mlcube/migrations/0002_alter_mlcube_unique_together.py @@ -0,0 +1,17 @@ +# Generated by Django 3.2.20 on 2023-11-16 23:12 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('mlcube', '0001_initial'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='mlcube', + unique_together={('image_tarball_hash', 'image_hash', 'additional_files_tarball_hash', 'mlcube_hash', 'parameters_hash')}, + ), + ] diff --git a/server/mlcube/models.py b/server/mlcube/models.py index 74e5a7f37..8835b31cb 100644 --- a/server/mlcube/models.py +++ b/server/mlcube/models.py @@ -36,14 +36,10 @@ def __str__(self): class Meta: unique_together = ( ( - "image_tarball_url", "image_tarball_hash", "image_hash", - 
"additional_files_tarball_url", "additional_files_tarball_hash", - "git_mlcube_url", "mlcube_hash", - "git_parameters_url", "parameters_hash", ), ) diff --git a/server/mlcube/serializers.py b/server/mlcube/serializers.py index 8abb2004e..16bd63fb7 100644 --- a/server/mlcube/serializers.py +++ b/server/mlcube/serializers.py @@ -2,6 +2,54 @@ from .models import MlCube +def validate_optional_mlcube_components(data): + git_parameters_url = data.get("git_parameters_url", "") + parameters_hash = data.get("parameters_hash", "") + + additional_files_tarball_url = data.get("additional_files_tarball_url", "") + additional_files_tarball_hash = data.get("additional_files_tarball_hash", "") + + image_hash = data.get("image_hash", "") + + image_tarball_url = data.get("image_tarball_url", "") + image_tarball_hash = data.get("image_tarball_hash", "") + + # validate nonblank parameters file hash + if git_parameters_url and not parameters_hash: + raise serializers.ValidationError("Parameters require file hash") + + if not git_parameters_url and parameters_hash: + raise serializers.ValidationError("Paramters hash was provided without URL") + + # validate nonblank additional files hash + if additional_files_tarball_url and not additional_files_tarball_hash: + raise serializers.ValidationError("Additional files require file hash") + + if not additional_files_tarball_url and additional_files_tarball_hash: + raise serializers.ValidationError( + "Additional files hash was provided without URL" + ) + + # validate images attributes. + if not image_hash and not image_tarball_hash: + raise serializers.ValidationError( + "Image hash or Image tarball hash must be provided" + ) + if image_hash and image_tarball_hash: + raise serializers.ValidationError( + "Image hash and Image tarball hash can't be provided at the same time" + ) + if image_tarball_url and not image_tarball_hash: + raise serializers.ValidationError( + "Providing Image tarball requires providing image tarball hash" + ) + + if not image_tarball_url and image_tarball_hash: + raise serializers.ValidationError( + "image tarball hash should not be provided if no image tarball url is provided" + ) + + class MlCubeSerializer(serializers.ModelSerializer): class Meta: model = MlCube @@ -9,44 +57,7 @@ class Meta: read_only_fields = ["owner"] def validate(self, data): - git_parameters_url = data["git_parameters_url"] - parameters_hash = data["parameters_hash"] - - additional_files_tarball_url = data["additional_files_tarball_url"] - additional_files_tarball_hash = data["additional_files_tarball_hash"] - - image_hash = data["image_hash"] - - image_tarball_url = data["image_tarball_url"] - image_tarball_hash = data["image_tarball_hash"] - - # validate nonblank parameters file hash - if git_parameters_url and not parameters_hash: - raise serializers.ValidationError("Parameters require file hash") - - # validate nonblank additional files hash - if additional_files_tarball_url and not additional_files_tarball_hash: - raise serializers.ValidationError("Additional files require file hash") - - # validate images attributes. 
- if not image_hash and not image_tarball_hash: - raise serializers.ValidationError( - "Image hash or Image tarball hash must be provided" - ) - if image_hash and image_tarball_hash: - raise serializers.ValidationError( - "Image hash and Image tarball hash can't be provided at the same time" - ) - if image_tarball_url and not image_tarball_hash: - raise serializers.ValidationError( - "Providing Image tarball requires providing image tarball hash" - ) - - if not image_tarball_url and image_tarball_hash: - raise serializers.ValidationError( - "image tarball hash should not be provided if no image tarball url is provided" - ) - + validate_optional_mlcube_components(data) return data @@ -72,4 +83,19 @@ def validate(self, data): raise serializers.ValidationError( "User cannot update non editable fields in Operation mode" ) + + updated_dict = {} + for key in [ + "git_parameters_url", + "parameters_hash", + "additional_files_tarball_url", + "additional_files_tarball_hash", + "image_hash", + "image_tarball_url", + "image_tarball_hash", + ]: + updated_dict[key] = data.get(key, getattr(self.instance, key)) + + validate_optional_mlcube_components(updated_dict) + return data diff --git a/server/mlcube/tests.py b/server/mlcube/tests.py deleted file mode 100644 index e732826d5..000000000 --- a/server/mlcube/tests.py +++ /dev/null @@ -1,101 +0,0 @@ -from django.conf import settings -from django.contrib.auth import get_user_model -from rest_framework.test import APIClient -from rest_framework import status - -from medperf.tests import MedPerfTest - -User = get_user_model() - - -class MlCubeTest(MedPerfTest): - """Test module for MLCube APIs""" - - def setUp(self): - super(MlCubeTest, self).setUp() - username = "mlcubeowner" - token, _ = self.create_user(username) - self.api_prefix = "/api/" + settings.SERVER_API_VERSION - self.client = APIClient() - self.token = token - self.client.credentials(HTTP_AUTHORIZATION="Bearer " + self.token) - - def test_unauthenticated_user(self): - client = APIClient() - response = client.get(self.api_prefix + "/mlcubes/1/") - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.delete(self.api_prefix + "/mlcubes/1/") - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.put(self.api_prefix + "/mlcubes/1/") - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.post(self.api_prefix + "/mlcubes/", {}) - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - response = client.get(self.api_prefix + "/mlcubes/") - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - - def test_crud_user(self): - testmlcube = { - "name": "testmlcube", - "git_mlcube_url": "string", - "mlcube_hash": "string", - "git_parameters_url": "string", - "parameters_hash": "string", - "image_tarball_url": "", - "image_tarball_hash": "", - "image_hash": "string", - "additional_files_tarball_url": "string", - "additional_files_tarball_hash": "string", - "metadata": {"key": "value"}, - } - - response = self.client.post( - self.api_prefix + "/mlcubes/", testmlcube, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - uid = response.data["id"] - response = self.client.get(self.api_prefix + "/mlcubes/{0}/".format(uid)) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - for k, v in response.data.items(): - if k in testmlcube: - self.assertEqual(testmlcube[k], v) - - response = self.client.get(self.api_prefix + "/mlcubes/") - 
self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(len(response.data["results"]), 1) - - newmlcube = { - "name": "newtestmlcube", - "git_mlcube_url": "newstring", - "git_parameters_url": "newstring", - "tarball_url": "newstring", - "tarball_hash": "newstring", - "image_hash": "string", - "metadata": {"newkey": "newvalue"}, - } - - response = self.client.put( - self.api_prefix + "/mlcubes/{0}/".format(uid), newmlcube, format="json" - ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self.client.get(self.api_prefix + "/mlcubes/{0}/".format(uid)) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - for k, v in response.data.items(): - if k in newmlcube: - self.assertEqual(newmlcube[k], v) - - # TODO Revisit when delete permissions are fixed - # response = self.client.delete(self.api_prefix + "/mlcubes/{0}/".format(uid)) - # self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) - - # response = self.client.get(self.api_prefix + "/mlcubes/{0}/".format(uid)) - # self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_invalid_mlcube(self): - invalid_id = 9999 - response = self.client.get(self.api_prefix + "/mlcubes/{0}/".format(invalid_id)) - self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - - def test_optional_fields(self): - pass diff --git a/server/result/tests.py b/server/mlcube/tests/__init__.py similarity index 100% rename from server/result/tests.py rename to server/mlcube/tests/__init__.py diff --git a/server/mlcube/tests/test_.py b/server/mlcube/tests/test_.py new file mode 100644 index 000000000..311cb640f --- /dev/null +++ b/server/mlcube/tests/test_.py @@ -0,0 +1,351 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class MlCubeTest(MedPerfTest): + def generic_setup(self): + # setup users + mlcube_owner = "mlcube_owner" + self.create_user(mlcube_owner) + + # setup globals + self.mlcube_owner = mlcube_owner + self.url = self.api_prefix + "/mlcubes/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"actor": "mlcube_owner"}, + ] +) +class MlCubePostTest(MlCubeTest): + """Test module for POST /mlcubes""" + + def setUp(self): + super(MlCubePostTest, self).setUp() + self.generic_setup() + self.set_credentials(self.actor) + + def test_created_mlcube_fields_are_saved_as_expected(self): + """Testing the valid scenario""" + # Arrange + testmlcube = self.mock_mlcube() + get_mlcube_url = self.api_prefix + "/mlcubes/{0}/" + + # Act + response = self.client.post(self.url, testmlcube, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + uid = response.data["id"] + response = self.client.get(get_mlcube_url.format(uid)) + + self.assertEqual( + response.status_code, status.HTTP_200_OK, "mlcube retrieval failed" + ) + + for k, v in response.data.items(): + if k in testmlcube: + self.assertEqual(testmlcube[k], v, f"Unexpected value for {k}") + + @parameterized.expand([(True,), (False,)]) + def test_creation_of_duplicate_name_gets_rejected(self, different_name): + """Testing the model fields rules""" + # Arrange + testmlcube = self.mock_mlcube() + self.create_mlcube(testmlcube) + testmlcube["image_hash"] = "new string" + + if different_name: + testmlcube["name"] = "different name" + + # Act + response = self.client.post(self.url, testmlcube, format="json") + + # Assert + if different_name: + exp_status = status.HTTP_201_CREATED + else: 
+ exp_status = status.HTTP_400_BAD_REQUEST + + self.assertEqual(response.status_code, exp_status) + + @parameterized.expand( + [ + ("image_hash",), + ("additional_files_tarball_hash",), + ("mlcube_hash",), + ("parameters_hash",), + (None,), + ] + ) + def test_creation_of_duplicate_mlcubes_with_image_hash(self, field): + """Testing the model unique_together constraint""" + # Arrange + testmlcube = self.mock_mlcube() + self.create_mlcube(testmlcube) + testmlcube["name"] = "new name" + + if field is not None: + testmlcube[field] = "different string" + + # Act + response = self.client.post(self.url, testmlcube, format="json") + + # Assert + if field is not None: + exp_status = status.HTTP_201_CREATED + else: + exp_status = status.HTTP_400_BAD_REQUEST + + self.assertEqual( + response.status_code, exp_status, f"test failed with field {field}" + ) + + @parameterized.expand( + [ + ("image_tarball_hash",), + ("additional_files_tarball_hash",), + ("mlcube_hash",), + ("parameters_hash",), + (None,), + ] + ) + def test_creation_of_duplicate_mlcubes_with_image_tarball(self, field): + """Testing the model unique_together constraint""" + # Arrange + testmlcube = self.mock_mlcube( + image_hash="", image_tarball_url="string", image_tarball_hash="string" + ) + self.create_mlcube(testmlcube) + testmlcube["name"] = "new name" + + if field is not None: + testmlcube[field] = "different string" + + # Act + response = self.client.post(self.url, testmlcube, format="json") + + # Assert + if field is not None: + exp_status = status.HTTP_201_CREATED + else: + exp_status = status.HTTP_400_BAD_REQUEST + + self.assertEqual( + response.status_code, exp_status, f"test failed with field {field}" + ) + + def test_default_values_are_as_expected(self): + """Testing the model fields rules""" + # Arrange + default_values = { + "state": "DEVELOPMENT", + "is_valid": True, + "metadata": {}, + "user_metadata": {}, + "git_parameters_url": "", + "image_tarball_url": "", + "additional_files_tarball_url": "", + } + testmlcube = self.mock_mlcube() + for key in default_values: + if key in testmlcube: + del testmlcube[key] + + # in order to allow empty urls + testmlcube.update( + { + "parameters_hash": "", + "image_tarball_hash": "", + "additional_files_tarball_hash": "", + } + ) + # Act + response = self.client.post(self.url, testmlcube, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in default_values.items(): + self.assertEqual( + val, response.data[key], f"unexpected default value for {key}" + ) + + def test_readonly_fields(self): + """Testing the serializer rules""" + # Arrange + readonly = { + "owner": 10, + "created_at": "some time", + "modified_at": "some time", + } + testmlcube = self.mock_mlcube() + testmlcube.update(readonly) + + # Act + response = self.client.post(self.url, testmlcube, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in readonly.items(): + self.assertNotEqual( + val, response.data[key], f"readonly field {key} was modified" + ) + + @parameterized.expand([(True,), (False,)]) + def test_parameters_file_should_have_a_hash(self, url_provided): + """Testing the serializer rules""" + # Arrange + testmlcube = self.mock_mlcube(parameters_hash="") + if not url_provided: + testmlcube["git_parameters_url"] = "" + + # Act + response = self.client.post(self.url, testmlcube, format="json") + + # Assert + if not url_provided: + exp_status = status.HTTP_201_CREATED + else: + exp_status = 
status.HTTP_400_BAD_REQUEST + self.assertEqual(response.status_code, exp_status) + + @parameterized.expand([(True,), (False,)]) + def test_additional_files_should_have_a_hash(self, url_provided): + """Testing the serializer rules""" + + # Arrange + testmlcube = self.mock_mlcube(additional_files_tarball_hash="") + if not url_provided: + testmlcube["additional_files_tarball_url"] = "" + + # Act + response = self.client.post(self.url, testmlcube, format="json") + + # Assert + if not url_provided: + exp_status = status.HTTP_201_CREATED + else: + exp_status = status.HTTP_400_BAD_REQUEST + self.assertEqual(response.status_code, exp_status) + + @parameterized.expand( + [ + (False, False, False, status.HTTP_400_BAD_REQUEST), + (False, False, True, status.HTTP_400_BAD_REQUEST), + (False, True, False, status.HTTP_400_BAD_REQUEST), + (False, True, True, status.HTTP_201_CREATED), + (True, False, False, status.HTTP_201_CREATED), + (True, False, True, status.HTTP_400_BAD_REQUEST), + (True, True, False, status.HTTP_400_BAD_REQUEST), + (True, True, True, status.HTTP_400_BAD_REQUEST), + ] + ) + def test_image_fields_cases( + self, image_hash, image_tarball_url, image_tarball_hash, exp_status + ): + """Testing the serializer rules + The rules are simply stating that either "image_hash", or both the + "image_tarball_url" and "image_tarball_hash", should be provided.""" + + # Arrange + testmlcube = self.mock_mlcube( + image_hash="string" if image_hash else "", + image_tarball_url="string" if image_tarball_url else "", + image_tarball_hash="string" if image_tarball_hash else "", + ) + # Act + response = self.client.post(self.url, testmlcube, format="json") + + # Assert + self.assertEqual( + response.status_code, + exp_status, + f"test failed with image_hash={image_hash}, " + f"image_tarball_url={image_tarball_url}, " + f"image_tarball_hash={image_tarball_hash}", + ) + + +@parameterized_class( + [ + {"actor": "mlcube_owner"}, + {"actor": "other_user"}, + ] +) +class MlCubeGetListTest(MlCubeTest): + """Test module for GET /mlcubes/ endpoint""" + + def setUp(self): + super(MlCubeGetListTest, self).setUp() + self.generic_setup() + self.set_credentials(self.mlcube_owner) + testmlcube = self.mock_mlcube() + testmlcube = self.create_mlcube(testmlcube).data + + other_user = "other_user" + self.create_user("other_user") + self.other_user = other_user + + self.testmlcube = testmlcube + self.set_credentials(self.actor) + + def test_generic_get_mlcube_list(self): + # Arrange + mlcube_id = self.testmlcube["id"] + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["id"], mlcube_id) + + +class PermissionTest(MlCubeTest): + """Test module for permissions of /mlcubes/ endpoint + Non-permitted actions: both GET and POST for unauthenticated users.""" + + def setUp(self): + super(PermissionTest, self).setUp() + self.generic_setup() + self.set_credentials(self.mlcube_owner) + testmlcube = self.mock_mlcube() + self.testmlcube = testmlcube + + @parameterized.expand( + [ + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_get_permissions(self, user, exp_status): + # Arrange + self.set_credentials(self.mlcube_owner) + self.create_mlcube(self.testmlcube) + self.set_credentials(user) + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, exp_status) + + @parameterized.expand( + [ + (None, 
status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_post_permissions(self, user, exp_status): + # Arrange + self.set_credentials(user) + + # Act + response = self.client.post(self.url, self.testmlcube, format="json") + + # Assert + self.assertEqual(response.status_code, exp_status) diff --git a/server/mlcube/tests/test_benchmarks.py b/server/mlcube/tests/test_benchmarks.py new file mode 100644 index 000000000..bd07594d4 --- /dev/null +++ b/server/mlcube/tests/test_benchmarks.py @@ -0,0 +1,383 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class MlCubeBenchmarksTest(MedPerfTest): + def generic_setup(self): + # setup users + mlcube_owner = "mlcube_owner" + bmk_owner = "bmk_owner" + bmk_prep_mlcube_owner = "bmk_prep_mlcube_owner" + ref_mlcube_owner = "ref_mlcube_owner" + eval_mlcube_owner = "eval_mlcube_owner" + other_user = "other_user" + + self.create_user(mlcube_owner) + self.create_user(bmk_owner) + self.create_user(bmk_prep_mlcube_owner) + self.create_user(ref_mlcube_owner) + self.create_user(eval_mlcube_owner) + self.create_user(other_user) + + # setup globals + self.mlcube_owner = mlcube_owner + self.bmk_owner = bmk_owner + self.bmk_prep_mlcube_owner = bmk_prep_mlcube_owner + self.ref_mlcube_owner = ref_mlcube_owner + self.eval_mlcube_owner = eval_mlcube_owner + self.other_user = other_user + + self.url = self.api_prefix + "/mlcubes/benchmarks/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"actor": "mlcube_owner"}, + {"actor": "bmk_owner"}, + ], +) +class GenericMlCubeBenchmarksPostTest(MlCubeBenchmarksTest): + """Test module for POST /mlcubes/benchmarks""" + + def setUp(self): + super(GenericMlCubeBenchmarksPostTest, self).setUp() + self.generic_setup() + _, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + self.set_credentials(self.mlcube_owner) + mlcube = self.mock_mlcube(state="OPERATION") + mlcube = self.create_mlcube(mlcube).data + + self.bmk_id = benchmark["id"] + self.mlcube_id = mlcube["id"] + self.set_credentials(self.actor) + + def test_created_association_fields_are_saved_as_expected(self): + """Testing the valid scenario""" + # Arrange + testassoc = self.mock_mlcube_association(self.bmk_id, self.mlcube_id) + get_association_url = ( + self.api_prefix + f"/mlcubes/{self.mlcube_id}/benchmarks/{self.bmk_id}/" + ) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + response = self.client.get(get_association_url) + + self.assertEqual( + response.status_code, + status.HTTP_200_OK, + "association retrieval failed", + ) + + for k, v in response.data[0].items(): + if k in testassoc: + self.assertEqual(testassoc[k], v, f"unexpected value for {k}") + + def test_default_values_are_as_expected(self): + """Testing the model fields rules""" + + # Arrange + default_values = { + "approved_at": None, + "approval_status": "PENDING", + "priority": 0, + } + testassoc = self.mock_mlcube_association(self.bmk_id, self.mlcube_id) + + for key in default_values: + if key in testassoc: + del testassoc[key] + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in default_values.items(): + self.assertEqual( + val, response.data[key], f"unexpected default value for 
{key}" + ) + + def test_readonly_fields(self): + """Testing the serializer rules""" + + # Arrange + readonly = { + "initiated_by": 55, + "created_at": "time", + "modified_at": "time2", + "approved_at": "time3", + "priority": 555, + } + testassoc = self.mock_mlcube_association(self.bmk_id, self.mlcube_id) + + testassoc.update(readonly) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in readonly.items(): + self.assertNotEqual( + val, response.data[key], f"readonly field {key} was modified" + ) + + +@parameterized_class( + [ + {"actor": "mlcube_owner"}, + {"actor": "bmk_owner"}, + ] +) +class SerializersMlCubeBenchmarksPostTest(MlCubeBenchmarksTest): + """Test module for serializers rules of POST /mlcubes/benchmarks""" + + def setUp(self): + super(SerializersMlCubeBenchmarksPostTest, self).setUp() + self.generic_setup() + self.set_credentials(self.actor) + + @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)]) + def test_association_with_unapproved_benchmark(self, state): + # NOTE: the serializer checks also if benchmark is operation + # however, an approved benchmark cannot be in development + # (i.e. there is a redundant check that we can't test) + + # Arrange + _, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + target_approval_status="PENDING", + state=state, + ) + mlcube = self.mock_mlcube(state="OPERATION") + + self.set_credentials(self.mlcube_owner) + mlcube = self.create_mlcube(mlcube).data + self.set_credentials(self.actor) + + testassoc = self.mock_mlcube_association(benchmark["id"], mlcube["id"]) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_association_failure_with_development_mlcube(self): + # Arrange + _, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + mlcube = self.mock_mlcube(state="DEVELOPMENT") + + self.set_credentials(self.mlcube_owner) + mlcube = self.create_mlcube(mlcube).data + self.set_credentials(self.actor) + + testassoc = self.mock_mlcube_association(benchmark["id"], mlcube["id"]) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @parameterized.expand( + [ + ("PENDING", status.HTTP_201_CREATED), + ("APPROVED", status.HTTP_400_BAD_REQUEST), + ("REJECTED", status.HTTP_400_BAD_REQUEST), + ] + ) + def test_specified_association_approval_status_while_not_having_previous_association( + self, approval_status, exp_statuscode + ): + # Arrange + _, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + mlcube = self.mock_mlcube(state="OPERATION") + + self.set_credentials(self.mlcube_owner) + mlcube = self.create_mlcube(mlcube).data + self.set_credentials(self.actor) + + testassoc = self.mock_mlcube_association( + benchmark["id"], mlcube["id"], approval_status=approval_status + ) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual( + response.status_code, + exp_statuscode, + f"test failed for approval_status={approval_status}", + ) + + @parameterized.expand( + 
[ + ("PENDING", "PENDING", status.HTTP_400_BAD_REQUEST), + ("PENDING", "APPROVED", status.HTTP_400_BAD_REQUEST), + ("PENDING", "REJECTED", status.HTTP_400_BAD_REQUEST), + ("APPROVED", "PENDING", status.HTTP_400_BAD_REQUEST), + ("APPROVED", "APPROVED", status.HTTP_400_BAD_REQUEST), + ("APPROVED", "REJECTED", status.HTTP_201_CREATED), + ("REJECTED", "PENDING", status.HTTP_201_CREATED), + ("REJECTED", "APPROVED", status.HTTP_400_BAD_REQUEST), + ("REJECTED", "REJECTED", status.HTTP_400_BAD_REQUEST), + ] + ) + def test_specified_association_approval_status_while_having_previous_association( + self, prev_approval_status, new_approval_status, exp_statuscode + ): + # Arrange + _, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + mlcube = self.mock_mlcube(state="OPERATION") + + self.set_credentials(self.mlcube_owner) + mlcube = self.create_mlcube(mlcube).data + self.set_credentials(self.actor) + + prev_assoc = self.mock_mlcube_association( + benchmark["id"], mlcube["id"], approval_status=prev_approval_status + ) + self.create_mlcube_association(prev_assoc, self.mlcube_owner, self.bmk_owner) + + new_assoc = self.mock_mlcube_association( + benchmark["id"], mlcube["id"], approval_status=new_approval_status + ) + + # Act + response = self.client.post(self.url, new_assoc, format="json") + + # Assert + self.assertEqual( + response.status_code, + exp_statuscode, + f"test failed when creating {new_approval_status} association " + f"with a previous {prev_approval_status} one", + ) + + def test_creation_of_rejected_association_sets_approval_time(self): + # Arrange + _, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + mlcube = self.mock_mlcube(state="OPERATION") + + self.set_credentials(self.mlcube_owner) + mlcube = self.create_mlcube(mlcube).data + self.set_credentials(self.actor) + + prev_assoc = self.mock_mlcube_association( + benchmark["id"], mlcube["id"], approval_status="APPROVED" + ) + self.create_mlcube_association(prev_assoc, self.mlcube_owner, self.bmk_owner) + + new_assoc = self.mock_mlcube_association( + benchmark["id"], mlcube["id"], approval_status="REJECTED" + ) + + # Act + response = self.client.post(self.url, new_assoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertNotEqual(response.data["approved_at"], None) + + def test_creation_of_pending_association_by_same_user_is_auto_approved(self): + # Arrange + _, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.actor, + ) + mlcube = self.mock_mlcube(state="OPERATION") + mlcube = self.create_mlcube(mlcube).data + + testassoc = self.mock_mlcube_association(benchmark["id"], mlcube["id"]) + + # Act + response = self.client.post(self.url, testassoc, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data["approval_status"], "APPROVED") + self.assertNotEqual(response.data["approved_at"], None) + + +class PermissionTest(MlCubeBenchmarksTest): + """Test module for permissions of /mlcubes/benchmarks endpoint + Non-permitted actions: POST for all users except mlcube_owner, bmk_owner, and admins + """ + + def setUp(self): + super(PermissionTest, self).setUp() + self.generic_setup() + + _, _, _, benchmark = 
self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + self.set_credentials(self.mlcube_owner) + mlcube = self.mock_mlcube(state="OPERATION") + mlcube = self.create_mlcube(mlcube).data + + self.bmk_id = benchmark["id"] + self.mlcube_id = mlcube["id"] + self.set_credentials(None) + + @parameterized.expand( + [ + ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_post_permissions(self, user, exp_status): + # Arrange + self.set_credentials(user) + assoc = self.mock_mlcube_association(self.bmk_id, self.mlcube_id) + + # Act + response = self.client.post(self.url, assoc, format="json") + + # Assert + self.assertEqual(response.status_code, exp_status) diff --git a/server/mlcube/tests/test_pk.py b/server/mlcube/tests/test_pk.py new file mode 100644 index 000000000..5e4b6f080 --- /dev/null +++ b/server/mlcube/tests/test_pk.py @@ -0,0 +1,590 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class MlCubeTest(MedPerfTest): + def generic_setup(self): + # setup users + mlcube_owner = "mlcube_owner" + other_user = "other_user" + + self.create_user(mlcube_owner) + self.create_user(other_user) + + # setup globals + self.mlcube_owner = mlcube_owner + self.other_user = other_user + self.url = self.api_prefix + "/mlcubes/{0}/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"actor": "mlcube_owner"}, + {"actor": "other_user"}, + ] +) +class MlCubeGetTest(MlCubeTest): + """Test module for GET /mlcubes/""" + + def setUp(self): + super(MlCubeGetTest, self).setUp() + self.generic_setup() + self.set_credentials(self.mlcube_owner) + testmlcube = self.mock_mlcube() + testmlcube = self.create_mlcube(testmlcube).data + self.testmlcube = testmlcube + self.set_credentials(self.actor) + + def test_generic_get_mlcube(self): + # Arrange + mlcube_id = self.testmlcube["id"] + url = self.url.format(mlcube_id) + + # Act + response = self.client.get(url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + for k, v in response.data.items(): + if k in self.testmlcube: + self.assertEqual(self.testmlcube[k], v, f"Unexpected value for {k}") + + def test_mlcube_not_found(self): + # Arrange + invalid_id = 9999 + url = self.url.format(invalid_id) + + # Act + response = self.client.get(url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + +@parameterized_class( + [ + {"actor": "mlcube_owner"}, + ] +) +class MlCubePutTest(MlCubeTest): + """Test module for PUT /mlcubes/""" + + def setUp(self): + super(MlCubePutTest, self).setUp() + self.generic_setup() + self.set_credentials(self.actor) + + def test_put_modifies_editable_fields_in_development(self): + # Arrange + testmlcube = self.mock_mlcube(state="DEVELOPMENT") + testmlcube = self.create_mlcube(testmlcube).data + + newtestmlcube = { + "name": "newtestmlcube", + "git_mlcube_url": "newstring", + "mlcube_hash": "newstring", + "git_parameters_url": "newstring", + "parameters_hash": "newstring", + "image_tarball_url": "new", + "image_tarball_hash": "new", + "image_hash": "", + "additional_files_tarball_url": "newstring", + "additional_files_tarball_hash": "newstring", + "state": "OPERATION", + "is_valid": False, + "metadata": {"newkey": 
"newvalue"}, + "user_metadata": {"newkey2": "newvalue2"}, + } + url = self.url.format(testmlcube["id"]) + + # Act + response = self.client.put(url, newtestmlcube, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + for k, v in response.data.items(): + if k in newtestmlcube: + self.assertEqual(newtestmlcube[k], v, f"{k} was not modified") + + def test_put_modifies_editable_fields_in_operation(self): + # Arrange + testmlcube = self.mock_mlcube(state="OPERATION") + testmlcube = self.create_mlcube(testmlcube).data + + newtestmlcube = { + "additional_files_tarball_url": "newurl", + "git_mlcube_url": "newurl", + "git_parameters_url": "newurl", + "is_valid": False, + "user_metadata": {"newkey": "newval"}, + } + url = self.url.format(testmlcube["id"]) + + # Act + response = self.client.put(url, newtestmlcube, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + for k, v in response.data.items(): + if k in newtestmlcube: + self.assertEqual(newtestmlcube[k], v, f"{k} was not modified") + + def test_put_modifies_image_tarball_url_in_operation(self): + # Arrange + testmlcube = self.mock_mlcube( + state="OPERATION", + image_hash="", + image_tarball_url="url", + image_tarball_hash="hash", + ) + testmlcube = self.create_mlcube(testmlcube).data + + newtestmlcube = { + "image_tarball_url": "newurl", + } + url = self.url.format(testmlcube["id"]) + + # Act + response = self.client.put(url, newtestmlcube, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + for k, v in response.data.items(): + if k in newtestmlcube: + self.assertEqual(newtestmlcube[k], v, f"{k} was not modified") + + def test_put_does_not_modify_non_editable_fields_in_operation(self): + # Arrange + testmlcube = self.mock_mlcube(state="OPERATION") + testmlcube = self.create_mlcube(testmlcube).data + + newtestmlcube = { + "name": "newtestmlcube", + "mlcube_hash": "newstring", + "parameters_hash": "newstring", + "image_hash": "newhash", + "additional_files_tarball_hash": "newstring", + "state": "DEVELOPMENT", + "metadata": {"newkey": "newvalue"}, + } + + url = self.url.format(testmlcube["id"]) + + for key in newtestmlcube: + # Act + response = self.client.put(url, {key: newtestmlcube[key]}, format="json") + # Assert + self.assertEqual( + response.status_code, + status.HTTP_400_BAD_REQUEST, + f"{key} was modified", + ) + + def test_put_does_not_modify_non_editable_fields_in_operation_special_case(self): + """This test is the same as the previous one, except that it tries to modify + image_tarball_hash, which should be accompanied by setting the image_hash + to blank and adding image_tarball_url to work successfully in development. 
+ """ + # Arrange + testmlcube = self.mock_mlcube(state="OPERATION") + testmlcube = self.create_mlcube(testmlcube).data + + newtestmlcube = { + "image_tarball_url": "newurl", + "image_tarball_hash": "new", + "image_hash": "", + } + + url = self.url.format(testmlcube["id"]) + + # Act + response = self.client.put(url, newtestmlcube, format="json") + + # Assert + self.assertEqual( + response.status_code, + status.HTTP_400_BAD_REQUEST, + "image_tarball_hash was modified", + ) + + @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)]) + def test_put_does_not_modify_readonly_fields_in_both_states(self, state): + # Arrange + testmlcube = self.mock_mlcube(state=state) + testmlcube = self.create_mlcube(testmlcube).data + + newtestmlcube = {"owner": 5, "created_at": "time", "modified_at": "time"} + url = self.url.format(testmlcube["id"]) + + # Act + response = self.client.put(url, newtestmlcube, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + for k, v in newtestmlcube.items(): + self.assertNotEqual(v, response.data[k], f"{k} was modified") + + @parameterized.expand( + [ + ("image_hash",), + ("additional_files_tarball_hash",), + ("mlcube_hash",), + ("parameters_hash",), + ] + ) + def test_put_respects_rules_of_duplicate_mlcubes_with_image_hash(self, field): + """Testing the model unique_together constraint""" + # Arrange + testmlcube = self.mock_mlcube() + testmlcube = self.create_mlcube(testmlcube).data + + newtestmlcube = self.mock_mlcube( + name="newname", state="DEVELOPMENT", **{field: "newvalue"} + ) + newtestmlcube = self.create_mlcube(newtestmlcube).data + + put_body = {field: testmlcube[field]} + url = self.url.format(newtestmlcube["id"]) + + # Act + response = self.client.put(url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @parameterized.expand( + [ + ("image_tarball_hash",), + ("additional_files_tarball_hash",), + ("mlcube_hash",), + ("parameters_hash",), + ] + ) + def test_put_respects_rules_of_duplicate_mlcubes_with_image_tarball_hash( + self, field + ): + """Testing the model unique_together constraint""" + # Arrange + testmlcube = self.mock_mlcube( + image_hash="", image_tarball_url="url", image_tarball_hash="hash" + ) + testmlcube = self.create_mlcube(testmlcube).data + + newtestmlcube = self.mock_mlcube( + name="newname", + state="DEVELOPMENT", + image_hash="", + image_tarball_url="url", + **{"image_tarball_hash": "hash", field: "newvalue"}, + ) + newtestmlcube = self.create_mlcube(newtestmlcube).data + + put_body = {field: testmlcube[field]} + url = self.url.format(newtestmlcube["id"]) + + # Act + response = self.client.put(url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_put_respects_unique_name(self): + # Arrange + testmlcube = self.mock_mlcube() + testmlcube = self.create_mlcube(testmlcube).data + + newtestmlcube = self.mock_mlcube( + state="DEVELOPMENT", name="newname", mlcube_hash="newhash" + ) + newtestmlcube = self.create_mlcube(newtestmlcube).data + + put_body = {"name": testmlcube["name"]} + url = self.url.format(newtestmlcube["id"]) + + # Act + response = self.client.put(url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)]) + def 
test_put_doesnot_allow_adding_image_tarball_when_image_hash_is_present(
+        self, state
+    ):
+        # Arrange
+        testmlcube = self.mock_mlcube(state=state)
+        testmlcube = self.create_mlcube(testmlcube).data
+
+        put_body = {"image_tarball_url": "url", "image_tarball_hash": "hash"}
+        url = self.url.format(testmlcube["id"])
+
+        # Act
+        response = self.client.put(url, put_body, format="json")
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+    @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)])
+    def test_put_doesnot_allow_adding_image_hash_when_image_tarball_is_present(
+        self, state
+    ):
+        # Arrange
+        testmlcube = self.mock_mlcube(
+            state=state,
+            image_hash="",
+            image_tarball_url="url",
+            image_tarball_hash="hash",
+        )
+        testmlcube = self.create_mlcube(testmlcube).data
+
+        put_body = {"image_hash": "hash"}
+        url = self.url.format(testmlcube["id"])
+
+        # Act
+        response = self.client.put(url, put_body, format="json")
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+    @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)])
+    def test_put_doesnot_allow_adding_parameters_url_without_hash(self, state):
+        # Arrange
+        testmlcube = self.mock_mlcube(
+            state=state, git_parameters_url="", parameters_hash=""
+        )
+        testmlcube = self.create_mlcube(testmlcube).data
+
+        put_body = {"git_parameters_url": "url"}
+        url = self.url.format(testmlcube["id"])
+
+        # Act
+        response = self.client.put(url, put_body, format="json")
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+    @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)])
+    def test_put_doesnot_allow_adding_additional_files_url_without_hash(self, state):
+        # Arrange
+        testmlcube = self.mock_mlcube(
+            state=state,
+            additional_files_tarball_url="",
+            additional_files_tarball_hash="",
+        )
+        testmlcube = self.create_mlcube(testmlcube).data
+
+        put_body = {"additional_files_tarball_url": "url"}
+        url = self.url.format(testmlcube["id"])
+
+        # Act
+        response = self.client.put(url, put_body, format="json")
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+    @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)])
+    def test_put_doesnot_allow_adding_image_tarball_url_without_hash(self, state):
+        # Arrange
+        testmlcube = self.mock_mlcube(
+            state=state,
+            image_tarball_url="",
+            image_tarball_hash="",
+        )
+        testmlcube = self.create_mlcube(testmlcube).data
+
+        put_body = {"image_hash": "", "image_tarball_url": "url"}
+        url = self.url.format(testmlcube["id"])
+
+        # Act
+        response = self.client.put(url, put_body, format="json")
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+    @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)])
+    def test_put_doesnot_allow_clearing_parameters_hash_without_url(self, state):
+        # Arrange
+        testmlcube = self.mock_mlcube(state=state)
+        testmlcube = self.create_mlcube(testmlcube).data
+
+        put_body = {"parameters_hash": ""}
+        url = self.url.format(testmlcube["id"])
+
+        # Act
+        response = self.client.put(url, put_body, format="json")
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+    @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)])
+    def test_put_doesnot_allow_clearing_additional_files_hash_without_url(self, state):
+        # Arrange
+        testmlcube = self.mock_mlcube(state=state)
+        testmlcube = self.create_mlcube(testmlcube).data
+
+        put_body = {"additional_files_tarball_hash": ""}
+        url = 
self.url.format(testmlcube["id"]) + + # Act + response = self.client.put(url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @parameterized.expand([("DEVELOPMENT",), ("OPERATION",)]) + def test_put_doesnot_allow_clearing_image_tarball_hash_without_url(self, state): + # Arrange + testmlcube = self.mock_mlcube( + state=state, + image_tarball_url="url", + image_tarball_hash="hash", + image_hash="", + ) + testmlcube = self.create_mlcube(testmlcube).data + + put_body = {"image_tarball_hash": ""} + url = self.url.format(testmlcube["id"]) + + # Act + response = self.client.put(url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + +@parameterized_class( + [ + {"actor": "api_admin"}, + ] +) +class MlCubeDeleteTest(MlCubeTest): + def setUp(self): + super(MlCubeDeleteTest, self).setUp() + self.generic_setup() + self.set_credentials(self.mlcube_owner) + testmlcube = self.mock_mlcube() + testmlcube = self.create_mlcube(testmlcube).data + self.testmlcube = testmlcube + + self.set_credentials(self.actor) + + def test_deletion_works_as_expected(self): + # Arrange + url = self.url.format(self.testmlcube["id"]) + + # Act + response = self.client.delete(url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + +class PermissionTest(MlCubeTest): + """Test module for permissions of /mlcubes/{pk} endpoint + Non-permitted actions: + GET: for unauthenticated users + DELETE: for all users except admin + PUT: for all users except mlcube owner and admin + """ + + def setUp(self): + super(PermissionTest, self).setUp() + self.generic_setup() + + self.set_credentials(self.mlcube_owner) + testmlcube = self.mock_mlcube() + testmlcube = self.create_mlcube(testmlcube).data + + self.testmlcube = testmlcube + self.url = self.url.format(self.testmlcube["id"]) + + @parameterized.expand( + [ + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_get_permissions(self, user, expected_status): + # Arrange + self.set_credentials(user) + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, expected_status) + + @parameterized.expand( + [ + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_put_permissions(self, user, expected_status): + # Arrange + + newtestmlcube = { + "name": "newtestmlcube", + "git_mlcube_url": "newstring", + "mlcube_hash": "newstring", + "git_parameters_url": "newstring", + "parameters_hash": "newstring", + "image_tarball_url": "new", + "image_tarball_hash": "new", + "image_hash": "", + "additional_files_tarball_url": "newstring", + "additional_files_tarball_hash": "newstring", + "state": "OPERATION", + "is_valid": False, + "metadata": {"newkey": "newvalue"}, + "user_metadata": {"newkey2": "newvalue2"}, + "owner": 5, + "created_at": "time", + "modified_at": "time", + } + self.set_credentials(user) + + for key in newtestmlcube: + # Act + response = self.client.put( + self.url, {key: newtestmlcube[key]}, format="json" + ) + # Assert + self.assertEqual( + response.status_code, + expected_status, + f"{key} was modified", + ) + + @parameterized.expand( + [ + ("mlcube_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_delete_permissions(self, user, expected_status): + # 
Arrange + self.set_credentials(user) + + # Act + response = self.client.delete(self.url) + + # Assert + self.assertEqual(response.status_code, expected_status) diff --git a/server/mlcube/tests/test_pk_benchmarks_bid.py b/server/mlcube/tests/test_pk_benchmarks_bid.py new file mode 100644 index 000000000..bbb14f0a4 --- /dev/null +++ b/server/mlcube/tests/test_pk_benchmarks_bid.py @@ -0,0 +1,483 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class MlCubeTest(MedPerfTest): + def generic_setup(self): + # setup users + mlcube_owner = "mlcube_owner" + bmk_owner = "bmk_owner" + bmk_prep_mlcube_owner = "bmk_prep_mlcube_owner" + ref_mlcube_owner = "ref_mlcube_owner" + eval_mlcube_owner = "eval_mlcube_owner" + other_user = "other_user" + + self.create_user(mlcube_owner) + self.create_user(bmk_owner) + self.create_user(bmk_prep_mlcube_owner) + self.create_user(ref_mlcube_owner) + self.create_user(eval_mlcube_owner) + self.create_user(other_user) + + # create benchmark and mlcube + self.set_credentials(bmk_owner) + _, _, _, benchmark = self.shortcut_create_benchmark( + bmk_prep_mlcube_owner, + ref_mlcube_owner, + eval_mlcube_owner, + bmk_owner, + ) + self.set_credentials(mlcube_owner) + mlcube = self.mock_mlcube(state="OPERATION") + mlcube = self.create_mlcube(mlcube).data + + # setup globals + self.mlcube_owner = mlcube_owner + self.bmk_owner = bmk_owner + self.bmk_prep_mlcube_owner = bmk_prep_mlcube_owner + self.ref_mlcube_owner = ref_mlcube_owner + self.eval_mlcube_owner = eval_mlcube_owner + self.other_user = other_user + + self.bmk_id = benchmark["id"] + self.mlcube_id = mlcube["id"] + self.url = self.api_prefix + "/mlcubes/{0}/benchmarks/{1}/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"initiator": "mlcube_owner", "actor": "bmk_owner"}, + {"initiator": "bmk_owner", "actor": "mlcube_owner"}, + ] +) +class BenchmarkMlCubeGetTest(MlCubeTest): + """Test module for GET /mlcubes/""" + + def setUp(self): + super(BenchmarkMlCubeGetTest, self).setUp() + self.generic_setup() + + self.url = self.url.format(self.mlcube_id, self.bmk_id) + self.visible_fields = [ + "approval_status", + "initiated_by", + "approved_at", + "created_at", + "modified_at", + "priority", + ] + + if self.initiator == self.mlcube_owner: + self.approving_user = self.bmk_owner + else: + self.approving_user = self.mlcube_owner + + self.set_credentials(self.actor) + + @parameterized.expand([("PENDING",), ("APPROVED",), ("REJECTED",)]) + def test_generic_get_benchmark_mlcube(self, approval_status): + # Arrange + testassoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status=approval_status + ) + + testassoc = self.create_mlcube_association( + testassoc, self.initiator, self.approving_user + ).data + if isinstance(testassoc, list): + testassoc = testassoc[0] + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + for k, v in testassoc.items(): + if k in self.visible_fields: + self.assertIn(k, response.data[0]) + self.assertEqual(response.data[0][k], v, f"Unexpected value for {k}") + else: + self.assertNotIn(k, response.data[0], f"{k} should not be visible") + + def test_benchmark_mlcube_returns_a_list(self): + # Arrange + testassoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status="REJECTED" + ) + + self.create_mlcube_association(testassoc, self.initiator, self.approving_user) + + testassoc2 = 
self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status="PENDING" + ) + + self.create_mlcube_association(testassoc2, self.initiator, self.approving_user) + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 2) + + +@parameterized_class( + [ + {"initiator": "mlcube_owner", "actor": "bmk_owner"}, + {"initiator": "bmk_owner", "actor": "mlcube_owner"}, + ] +) +class MlCubePutTest(MlCubeTest): + """Test module for PUT /mlcubes/""" + + def setUp(self): + super(MlCubePutTest, self).setUp() + self.generic_setup() + self.url = self.url.format(self.mlcube_id, self.bmk_id) + if self.initiator == self.mlcube_owner: + self.approving_user = self.bmk_owner + else: + self.approving_user = self.mlcube_owner + self.set_credentials(self.actor) + + @parameterized.expand([("PENDING",), ("APPROVED",), ("REJECTED",)]) + def test_put_does_not_modify_readonly_fields(self, approval_status): + # Arrange + testassoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status=approval_status + ) + + self.create_mlcube_association(testassoc, self.initiator, self.approving_user) + + put_body = { + "initiated_by": 55, + "approved_at": "time", + "created_at": "time", + "modified_at": "time", + } + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.get(self.url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + for k, v in put_body.items(): + self.assertNotEqual(v, response.data[0][k], f"{k} was modified") + + @parameterized.expand([("APPROVED",), ("REJECTED",)]) + def test_modifying_approval_status_to_pending_is_not_allowed( + self, original_approval_status + ): + # Arrange + testassoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status=original_approval_status + ) + + testassoc = self.create_mlcube_association( + testassoc, self.initiator, self.approving_user + ).data + if isinstance(testassoc, list): + testassoc = testassoc[0] + + put_body = { + "approval_status": "PENDING", + } + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @parameterized.expand( + [ + ("PENDING", status.HTTP_200_OK), + ("APPROVED", status.HTTP_400_BAD_REQUEST), + ("REJECTED", status.HTTP_400_BAD_REQUEST), + ] + ) + def test_rejecting_an_association_is_allowed_only_if_it_was_pending( + self, original_approval_status, exp_status + ): + # Arrange + testassoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status=original_approval_status + ) + + testassoc = self.create_mlcube_association( + testassoc, self.initiator, self.approving_user + ).data + if isinstance(testassoc, list): + testassoc = testassoc[0] + + put_body = { + "approval_status": "REJECTED", + } + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, exp_status) + + @parameterized.expand([("APPROVED",), ("REJECTED",)]) + def test_approving_an_association_is_disallowed_if_it_was_not_pending( + self, original_approval_status + ): + # Arrange + testassoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status=original_approval_status + ) + + testassoc = self.create_mlcube_association( + testassoc, self.initiator, self.approving_user + ).data + if isinstance(testassoc, 
list): + testassoc = testassoc[0] + + put_body = { + "approval_status": "APPROVED", + } + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_approving_an_association_is_allowed_if_user_is_different(self): + """This is also a permission test""" + + # Arrange + testassoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status="PENDING" + ) + + self.create_mlcube_association(testassoc, self.initiator, None) + + put_body = { + "approval_status": "APPROVED", + } + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + if self.initiator == self.actor: + exp_status = status.HTTP_400_BAD_REQUEST + else: + exp_status = status.HTTP_200_OK + + self.assertEqual(response.status_code, exp_status) + + def test_put_works_on_latest_association(self): + # Arrange + testassoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status="REJECTED" + ) + + self.create_mlcube_association(testassoc, self.initiator, self.approving_user) + + testassoc2 = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status="PENDING" + ) + + self.create_mlcube_association(testassoc2, self.initiator, self.approving_user) + + put_body = {"approval_status": "REJECTED"} + # this will fail if latest assoc is not pending. + # so, success of this test implies the PUT acts on testassoc2 (latest) + + # Act + response = self.client.put(self.url, put_body, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +@parameterized_class( + [ + {"actor": "api_admin"}, + ] +) +class MlCubeDeleteTest(MlCubeTest): + def setUp(self): + super(MlCubeDeleteTest, self).setUp() + self.generic_setup() + self.url = self.url.format(self.mlcube_id, self.bmk_id) + self.set_credentials(self.actor) + + def test_deletion_works_as_expected_for_single_assoc(self): + # Arrange + testassoc = self.mock_mlcube_association(self.bmk_id, self.mlcube_id) + self.create_mlcube_association(testassoc, self.mlcube_owner, self.bmk_owner) + + # Act + response = self.client.delete(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + response = self.client.get(self.url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 0) + + def test_deletion_works_as_expected_for_multiple_assoc(self): + # Arrange + testassoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status="REJECTED" + ) + self.create_mlcube_association(testassoc, self.mlcube_owner, self.bmk_owner) + + testassoc2 = self.mock_mlcube_association(self.bmk_id, self.mlcube_id) + self.create_mlcube_association(testassoc2, self.mlcube_owner, self.bmk_owner) + + # Act + response = self.client.delete(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + response = self.client.get(self.url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 0) + + +@parameterized_class( + [ + {"initiator": "mlcube_owner", "association_status": "PENDING"}, + {"initiator": "bmk_owner", "association_status": "PENDING"}, + {"initiator": "mlcube_owner", "association_status": "REJECTED"}, + {"initiator": "bmk_owner", "association_status": "REJECTED"}, + {"initiator": "mlcube_owner", "association_status": "APPROVED"}, + {"initiator": "bmk_owner", "association_status": "APPROVED"}, + ] +) +class 
PermissionTest(MlCubeTest):
+    """Test module for permissions of /mlcubes/{pk}/benchmarks/{bid} endpoint
+    Non-permitted actions:
+    GET: for all users except mlcube owner, bmk_owner and admin
+    DELETE: for all users except admin
+    PUT: for all users except mlcube owner, bmk_owner and admin
+    if approval_status == APPROVED, the initiating user is not allowed
+    if priority exists in PUT body, mlcube_owner is not allowed
+    """
+
+    def setUp(self):
+        super(PermissionTest, self).setUp()
+        self.generic_setup()
+
+        self.url = self.url.format(self.mlcube_id, self.bmk_id)
+
+        if self.initiator == self.mlcube_owner:
+            self.approving_user = self.bmk_owner
+        else:
+            self.approving_user = self.mlcube_owner
+
+        testassoc = self.mock_mlcube_association(
+            self.bmk_id, self.mlcube_id, approval_status=self.association_status
+        )
+
+        self.create_mlcube_association(testassoc, self.initiator, self.approving_user)
+
+    # TODO: determine for all tests what should be 404 instead of 400 or 403
+    @parameterized.expand(
+        [
+            ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("other_user", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_get_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = self.client.get(self.url)
+
+        # Assert
+        self.assertEqual(response.status_code, expected_status)
+
+    @parameterized.expand(
+        [
+            ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("other_user", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_put_permissions(self, user, expected_status):
+        # Arrange
+
+        put_body = {
+            "approval_status": "APPROVED",
+            "initiated_by": 55,
+            "approved_at": "time",
+            "created_at": "time",
+            "modified_at": "time",
+        }
+        self.set_credentials(user)
+
+        for key in put_body:
+            # Act
+            response = self.client.put(self.url, {key: put_body[key]}, format="json")
+            # Assert
+            self.assertEqual(
+                response.status_code,
+                expected_status,
+                f"{key} was modified",
+            )
+
+    def test_put_permissions_for_approval_status(self):
+        # Arrange
+        if self.association_status != "PENDING":
+            # skip cases that would fail anyway; they were already covered
+            # above and fail for reasons other than permissions
+            return
+
+        put_body = {
+            "approval_status": "APPROVED",
+        }
+        self.set_credentials(self.initiator)
+
+        # Act
+        response = self.client.put(self.url, put_body, format="json")
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+        # TODO: move this check to permission checks in serializers to return 403
+
+    def test_put_permissions_for_priority(self):
+        # Arrange
+        put_body = {
+            "priority": 555,
+        }
+        self.set_credentials(self.mlcube_owner)
+
+        # Act
+        response = self.client.put(self.url, put_body, format="json")
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+    @parameterized.expand(
+        [
+            ("mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("bmk_owner", status.HTTP_403_FORBIDDEN),
+            ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("other_user", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_delete_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = 
self.client.delete(self.url) + + # Assert + self.assertEqual(response.status_code, expected_status) diff --git a/server/mlcube/urls.py b/server/mlcube/urls.py index 223c7df7a..4ad14472c 100644 --- a/server/mlcube/urls.py +++ b/server/mlcube/urls.py @@ -8,7 +8,8 @@ path("", views.MlCubeList.as_view()), path("/", views.MlCubeDetail.as_view()), path("benchmarks/", bviews.BenchmarkModelList.as_view()), - path("/benchmarks/", bviews.BenchmarkModelApproval.as_view()), path("/benchmarks//", bviews.ModelApproval.as_view()), path("/reports/", views.MlCubeReportList.as_view()), + # path("/benchmarks/", bviews.ModelBenchmarksList.as_view()), + # NOTE: when activating this endpoint later, check permissions and write tests ] diff --git a/server/result/migrations/0002_auto_20231124_0208.py b/server/result/migrations/0002_auto_20231124_0208.py new file mode 100644 index 000000000..c43f319ff --- /dev/null +++ b/server/result/migrations/0002_auto_20231124_0208.py @@ -0,0 +1,23 @@ +# Generated by Django 3.2.20 on 2023-11-24 02:08 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('result', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='modelresult', + name='is_valid', + field=models.BooleanField(default=True), + ), + migrations.AddField( + model_name='modelresult', + name='user_metadata', + field=models.JSONField(default=dict), + ), + ] diff --git a/server/result/models.py b/server/result/models.py index 4d488e740..3bed25b31 100644 --- a/server/result/models.py +++ b/server/result/models.py @@ -18,9 +18,11 @@ class ModelResult(models.Model): dataset = models.ForeignKey("dataset.Dataset", on_delete=models.PROTECT) results = models.JSONField() metadata = models.JSONField(default=dict) + user_metadata = models.JSONField(default=dict) approval_status = models.CharField( choices=MODEL_RESULT_STATUS, max_length=100, default="PENDING" ) + is_valid = models.BooleanField(default=True) approved_at = models.DateTimeField(null=True, blank=True) created_at = models.DateTimeField(auto_now_add=True) modified_at = models.DateTimeField(auto_now=True) diff --git a/server/result/serializers.py b/server/result/serializers.py index fb193353c..55129c748 100644 --- a/server/result/serializers.py +++ b/server/result/serializers.py @@ -51,3 +51,17 @@ def validate(self, data): "Dataset-Benchmark association must be approved" ) return data + + +class ModelResultDetailSerializer(serializers.ModelSerializer): + class Meta: + model = ModelResult + fields = "__all__" + read_only_fields = [ + "owner", + "approved_at", + "benchmark", + "model", + "dataset", + "results", + ] diff --git a/server/result/tests/__init__.py b/server/result/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/server/result/tests/test_.py b/server/result/tests/test_.py new file mode 100644 index 000000000..79a418bbe --- /dev/null +++ b/server/result/tests/test_.py @@ -0,0 +1,521 @@ +from rest_framework import status + +from medperf.tests import MedPerfTest + +from parameterized import parameterized, parameterized_class + + +class ResultsTest(MedPerfTest): + def generic_setup(self): + # setup users + data_owner = "data_owner" + mlcube_owner = "mlcube_owner" + bmk_owner = "bmk_owner" + bmk_prep_mlcube_owner = "bmk_prep_mlcube_owner" + ref_mlcube_owner = "ref_mlcube_owner" + eval_mlcube_owner = "eval_mlcube_owner" + other_user = "other_user" + + self.create_user(data_owner) + self.create_user(mlcube_owner) + self.create_user(bmk_owner) + 
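+        # (Hedged aside for readers of this patch: the MedPerfTest helpers
+        # used throughout these tests are not shown in this diff. A minimal
+        # sketch of the assumed semantics -- create_user registers a named
+        # account, set_credentials switches the DRF test client's identity,
+        # and None means anonymous:
+        #
+        #     def create_user(self, username):
+        #         self.users[username] = User.objects.create(username=username)
+        #
+        #     def set_credentials(self, username):
+        #         user = self.users.get(username) if username else None
+        #         self.client.force_authenticate(user=user)
+        #
+        # force_authenticate(user=None) is rest_framework.test.APIClient's
+        # documented way to clear authentication.)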
self.create_user(bmk_prep_mlcube_owner) + self.create_user(ref_mlcube_owner) + self.create_user(eval_mlcube_owner) + self.create_user(other_user) + + # setup globals + self.data_owner = data_owner + self.mlcube_owner = mlcube_owner + self.bmk_owner = bmk_owner + self.bmk_prep_mlcube_owner = bmk_prep_mlcube_owner + self.ref_mlcube_owner = ref_mlcube_owner + self.eval_mlcube_owner = eval_mlcube_owner + self.other_user = other_user + + self.url = self.api_prefix + "/results/" + self.set_credentials(None) + + +@parameterized_class( + [ + {"actor": "data_owner"}, + ], +) +class GenericResultsPostTest(ResultsTest): + """Test module for POST /results""" + + def setUp(self): + super(GenericResultsPostTest, self).setUp() + self.generic_setup() + + # create benchmark + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + + # create dataset + self.set_credentials(self.data_owner) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + dataset = self.create_dataset(dataset).data + + # create dataset assoc + assoc = self.mock_dataset_association( + benchmark["id"], dataset["id"], approval_status="APPROVED" + ) + self.create_dataset_association(assoc, self.data_owner, self.bmk_owner) + + # create model mlcube + self.set_credentials(self.mlcube_owner) + mlcube = self.mock_mlcube(state="OPERATION") + mlcube = self.create_mlcube(mlcube).data + + # create mlcube assoc + assoc = self.mock_mlcube_association( + benchmark["id"], mlcube["id"], approval_status="APPROVED" + ) + self.create_mlcube_association(assoc, self.mlcube_owner, self.bmk_owner) + + self.bmk_id = benchmark["id"] + self.dataset_id = dataset["id"] + self.mlcube_id = mlcube["id"] + + self.set_credentials(self.actor) + + def test_created_result_fields_are_saved_as_expected(self): + """Testing the valid scenario""" + # Arrange + testresult = self.mock_result( + self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1} + ) + get_result_url = self.api_prefix + "/results/{0}/" + + # Act + response = self.client.post(self.url, testresult, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + uid = response.data["id"] + response = self.client.get(get_result_url.format(uid)) + + self.assertEqual( + response.status_code, + status.HTTP_200_OK, + "result retrieval failed", + ) + + for k, v in response.data.items(): + if k in testresult: + self.assertEqual(testresult[k], v, f"unexpected value for {k}") + + def test_default_values_are_as_expected(self): + """Testing the model fields rules""" + + # Arrange + default_values = { + "approved_at": None, + "approval_status": "PENDING", + "name": "", + "metadata": {}, + "user_metadata": {}, + "is_valid": True, + } + testresult = self.mock_result( + self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1} + ) + + for key in default_values: + if key in testresult: + del testresult[key] + + # Act + response = self.client.post(self.url, testresult, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in default_values.items(): + self.assertEqual( + val, response.data[key], f"unexpected default value for {key}" + ) + + def test_readonly_fields(self): + """Testing the serializer rules""" + + # Arrange + readonly = { + "owner": 55, + "created_at": "time", + "modified_at": "time2", + "approved_at": "time3", + "approval_status": "APPROVED", + } + testresult = 
self.mock_result( + self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1} + ) + + testresult.update(readonly) + + # Act + response = self.client.post(self.url, testresult, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + for key, val in readonly.items(): + self.assertNotEqual( + val, response.data[key], f"readonly field {key} was modified" + ) + + +@parameterized_class( + [ + {"actor": "data_owner"}, + ], +) +class SerializersResultsPostTest(ResultsTest): + """Test module for serializers rules of POST /results""" + + def setUp(self): + super(SerializersResultsPostTest, self).setUp() + self.generic_setup() + + # create benchmark + prep, ref_mlcube, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + + # create dataset + self.set_credentials(self.data_owner) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + dataset = self.create_dataset(dataset).data + + # create model mlcube + self.set_credentials(self.mlcube_owner) + mlcube = self.mock_mlcube(state="OPERATION") + mlcube = self.create_mlcube(mlcube).data + + self.bmk_id = benchmark["id"] + self.dataset_id = dataset["id"] + self.mlcube_id = mlcube["id"] + self.ref_mlcube_id = ref_mlcube["id"] + + self.set_credentials(self.actor) + + def test_result_creation_with_unassociated_dataset(self): + # Arrange + assoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status="APPROVED" + ) + self.create_mlcube_association(assoc, self.mlcube_owner, self.bmk_owner) + + testresult = self.mock_result( + self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1} + ) + # Act + response = self.client.post(self.url, testresult, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_result_creation_with_unassociated_mlcube(self): + # Arrange + assoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="APPROVED" + ) + self.create_dataset_association(assoc, self.data_owner, self.bmk_owner) + + testresult = self.mock_result( + self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1} + ) + # Act + response = self.client.post(self.url, testresult, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @parameterized.expand( + [ + ("PENDING", "PENDING"), + ("APPROVED", "PENDING"), + ("REJECTED", "PENDING"), + ("PENDING", "REJECTED"), + ("APPROVED", "REJECTED"), + ("REJECTED", "REJECTED"), + ("PENDING", "APPROVED"), + ("APPROVED", "APPROVED"), + ("REJECTED", "APPROVED"), + ] + ) + def test_result_creation_with_created_associations( + self, dataset_status, mlcube_status + ): + # Arrange + assoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status=mlcube_status + ) + self.create_mlcube_association(assoc, self.mlcube_owner, self.bmk_owner) + + assoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status=dataset_status + ) + self.create_dataset_association(assoc, self.data_owner, self.bmk_owner) + + testresult = self.mock_result( + self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1} + ) + # Act + response = self.client.post(self.url, testresult, format="json") + + # Assert + if dataset_status == mlcube_status == "APPROVED": + exp_status = status.HTTP_201_CREATED + else: + exp_status = status.HTTP_400_BAD_REQUEST + + 
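+        # (Hedged aside for readers of this patch: the rule exercised here is
+        # that a result is accepted only when the latest dataset-benchmark and
+        # mlcube-benchmark associations are both APPROVED. Assuming serializer
+        # internals along the lines of the dataset check added earlier in this
+        # diff, roughly:
+        #
+        #     def validate(self, data):
+        #         if dataset_assoc.approval_status != "APPROVED" or \
+        #                 mlcube_assoc.approval_status != "APPROVED":
+        #             raise serializers.ValidationError(
+        #                 "Both associations must be approved"
+        #             )
+        #         return data
+        #
+        # so every combination in the parameterization above except
+        # APPROVED/APPROVED is expected to yield HTTP 400.)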
self.assertEqual(response.status_code, exp_status) + + def test_result_creation_looks_for_latest_model_assocs(self): + # Arrange + assoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status="APPROVED" + ) + self.create_mlcube_association(assoc, self.mlcube_owner, self.bmk_owner) + + assoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="APPROVED" + ) + self.create_dataset_association(assoc, self.data_owner, self.bmk_owner) + + assoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status="REJECTED" + ) + self.create_mlcube_association( + assoc, self.mlcube_owner, self.bmk_owner, set_status_directly=True + ) + + testresult = self.mock_result( + self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1} + ) + # Act + response = self.client.post(self.url, testresult, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_result_creation_looks_for_latest_dataset_assocs(self): + # Arrange + assoc = self.mock_mlcube_association( + self.bmk_id, self.mlcube_id, approval_status="APPROVED" + ) + self.create_mlcube_association(assoc, self.mlcube_owner, self.bmk_owner) + + assoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="APPROVED" + ) + self.create_dataset_association(assoc, self.data_owner, self.bmk_owner) + + assoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="REJECTED" + ) + self.create_dataset_association( + assoc, self.data_owner, self.bmk_owner, set_status_directly=True + ) + + testresult = self.mock_result( + self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1} + ) + # Act + response = self.client.post(self.url, testresult, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_result_creation_with_ref_model(self): + # Arrange + assoc = self.mock_dataset_association( + self.bmk_id, self.dataset_id, approval_status="APPROVED" + ) + self.create_dataset_association(assoc, self.data_owner, self.bmk_owner) + + testresult = self.mock_result( + self.bmk_id, self.ref_mlcube_id, self.dataset_id, results={"r": 1} + ) + # Act + response = self.client.post(self.url, testresult, format="json") + + # Assert + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + +@parameterized_class( + [ + {"actor": "api_admin"}, + ], +) +class GenericResultsGetListTest(ResultsTest): + """Test module for GET /results""" + + def setUp(self): + super(GenericResultsGetListTest, self).setUp() + self.generic_setup() + + # create benchmark + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + + # create dataset + self.set_credentials(self.data_owner) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + dataset = self.create_dataset(dataset).data + + # create dataset assoc + assoc = self.mock_dataset_association( + benchmark["id"], dataset["id"], approval_status="APPROVED" + ) + self.create_dataset_association(assoc, self.data_owner, self.bmk_owner) + + # create model mlcube + self.set_credentials(self.mlcube_owner) + mlcube = self.mock_mlcube(state="OPERATION") + mlcube = self.create_mlcube(mlcube).data + + # create mlcube assoc + assoc = self.mock_mlcube_association( + benchmark["id"], mlcube["id"], approval_status="APPROVED" + ) + self.create_mlcube_association(assoc, self.mlcube_owner, 
self.bmk_owner) + + # create result + self.set_credentials(self.data_owner) + result = self.mock_result( + benchmark["id"], mlcube["id"], dataset["id"], results={"r": 1} + ) + result = self.create_result(result).data + self.testresult = result + + self.set_credentials(self.actor) + + def test_generic_get_result_list(self): + # Arrange + result_id = self.testresult["id"] + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data["results"]), 1) + self.assertEqual(response.data["results"][0]["id"], result_id) + + +class PermissionTest(ResultsTest): + """Test module for permissions of /results endpoint + Non-permitted actions: + POST: for all users except data_owner, and admins + GET: for all users except admin + """ + + def setUp(self): + super(PermissionTest, self).setUp() + self.generic_setup() + + # create benchmark + prep, _, _, benchmark = self.shortcut_create_benchmark( + self.bmk_prep_mlcube_owner, + self.ref_mlcube_owner, + self.eval_mlcube_owner, + self.bmk_owner, + ) + + # create dataset + self.set_credentials(self.data_owner) + dataset = self.mock_dataset( + data_preparation_mlcube=prep["id"], state="OPERATION" + ) + dataset = self.create_dataset(dataset).data + + # create dataset assoc + assoc = self.mock_dataset_association( + benchmark["id"], dataset["id"], approval_status="APPROVED" + ) + self.create_dataset_association(assoc, self.data_owner, self.bmk_owner) + + # create model mlcube + self.set_credentials(self.mlcube_owner) + mlcube = self.mock_mlcube(state="OPERATION") + mlcube = self.create_mlcube(mlcube).data + + # create mlcube assoc + assoc = self.mock_mlcube_association( + benchmark["id"], mlcube["id"], approval_status="APPROVED" + ) + self.create_mlcube_association(assoc, self.mlcube_owner, self.bmk_owner) + + self.set_credentials(self.data_owner) + result = self.mock_result( + benchmark["id"], mlcube["id"], dataset["id"], results={"r": 1} + ) + + self.testresult = result + + self.set_credentials(None) + + @parameterized.expand( + [ + ("bmk_owner", status.HTTP_403_FORBIDDEN), + ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("mlcube_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_post_permissions(self, user, exp_status): + # Arrange + self.set_credentials(user) + + # Act + response = self.client.post(self.url, self.testresult, format="json") + + # Assert + self.assertEqual(response.status_code, exp_status) + + @parameterized.expand( + [ + ("bmk_owner", status.HTTP_403_FORBIDDEN), + ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN), + ("mlcube_owner", status.HTTP_403_FORBIDDEN), + ("data_owner", status.HTTP_403_FORBIDDEN), + ("other_user", status.HTTP_403_FORBIDDEN), + (None, status.HTTP_401_UNAUTHORIZED), + ] + ) + def test_get_permissions(self, user, exp_status): + # Arrange + self.set_credentials(self.data_owner) + self.create_result(self.testresult) + self.set_credentials(user) + + # Act + response = self.client.get(self.url) + + # Assert + self.assertEqual(response.status_code, exp_status) diff --git a/server/result/tests/test_pk.py b/server/result/tests/test_pk.py new file mode 100644 index 000000000..4fcc3fe2e --- /dev/null +++ 
@@ -0,0 +1,327 @@
+from rest_framework import status
+
+from medperf.tests import MedPerfTest
+
+from parameterized import parameterized, parameterized_class
+
+
+class ResultsTest(MedPerfTest):
+    def generic_setup(self):
+        # setup users
+        data_owner = "data_owner"
+        mlcube_owner = "mlcube_owner"
+        bmk_owner = "bmk_owner"
+        bmk_prep_mlcube_owner = "bmk_prep_mlcube_owner"
+        ref_mlcube_owner = "ref_mlcube_owner"
+        eval_mlcube_owner = "eval_mlcube_owner"
+        other_user = "other_user"
+
+        self.create_user(data_owner)
+        self.create_user(mlcube_owner)
+        self.create_user(bmk_owner)
+        self.create_user(bmk_prep_mlcube_owner)
+        self.create_user(ref_mlcube_owner)
+        self.create_user(eval_mlcube_owner)
+        self.create_user(other_user)
+
+        # create benchmark
+        prep, _, _, benchmark = self.shortcut_create_benchmark(
+            bmk_prep_mlcube_owner,
+            ref_mlcube_owner,
+            eval_mlcube_owner,
+            bmk_owner,
+        )
+
+        # create dataset
+        self.set_credentials(data_owner)
+        dataset = self.mock_dataset(
+            data_preparation_mlcube=prep["id"], state="OPERATION"
+        )
+        dataset = self.create_dataset(dataset).data
+
+        # create dataset assoc
+        assoc = self.mock_dataset_association(
+            benchmark["id"], dataset["id"], approval_status="APPROVED"
+        )
+        self.create_dataset_association(assoc, data_owner, bmk_owner)
+
+        # create model mlcube
+        self.set_credentials(mlcube_owner)
+        mlcube = self.mock_mlcube(state="OPERATION")
+        mlcube = self.create_mlcube(mlcube).data
+
+        # create mlcube assoc
+        assoc = self.mock_mlcube_association(
+            benchmark["id"], mlcube["id"], approval_status="APPROVED"
+        )
+        self.create_mlcube_association(assoc, mlcube_owner, bmk_owner)
+
+        # setup globals
+        self.data_owner = data_owner
+        self.mlcube_owner = mlcube_owner
+        self.bmk_owner = bmk_owner
+        self.bmk_prep_mlcube_owner = bmk_prep_mlcube_owner
+        self.ref_mlcube_owner = ref_mlcube_owner
+        self.eval_mlcube_owner = eval_mlcube_owner
+        self.other_user = other_user
+
+        self.bmk_id = benchmark["id"]
+        self.dataset_id = dataset["id"]
+        self.mlcube_id = mlcube["id"]
+
+        self.url = self.api_prefix + "/results/{0}/"
+        self.set_credentials(None)
+
+
+@parameterized_class(
+    [
+        {"actor": "data_owner"},
+        {"actor": "bmk_owner"},
+    ]
+)
+class ResultGetTest(ResultsTest):
+    """Test module for GET /results/{pk}"""
+
+    def setUp(self):
+        super(ResultGetTest, self).setUp()
+        self.generic_setup()
+        self.set_credentials(self.actor)
+
+    def test_generic_get_result(self):
+        # Arrange
+        result = self.mock_result(
+            self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1}
+        )
+        self.set_credentials(self.data_owner)
+        result = self.create_result(result).data
+        self.set_credentials(self.actor)
+
+        url = self.url.format(result["id"])
+
+        # Act
+        response = self.client.get(url)
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        for k, v in response.data.items():
+            if k in result:
+                self.assertEqual(result[k], v, f"Unexpected value for {k}")
+
+    def test_result_not_found(self):
+        # Arrange
+        invalid_id = 9999
+        url = self.url.format(invalid_id)
+
+        # Act
+        response = self.client.get(url)
+
+        # Assert
+        # TODO: fixme after refactoring permissions; should be 404
+        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+
+@parameterized_class(
+    [
+        {"actor": "api_admin"},
+    ]
+)
+class ResultPutTest(ResultsTest):
+    """Test module for PUT /results/{pk}"""
+
+    def setUp(self):
+        super(ResultPutTest, self).setUp()
+        self.generic_setup()
+        self.set_credentials(self.actor)
+
+    def test_put_does_not_modify_readonly_fields(self):
+        # Arrange
+        result = self.mock_result(
+            self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1}
+        )
+        self.set_credentials(self.data_owner)
+        result = self.create_result(result).data
+        self.set_credentials(self.actor)
+
+        newtestresult = {
+            "owner": 10,
+            "approved_at": "some time",
+            "created_at": "some time",
+            "modified_at": "some time",
+            "benchmark": 44,
+            "model": 444,
+            "dataset": 55,
+            "results": {"new": 111},
+        }
+        url = self.url.format(result["id"])
+
+        # Act
+        response = self.client.put(url, newtestresult, format="json")
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        response = self.client.get(url)
+        self.assertEqual(response.status_code, status.HTTP_200_OK)
+        for k, v in newtestresult.items():
+            self.assertNotEqual(v, response.data[k], f"{k} was modified")
+
+
+@parameterized_class(
+    [
+        {"actor": "api_admin"},
+    ]
+)
+class ResultDeleteTest(ResultsTest):
+    """Test module for DELETE /results/{pk}"""
+
+    def setUp(self):
+        super(ResultDeleteTest, self).setUp()
+        self.generic_setup()
+        self.set_credentials(self.actor)
+
+    def test_deletion_works_as_expected(self):
+        # Arrange
+        result = self.mock_result(
+            self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1}
+        )
+        self.set_credentials(self.data_owner)
+        result = self.create_result(result).data
+        self.set_credentials(self.actor)
+
+        url = self.url.format(result["id"])
+
+        # Act
+        response = self.client.delete(url)
+
+        # Assert
+        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+        response = self.client.get(url)
+
+        # TODO: fixme after refactoring permissions; this should simply be:
+        # self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+        if self.actor == self.data_owner:
+            self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+        else:
+            self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+
+class PermissionTest(ResultsTest):
+    """Test module for permissions of /results/{pk} endpoint
+    Non-permitted actions:
+    GET: for all users except bmk_owner, data_owner, and admin
+    DELETE: for all users except admin
+    PUT: for all users except admin
+    """
+
+    def setUp(self):
+        super(PermissionTest, self).setUp()
+        self.generic_setup()
+        result = self.mock_result(
+            self.bmk_id, self.mlcube_id, self.dataset_id, results={"r": 1}
+        )
+        self.set_credentials(self.data_owner)
+        result = self.create_result(result).data
+        self.url = self.url.format(result["id"])
+
+        self.result = result
+        self.set_credentials(None)
+
+    @parameterized.expand(
+        [
+            ("mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("other_user", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_get_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = self.client.get(self.url)
+
+        # Assert
+        self.assertEqual(response.status_code, expected_status)
+
+    @parameterized.expand(
+        [
+            ("bmk_owner", status.HTTP_403_FORBIDDEN),
+            ("mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("data_owner", status.HTTP_403_FORBIDDEN),
+            ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("other_user", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_put_permissions(self, user, expected_status):
+        # Arrange
+
+        # create new assets to edit with
+        prep, refmodel, _, newbenchmark = self.shortcut_create_benchmark(
+            self.bmk_prep_mlcube_owner,
+            self.ref_mlcube_owner,
+            self.eval_mlcube_owner,
+            self.bmk_owner,
+            prep_mlcube_kwargs={"name": "newprep", "mlcube_hash": "newprephash"},
+            ref_mlcube_kwargs={"name": "newref", "mlcube_hash": "newrefhash"},
+            eval_mlcube_kwargs={"name": "neweval", "mlcube_hash": "newevalhash"},
+            name="newbmk",
+        )
+        self.set_credentials(self.data_owner)
+        newdataset = self.mock_dataset(prep["id"], generated_uid="newgen")
+        newdataset = self.create_dataset(newdataset).data
+
+        newtestresult = {
+            "name": "new",
+            "owner": 55,
+            "benchmark": newbenchmark["id"],
+            "model": refmodel["id"],
+            "dataset": newdataset["id"],
+            "results": {"new": "t"},
+            "metadata": {"new": "t"},
+            "user_metadata": {"new": "t"},
+            "approval_status": "APPROVED",
+            "is_valid": False,
+            "approved_at": "time",
+            "created_at": "time",
+            "modified_at": "time",
+        }
+
+        self.set_credentials(user)
+
+        for key in newtestresult:
+            # Act
+            response = self.client.put(
+                self.url, {key: newtestresult[key]}, format="json"
+            )
+
+            # Assert
+            self.assertEqual(
+                response.status_code, expected_status, f"{key} was modified"
+            )
+
+    @parameterized.expand(
+        [
+            ("bmk_owner", status.HTTP_403_FORBIDDEN),
+            ("mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("data_owner", status.HTTP_403_FORBIDDEN),
+            ("bmk_prep_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("ref_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("eval_mlcube_owner", status.HTTP_403_FORBIDDEN),
+            ("other_user", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_delete_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = self.client.delete(self.url)
+
+        # Assert
+        self.assertEqual(response.status_code, expected_status)
diff --git a/server/result/views.py b/server/result/views.py
index b650045d8..a8d3debd3 100644
--- a/server/result/views.py
+++ b/server/result/views.py
@@ -5,8 +5,8 @@
 from drf_spectacular.utils import extend_schema
 
 from .models import ModelResult
-from .serializers import ModelResultSerializer
-from .permissions import IsAdmin, IsBenchmarkOwner, IsDatasetOwner, IsResultOwner
+from .serializers import ModelResultSerializer, ModelResultDetailSerializer
+from .permissions import IsAdmin, IsBenchmarkOwner, IsDatasetOwner
 
 
 class ModelResultList(GenericAPIView):
@@ -42,12 +42,12 @@ def post(self, request, format=None):
 
 
 class ModelResultDetail(GenericAPIView):
-    serializer_class = ModelResultSerializer
+    serializer_class = ModelResultDetailSerializer
     queryset = ""
 
     def get_permissions(self):
         if self.request.method == "PUT" or self.request.method == "DELETE":
-            self.permission_classes = [IsAdmin | IsResultOwner]
+            self.permission_classes = [IsAdmin]
         elif self.request.method == "GET":
             self.permission_classes = [IsAdmin | IsDatasetOwner | IsBenchmarkOwner]
         return super(self.__class__, self).get_permissions()
@@ -63,7 +63,7 @@ def get(self, request, pk, format=None):
         Retrieve a result instance.
         """
         modelresult = self.get_object(pk)
-        serializer = ModelResultSerializer(modelresult)
+        serializer = ModelResultDetailSerializer(modelresult)
         return Response(serializer.data)
 
     def put(self, request, pk, format=None):
@@ -71,7 +71,7 @@ def put(self, request, pk, format=None):
         Update a result instance.
         """
         modelresult = self.get_object(pk)
-        serializer = ModelResultSerializer(modelresult, data=request.data)
+        serializer = ModelResultDetailSerializer(modelresult, data=request.data)
         if serializer.is_valid():
             serializer.save()
             return Response(serializer.data)
diff --git a/server/test-requirements.txt b/server/test-requirements.txt
index fad08122a..e4065a762 100644
--- a/server/test-requirements.txt
+++ b/server/test-requirements.txt
@@ -1,2 +1,4 @@
 requests
 curlify
+parameterized==0.9.0
+tblib==3.0.0 # to display tracebacks during parallel testing in Django
\ No newline at end of file
diff --git a/server/user/tests.py b/server/user/tests.py
deleted file mode 100644
index 82031f25a..000000000
--- a/server/user/tests.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from rest_framework.test import APIClient
-from rest_framework import status
-from django.conf import settings
-
-from medperf.tests import MedPerfTest
-
-
-class UserTest(MedPerfTest):
-    """Test module for users APIs"""
-
-    def setUp(self):
-        super(UserTest, self).setUp()
-        self.api_prefix = "/api/" + settings.SERVER_API_VERSION
-        self.client = APIClient()
-        self.client.credentials(HTTP_AUTHORIZATION="Bearer " + self.admin_token)
-
-    def test_unauthenticated_user(self):
-        client = APIClient()
-        response = client.get(self.api_prefix + "/users/1/")
-        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-        response = client.delete(self.api_prefix + "/users/1/")
-        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-        response = client.put(self.api_prefix + "/users/1/")
-        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-        response = client.get(self.api_prefix + "/users/")
-        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
-    def test_crud_user(self):
-        testusername = "testdataowner"
-        _, userinfo = self.create_user(testusername)
-
-        uid = userinfo["id"]
-        response = self.client.get(self.api_prefix + "/users/{0}/".format(uid))
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-        self.assertEqual(response.data["username"], testusername)
-
-        response = self.client.get(self.api_prefix + "/users/")
-        self.assertEqual(response.status_code, status.HTTP_200_OK)
-        self.assertEqual(len(response.data["results"]), 3)
-
-        response = self.client.delete(self.api_prefix + "/users/{0}/".format(uid))
-        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
-
-        response = self.client.get(self.api_prefix + "/users/{0}/".format(uid))
-        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
-    def test_invalid_user(self):
-        invalid_uid = 9999
-        response = self.client.get(self.api_prefix + "/users/{0}/".format(invalid_uid))
-        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
-    def test_optional_fields(self):
-        pass
diff --git a/server/user/tests/__init__.py b/server/user/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/server/user/tests/test_.py b/server/user/tests/test_.py
new file mode 100644
index 000000000..fc7f01233
--- /dev/null
+++ b/server/user/tests/test_.py
@@ -0,0 +1,47 @@
+from rest_framework import status
+
+from medperf.tests import MedPerfTest
+
+from parameterized import parameterized
+
+
+class UserTest(MedPerfTest):
+    def generic_setup(self):
+        # setup users
+        user1 = "user1"
+        user2 = "user2"
+
+        self.create_user(user1)
+        self.create_user(user2)
+
+        self.url = self.api_prefix + "/users/"
+        self.set_credentials(None)
+
+
+class PermissionTest(UserTest):
+    """Test module for permissions of /users/ endpoint
+    Non-permitted actions:
+    GET: for all users except admin
+    """
+
+    def setUp(self):
+        super(PermissionTest, self).setUp()
+        self.generic_setup()
+        self.set_credentials(None)
+
+    @parameterized.expand(
+        [
+            ("user1", status.HTTP_403_FORBIDDEN),
+            ("user2", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_get_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = self.client.get(self.url)
+
+        # Assert
+        self.assertEqual(response.status_code, expected_status)
diff --git a/server/user/tests/test_pk.py b/server/user/tests/test_pk.py
new file mode 100644
index 000000000..fff6fbee0
--- /dev/null
+++ b/server/user/tests/test_pk.py
@@ -0,0 +1,99 @@
+from rest_framework import status
+
+from medperf.tests import MedPerfTest
+
+from parameterized import parameterized
+
+
+class UserTest(MedPerfTest):
+    def generic_setup(self):
+        # setup users
+        user1 = "user1"
+        user2 = "user2"
+
+        self.create_user(user1)
+        self.create_user(user2)
+
+        self.user1 = "user1"
+        self.user2 = "user2"
+
+        self.url = self.api_prefix + "/users/{0}/"
+        self.set_credentials(None)
+
+
+class PermissionTest(UserTest):
+    """Test module for permissions of /users/{pk} endpoint
+    Non-permitted actions:
+    GET: for all users except admin and the user themselves
+    PUT: for all users except admin
+    DELETE: for all users except admin
+
+
+    """
+
+    def setUp(self):
+        super(PermissionTest, self).setUp()
+        self.generic_setup()
+        self.set_credentials(self.user1)
+        user1_id = self.client.get(self.api_prefix + "/me/").data["id"]
+        self.url = self.url.format(user1_id)
+        self.set_credentials(None)
+
+    @parameterized.expand(
+        [
+            ("user2", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_get_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = self.client.get(self.url)
+
+        # Assert
+        self.assertEqual(response.status_code, expected_status)
+
+    @parameterized.expand(
+        [
+            ("user1", status.HTTP_403_FORBIDDEN),
+            ("user2", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_put_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        fields = {
+            "username": "new",
+            "email": "new",
+            "first_name": "new",
+            "last_name": "new",
+        }
+        for field in fields:
+            # Act
+            response = self.client.put(self.url, {field: fields[field]}, format="json")
+
+            # Assert
+            self.assertEqual(
+                response.status_code, expected_status, f"{field} was modified"
+            )
+
+    @parameterized.expand(
+        [
+            ("user1", status.HTTP_403_FORBIDDEN),
+            ("user2", status.HTTP_403_FORBIDDEN),
+            (None, status.HTTP_401_UNAUTHORIZED),
+        ]
+    )
+    def test_delete_permissions(self, user, expected_status):
+        # Arrange
+        self.set_credentials(user)
+
+        # Act
+        response = self.client.delete(self.url)
+
+        # Assert
+        self.assertEqual(response.status_code, expected_status)
diff --git a/server/user/views.py b/server/user/views.py
index 90413c97a..14a8a983b 100644
--- a/server/user/views.py
+++ b/server/user/views.py
@@ -32,9 +32,9 @@ class UserDetail(GenericAPIView):
     queryset = ""
 
     def get_permissions(self):
-        if self.request.method == "PUT" or self.request.method == "GET":
+        if self.request.method == "GET":
             self.permission_classes = [IsAdmin | IsOwnUser]
-        elif self.request.method == "DELETE":
+        elif self.request.method == "DELETE" or self.request.method == "PUT":
             self.permission_classes = [IsAdmin]
         return super(self.__class__, self).get_permissions()
 
diff --git a/server/utils/tests.py b/server/utils/tests.py
new file mode 100644
index 000000000..c1d17b5ca
--- /dev/null
+++ b/server/utils/tests.py
@@ -0,0 +1,311 @@
+from rest_framework import status
+
+from medperf.tests import MedPerfTest
+
+
+class UserTest(MedPerfTest):
+    def test_me_returns_current_user(self):
+        url = self.api_prefix + "/me/"
+
+        # setup users
+        user1 = "user1"
+        user2 = "user2"
+        self.create_user(user1)
+        self.create_user(user2)
+
+        # Act
+        self.set_credentials(user1)
+        response1 = self.client.get(url)
+        self.set_credentials(user2)
+        response2 = self.client.get(url)
+
+        # Assert
+        self.assertEqual(response1.status_code, status.HTTP_200_OK)
+        self.assertEqual(response2.status_code, status.HTTP_200_OK)
+        self.assertEqual(response1.data["username"], user1)
+        self.assertEqual(response2.data["username"], user2)
+
+
+class BenchmarksTest(MedPerfTest):
+    def __create_asset(self, user):
+        _, _, _, benchmark = self.shortcut_create_benchmark(
+            user,
+            user,
+            user,
+            user,
+            prep_mlcube_kwargs={"name": f"{user}prep", "mlcube_hash": f"{user}prep"},
+            ref_mlcube_kwargs={"name": f"{user}ref", "mlcube_hash": f"{user}ref"},
+            eval_mlcube_kwargs={"name": f"{user}eval", "mlcube_hash": f"{user}eval"},
+            name=f"{user}name",
+        )
+        return benchmark
+
+    def test_endpoint_returns_current_user_assets(self):
+        url = self.api_prefix + "/me/benchmarks/"
+
+        # setup users
+        user1 = "user1"
+        user2 = "user2"
+        self.create_user(user1)
+        self.create_user(user2)
+
+        # create an asset for each user
+        benchmark1 = self.__create_asset(user1)
+        benchmark2 = self.__create_asset(user2)
+
+        # Act
+        self.set_credentials(user1)
+        response1 = self.client.get(url)
+        self.set_credentials(user2)
+        response2 = self.client.get(url)
+
+        # Assert
+        self.assertEqual(response1.status_code, status.HTTP_200_OK)
+        self.assertEqual(response2.status_code, status.HTTP_200_OK)
+        resp1 = response1.data["results"]
+        resp2 = response2.data["results"]
+        self.assertEqual(len(resp1), 1)
+        self.assertEqual(len(resp2), 1)
+        self.assertEqual(resp1[0]["id"], benchmark1["id"])
+        self.assertEqual(resp2[0]["id"], benchmark2["id"])
+
+
+class DatasetsTest(MedPerfTest):
+    def __create_asset(self, user):
+        self.set_credentials(user)
+        prep = self.mock_mlcube(name=f"{user}name", mlcube_hash=f"{user}hash")
+        prep = self.create_mlcube(prep).data
+        dataset = self.mock_dataset(prep["id"], generated_uid=f"{user}genid")
+        dataset = self.create_dataset(dataset).data
+        return dataset
+
+    def test_endpoint_returns_current_user_assets(self):
+        url = self.api_prefix + "/me/datasets/"
+
+        # setup users
+        user1 = "user1"
+        user2 = "user2"
+        self.create_user(user1)
+        self.create_user(user2)
+
+        # create an asset for each user
+        dataset1 = self.__create_asset(user1)
+        dataset2 = self.__create_asset(user2)
+
+        # Act
+        self.set_credentials(user1)
+        response1 = self.client.get(url)
+        self.set_credentials(user2)
+        response2 = self.client.get(url)
+
+        # Assert
+        self.assertEqual(response1.status_code, status.HTTP_200_OK)
+        self.assertEqual(response2.status_code, status.HTTP_200_OK)
+        resp1 = response1.data["results"]
+        resp2 = response2.data["results"]
+        self.assertEqual(len(resp1), 1)
+        self.assertEqual(len(resp2), 1)
+        self.assertEqual(resp1[0]["id"], dataset1["id"])
+        self.assertEqual(resp2[0]["id"], dataset2["id"])
+
+
+class MlCubesTest(MedPerfTest):
+    def __create_asset(self, user):
+        self.set_credentials(user)
+        mlcube = self.mock_mlcube(name=f"{user}name", mlcube_hash=f"{user}hash")
+        mlcube = self.create_mlcube(mlcube).data
+        return mlcube
+
+    def test_endpoint_returns_current_user_assets(self):
+        url = self.api_prefix + "/me/mlcubes/"
+
+        # setup users
+        user1 = "user1"
+        user2 = "user2"
+        self.create_user(user1)
+        self.create_user(user2)
+
+        # create an asset for each user
+        mlcube1 = self.__create_asset(user1)
+        mlcube2 = self.__create_asset(user2)
+
+        # Act
+        self.set_credentials(user1)
+        response1 = self.client.get(url)
+        self.set_credentials(user2)
+        response2 = self.client.get(url)
+
+        # Assert
+        self.assertEqual(response1.status_code, status.HTTP_200_OK)
+        self.assertEqual(response2.status_code, status.HTTP_200_OK)
+        resp1 = response1.data["results"]
+        resp2 = response2.data["results"]
+        self.assertEqual(len(resp1), 1)
+        self.assertEqual(len(resp2), 1)
+        self.assertEqual(resp1[0]["id"], mlcube1["id"])
+        self.assertEqual(resp2[0]["id"], mlcube2["id"])
+
+
+class ResultsTest(MedPerfTest):
+    def setUp(self):
+        super(ResultsTest, self).setUp()
+        bmk_owner = "bmk_owner"
+        self.create_user(bmk_owner)
+        prep, refmodel, _, benchmark = self.shortcut_create_benchmark(
+            bmk_owner, bmk_owner, bmk_owner, bmk_owner
+        )
+        self.bmk_owner = bmk_owner
+        self.benchmark = benchmark
+        self.prep = prep
+        self.refmodel = refmodel
+
+    def __create_asset(self, user):
+        self.set_credentials(user)
+        dataset = self.mock_dataset(
+            self.prep["id"], generated_uid=f"{user}genid", state="OPERATION"
+        )
+        dataset = self.create_dataset(dataset).data
+        assoc = self.mock_dataset_association(
+            self.benchmark["id"], dataset["id"], approval_status="APPROVED"
+        )
+        self.create_dataset_association(assoc, user, self.bmk_owner)
+        result = self.mock_result(
+            self.benchmark["id"], self.refmodel["id"], dataset["id"]
+        )
+        result = self.create_result(result).data
+        return result
+
+    def test_endpoint_returns_current_user_assets(self):
+        url = self.api_prefix + "/me/results/"
+
+        # setup users
+        user1 = "user1"
+        user2 = "user2"
+        self.create_user(user1)
+        self.create_user(user2)
+
+        # create an asset for each user
+        result1 = self.__create_asset(user1)
+        result2 = self.__create_asset(user2)
+
+        # Act
+        self.set_credentials(user1)
+        response1 = self.client.get(url)
+        self.set_credentials(user2)
+        response2 = self.client.get(url)
+
+        # Assert
+        self.assertEqual(response1.status_code, status.HTTP_200_OK)
+        self.assertEqual(response2.status_code, status.HTTP_200_OK)
+        resp1 = response1.data["results"]
+        resp2 = response2.data["results"]
+        self.assertEqual(len(resp1), 1)
+        self.assertEqual(len(resp2), 1)
+        self.assertEqual(resp1[0]["id"], result1["id"])
+        self.assertEqual(resp2[0]["id"], result2["id"])
+
+
+class BenchmarkDatasetTest(MedPerfTest):
+    def __create_asset(self, user):
+        prep, _, _, benchmark = self.shortcut_create_benchmark(
+            user,
+            user,
+            user,
+            user,
+            prep_mlcube_kwargs={"name": f"{user}prep", "mlcube_hash": f"{user}prep"},
+            ref_mlcube_kwargs={"name": f"{user}ref", "mlcube_hash": f"{user}ref"},
+            eval_mlcube_kwargs={"name": f"{user}eval", "mlcube_hash": f"{user}eval"},
+            name=f"{user}name",
+        )
+        self.set_credentials(user)
+        dataset = self.mock_dataset(
+            prep["id"], generated_uid=f"{user}genuid", state="OPERATION"
+        )
+        dataset = self.create_dataset(dataset).data
+
+        assoc = self.mock_dataset_association(benchmark["id"], dataset["id"])
+        assoc = self.create_dataset_association(assoc, user, user).data
+
+        return assoc
+
+    def test_endpoint_returns_current_user_assets(self):
+        url = self.api_prefix + "/me/datasets/associations/"
+
+        # setup users
+        user1 = "user1"
+        user2 = "user2"
+        self.create_user(user1)
+        self.create_user(user2)
+
+        # create an asset for each user
+        assoc1 = self.__create_asset(user1)
+        assoc2 = self.__create_asset(user2)
+
+        # Act
+        self.set_credentials(user1)
+        response1 = self.client.get(url)
+        self.set_credentials(user2)
+        response2 = self.client.get(url)
+
+        # Assert
+        self.assertEqual(response1.status_code, status.HTTP_200_OK)
+        self.assertEqual(response2.status_code, status.HTTP_200_OK)
+        resp1 = response1.data["results"]
+        resp2 = response2.data["results"]
+        self.assertEqual(len(resp1), 1)
+        self.assertEqual(len(resp2), 1)
+        self.assertEqual(resp1[0]["id"], assoc1["id"])
+        self.assertEqual(resp2[0]["id"], assoc2["id"])
+
+
+class BenchmarkMlCubeTest(MedPerfTest):
+    def __create_asset(self, user):
+        _, _, _, benchmark = self.shortcut_create_benchmark(
+            user,
+            user,
+            user,
+            user,
+            prep_mlcube_kwargs={"name": f"{user}prep", "mlcube_hash": f"{user}prep"},
+            ref_mlcube_kwargs={"name": f"{user}ref", "mlcube_hash": f"{user}ref"},
+            eval_mlcube_kwargs={"name": f"{user}eval", "mlcube_hash": f"{user}eval"},
+            name=f"{user}name",
+        )
+        self.set_credentials(user)
+        mlcube = self.mock_mlcube(
+            name=f"{user}name", mlcube_hash=f"{user}hash", state="OPERATION"
+        )
+        mlcube = self.create_mlcube(mlcube).data
+
+        assoc = self.mock_mlcube_association(benchmark["id"], mlcube["id"])
+        assoc = self.create_mlcube_association(assoc, user, user).data
+
+        return assoc
+
+    def test_endpoint_returns_current_user_assets(self):
+        url = self.api_prefix + "/me/mlcubes/associations/"
+
+        # setup users
+        user1 = "user1"
+        user2 = "user2"
+        self.create_user(user1)
+        self.create_user(user2)
+
+        # create an asset for each user
+        assoc1 = self.__create_asset(user1)
+        assoc2 = self.__create_asset(user2)
+
+        # Act
+        self.set_credentials(user1)
+        response1 = self.client.get(url)
+        self.set_credentials(user2)
+        response2 = self.client.get(url)
+
+        # Assert
+        self.assertEqual(response1.status_code, status.HTTP_200_OK)
+        self.assertEqual(response2.status_code, status.HTTP_200_OK)
+        resp1 = response1.data["results"]
+        resp2 = response2.data["results"]
+        self.assertEqual(len(resp1), 1)
+        self.assertEqual(len(resp2), 1)
+        self.assertEqual(resp1[0]["id"], assoc1["id"])
+        self.assertEqual(resp2[0]["id"], assoc2["id"])
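
The permission suites above all follow the same table-driven pattern from the parameterized package, pinned as parameterized==0.9.0 in server/test-requirements.txt. Below is a minimal, self-contained sketch of that pattern; the class, function, and test names are illustrative only and not part of the MedPerf codebase:

# Sketch of the parameterized pattern used by the new permission tests.
# Assumes parameterized==0.9.0 is installed; all names here are hypothetical.
import unittest

from parameterized import parameterized, parameterized_class


def toy_status(user):
    # Toy stand-in for a server permission check: anonymous -> 401, others -> 403.
    return 401 if user is None else 403


@parameterized_class([{"actor": "data_owner"}, {"actor": "bmk_owner"}])
class ExamplePermissionTest(unittest.TestCase):
    # parameterized_class re-runs this whole class once per dict above,
    # injecting "actor" as a class attribute (as ResultGetTest does).

    @parameterized.expand(
        [
            ("other_user", 403),
            (None, 401),
        ]
    )
    def test_expected_status(self, user, expected_status):
        # parameterized.expand turns each tuple into its own test case with
        # these positional args, mirroring the test_*_permissions methods above.
        self.assertEqual(toy_status(user), expected_status)


if __name__ == "__main__":
    unittest.main()

Since the workflow now invokes python manage.py test --parallel, tblib is pinned alongside parameterized so that tracebacks raised in Django's parallel test workers can be pickled back to the main process and displayed, as noted in the test-requirements.txt comment.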