Skip to content

Commit

Permalink
feat: MLFlow temp artifact utility (#223)
Browse files Browse the repository at this point in the history
# The problem

Uploading artifacts to MLFlow requires writing them to file first.
Therefore we need a runtime-agnostic command for getting an appropriate
temp direstory.

# This PR's solution

Implements `sdk.mlflow.get_temp_artifacts_dir()` as a user-facing
function.

Example usage:

```python
from orquestra import sdk
import mlflow

artifacts_dir: Path = sdk.mlflow.get_temp_artifacts_dir()

artifact_path = artifacts_dir / "final_state_dict.pickle"
with artifact_path.open("wb") as f:
    pickle.dumps(my_model.state_dict())
mlflow.log_artifact(artifact_path)
```

Under the hood, this looks for the `ORQ_MLFLOW_ARTIFACTS_DIR` that
exists in Studio and CE. If it can't find it, it returns a
`mllflow/artifacts` subdir under the base orquestra directory.

This approach means that local users can control the artifacts dir
location by setting `ORQ_MLFLOW_ARTIFACTS_DIR`.

# Checklist

_Check that this PR satisfies the following items:_

- [X] Tests have been added for new features/changed behavior (if no new
features have been added, check the box).
- [x] The [changelog file](CHANGELOG.md) has been updated with a
user-readable description of the changes (if the change isn't visible to
the user in any way, check the box).
- [X] The PR's title is prefixed with
`<feat/fix/chore/imp[rovement]/int[ernal]/docs>[!]:`
- [X] The PR is linked to a JIRA ticket (if there's no suitable ticket,
check the box).

[ORQSDK-880]

[ORQSDK-880]:
https://zapatacomputing.atlassian.net/browse/ORQSDK-880?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ

---------

Co-authored-by: James Clark <[email protected]>
  • Loading branch information
BenjaminMummery and jamesclark-Zapata authored Jul 5, 2023
1 parent 1afbee2 commit 4cc2ca5
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
🔥 *Features*
* Adding `FutureWarning` when accessing CE Secrets without specifying Workspace.
* Users can use `ORQ_CURRENT_PROJECT` and `ORQ_CURRENT_WORKSPACE` env variables to set default workspace and project for their interactions with CE.
* In any execution environment, users can use `sdk.mlflow.get_temp_artifacts_dir()` to get the path to a temporary directory for writing artifacts prior to uploading to MLFlow.
* Add `--list` option to `orq login` that displays the stored remote logins, which runtimes they are using, and whether their access tokens are up to date.
* Local runtime now captures any logs printed to standard output and error streams when a task is running. In particular, this means plain `print()`s will be captured and reported back with `orq wf logs` or `orq task logs`.

Expand Down
53 changes: 53 additions & 0 deletions docs/examples/tests/test_mlflow_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
################################################################################
# © Copyright 2023 Zapata Computing Inc.
################################################################################

"""
Snippets and tests used in the "MLFlow Utilities" tutorial.
"""

import pickle
from pathlib import Path
from unittest.mock import Mock

from pytest import MonkeyPatch

from orquestra import sdk

mlflow = Mock()
my_model = Mock()
my_model.state_dict.return_value = {"state": "awesome"}


class Snippets:
@staticmethod
def artifacts_dir_snippet():
artifacts_dir: Path = sdk.mlflow.get_temp_artifacts_dir()

artifact_path = artifacts_dir / "final_state_dict.pickle"
with artifact_path.open("wb"):
pickle.dumps(my_model.state_dict())
mlflow.log_artifact(artifact_path)

# </snippet>
return artifacts_dir


class TestSnippets:
@staticmethod
def test_artifact_dir_local(monkeypatch: MonkeyPatch, tmp_path):
monkeypatch.delenv("ORQ_MLFLOW_ARTIFACTS_DIR", raising=False)
monkeypatch.setattr(
sdk.mlflow._connection_utils, "DEFAULT_TEMP_ARTIFACTS_DIR", tmp_path
)

assert Snippets.artifacts_dir_snippet() == tmp_path

@staticmethod
def test_artifact_dir_env_var_override(monkeypatch: MonkeyPatch, tmp_path):
monkeypatch.setenv("ORQ_MLFLOW_ARTIFACTS_DIR", str(tmp_path))
monkeypatch.setattr(
sdk.mlflow._connection_utils, "DEFAULT_TEMP_ARTIFACTS_DIR", "NOT/THIS/PATH"
)

assert Snippets.artifacts_dir_snippet() == tmp_path
1 change: 1 addition & 0 deletions docs/guides/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ Guides
resource-management
custom-images
version-compatibility
mlflow
18 changes: 18 additions & 0 deletions docs/guides/mlflow.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
.. © Copyright 2023 Zapata Computing Inc.
MLflow Utilities
================

Uploading Artifacts
-------------------

Artifacts must be written to a file before being uploaded to MLflow. Rather than manually configuring an appropriate temporary directory for creating these files, we provide the ``get_temp_artifacts_dir()`` utility. This function can be used in any runtime and will return an appropriate directory.

.. literalinclude:: ../examples/tests/test_mlflow_utilities.py
:start-after: def artifacts_dir_snippet():
:end-before: </snippet>
:language: python
:dedent: 8


Compute Engine and Studio both configure temporary directories automatically. For local workflows, the default location is `~/.orquestra/mlflow/artifacts`, however this can be overridden by setting the ``ORQ_MLFLOW_ARTIFACTS_DIR`` environment variable.
3 changes: 2 additions & 1 deletion src/orquestra/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
################################################################################
"""Orquestra SDK allows to define computational workflows using Python DSL."""

from . import secrets
from . import mlflow, secrets
from ._base._api import (
RuntimeConfig,
TaskRun,
Expand Down Expand Up @@ -45,6 +45,7 @@
"Import",
"InlineImport",
"LocalImport",
"mlflow",
"NotATaskWarning",
"PythonImports",
"Resources",
Expand Down
9 changes: 9 additions & 0 deletions src/orquestra/sdk/mlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
################################################################################
# © Copyright 2023 Zapata Computing Inc.
################################################################################

"""A set of Orquestra utilities relating to interacting with MLFlow."""

from ._connection_utils import get_temp_artifacts_dir

__all__ = ["get_temp_artifacts_dir"]
36 changes: 36 additions & 0 deletions src/orquestra/sdk/mlflow/_connection_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
################################################################################
# © Copyright 2023 Zapata Computing Inc.
################################################################################

"""Utilities for communicating with mlflow."""

import os
from pathlib import Path

from orquestra.sdk._base._services import ORQUESTRA_BASE_PATH

DEFAULT_TEMP_ARTIFACTS_DIR: Path = ORQUESTRA_BASE_PATH / "mlflow" / "artifacts"


def get_temp_artifacts_dir() -> Path:
"""
Return a path to a temp directory that can be used to temporarily store artifacts.
Uploading artifacts to MLflow requires them to be written locally first. Finding an
appropriate directory vary significantly between a workflow running locally and one
running on a remote cluster. This function handles that complexity so that workflows
do not need adjusting between runtimes.
"""

path: Path
if "ORQ_MLFLOW_ARTIFACTS_DIR" in os.environ:
# In Studio and CE there is an environment variable that points to the artifact
# directory.
path = Path(os.environ["ORQ_MLFLOW_ARTIFACTS_DIR"])
else:
# If the artifact dir envvar doesn't exist, we're probably executing locally.
path = DEFAULT_TEMP_ARTIFACTS_DIR

path.mkdir(parents=True, exist_ok=True)

return path
67 changes: 67 additions & 0 deletions tests/sdk/v2/mlflow/test_connection_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
################################################################################
# © Copyright 2023 Zapata Computing Inc.
################################################################################

import pathlib

from orquestra import sdk


class TestGetTempArtifactsDir:
class TestWithRemote:
"""
'Remote' here covers studio and CE as these are treated identically in the code.
"""

@staticmethod
def test_happy_path(tmp_path: pathlib.Path, monkeypatch):
# Given
monkeypatch.setenv("ORQ_MLFLOW_ARTIFACTS_DIR", str(tmp_path))

# When
path = sdk.mlflow.get_temp_artifacts_dir()

# Then
assert path == tmp_path

@staticmethod
def test_creates_dir(tmp_path: pathlib.Path, monkeypatch):
# Given
dir = tmp_path / "dir_that_does_not_exist"
assert not dir.exists()
monkeypatch.setenv("ORQ_MLFLOW_ARTIFACTS_DIR", str(dir))

# When
_ = sdk.mlflow.get_temp_artifacts_dir()

# Then
assert dir.exists()

class TestWithLocal:
@staticmethod
def test_happy_path(tmp_path: pathlib.Path, monkeypatch):
# Given
monkeypatch.setattr(
sdk.mlflow._connection_utils, "DEFAULT_TEMP_ARTIFACTS_DIR", tmp_path
)

# When
path = sdk.mlflow.get_temp_artifacts_dir()

# Then
assert path == tmp_path

@staticmethod
def test_creates_dir(tmp_path: pathlib.Path, monkeypatch):
# Given
dir = tmp_path / "dir_that_does_not_exist"
assert not dir.exists()
monkeypatch.setattr(
sdk.mlflow._connection_utils, "DEFAULT_TEMP_ARTIFACTS_DIR", dir
)

# When
_ = sdk.mlflow.get_temp_artifacts_dir()

# Then
assert dir.exists()

0 comments on commit 4cc2ca5

Please sign in to comment.