From 871fcc85e3571f1dc73f190ed9a7b19aa484e4c1 Mon Sep 17 00:00:00 2001 From: Daan Rademaker Date: Fri, 10 Nov 2023 16:48:40 +0100 Subject: [PATCH 1/2] change api experiment to databricks sdk --- .../resources/mlflow/experiment.py | 115 ++------ aws-lambda/tests/conftest.py | 3 +- .../tests/resources/mlflow/test_experiment.py | 277 +++--------------- 3 files changed, 77 insertions(+), 318 deletions(-) diff --git a/aws-lambda/src/databricks_cdk/resources/mlflow/experiment.py b/aws-lambda/src/databricks_cdk/resources/mlflow/experiment.py index 9a9d29d1..10073404 100644 --- a/aws-lambda/src/databricks_cdk/resources/mlflow/experiment.py +++ b/aws-lambda/src/databricks_cdk/resources/mlflow/experiment.py @@ -1,13 +1,13 @@ -from typing import List, Optional +from typing import Optional +from databricks.sdk.service.ml import ExperimentTag from pydantic import BaseModel -from databricks_cdk.utils import CnfResponse, get_request, post_request +from databricks_cdk.utils import CnfResponse, get_workspace_client -class ExperimentTag(BaseModel): - key: str - value: str +class ExperimentIdNoneError(Exception): + pass class ExperimentProperties(BaseModel): @@ -22,79 +22,6 @@ class ExperimentCreateResponse(CnfResponse): name: str -class ExperimentExisting(BaseModel): - experiment_id: str - name: str - artifact_location: str - lifecycle_stage: str - last_update_time: int - creation_time: int - tags: List[ExperimentTag] - - -def get_experiment_url(workspace_url: str) -> str: - """Get the mlflow experiment url""" - return f"{workspace_url}/api/2.0/mlflow/experiments" - - -def _create_experiment(experiment_url: str, properties: ExperimentProperties) -> str: - """Creates a new experiment""" - return post_request( - f"{experiment_url}/create", - { - "name": properties.name, - "artifact_location": properties.artifact_location, - "tags": {"key": "mlflow.note.content", "value": properties.description}, - }, - )["experiment_id"] - - -def _update_experiment_name(experiment_url: str, experiment_id: str, new_name: str): - """Updates the experiment name""" - post_request( - f"{experiment_url}/update", - {"experiment_id": experiment_id, "new_name": new_name}, - ) - - -def _update_experiment_description(experiment_url: str, experiment_id: str, new_description: str): - """Updates the description of the experiment which is a fixed key of the experiment tags (mlflow.note.content)""" - post_request( - f"{experiment_url}/set-experiment-tag", - { - "experiment_id": experiment_id, - "key": "mlflow.note.content", - "value": new_description, - }, - ) - - -def _update_experiment(experiment_url: str, properties: ExperimentProperties, existing_experiment: ExperimentExisting): - """Updates an experiment if there is a new name or a new description""" - if properties.name != existing_experiment.name: - _update_experiment_name(experiment_url, existing_experiment.experiment_id, properties.name) - - new_description = ExperimentTag(key="mlflow.note.content", value=properties.description) - if new_description not in existing_experiment.tags: - _update_experiment_description( - experiment_url, - existing_experiment.experiment_id, - new_description.value, - ) - - -def _get_existing_experiment( - experiment_url: str, properties: ExperimentProperties, experiment_id: Optional[str] -) -> Optional[ExperimentExisting]: - """Gets an existing experiment from mlflow based on experiment_id""" - existing_experiment = get_request(f"{experiment_url}/get?experiment_id={experiment_id}") - - if existing_experiment: - return ExperimentExisting.parse_obj(existing_experiment["experiment"]) - - return None - - def create_or_update_experiment(properties: ExperimentProperties, physical_resource_id: Optional[str] = None): """ Creates a new experiment if there is no physical_resource_id provided, else it checks if based on the @@ -106,22 +33,38 @@ def create_or_update_experiment(properties: ExperimentProperties, physical_resou defaults to None :return: Both the physical_resouce_id (experiment_id) and the name of the experiment """ - experiment_url = get_experiment_url(properties.workspace_url) + workspace_client = get_workspace_client(properties.workspace_url) + description_tag = ExperimentTag(key="mlflow.note.content", value=properties.description) if not physical_resource_id: - experiment_id = _create_experiment(experiment_url, properties) - return ExperimentCreateResponse(name=properties.name, physical_resource_id=experiment_id) + response = workspace_client.experiments.create_experiment( + name=properties.name, artifact_location=properties.artifact_location, tags=[description_tag] + ) + return ( + ExperimentCreateResponse(name=properties.name, physical_resource_id=response.experiment_id) + if response.experiment_id is not None + else ExperimentIdNoneError("Experiment ID is None") + ) - existing_experiment = _get_existing_experiment(experiment_url, properties, physical_resource_id) + existing_experiment = workspace_client.experiments.get_experiment(experiment_id=physical_resource_id) - if not existing_experiment: - raise ValueError("Existing experiment cannot be found but physical_resouce_id is provided") + if existing_experiment.experiment is None: + raise ExperimentIdNoneError("Existing experiment cannot be found but physical_resouce_id is provided") + + if existing_experiment.experiment.name != properties.name: + workspace_client.experiments.update_experiment(experiment_id=physical_resource_id, new_name=properties.name) + + existing_tags = existing_experiment.experiment.tags + if existing_tags is not None and properties.description is not None and properties.description not in existing_tags: + workspace_client.experiments.set_experiment_tag( + experiment_id=physical_resource_id, key="mlflow.note.content", value=properties.description + ) - _update_experiment(experiment_url, properties, existing_experiment) return ExperimentCreateResponse(name=properties.name, physical_resource_id=physical_resource_id) def delete_experiment(properties: ExperimentProperties, physical_resource_id: str) -> CnfResponse: """Deletes an existing mlflow experiment""" - post_request(f"{get_experiment_url(properties.workspace_url)}/delete", {"experiment_id": physical_resource_id}) + workspace_client = get_workspace_client(properties.workspace_url) + workspace_client.experiments.delete_experiment(experiment_id=physical_resource_id) return CnfResponse(physical_resource_id=physical_resource_id) diff --git a/aws-lambda/tests/conftest.py b/aws-lambda/tests/conftest.py index 588f7492..a6a4005e 100644 --- a/aws-lambda/tests/conftest.py +++ b/aws-lambda/tests/conftest.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock import pytest -from databricks.sdk import ModelRegistryAPI, WorkspaceClient +from databricks.sdk import ExperimentsAPI, ModelRegistryAPI, WorkspaceClient @pytest.fixture(scope="function", autouse=True) @@ -22,5 +22,6 @@ def workspace_client(): # mock all of the underlying service api's workspace_client.model_registry = MagicMock(spec=ModelRegistryAPI) + workspace_client.experiments = MagicMock(spec=ExperimentsAPI) return workspace_client diff --git a/aws-lambda/tests/resources/mlflow/test_experiment.py b/aws-lambda/tests/resources/mlflow/test_experiment.py index 57d5568c..3cdb45bc 100644 --- a/aws-lambda/tests/resources/mlflow/test_experiment.py +++ b/aws-lambda/tests/resources/mlflow/test_experiment.py @@ -1,197 +1,24 @@ from unittest.mock import patch import pytest +from databricks.sdk.service.ml import CreateExperimentResponse, Experiment, ExperimentTag, GetExperimentResponse from databricks_cdk.resources.mlflow.experiment import ( + CnfResponse, ExperimentCreateResponse, - ExperimentExisting, + ExperimentIdNoneError, ExperimentProperties, - ExperimentTag, - _create_experiment, - _get_existing_experiment, - _update_experiment, - _update_experiment_description, - _update_experiment_name, create_or_update_experiment, delete_experiment, - get_experiment_url, ) -def test_get_experiment_url(): - workspace_url = "https://test.cloud.databricks.com" - assert get_experiment_url(workspace_url) == "https://test.cloud.databricks.com/api/2.0/mlflow/experiments" - - -@patch("databricks_cdk.resources.mlflow.experiment.post_request") -def test__create_experiment(patched_post_request): - props = ExperimentProperties( - name="test", - artifact_location="s3://test", - workspace_url="https://test.cloud.databricks.com", - description="some description", - ) - patched_post_request.return_value = {"experiment_id": "test_id"} - experiment_id = _create_experiment("https://test.cloud.databricks.com/api/2.0/mlflow/experiments", props) - - assert experiment_id == "test_id" - assert patched_post_request.call_count == 1 - - assert patched_post_request.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/experiments/create", - { - "name": "test", - "artifact_location": "s3://test", - "tags": {"key": "mlflow.note.content", "value": "some description"}, - }, - ) - - -@patch("databricks_cdk.resources.mlflow.experiment.post_request") -def test__update_experiment_name(patched_post_request): - _update_experiment_name( - experiment_url="https://test.cloud.databricks.com/api/2.0/mlflow/experiments", - experiment_id="test_id", - new_name="new_name", - ) - - assert patched_post_request.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/experiments/update", - {"experiment_id": "test_id", "new_name": "new_name"}, - ) - - -@patch("databricks_cdk.resources.mlflow.experiment.post_request") -def test__update_experiment_description(patched_post_request): - _update_experiment_description( - experiment_url="https://test.cloud.databricks.com/api/2.0/mlflow/experiments", - experiment_id="test_id", - new_description="new_description", - ) - - assert patched_post_request.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/experiments/set-experiment-tag", - {"experiment_id": "test_id", "key": "mlflow.note.content", "value": "new_description"}, - ) - - -@patch("databricks_cdk.resources.mlflow.experiment._update_experiment_name") -@patch("databricks_cdk.resources.mlflow.experiment._update_experiment_description") -def test__update_experiment_changes( - patched__updated_experiment_description, - patched__update_experiment_name, -): - props = ExperimentProperties( - name="new_name", - artifact_location="s3://test", - workspace_url="https://test.cloud.databricks.com", - description="some new description", - ) - - existing_experiment = ExperimentExisting( - experiment_id="test_id", - name="old_name", - artifact_location="s3://test", - lifecycle_stage="blah", - last_update_time=1234, - creation_time=1234, - tags=[ExperimentTag(key="mlflow.note.content", value="some old description")], - ) - - # both name update and description update should be called - _update_experiment( - experiment_url="https://test.cloud.databricks.com/api/2.0/mlflow/experiments", - properties=props, - existing_experiment=existing_experiment, +@patch("databricks_cdk.resources.mlflow.experiment.get_workspace_client") +def test_create_or_update_experiment_new(patched_get_workspace_client, workspace_client): + patched_get_workspace_client.return_value = workspace_client + workspace_client.experiments.create_experiment.return_value = CreateExperimentResponse( + experiment_id="some_experiment_id" ) - - assert patched__update_experiment_name.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/experiments", - "test_id", - "new_name", - ) - - assert patched__updated_experiment_description.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/experiments", - "test_id", - "some new description", - ) - - -@patch("databricks_cdk.resources.mlflow.experiment._update_experiment_name") -@patch("databricks_cdk.resources.mlflow.experiment._update_experiment_description") -def test__update_experiment_no_changes( - patched__updated_experiment_description, - patched__update_experiment_name, -): - props = ExperimentProperties( - name="same_name", - artifact_location="s3://test", - workspace_url="https://test.cloud.databricks.com", - description="same description", - ) - - existing_experiment = ExperimentExisting( - experiment_id="test_id", - name="same_name", - artifact_location="s3://test", - lifecycle_stage="blah", - last_update_time=1234, - creation_time=1234, - tags=[ExperimentTag(key="mlflow.note.content", value="same description")], - ) - - # both name update and description update should be called - _update_experiment( - experiment_url="https://test.cloud.databricks.com/api/2.0/mlflow/experiments", - properties=props, - existing_experiment=existing_experiment, - ) - - # update functions should not have been called - patched__updated_experiment_description.call_count == 0 - patched__update_experiment_name.call_count == 0 - - -@patch("databricks_cdk.resources.mlflow.experiment.get_request") -def test__get_existing_experiment(patched__get_request): - props = ExperimentProperties( - name="same_name", - artifact_location="s3://test", - workspace_url="https://test.cloud.databricks.com", - description="same description", - ) - - patched__get_request.return_value = { - "experiment": { - "experiment_id": "1234", - "name": "same_name", - "artifact_location": "s3://test", - "lifecycle_stage": "blah", - "last_update_time": 1234, - "creation_time": 1234, - "tags": [{"key": "mlflow.note.content", "value": "same description"}], - } - } - - existing_experiment = _get_existing_experiment( - experiment_url="https://test.cloud.databricks.com/api/2.0/mlflow/experiments", - properties=props, - experiment_id="1234", - ) - - assert isinstance(existing_experiment, ExperimentExisting) - - patched__get_request.return_value = None - assert not _get_existing_experiment( - experiment_url="https://test.cloud.databricks.com/api/2.0/mlflow/experiments", - properties=props, - experiment_id="1234", - ) - - -@patch("databricks_cdk.resources.mlflow.experiment._create_experiment") -def test_create_or_update_experiment_new(patched__create_experiment): props = ExperimentProperties( name="name", artifact_location="s3://test", @@ -199,45 +26,27 @@ def test_create_or_update_experiment_new(patched__create_experiment): description="same description", ) - patched__create_experiment.return_value = "some_experiment_id" - # completely new experiment response = create_or_update_experiment(props, physical_resource_id=None) assert response == ExperimentCreateResponse(physical_resource_id="some_experiment_id", name="name") + workspace_client.experiments.create_experiment.assert_called_once_with( + name="name", + artifact_location="s3://test", + tags=[ExperimentTag(key="mlflow.note.content", value="same description")], + ) -@patch("databricks_cdk.resources.mlflow.experiment._get_existing_experiment") -@patch("databricks_cdk.resources.mlflow.experiment._update_experiment") -def test_create_or_update_experiment_existing(patched__update_experiment, patched__get_existing_experiment): +@patch("databricks_cdk.resources.mlflow.experiment.get_workspace_client") +def test_create_or_update_experiment_existing(patched_get_workspace_client, workspace_client): props = ExperimentProperties( name="name", artifact_location="s3://test", workspace_url="https://test.cloud.databricks.com", description="same description", ) - - patched__get_existing_experiment.return_value = ExperimentExisting( - experiment_id="test_id", - name="old_name", - artifact_location="s3://test", - lifecycle_stage="blah", - last_update_time=1234, - creation_time=1234, - tags=[ExperimentTag(key="mlflow.note.content", value="some old description")], - ) - - # already existing experiment - response = create_or_update_experiment(props, physical_resource_id="test_id") - - assert patched__update_experiment.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/experiments", - ExperimentProperties( - name="name", - artifact_location="s3://test", - workspace_url="https://test.cloud.databricks.com", - description="same description", - ), - ExperimentExisting( + patched_get_workspace_client.return_value = workspace_client + workspace_client.experiments.get_experiment.return_value = GetExperimentResponse( + experiment=Experiment( experiment_id="test_id", name="old_name", artifact_location="s3://test", @@ -245,38 +54,44 @@ def test_create_or_update_experiment_existing(patched__update_experiment, patche last_update_time=1234, creation_time=1234, tags=[ExperimentTag(key="mlflow.note.content", value="some old description")], - ), + ) ) - assert patched__get_existing_experiment.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/experiments", - ExperimentProperties( - name="name", - artifact_location="s3://test", - workspace_url="https://test.cloud.databricks.com", - description="same description", - ), - "test_id", - ) + # already existing experiment + response = create_or_update_experiment(props, physical_resource_id="test_id") assert response == ExperimentCreateResponse(physical_resource_id="test_id", name="name") - - # this is invalid and should raise error - patched__get_existing_experiment.return_value = None - with pytest.raises(ValueError): - response = create_or_update_experiment(props, physical_resource_id="test_id") + workspace_client.experiments.update_experiment.assert_called_once_with(experiment_id="test_id", new_name="name") -@patch("databricks_cdk.resources.mlflow.experiment.post_request") -def test_delete_experiment(patched_post_request): +@patch("databricks_cdk.resources.mlflow.experiment.get_workspace_client") +def test_create_or_update_experiment_existing_invalid_id(patched_get_workspace_client, workspace_client): props = ExperimentProperties( name="name", artifact_location="s3://test", workspace_url="https://test.cloud.databricks.com", description="same description", ) - delete_experiment(props, "some_id") - assert patched_post_request.call_args.args == ( - "https://test.cloud.databricks.com/api/2.0/mlflow/experiments/delete", - {"experiment_id": "some_id"}, + patched_get_workspace_client.return_value = workspace_client + workspace_client.experiments.get_experiment.return_value = GetExperimentResponse(experiment=None) + + # already existing experiment + with pytest.raises(ExperimentIdNoneError): + create_or_update_experiment(props, physical_resource_id="test_id") + + +@patch("databricks_cdk.resources.mlflow.experiment.get_workspace_client") +def test_delete_experiment(patched_get_workspace_client, workspace_client): + patched_get_workspace_client.return_value = workspace_client + props = ExperimentProperties( + name="name", + artifact_location="s3://test", + workspace_url="https://test.cloud.databricks.com", + description="same description", ) + response = delete_experiment(props, "some_id") + assert response == CnfResponse(physical_resource_id="some_id") + workspace_client.experiments.delete_experiment.assert_called_once_with(experiment_id="some_id") + + +# TODO: Add test for tag updates From de210ec823b9cf3d89b18a49bc320846f13723fb Mon Sep 17 00:00:00 2001 From: Daan Rademaker Date: Fri, 10 Nov 2023 17:17:20 +0100 Subject: [PATCH 2/2] add new test for tags --- .../tests/resources/mlflow/test_experiment.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/aws-lambda/tests/resources/mlflow/test_experiment.py b/aws-lambda/tests/resources/mlflow/test_experiment.py index 3cdb45bc..a969e662 100644 --- a/aws-lambda/tests/resources/mlflow/test_experiment.py +++ b/aws-lambda/tests/resources/mlflow/test_experiment.py @@ -64,6 +64,37 @@ def test_create_or_update_experiment_existing(patched_get_workspace_client, work workspace_client.experiments.update_experiment.assert_called_once_with(experiment_id="test_id", new_name="name") +@patch("databricks_cdk.resources.mlflow.experiment.get_workspace_client") +def test_create_or_update_experiment_existing_new_description(patched_get_workspace_client, workspace_client): + props = ExperimentProperties( + name="name", + artifact_location="s3://test", + workspace_url="https://test.cloud.databricks.com", + description="new description", + ) + patched_get_workspace_client.return_value = workspace_client + workspace_client.experiments.get_experiment.return_value = GetExperimentResponse( + experiment=Experiment( + experiment_id="test_id", + name="name", + artifact_location="s3://test", + lifecycle_stage="blah", + last_update_time=1234, + creation_time=1234, + tags=[ExperimentTag(key="mlflow.note.content", value="some old description")], + ) + ) + + # already existing experiment + response = create_or_update_experiment(props, physical_resource_id="test_id") + + assert response == ExperimentCreateResponse(physical_resource_id="test_id", name="name") + assert workspace_client.experiments.update_experiment.call_count == 0 + workspace_client.experiments.set_experiment_tag.assert_called_once_with( + experiment_id="test_id", key="mlflow.note.content", value="new description" + ) + + @patch("databricks_cdk.resources.mlflow.experiment.get_workspace_client") def test_create_or_update_experiment_existing_invalid_id(patched_get_workspace_client, workspace_client): props = ExperimentProperties(