Skip to content

Commit

Permalink
feat: Allow head node resources (#344)
Browse files Browse the repository at this point in the history
# The problem

We want users to be able to select the head node resources on CE.

# This PR's solution

- adds a new kwarg to the workflow decorator
- deprecates `data_aggregation` kwarg to the workflow decorator 

# Checklist

_Check that this PR satisfies the following items:_

- [x] Tests have been added for new features/changed behavior (if no new
features have been added, check the box).
- [x] The [changelog file](CHANGELOG.md) has been updated with a
user-readable description of the changes (if the change isn't visible to
the user in any way, check the box).
- [x] The PR's title is prefixed with
`<feat/fix/chore/imp[rovement]/int[ernal]/docs>[!]:`
- [x] The PR is linked to a JIRA ticket (if there's no suitable ticket,
check the box).

---------

Co-authored-by: Sebastian Morawiec <[email protected]>
  • Loading branch information
jamesclark-Zapata and SebastianMorawiec authored Dec 5, 2023
1 parent 5b643c1 commit f877188
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 96 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
🔥 *Features*

* Listing workflow runs on CE now supports filtering by state and max age.
* `head_node_resources` keyword argument has been added to the `@sdk.workflow` decorator to control the resources that the head node has available on CE.

🧟 *Deprecations*
* Deprecation of `project_dir` argument for all public API functions.
* `data_aggregation` keyword argument to `@sdk.workflow` decorator is deprecated and will be removed in a future version.

👩‍🔬 *Experimental*

Expand Down
25 changes: 24 additions & 1 deletion src/orquestra/sdk/_base/_driver/_ce_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# © Copyright 2022-2023 Zapata Computing Inc.
################################################################################
"""RuntimeInterface implementation that uses Compute Engine."""
import warnings
from datetime import timedelta
from pathlib import Path
from typing import Dict, List, Optional, Protocol, Sequence, Union
Expand All @@ -12,6 +13,7 @@
from orquestra.sdk._base._logs._interfaces import LogOutput, WorkflowLogs
from orquestra.sdk._base._logs._models import LogAccumulator, LogStreamType
from orquestra.sdk._base.abc import RuntimeInterface
from orquestra.sdk.exceptions import IgnoredFieldWarning
from orquestra.sdk.kubernetes.quantity import parse_quantity
from orquestra.sdk.schema.configs import RuntimeConfiguration
from orquestra.sdk.schema.ir import ArtifactFormat, TaskInvocationId, WorkflowDef
Expand Down Expand Up @@ -156,11 +158,32 @@ def create_workflow_run(

_verify_workflow_resources(resources, max_invocation_resources)

if (
workflow_def.data_aggregation is not None
and workflow_def.data_aggregation.resources is not None
):
head_node_resources = _models.HeadNodeResources(
cpu=workflow_def.data_aggregation.resources.cpu,
memory=workflow_def.data_aggregation.resources.memory,
)
if workflow_def.data_aggregation.resources.gpu:
warnings.warn(
"Head node resources will ignore GPU settings",
category=IgnoredFieldWarning,
)
if workflow_def.data_aggregation.resources.nodes:
warnings.warn(
'Head node resources will ignore "nodes" settings',
category=IgnoredFieldWarning,
)
else:
head_node_resources = None

try:
workflow_def_id = self._client.create_workflow_def(workflow_def, project)

workflow_run_id = self._client.create_workflow_run(
workflow_def_id, resources, dry_run
workflow_def_id, resources, dry_run, head_node_resources
)
except _exceptions.InvalidWorkflowDef as e:
raise exceptions.WorkflowSyntaxError(
Expand Down
3 changes: 3 additions & 0 deletions src/orquestra/sdk/_base/_driver/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,13 +471,15 @@ def create_workflow_run(
workflow_def_id: _models.WorkflowDefID,
resources: _models.Resources,
dry_run: bool,
head_node_resources: Optional[_models.HeadNodeResources],
) -> _models.WorkflowRunID:
"""Submit a workflow def to run in the workflow driver.
Args:
workflow_def_id: ID of the workflow definition to be submitted.
resources: The resources required to execute the workflow.
dry_run: Run the workflow without actually executing any task code.
head_node_resources: the requested resources for the head node
Raises:
orquestra.sdk._base._driver._exceptions.InvalidWorkflowRunRequest: when an
Expand All @@ -497,6 +499,7 @@ def create_workflow_run(
workflowDefinitionID=workflow_def_id,
resources=resources,
dryRun=dry_run,
headNodeResources=head_node_resources,
).dict(),
)

Expand Down
17 changes: 16 additions & 1 deletion src/orquestra/sdk/_base/_driver/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,15 +287,30 @@ class Resources(pydantic.BaseModel):
gpu: Optional[str] = pydantic.Field(regex="^[01]+$")


class HeadNodeResources(pydantic.BaseModel):
    """
    Implements:
    https://github.com/zapatacomputing/workflow-driver/blob/ac1e97ea00fc3526c93187a1da02170bff45b74f/openapi/src/schemas/HeadNodeResources.yaml.
    """  # noqa: D205, D212

    # Requested CPU for the workflow's head node. The regex appears to accept
    # Kubernetes-style resource quantity strings (e.g. "1", "1000m") —
    # NOTE(review): confirm against the HeadNodeResources.yaml schema linked above.
    cpu: Optional[str] = pydantic.Field(
        regex=r"^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$"
    )
    # Requested memory for the head node; same quantity-string format as `cpu`
    # (e.g. "10Gi"). Unlike the task-level `Resources` model there is no gpu or
    # nodes field — those settings do not apply to the head node.
    memory: Optional[str] = pydantic.Field(
        regex=r"^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$"
    )


class CreateWorkflowRunRequest(pydantic.BaseModel):
    """
    Implements:
    https://github.com/zapatacomputing/workflow-driver/blob/ac1e97ea00fc3526c93187a1da02170bff45b74f/openapi/src/schemas/CreateWorkflowRunRequest.yaml.
    """  # noqa: D205, D212

    # ID returned by the driver when the workflow definition was created.
    workflowDefinitionID: WorkflowDefID
    # Per-task/cluster resources requested for the run.
    resources: Resources
    # When True, the workflow runs without executing any task code.
    dryRun: bool
    # Resources requested for the head node. Typed as HeadNodeResources (not
    # Resources): callers construct HeadNodeResources for this field, and the
    # head node schema has no gpu/nodes settings.
    headNodeResources: Optional[HeadNodeResources]


class CreateWorkflowRunResponse(pydantic.BaseModel):
Expand Down
24 changes: 20 additions & 4 deletions src/orquestra/sdk/_base/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def workflow(fn: Callable[_P, _R]) -> WorkflowTemplate[_P, _R]:
def workflow(
*,
resources: Optional[_dsl.Resources] = None,
head_node_resources: Optional[_dsl.Resources] = None,
data_aggregation: Optional[Union[DataAggregation, bool]] = None,
custom_name: Optional[str] = None,
default_source_import: Optional[Import] = None,
Expand All @@ -551,6 +552,7 @@ def workflow(
fn: Optional[Callable[_P, _R]] = None,
*,
resources: Optional[_dsl.Resources] = None,
head_node_resources: Optional[_dsl.Resources] = None,
data_aggregation: Optional[Union[DataAggregation, bool]] = None,
custom_name: Optional[str] = None,
default_source_import: Optional[Import] = None,
Expand All @@ -569,9 +571,10 @@ def workflow(
10 nodes with 20 CPUs and a GPU each would be:
resources=sdk.Resources(cpu="20", gpu="1", nodes=10)
If omitted, the cluster's default resources will be used.
data_aggregation: Used to set up resources used during data step. If skipped,
or assigned True default values will be used. If assigned False
data aggregation step will not run.
head_node_resources: !Unstable API! The resources that the head node requires.
Only used on Compute Engine. If omitted, the cluster's default head node
resources will be used.
data_aggregation: Deprecated.
custom_name: custom name for the workflow
default_source_import: Set the default source import for all tasks inside
this workflow.
Expand Down Expand Up @@ -622,6 +625,19 @@ def wf():
elif default_dependency_imports is not None:
workflow_default_dependency_imports = tuple(default_dependency_imports)

if data_aggregation is not None:
warnings.warn(
"data_aggregation argument is deprecated and will be removed "
"in upcoming versions of orquestra-sdk. It will be ignored "
"for this workflow.",
category=FutureWarning,
)

if head_node_resources is not None:
_data_aggregation = DataAggregation(resources=head_node_resources)
else:
_data_aggregation = None

def _inner(fn: Callable[_P, _R]):
signature = inspect.signature(fn)
name = custom_name
Expand All @@ -632,7 +648,7 @@ def _inner(fn: Callable[_P, _R]):
workflow_fn=fn,
fn_ref=fn_ref,
is_parametrized=len(signature.parameters) > 0,
data_aggregation=data_aggregation,
data_aggregation=_data_aggregation,
default_source_import=default_source_import,
default_dependency_imports=workflow_default_dependency_imports,
)
Expand Down
6 changes: 6 additions & 0 deletions src/orquestra/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,9 @@ class WorkspacesNotSupportedError(BaseRuntimeError):
"""When a non-workspaces supporting runtime gets a workspaces-related request."""

pass


class IgnoredFieldWarning(Warning):
    """Raised when a requested feature is not supported on the selected runtime."""
92 changes: 91 additions & 1 deletion tests/sdk/driver/test_ce_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from orquestra.sdk._base._driver import _ce_runtime, _client, _exceptions, _models
from orquestra.sdk._base._spaces._structs import ProjectRef
from orquestra.sdk._base._testing._example_wfs import (
add,
my_workflow,
workflow_parametrised_with_resources,
workflow_with_different_resources,
Expand Down Expand Up @@ -94,6 +95,7 @@ def test_happy_path(
workflow_def_id,
_models.Resources(cpu=None, memory=None, gpu=None, nodes=None),
False,
None,
)
assert isinstance(wf_run_id, WorkflowRunId)
assert (
Expand Down Expand Up @@ -124,6 +126,7 @@ def test_with_memory(
workflow_def_id,
_models.Resources(cpu=None, memory="10Gi", gpu=None, nodes=None),
False,
None,
)

def test_with_cpu(
Expand All @@ -149,6 +152,7 @@ def test_with_cpu(
workflow_def_id,
_models.Resources(cpu="1000m", memory=None, gpu=None, nodes=None),
False,
None,
)

def test_with_gpu(
Expand All @@ -172,6 +176,7 @@ def test_with_gpu(
workflow_def_id,
_models.Resources(cpu=None, memory=None, gpu="1", nodes=None),
False,
None,
)

def test_maximum_resource(
Expand All @@ -195,6 +200,7 @@ def test_maximum_resource(
workflow_def_id,
_models.Resources(cpu="5000m", memory="3G", gpu="1", nodes=None),
False,
None,
)

def test_resources_from_workflow(
Expand Down Expand Up @@ -222,8 +228,89 @@ def test_resources_from_workflow(
workflow_def_id,
_models.Resources(cpu="1", memory="1.5G", gpu="1", nodes=20),
False,
None,
)

@pytest.mark.parametrize(
    "head_node_resources",
    (
        # Resources with every field unset should be treated the same as
        # passing no head-node resources at all.
        sdk.Resources(cpu=None, memory=None),
        None,
    ),
)
def test_with_no_head_node_resources(
    self,
    mocked_client: MagicMock,
    runtime: _ce_runtime.CERuntime,
    workflow_def_id: str,
    workflow_run_id: str,
    head_node_resources: sdk.Resources,
):
    """When no (effective) head node resources are given, the client must be
    called with `None` for the head-node-resources argument."""
    # Given
    @sdk.workflow(
        head_node_resources=head_node_resources,
    )
    def wf():
        return add(1, 2)

    mocked_client.create_workflow_def.return_value = workflow_def_id
    mocked_client.create_workflow_run.return_value = workflow_run_id

    # When
    _ = runtime.create_workflow_run(
        wf().model,
        None,
        dry_run=False,
    )

    # Then: positional args are (workflow_def_id, resources, dry_run,
    # head_node_resources); the last one must be None in both parametrized cases.
    mocked_client.create_workflow_run.assert_called_once_with(
        workflow_def_id,
        _models.Resources(cpu=None, memory=None, nodes=None, gpu=None),
        False,
        None,
    )

@pytest.mark.parametrize(
    "cpu, memory,",
    (
        # cpu only, memory only, and both set.
        ("1", None),
        (None, "10Gi"),
        ("2", "20Gi"),
    ),
)
def test_with_head_node_resources(
    self,
    mocked_client: MagicMock,
    runtime: _ce_runtime.CERuntime,
    workflow_def_id: str,
    workflow_run_id: str,
    cpu: str,
    memory: str,
):
    """Head node resources set on the workflow decorator must be forwarded to
    the client as a `_models.HeadNodeResources` value."""
    # Given
    @sdk.workflow(head_node_resources=sdk.Resources(cpu=cpu, memory=memory))
    def wf():
        return add(1, 2)

    mocked_client.create_workflow_def.return_value = workflow_def_id
    mocked_client.create_workflow_run.return_value = workflow_run_id

    # When
    _ = runtime.create_workflow_run(
        wf().model,
        None,
        dry_run=False,
    )

    # Then: task-level resources stay empty; the head-node settings travel in
    # the fourth positional argument as a dedicated HeadNodeResources model.
    mocked_client.create_workflow_run.assert_called_once_with(
        workflow_def_id,
        _models.Resources(cpu=None, memory=None, nodes=None, gpu=None),
        False,
        _models.HeadNodeResources(cpu=cpu, memory=memory),
    )

class TestWorkflowDefFailure:
def test_invalid_wf_def(
self, mocked_client: MagicMock, runtime: _ce_runtime.CERuntime
Expand Down Expand Up @@ -1988,5 +2075,8 @@ def wf():
assert all([telltale in str(exec_info) for telltale in telltales])
else:
mocked_client.create_workflow_run.assert_called_once_with(
workflow_def_id, expected_resources, False
workflow_def_id,
expected_resources,
False,
None,
)
Loading

0 comments on commit f877188

Please sign in to comment.