-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: MLFlow temp artifact utility (#223)
# The problem Uploading artifacts to MLFlow requires writing them to a file first. Therefore we need a runtime-agnostic command for getting an appropriate temp directory. # This PR's solution Implements `sdk.mlflow.get_temp_artifacts_dir()` as a user-facing function. Example usage: ```python from orquestra import sdk import mlflow artifacts_dir: Path = sdk.mlflow.get_temp_artifacts_dir() artifact_path = artifacts_dir / "final_state_dict.pickle" with artifact_path.open("wb") as f: pickle.dump(my_model.state_dict(), f) mlflow.log_artifact(artifact_path) ``` Under the hood, this looks for the `ORQ_MLFLOW_ARTIFACTS_DIR` environment variable that exists in Studio and CE. If it can't find it, it returns a `mlflow/artifacts` subdir under the base orquestra directory. This approach means that local users can control the artifacts dir location by setting `ORQ_MLFLOW_ARTIFACTS_DIR`. # Checklist _Check that this PR satisfies the following items:_ - [X] Tests have been added for new features/changed behavior (if no new features have been added, check the box). - [x] The [changelog file](CHANGELOG.md) has been updated with a user-readable description of the changes (if the change isn't visible to the user in any way, check the box). - [X] The PR's title is prefixed with `<feat/fix/chore/imp[rovement]/int[ernal]/docs>[!]:` - [X] The PR is linked to a JIRA ticket (if there's no suitable ticket, check the box). [ORQSDK-880] [ORQSDK-880]: https://zapatacomputing.atlassian.net/browse/ORQSDK-880?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ --------- Co-authored-by: James Clark <[email protected]>
- Loading branch information
1 parent
1afbee2
commit 4cc2ca5
Showing
8 changed files
with
187 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
################################################################################ | ||
# © Copyright 2023 Zapata Computing Inc. | ||
################################################################################ | ||
|
||
""" | ||
Snippets and tests used in the "MLFlow Utilities" tutorial. | ||
""" | ||
|
||
import pickle | ||
from pathlib import Path | ||
from unittest.mock import Mock | ||
|
||
from pytest import MonkeyPatch | ||
|
||
from orquestra import sdk | ||
|
||
# Stand-ins for the third-party ``mlflow`` module and for a trained model, so that the
# tutorial snippet below can be executed without either of them being available.
mlflow = Mock()
my_model = Mock()
my_model.configure_mock(**{"state_dict.return_value": {"state": "awesome"}})
|
||
|
||
class Snippets:
    """
    Code samples that are pulled into the "MLFlow Utilities" tutorial.

    The tutorial includes the text between a snippet's ``def`` line and its
    ``</snippet>`` marker verbatim, so anything written inside that region (including
    comments) is shown to readers. Keep that region limited to user-facing code.
    """

    @staticmethod
    def artifacts_dir_snippet():
        artifacts_dir: Path = sdk.mlflow.get_temp_artifacts_dir()

        artifact_path = artifacts_dir / "final_state_dict.pickle"
        with artifact_path.open("wb") as f:
            pickle.dump(my_model.state_dict(), f)
        mlflow.log_artifact(artifact_path)

        # </snippet>
        return artifacts_dir
|
||
|
||
class TestSnippets:
    """Runs the tutorial snippet and checks which directory it ends up using."""

    @staticmethod
    def test_artifact_dir_local(monkeypatch: MonkeyPatch, tmp_path):
        """With the env var unset, the snippet uses the (patched) default directory."""
        connection_utils = sdk.mlflow._connection_utils
        monkeypatch.delenv("ORQ_MLFLOW_ARTIFACTS_DIR", raising=False)
        monkeypatch.setattr(connection_utils, "DEFAULT_TEMP_ARTIFACTS_DIR", tmp_path)

        returned_dir = Snippets.artifacts_dir_snippet()

        assert returned_dir == tmp_path

    @staticmethod
    def test_artifact_dir_env_var_override(monkeypatch: MonkeyPatch, tmp_path):
        """With the env var set, its value wins over the default directory."""
        connection_utils = sdk.mlflow._connection_utils
        monkeypatch.setenv("ORQ_MLFLOW_ARTIFACTS_DIR", str(tmp_path))
        # The default is pointed at a bogus location so that accidentally falling back
        # to it would make the assertion below fail.
        monkeypatch.setattr(
            connection_utils, "DEFAULT_TEMP_ARTIFACTS_DIR", "NOT/THIS/PATH"
        )

        returned_dir = Snippets.artifacts_dir_snippet()

        assert returned_dir == tmp_path
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ Guides | |
resource-management | ||
custom-images | ||
version-compatibility | ||
mlflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
.. © Copyright 2023 Zapata Computing Inc. | ||
MLflow Utilities | ||
================ | ||
|
||
Uploading Artifacts | ||
------------------- | ||
|
||
Artifacts must be written to a file before being uploaded to MLflow. Rather than manually configuring an appropriate temporary directory for creating these files, we provide the ``get_temp_artifacts_dir()`` utility. This function can be used in any runtime and will return an appropriate directory. | ||
|
||
.. literalinclude:: ../examples/tests/test_mlflow_utilities.py | ||
:start-after: def artifacts_dir_snippet(): | ||
:end-before: </snippet> | ||
:language: python | ||
:dedent: 8 | ||
|
||
|
||
Compute Engine and Studio both configure temporary directories automatically. For local workflows, the default location is ``~/.orquestra/mlflow/artifacts``; however, this can be overridden by setting the ``ORQ_MLFLOW_ARTIFACTS_DIR`` environment variable. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
################################################################################ | ||
# © Copyright 2023 Zapata Computing Inc. | ||
################################################################################ | ||
|
||
"""A set of Orquestra utilities relating to interacting with MLFlow.""" | ||
|
||
from ._connection_utils import get_temp_artifacts_dir | ||
|
||
__all__ = ["get_temp_artifacts_dir"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
################################################################################ | ||
# © Copyright 2023 Zapata Computing Inc. | ||
################################################################################ | ||
|
||
"""Utilities for communicating with mlflow.""" | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
from orquestra.sdk._base._services import ORQUESTRA_BASE_PATH | ||
|
||
# Fallback artifacts directory, used by get_temp_artifacts_dir() when the
# ORQ_MLFLOW_ARTIFACTS_DIR environment variable is not set.
DEFAULT_TEMP_ARTIFACTS_DIR: Path = ORQUESTRA_BASE_PATH / "mlflow" / "artifacts"
|
||
|
||
def get_temp_artifacts_dir() -> Path:
    """
    Return a directory in which artifacts can be staged before an MLflow upload.

    Uploading artifacts to MLflow requires them to be written locally first. Finding
    an appropriate directory varies significantly between a workflow running locally
    and one running on a remote cluster. This function handles that complexity so
    that workflows do not need adjusting between runtimes.

    The directory, including any missing parents, is created if it does not exist.

    Returns:
        The path given by the ``ORQ_MLFLOW_ARTIFACTS_DIR`` environment variable when
        it is set, otherwise ``DEFAULT_TEMP_ARTIFACTS_DIR``.
    """
    # In Studio and CE there is an environment variable that points to the artifact
    # directory. If it is missing, we're probably executing locally and fall back to
    # the default location.
    configured_dir = os.environ.get("ORQ_MLFLOW_ARTIFACTS_DIR")
    artifacts_dir: Path = (
        DEFAULT_TEMP_ARTIFACTS_DIR if configured_dir is None else Path(configured_dir)
    )

    artifacts_dir.mkdir(parents=True, exist_ok=True)
    return artifacts_dir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
################################################################################ | ||
# © Copyright 2023 Zapata Computing Inc. | ||
################################################################################ | ||
|
||
import pathlib | ||
|
||
from orquestra import sdk | ||
|
||
|
||
class TestGetTempArtifactsDir:
    """Tests for ``sdk.mlflow.get_temp_artifacts_dir()``."""

    class TestWithRemote:
        """
        'Remote' here covers studio and CE as these are treated identically in the code.
        Both are simulated by setting the ``ORQ_MLFLOW_ARTIFACTS_DIR`` env var.
        """

        @staticmethod
        def test_happy_path(tmp_path: pathlib.Path, monkeypatch):
            # Given
            monkeypatch.setenv("ORQ_MLFLOW_ARTIFACTS_DIR", str(tmp_path))

            # When
            path = sdk.mlflow.get_temp_artifacts_dir()

            # Then
            assert path == tmp_path

        @staticmethod
        def test_creates_dir(tmp_path: pathlib.Path, monkeypatch):
            # Given
            missing_dir = tmp_path / "dir_that_does_not_exist"
            assert not missing_dir.exists()
            monkeypatch.setenv("ORQ_MLFLOW_ARTIFACTS_DIR", str(missing_dir))

            # When
            _ = sdk.mlflow.get_temp_artifacts_dir()

            # Then
            assert missing_dir.exists()

    class TestWithLocal:
        """
        'Local' means the ``ORQ_MLFLOW_ARTIFACTS_DIR`` env var is absent, so the default
        directory is used. The variable is removed explicitly in each test: otherwise
        running the suite in an environment that defines it would exercise the remote
        code path instead.
        """

        @staticmethod
        def test_happy_path(tmp_path: pathlib.Path, monkeypatch):
            # Given
            monkeypatch.delenv("ORQ_MLFLOW_ARTIFACTS_DIR", raising=False)
            monkeypatch.setattr(
                sdk.mlflow._connection_utils, "DEFAULT_TEMP_ARTIFACTS_DIR", tmp_path
            )

            # When
            path = sdk.mlflow.get_temp_artifacts_dir()

            # Then
            assert path == tmp_path

        @staticmethod
        def test_creates_dir(tmp_path: pathlib.Path, monkeypatch):
            # Given
            missing_dir = tmp_path / "dir_that_does_not_exist"
            assert not missing_dir.exists()
            monkeypatch.delenv("ORQ_MLFLOW_ARTIFACTS_DIR", raising=False)
            monkeypatch.setattr(
                sdk.mlflow._connection_utils,
                "DEFAULT_TEMP_ARTIFACTS_DIR",
                missing_dir,
            )

            # When
            _ = sdk.mlflow.get_temp_artifacts_dir()

            # Then
            assert missing_dir.exists()