Skip to content

Commit

Permalink
Merge pull request godatadriven#1055 from DaanRademaker/change_experi…
Browse files Browse the repository at this point in the history
…ment_to_databricks_sdk

Change experiment to databricks sdk
  • Loading branch information
dan1elt0m authored Nov 10, 2023
2 parents bbf737b + de210ec commit b0d5107
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 310 deletions.
115 changes: 29 additions & 86 deletions aws-lambda/src/databricks_cdk/resources/mlflow/experiment.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import List, Optional
from typing import Optional

from databricks.sdk.service.ml import ExperimentTag
from pydantic import BaseModel

from databricks_cdk.utils import CnfResponse, get_request, post_request
from databricks_cdk.utils import CnfResponse, get_workspace_client


class ExperimentTag(BaseModel):
key: str
value: str
class ExperimentIdNoneError(Exception):
pass


class ExperimentProperties(BaseModel):
Expand All @@ -22,79 +22,6 @@ class ExperimentCreateResponse(CnfResponse):
name: str


class ExperimentExisting(BaseModel):
experiment_id: str
name: str
artifact_location: str
lifecycle_stage: str
last_update_time: int
creation_time: int
tags: List[ExperimentTag]


def get_experiment_url(workspace_url: str) -> str:
"""Get the mlflow experiment url"""
return f"{workspace_url}/api/2.0/mlflow/experiments"


def _create_experiment(experiment_url: str, properties: ExperimentProperties) -> str:
"""Creates a new experiment"""
return post_request(
f"{experiment_url}/create",
{
"name": properties.name,
"artifact_location": properties.artifact_location,
"tags": {"key": "mlflow.note.content", "value": properties.description},
},
)["experiment_id"]


def _update_experiment_name(experiment_url: str, experiment_id: str, new_name: str):
"""Updates the experiment name"""
post_request(
f"{experiment_url}/update",
{"experiment_id": experiment_id, "new_name": new_name},
)


def _update_experiment_description(experiment_url: str, experiment_id: str, new_description: str):
"""Updates the description of the experiment which is a fixed key of the experiment tags (mlflow.note.content)"""
post_request(
f"{experiment_url}/set-experiment-tag",
{
"experiment_id": experiment_id,
"key": "mlflow.note.content",
"value": new_description,
},
)


def _update_experiment(experiment_url: str, properties: ExperimentProperties, existing_experiment: ExperimentExisting):
"""Updates an experiment if there is a new name or a new description"""
if properties.name != existing_experiment.name:
_update_experiment_name(experiment_url, existing_experiment.experiment_id, properties.name)

new_description = ExperimentTag(key="mlflow.note.content", value=properties.description)
if new_description not in existing_experiment.tags:
_update_experiment_description(
experiment_url,
existing_experiment.experiment_id,
new_description.value,
)


def _get_existing_experiment(
experiment_url: str, properties: ExperimentProperties, experiment_id: Optional[str]
) -> Optional[ExperimentExisting]:
"""Gets an existing experiment from mlflow based on experiment_id"""
existing_experiment = get_request(f"{experiment_url}/get?experiment_id={experiment_id}")

if existing_experiment:
return ExperimentExisting.parse_obj(existing_experiment["experiment"])

return None


def create_or_update_experiment(properties: ExperimentProperties, physical_resource_id: Optional[str] = None):
"""
Creates a new experiment if there is no physical_resource_id provided, else it checks if based on the
Expand All @@ -106,22 +33,38 @@ def create_or_update_experiment(properties: ExperimentProperties, physical_resou
defaults to None
:return: Both the physical_resouce_id (experiment_id) and the name of the experiment
"""
experiment_url = get_experiment_url(properties.workspace_url)
workspace_client = get_workspace_client(properties.workspace_url)

description_tag = ExperimentTag(key="mlflow.note.content", value=properties.description)
if not physical_resource_id:
experiment_id = _create_experiment(experiment_url, properties)
return ExperimentCreateResponse(name=properties.name, physical_resource_id=experiment_id)
response = workspace_client.experiments.create_experiment(
name=properties.name, artifact_location=properties.artifact_location, tags=[description_tag]
)
return (
ExperimentCreateResponse(name=properties.name, physical_resource_id=response.experiment_id)
if response.experiment_id is not None
else ExperimentIdNoneError("Experiment ID is None")
)

existing_experiment = _get_existing_experiment(experiment_url, properties, physical_resource_id)
existing_experiment = workspace_client.experiments.get_experiment(experiment_id=physical_resource_id)

if not existing_experiment:
raise ValueError("Existing experiment cannot be found but physical_resouce_id is provided")
if existing_experiment.experiment is None:
raise ExperimentIdNoneError("Existing experiment cannot be found but physical_resouce_id is provided")

if existing_experiment.experiment.name != properties.name:
workspace_client.experiments.update_experiment(experiment_id=physical_resource_id, new_name=properties.name)

existing_tags = existing_experiment.experiment.tags
if existing_tags is not None and properties.description is not None and properties.description not in existing_tags:
workspace_client.experiments.set_experiment_tag(
experiment_id=physical_resource_id, key="mlflow.note.content", value=properties.description
)

_update_experiment(experiment_url, properties, existing_experiment)
return ExperimentCreateResponse(name=properties.name, physical_resource_id=physical_resource_id)


def delete_experiment(properties: ExperimentProperties, physical_resource_id: str) -> CnfResponse:
"""Deletes an existing mlflow experiment"""
post_request(f"{get_experiment_url(properties.workspace_url)}/delete", {"experiment_id": physical_resource_id})
workspace_client = get_workspace_client(properties.workspace_url)
workspace_client.experiments.delete_experiment(experiment_id=physical_resource_id)
return CnfResponse(physical_resource_id=physical_resource_id)
3 changes: 2 additions & 1 deletion aws-lambda/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest.mock import MagicMock

import pytest
from databricks.sdk import ModelRegistryAPI, WorkspaceClient
from databricks.sdk import ExperimentsAPI, ModelRegistryAPI, WorkspaceClient


@pytest.fixture(scope="function", autouse=True)
Expand All @@ -22,5 +22,6 @@ def workspace_client():

# mock all of the underlying service api's
workspace_client.model_registry = MagicMock(spec=ModelRegistryAPI)
workspace_client.experiments = MagicMock(spec=ExperimentsAPI)

return workspace_client
Loading

0 comments on commit b0d5107

Please sign in to comment.